Fix tuple unapply parsing & try add group by

This commit is contained in:
jilen 2025-07-15 14:29:12 +08:00
parent feeb9cab1e
commit 0b4a6cb0c4
9 changed files with 161 additions and 41 deletions

View file

@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
+ [x] 基础 Ast 及相关操作 + [x] 基础 Ast 及相关操作
+ [x] 验证 `inline``FromExpr` 是否生效 + [x] 验证 `inline``FromExpr` 是否生效
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理 + [ ] 验证 `lift` / `liftCaseClass` 如何处理
+ [ ] 验证 `Insert/Update` 实现 + [x] 验证 `Insert/Update` 实现
+ [ ] DSL + [ ] DSL
- [x] Map - [x] Map
- [ ] Filter/FlatMap/ConcatMap/Union - [x] Filter/FlatMap/ConcatMap/Union
- [ ] Join - [x] Join
- [ ] GroupBy/Aggeration - [x] GroupBy/Aggeration
+ [ ] 函数解析 + [ ] 函数解析
- [x] Ident - [x] Ident
- [x] Property - [x] Property
- [ ] BinaryOperation - [x] BinaryOperation
- [ ] UnaryOperation - [x] UnaryOperation
- [ ] CaseClass - [ ] CaseClass
- [ ] Tuple - [ ] Tuple

View file

@ -6,6 +6,8 @@ import minisql.parsing.*
import minisql.util.* import minisql.util.*
import minisql.ast.{ import minisql.ast.{
Ast, Ast,
Aggregation,
AggregationOperator,
Entity, Entity,
Map, Map,
Property, Property,
@ -14,6 +16,7 @@ import minisql.ast.{
PropertyAlias, PropertyAlias,
JoinType, JoinType,
Join, Join,
GroupBy,
given given
} }
import scala.quoted.* import scala.quoted.*
@ -21,18 +24,25 @@ import scala.deriving.*
import scala.compiletime.* import scala.compiletime.*
import scala.compiletime.ops.string.* import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap} import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted opaque type Query[E] <: Quoted = Quoted
opaque type Agg[E] <: Quoted = Quoted
opaque type Grouped[G, E] = Quoted
opaque type Action[E] <: Quoted = Quoted opaque type Action[E] <: Quoted = Quoted
opaque type Update[E] <: Action[Long] = Quoted opaque type Update[E] <: Action[Long] = Quoted
opaque type Insert[E] <: Action[Long] = Quoted opaque type Insert[E] <: Action[Long] = Quoted
opaque type InsertReturning[E] <: Action[E] = Quoted
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Insert { object Insert {
extension [E](inline insert: Insert[E]) { extension [E](inline insert: Insert[E]) {
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = { inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
@ -47,12 +57,8 @@ object Insert {
} }
} }
opaque type InsertReturning[E] <: Action[E] = Quoted
sealed trait Joined[E1, E2] sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined { object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] = def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
@ -64,6 +70,20 @@ object Joined {
} }
} }
trait GroupCollection[E] {
def size: Long = throw new NonQuotedException
}
object Grouped {
extension [G, E](inline g: Grouped[G, E]) {
inline def map[E1](
inline f: (p: (G, GroupCollection[E])) => E1
): Query[E1] = {
transform(g)(f)(Map.apply)
}
}
}
private inline def joinOn[E1, E2]( private inline def joinOn[E1, E2](
inline j: Joined[E1, E2], inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean inline f: (E1, E2) => Boolean
@ -148,6 +168,15 @@ object Query {
transform(e)(f)(Filter.apply) transform(e)(f)(Filter.apply)
} }
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
transform(e)(f)(GroupBy.apply)
}
inline def size: Agg[Long] =
Aggregation(AggregationOperator.size, e)
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
} }
} }
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
n: Expr[N] n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = { )(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.* import quotes.reflect.*
q.value match { compileImpl(q.value, q, idiom, n)
}
private def compileImpl[I <: Idiom, N <: NamingStrategy](
staticAst: Option[Ast],
runtimeAst: Expr[Ast],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
staticAst match {
case Some(ast) => case Some(ast) =>
val idiom = LoadObject[I].getOrElse( val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile") report.errorAndAbort(s"Idiom not known at compile")
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
case None => case None =>
report.info("Dynamic Query") report.info("Dynamic Query")
'{ '{
$idiom.translate($q)(using $n)._2 $idiom.translate($runtimeAst)(using $n)._2
} }
} }

View file

@ -2,6 +2,7 @@ package minisql.ast
import minisql.NamingStrategy import minisql.NamingStrategy
import minisql.ParamEncoder import minisql.ParamEncoder
import minisql.idiom.MirrorIdiom
import scala.quoted.* import scala.quoted.*
@ -21,6 +22,10 @@ sealed trait Ast {
override def apply(a: Ast) = override def apply(a: Ast) =
super.apply(a.neutral) super.apply(a.neutral)
}.apply(this) }.apply(this)
override def toString(): String = {
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
}
} }
//************************************************************ //************************************************************

