diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index abb5e5f..3440061 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -12,6 +12,8 @@ import minisql.ast.{ Ident, Filter, PropertyAlias, + JoinType, + Join, given } import scala.quoted.* @@ -28,6 +30,54 @@ opaque type Action[E] <: Quoted = Quoted opaque type Insert <: Action[Long] = Quoted +sealed trait Joined[E1, E2] + +opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted + +object Joined { + + def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] = + new Joined[E1, E2] {} + + extension [E1, E2](inline j: Joined[E1, E2]) { + inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] = + joinOn(j, f) + } +} + +private inline def joinOn[E1, E2]( + inline j: Joined[E1, E2], + inline f: (E1, E2) => Boolean +): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body) + +extension [E1, E2](inline j: Joined[E1, E2]) { + private inline def toJoinQuery( + inline aliasA: Ident, + inline aliasB: Ident, + inline on: Ast + ): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) } +} + +private def joinQueryOf[E1, E2]( + x: Expr[Joined[E1, E2]], + aliasA: Expr[Ident], + aliasB: Expr[Ident], + on: Expr[Ast] +)(using Quotes, Type[E1], Type[E2]): Expr[Join] = { + import quotes.reflect.* + extractTerm(x.asTerm).asExpr match { + case '{ + Joined[E1, E2]($jt, $a, $b) + } => + '{ + Join($jt, $a, $b, $aliasA, $aliasB, $on) + } + case o => + println("====================---" + o.show) + throw new Exception(s"Fail") + } +} + private inline def quotedLift[X](x: X)(using e: ParamEncoder[X] ): ast.ScalarValueLift = ${ @@ -52,14 +102,15 @@ private def quotedLiftImpl[X: Type]( object Query { - private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast - extension [E](inline e: Query[E]) { private[minisql] inline def expanded: Query[E] = { - Query(expandFields[E](e)) + expandFields[E](e) } + inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, E1] = + Joined[E, E1](JoinType.LeftJoin, e, e1) + inline def map[E1](inline f: E => E1): Query[E1] = { transform(e)(f)(Map.apply) } diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index f6d0192..8029af8 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -69,9 +69,6 @@ private given FromExpr[Property] with { ) } => Some(Property(a, n, r, v)) - case o => - println(s"Cannot extract ${o.show}") - None } } @@ -168,7 +165,7 @@ private given FromExpr[Query] with { case '{ Nested(${ Expr(a) }) } => Some(Nested(a)) case o => - println(s"Cannot extract ${o.show}") + // println(s"Cannot extract ${o.show}") None } } @@ -406,23 +403,6 @@ private given FromExpr[If] with { } } -private def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - val o = unwrapTerm(x) - o -} - extension (e: Expr[Any]) { private def toTerm(using Quotes) = { import quotes.reflect.* @@ -434,7 +414,6 @@ extension (e: Expr[Any]) { private def fromBlock(using Quotes )(block: quotes.reflect.Block): Option[Ast] = { - println(s"Show block ${block.show}") import quotes.reflect.* val empty: Option[List[Ast]] = Some(Nil) val stmts = block.statements.foldLeft(empty) { (r, stmt) => diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala index 91bfbc0..4235398 100644 --- a/src/main/scala/minisql/parsing/Parser.scala +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -3,6 +3,7 @@ package minisql.parsing import minisql.ast import minisql.ast.Ast import scala.quoted.* +import minisql.util.* private[minisql] inline def parseParamAt[F]( inline f: F, diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index 07da46a..370133b 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -9,7 +9,7 @@ import scala.annotation.tailrec import minisql.ast.Implicits._ import minisql.ast.Renameable.Fixed import minisql.ast.Visibility.{Hidden, Visible} -import minisql.util.Interleave +import minisql.util.{Interleave, extractTerm} import scala.quoted.* type Parser[A] = PartialFunction[Expr[Any], Expr[A]] @@ -29,22 +29,6 @@ private def parser[A]( case e if f.isDefinedAt(e) => f(e) } -private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - unwrapTerm(x) -} - private[minisql] object Parsing { def parseExpr( @@ -103,7 +87,7 @@ private[minisql] object Parsing { } lazy val identParser: Parser[ast.Ident] = termParser { - case x @ Ident(n) if x.symbol.isValDef => + case x @ Ident(n) => '{ ast.Ident(${ Expr(n) }) } } diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala index 2db7652..b8bd924 100644 --- a/src/main/scala/minisql/parsing/PatMatchParsing.scala +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -16,16 +16,7 @@ private[parsing] def patMatchParsing( case (Bind(n, ident), idx) => n -> Select.unique(t, s"_${idx + 1}") }.toMap - val tm = new TreeMap { - override def transformTerm(tree: Term)(owner: Symbol): Term = { - tree match { - case Ident(n) => bm(n) - case o => super.transformTerm(o)(owner) - } - } - } - val newBody = tm.transformTree(body)(e.symbol) - astParser(newBody.asExpr) + blockParsing(astParser)(body.asExpr) } } diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala index 6c2fb9e..e7d4ae9 100644 --- a/src/main/scala/minisql/parsing/ValueParsing.scala +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -2,6 +2,7 @@ package minisql package parsing import scala.quoted._ +import minisql.util.* private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using Quotes diff --git a/src/main/scala/minisql/util/QuotesHelper.scala b/src/main/scala/minisql/util/QuotesHelper.scala index 6ecbc76..fe93aa7 100644 --- a/src/main/scala/minisql/util/QuotesHelper.scala +++ b/src/main/scala/minisql/util/QuotesHelper.scala @@ -22,3 +22,20 @@ private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = { val fileName = pos.sourceFile.name s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}" } + +private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { + import quotes.reflect.* + def unwrapTerm(t: Term): Term = t match { + case Inlined(_, _, o) => unwrapTerm(o) + case Block(Nil, last) => last + case Typed(t, _) => + unwrapTerm(t) + case Select(t, "$asInstanceOf$") => + unwrapTerm(t) + case TypeApply(t, _) => + unwrapTerm(t) + case o => o + } + val o = unwrapTerm(x) + o +} diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala index 4ef6d67..17ea26b 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -34,4 +34,13 @@ class QuotedSuite extends munit.FunSuite { "INSERT INTO foo (id,name) VALUES (?, ?)" ) } + + test("LeftJoin") { + val o = testContext + .io(Foos.leftJoin(Foos).on((f1, f2) => f1.id == f2.id).map { + case (f1, f2) => (f1.id, f2.id) + }) + + println(o) + } }