From a1201a67aa6dbd27ab075a76582ef255d5268432 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Jun 2025 16:12:12 +0800 Subject: [PATCH] Add more test case. Expand query elements --- src/main/scala/minisql/Quoted.scala | 24 +++- src/main/scala/minisql/ast/Ast.scala | 15 ++- src/main/scala/minisql/ast/FromExprs.scala | 7 +- src/main/scala/minisql/context/Context.scala | 10 +- .../scala/minisql/ast/FromExprsSuite.scala | 127 ++++++++++++++++++ 5 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 src/test/scala/minisql/ast/FromExprsSuite.scala diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 7abc115..3d4164d 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -52,8 +52,14 @@ private def quotedLiftImpl[X: Type]( object Query { + private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast + extension [E](inline e: Query[E]) { + private[minisql] inline def expanded: Query[E] = { + Query(expandFields[E](e)) + } + inline def map[E1](inline f: E => E1): Query[E1] = { transform(e)(f)(Map.apply) } @@ -157,10 +163,10 @@ def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() class NonQuotedException extends Exception("Cannot be used at runtime") -private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] = +private[minisql] inline def compileTimeAst(inline q: Ast): Option[String] = ${ compileTimeAstImpl('q) } -private def compileTimeAstImpl(e: Expr[Quoted])(using +private def compileTimeAstImpl(e: Expr[Ast])(using Quotes ): Expr[Option[String]] = { import quotes.reflect.* @@ -203,3 +209,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( } } + +private inline def expandFields[E](inline base: Ast): Ast = + ${ expandFieldsImpl[E]('base) } + +private def expandFieldsImpl[E](baseExpr: Expr[Ast])(using + Quotes, + Type[E] +): Expr[Ast] = { + import quotes.reflect.* + val values = TypeRepr.of[E].typeSymbol.caseFields.map { f => + '{ Property(ast.Ident("x"), ${ Expr(f.name) }) } + } + '{ Map(${ baseExpr }, ast.Ident("x"), ast.Tuple(${ Expr.ofList(values) })) } +} diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 52446e3..86407e1 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -59,9 +59,9 @@ object Entity { object Opinionated { inline def apply( - name: String, - properties: List[PropertyAlias], - renameableNew: Renameable + inline name: String, + inline properties: List[PropertyAlias], + inline renameableNew: Renameable ): Entity = Entity(name, properties, renameableNew) def unapply(e: Entity) = @@ -154,11 +154,14 @@ case class Ident(name: String, visibility: Visibility) extends Ast { * ExpandNestedQueries phase, needs to be marked invisible. */ object Ident { - def apply(name: String): Ident = Ident(name, Visibility.neutral) - def unapply(p: Ident) = Some((p.name)) + inline def apply(inline name: String): Ident = Ident(name, Visibility.neutral) + def unapply(p: Ident) = Some((p.name)) object Opinionated { - def apply(name: String, visibilityNew: Visibility): Ident = + inline def apply( + inline name: String, + inline visibilityNew: Visibility + ): Ident = Ident(name, visibilityNew) def unapply(p: Ident) = Some((p.name, p.visibility)) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index e527a6f..c6d48cc 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -52,7 +52,7 @@ private given FromExpr[ScalarValueLift] with { private given FromExpr[Ident] with { def unapply(x: Expr[Ident])(using Quotes): Option[Ident] = x match { - case '{ Ident(${ Expr(n) }) } => Some(Ident(n)) + case '{ Ident(${ Expr(n) }, ${ Expr(v) }) } => Some(Ident(n, v)) } } @@ -136,7 +136,7 @@ private given FromExpr[Query] with { case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } => Some(SortBy(b, p, s, o)) case o => - println(s"Cannot extract ${o}") + println(s"Cannot extract ${o.show}") None } } @@ -274,10 +274,11 @@ private def extractTerm(using Quotes)(x: quotes.reflect.Term) = { } extension (e: Expr[Any]) { - def toTerm(using Quotes) = { + private def toTerm(using Quotes) = { import quotes.reflect.* e.asTerm } + } private def fromBlock(using diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 875d11e..bf945c3 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -118,14 +118,14 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => inline q: minisql.Query[E] ): DBIO[IArray[E]] = { - val extractor = summonFrom { - case e: RowExtract[E, DBRow] => e + val (stmt, extractor) = summonFrom { + case e: RowExtract[E, DBRow] => + minisql.compile[I, N](q.expanded, idiom, naming) -> e case e: ColumnDecoder.Aux[DBRow, E] => - RowExtract.single(e) - } + minisql.compile[I, N](q, idiom, naming) -> RowExtract.single(e) + }: @unchecked val lifts = q.liftMap - val stmt = minisql.compile[I, N](q, idiom, naming) val (sql, params) = stmt.expand(lifts) ( sql = sql, diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala new file mode 100644 index 0000000..ea6d14b --- /dev/null +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -0,0 +1,127 @@ +package minisql.ast + +import munit.FunSuite +import minisql.ast.* +import scala.quoted.* + +class FromExprsSuite extends FunSuite { + + // Helper to test both compile-time and runtime extraction + inline def testFor[A <: Ast](label: String)(inline ast: A) = { + test(label) { + // Test compile-time extraction + val compileTimeResult = minisql.compileTimeAst(ast) + assert(compileTimeResult.contains(ast.toString)) + } + } + + testFor("Ident") { + Ident("test") + } + + testFor("Ident with visibility") { + Ident.Opinionated("test", Visibility.Hidden) + } + + testFor("Property") { + Property(Ident("a"), "b") + } + + testFor("Property with opinions") { + Property.Opinionated(Ident("a"), "b", Renameable.Fixed, Visibility.Visible) + } + + testFor("BinaryOperation") { + BinaryOperation(Ident("a"), EqualityOperator.==, Ident("b")) + } + + testFor("UnaryOperation") { + UnaryOperation(BooleanOperator.!, Ident("flag")) + } + + testFor("ScalarValueLift") { + ScalarValueLift("name", "id", None) + } + + testFor("Ordering") { + Asc + } + + testFor("TupleOrdering") { + TupleOrdering(List(Asc, Desc)) + } + + testFor("Entity") { + Entity("people", Nil) + } + + testFor("Entity with properties") { + Entity("people", List(PropertyAlias(List("name"), "full_name"))) + } + + testFor("Action - Insert") { + Insert( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + } + + testFor("Action - Update") { + Update( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + } + + testFor("If expression") { + If(Ident("cond"), Ident("then"), Ident("else")) + } + + testFor("Infix") { + Infix( + List("func(", ")"), + List(Ident("param")), + pure = true, + noParen = false + ) + } + + testFor("OptionOperation - OptionMap") { + OptionMap(Ident("opt"), Ident("x"), Ident("x")) + } + + testFor("OptionOperation - OptionFlatMap") { + OptionFlatMap(Ident("opt"), Ident("x"), Ident("x")) + } + + testFor("OptionOperation - OptionGetOrElse") { + OptionGetOrElse(Ident("opt"), Ident("default")) + } + + testFor("Join") { + Join( + InnerJoin, + Ident("a"), + Ident("b"), + Ident("a1"), + Ident("b1"), + BinaryOperation(Ident("a1.id"), EqualityOperator.==, Ident("b1.id")) + ) + } + + testFor("Distinct") { + Distinct(Ident("query")) + } + + testFor("GroupBy") { + GroupBy(Ident("query"), Ident("alias"), Ident("body")) + } + + testFor("Aggregation") { + Aggregation(AggregationOperator.avg, Ident("field")) + } + + testFor("CaseClass") { + CaseClass(List(("name", Ident("value")))) + } +}