View file

@ -137,6 +137,13 @@ private given FromExpr[Query] with {
Some(SortBy(b, p, s, o)) Some(SortBy(b, p, s, o))
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } => case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
Some(GroupBy(b, p, body)) Some(GroupBy(b, p, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(p) }
val z: Ast = ${ Expr(body) }
GroupBy(x, y, z)
} =>
Some(GroupBy(b, p, body))
case '{ Distinct(${ Expr(a) }) } => case '{ Distinct(${ Expr(a) }) } =>
Some(Distinct(a)) Some(Distinct(a))
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } => case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
Some(FlatJoin(t, a, ia, on)) Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } => case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a)) Some(Nested(a))
case o =>
import quotes.reflect.*
None
} }
} }
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match { def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } => case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements)) Some(Block(statements))
case '{ Block(List(${ Varargs(stmts) }*)) } =>
stmts.map { x =>
val o = x.asInstanceOf[Expr[Ast]].value
if (o.isEmpty) {
println(s"===================${x.show}")
}
}
None
} }
} }

View file

@ -7,7 +7,7 @@ import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst} import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.* import scala.deriving.*
import scala.compiletime.* import scala.compiletime.*
import scala.util.Try import scala.util.{Try, Success, Failure}
import scala.annotation.targetName import scala.annotation.targetName
trait RowExtract[A, Row] { trait RowExtract[A, Row] {
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
) )
} }
inline def io[E](inline q: minisql.Agg[E])(using
e: ColumnDecoder.Aux[DBRow, E]
): DBIO[E] = {
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
case _: (E <:< Option[?]) =>
(rows: Iterable[DBRow]) =>
rows.toVector match {
case Vector() => Success(None.asInstanceOf[E])
case Vector(r) =>
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
case _ =>
(rows) =>
rows.toVector match {
case Vector(r) => RowExtract.single(e).extract(r)
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
}
val lifts = q.liftMap
val stmt = minisql.compile[I, N](q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = mapper
)
}
@targetName("ioQuery") @targetName("ioQuery")
inline def io[E]( inline def io[E](
inline q: minisql.Query[E] inline q: minisql.Query[E]

View file

@ -62,6 +62,7 @@ private[minisql] object Parsing {
.orElse(ifParser) .orElse(ifParser)
.orElse(traversableOperationParser) .orElse(traversableOperationParser)
.orElse(patMatchParser) .orElse(patMatchParser)
.orElse(aggParser)
.orElse { .orElse {
case o => case o =>
val str = scala.util.Try(o.show).getOrElse("") val str = scala.util.Try(o.show).getOrElse("")
@ -104,7 +105,6 @@ private[minisql] object Parsing {
lazy val ifParser: Parser[ast.If] = { lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } => case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
} }
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
@ -115,6 +115,11 @@ private[minisql] object Parsing {
astParser astParser
) )
lazy val aggParser: Parser[ast.Aggregation] = {
case '{ ($t: minisql.GroupCollection[t]).size } =>
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
}
astParser(expr) astParser(expr)
} }

View file

@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
Ident(t), Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body)) List(CaseDef(IsTupleUnapply(binds), None, body))
) => ) =>
val bindStmts = binds.map { val bindStmts = binds.zipWithIndex.map {
case Bind(bn, _) => case (Bind(bn, _), i) =>
val fidx = Expr(s"_${i + 1}")
'{ '{
ast.Val( ast.Val(
ast.Ident(${ Expr(bn) }), ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1") ast.Property(ast.Ident(${ Expr(t) }), $fidx)
) )
} }
} }
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr) val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)

View file

@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals( assertEquals(
o.sql, o.sql,
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id" "SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
) )
} }
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET name = ?, age = ? WHERE id = 1" "UPDATE foo SET name = ?, age = ? WHERE id = 1"
) )
} }
test("GroupBy") {
val o = testContext.io(
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
)
assertEquals(
o.sql,
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
)
}
} }