diff --git a/README.md b/README.md index 6f82dea..377f016 100644 --- a/README.md +++ b/README.md @@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = { + [x] 基础 Ast 及相关操作 + [x] 验证 `inline` 和 `FromExpr` 是否生效 + [ ] 验证 `lift` / `liftCaseClass` 如何处理 -+ [ ] 验证 `Insert/Update` 实现 ++ [x] 验证 `Insert/Update` 实现 + [ ] DSL - [x] Map - - [ ] Filter/FlatMap/ConcatMap/Union - - [ ] Join - - [ ] GroupBy/Aggeration + - [x] Filter/FlatMap/ConcatMap/Union + - [x] Join + - [x] GroupBy/Aggeration + [ ] 函数解析 - [x] Ident - [x] Property - - [ ] BinaryOperation - - [ ] UnaryOperation + - [x] BinaryOperation + - [x] UnaryOperation - [ ] CaseClass - [ ] Tuple diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 37fd684..416b0fd 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -6,6 +6,8 @@ import minisql.parsing.* import minisql.util.* import minisql.ast.{ Ast, + Aggregation, + AggregationOperator, Entity, Map, Property, @@ -14,6 +16,7 @@ import minisql.ast.{ PropertyAlias, JoinType, Join, + GroupBy, given } import scala.quoted.* @@ -21,18 +24,25 @@ import scala.deriving.* import scala.compiletime.* import scala.compiletime.ops.string.* import scala.collection.immutable.{Map => IMap} -import scala.util.NotGiven opaque type Quoted <: Ast = Ast opaque type Query[E] <: Quoted = Quoted +opaque type Agg[E] <: Quoted = Quoted + +opaque type Grouped[G, E] = Quoted + opaque type Action[E] <: Quoted = Quoted opaque type Update[E] <: Action[Long] = Quoted opaque type Insert[E] <: Action[Long] = Quoted +opaque type InsertReturning[E] <: Action[E] = Quoted + +opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted + object Insert { extension [E](inline insert: Insert[E]) { inline def returning[E1](inline f: E => E1): InsertReturning[E1] = { @@ -47,12 +57,8 @@ object Insert { } } -opaque type InsertReturning[E] <: Action[E] = Quoted - sealed trait Joined[E1, E2] -opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted - object Joined { def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] = @@ -64,6 +70,20 @@ object Joined { } } +trait GroupCollection[E] { + def size: Long = throw new NonQuotedException +} + +object Grouped { + extension [G, E](inline g: Grouped[G, E]) { + inline def map[E1]( + inline f: (p: (G, GroupCollection[E])) => E1 + ): Query[E1] = { + transform(g)(f)(Map.apply) + } + } +} + private inline def joinOn[E1, E2]( inline j: Joined[E1, E2], inline f: (E1, E2) => Boolean @@ -148,6 +168,15 @@ object Query { transform(e)(f)(Filter.apply) } + inline def groupBy[G](inline f: E => G): Grouped[G, E] = { + transform(e)(f)(GroupBy.apply) + } + + inline def size: Agg[Long] = + Aggregation(AggregationOperator.size, e) + + inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e) + inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e) } } @@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( n: Expr[N] )(using Quotes, Type[I], Type[N]): Expr[Statement] = { import quotes.reflect.* - q.value match { + compileImpl(q.value, q, idiom, n) +} + +private def compileImpl[I <: Idiom, N <: NamingStrategy]( + staticAst: Option[Ast], + runtimeAst: Expr[Ast], + idiom: Expr[I], + n: Expr[N] +)(using Quotes, Type[I], Type[N]): Expr[Statement] = { + import quotes.reflect.* + staticAst match { case Some(ast) => val idiom = LoadObject[I].getOrElse( report.errorAndAbort(s"Idiom not known at compile") @@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( case None => report.info("Dynamic Query") '{ - $idiom.translate($q)(using $n)._2 + $idiom.translate($runtimeAst)(using $n)._2 } } diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 42386c7..80e413d 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -2,6 +2,7 @@ package minisql.ast import minisql.NamingStrategy import minisql.ParamEncoder +import minisql.idiom.MirrorIdiom import scala.quoted.* @@ -21,6 +22,10 @@ sealed trait Ast { override def apply(a: Ast) = super.apply(a.neutral) }.apply(this) + + override def toString(): String = { + MirrorIdiom.translate(this)(using minisql.Literal)._2.toString() + } } //************************************************************ diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 08686b3..6bb7f40 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -137,6 +137,13 @@ private given FromExpr[Query] with { Some(SortBy(b, p, s, o)) case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } => Some(GroupBy(b, p, body)) + case '{ + val x: Ast = ${ Expr(b) } + val y: Ident = ${ Expr(p) } + val z: Ast = ${ Expr(body) } + GroupBy(x, y, z) + } => + Some(GroupBy(b, p, body)) case '{ Distinct(${ Expr(a) }) } => Some(Distinct(a)) case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } => @@ -164,6 +171,9 @@ private given FromExpr[Query] with { Some(FlatJoin(t, a, ia, on)) case '{ Nested(${ Expr(a) }) } => Some(Nested(a)) + case o => + import quotes.reflect.* + None } } @@ -421,6 +431,14 @@ private given FromExpr[Block] with { def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match { case '{ Block(${ Expr(statements) }) } => Some(Block(statements)) + case '{ Block(List(${ Varargs(stmts) }*)) } => + stmts.map { x => + val o = x.asInstanceOf[Expr[Ast]].value + if (o.isEmpty) { + println(s"===================${x.show}") + } + } + None } } diff --git a/src/main/scala/minisql/ast/Operator.scala b/src/main/scala/minisql/ast/Operator.scala index 09489fa..d042bb0 100644 --- a/src/main/scala/minisql/ast/Operator.scala +++ b/src/main/scala/minisql/ast/Operator.scala @@ -2,58 +2,58 @@ package minisql.ast sealed trait Operator -sealed trait UnaryOperator extends Operator -sealed trait PrefixUnaryOperator extends UnaryOperator +sealed trait UnaryOperator extends Operator +sealed trait PrefixUnaryOperator extends UnaryOperator sealed trait PostfixUnaryOperator extends UnaryOperator -sealed trait BinaryOperator extends Operator +sealed trait BinaryOperator extends Operator object EqualityOperator { case object `==` extends BinaryOperator case object `!=` extends BinaryOperator - inline def Eq = `==` + inline def Eq = `==` inline def Neq = `!=` } object BooleanOperator { - case object `!` extends PrefixUnaryOperator + case object `!` extends PrefixUnaryOperator case object `&&` extends BinaryOperator case object `||` extends BinaryOperator } object StringOperator { - case object `concat` extends BinaryOperator - case object `startsWith` extends BinaryOperator - case object `split` extends BinaryOperator + case object `concat` extends BinaryOperator + case object `startsWith` extends BinaryOperator + case object `split` extends BinaryOperator case object `toUpperCase` extends PostfixUnaryOperator case object `toLowerCase` extends PostfixUnaryOperator - case object `toLong` extends PostfixUnaryOperator - case object `toInt` extends PostfixUnaryOperator + case object `toLong` extends PostfixUnaryOperator + case object `toInt` extends PostfixUnaryOperator } object NumericOperator { - case object `-` extends BinaryOperator with PrefixUnaryOperator - case object `+` extends BinaryOperator - case object `*` extends BinaryOperator - case object `>` extends BinaryOperator + case object `-` extends BinaryOperator with PrefixUnaryOperator + case object `+` extends BinaryOperator + case object `*` extends BinaryOperator + case object `>` extends BinaryOperator case object `>=` extends BinaryOperator - case object `<` extends BinaryOperator + case object `<` extends BinaryOperator case object `<=` extends BinaryOperator - case object `/` extends BinaryOperator - case object `%` extends BinaryOperator + case object `/` extends BinaryOperator + case object `%` extends BinaryOperator } object SetOperator { case object `contains` extends BinaryOperator case object `nonEmpty` extends PostfixUnaryOperator - case object `isEmpty` extends PostfixUnaryOperator + case object `isEmpty` extends PostfixUnaryOperator } sealed trait AggregationOperator extends Operator object AggregationOperator { - case object `min` extends AggregationOperator - case object `max` extends AggregationOperator - case object `avg` extends AggregationOperator - case object `sum` extends AggregationOperator + case object `min` extends AggregationOperator + case object `max` extends AggregationOperator + case object `avg` extends AggregationOperator + case object `sum` extends AggregationOperator case object `size` extends AggregationOperator } diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index bf945c3..90bb8e7 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -7,7 +7,7 @@ import minisql.ColumnDecoder import minisql.ast.{Ast, ScalarValueLift, CollectAst} import scala.deriving.* import scala.compiletime.* -import scala.util.Try +import scala.util.{Try, Success, Failure} import scala.annotation.targetName trait RowExtract[A, Row] { @@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => ) } + inline def io[E](inline q: minisql.Agg[E])(using + e: ColumnDecoder.Aux[DBRow, E] + ): DBIO[E] = { + val mapper: Iterable[DBRow] => Try[E] = summonFrom { + case _: (E <:< Option[?]) => + (rows: Iterable[DBRow]) => + rows.toVector match { + case Vector() => Success(None.asInstanceOf[E]) + case Vector(r) => + RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E]) + case o => + Failure( + new IllegalStateException( + s"Expect agg value, got ${o.size} rows" + ) + ) + } + case _ => + (rows) => + rows.toVector match { + case Vector(r) => RowExtract.single(e).extract(r) + case o => + Failure( + new IllegalStateException( + s"Expect agg value, got ${o.size} rows" + ) + ) + } + + } + val lifts = q.liftMap + val stmt = minisql.compile[I, N](q, idiom, naming) + val (sql, params) = stmt.expand(lifts) + ( + sql = sql, + params = params.map(_.value.get.asInstanceOf), + mapper = mapper + ) + } + @targetName("ioQuery") inline def io[E]( inline q: minisql.Query[E] diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index 8794e7f..3e9b583 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -62,6 +62,7 @@ private[minisql] object Parsing { .orElse(ifParser) .orElse(traversableOperationParser) .orElse(patMatchParser) + .orElse(aggParser) .orElse { case o => val str = scala.util.Try(o.show).getOrElse("") @@ -104,7 +105,6 @@ private[minisql] object Parsing { lazy val ifParser: Parser[ast.If] = { case '{ if ($a) $b else $c } => '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } - } lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) @@ -115,6 +115,11 @@ private[minisql] object Parsing { astParser ) + lazy val aggParser: Parser[ast.Aggregation] = { + case '{ ($t: minisql.GroupCollection[t]).size } => + '{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) } + } + astParser(expr) } diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala index 9a7686c..51b5b3e 100644 --- a/src/main/scala/minisql/parsing/PatMatchParsing.scala +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -15,14 +15,16 @@ private[parsing] def patMatchParsing( Ident(t), List(CaseDef(IsTupleUnapply(binds), None, body)) ) => - val bindStmts = binds.map { - case Bind(bn, _) => + val bindStmts = binds.zipWithIndex.map { + case (Bind(bn, _), i) => + val fidx = Expr(s"_${i + 1}") '{ ast.Val( ast.Ident(${ Expr(bn) }), - ast.Property(ast.Ident(${ Expr(t) }), "_1") + ast.Property(ast.Ident(${ Expr(t) }), $fidx) ) } + } val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr) diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index 3edde53..a356236 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { assertEquals( o.sql, - "SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id" + "SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id" ) } @@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite { "UPDATE foo SET name = ?, age = ? WHERE id = 1" ) } + + test("GroupBy") { + val o = testContext.io( + Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) } + ) + assertEquals( + o.sql, + "SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age" + ) + } + }