Fix tuple unapply parsing & try add group by

This commit is contained in:
jilen 2025-07-15 14:29:12 +08:00
parent feeb9cab1e
commit 0b4a6cb0c4
9 changed files with 161 additions and 41 deletions

View file

@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
+ [x] 基础 Ast 及相关操作
+ [x] 验证 `inline``FromExpr` 是否生效
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理
+ [ ] 验证 `Insert/Update` 实现
+ [x] 验证 `Insert/Update` 实现
+ [ ] DSL
- [x] Map
- [ ] Filter/FlatMap/ConcatMap/Union
- [ ] Join
- [ ] GroupBy/Aggeration
- [x] Filter/FlatMap/ConcatMap/Union
- [x] Join
- [x] GroupBy/Aggeration
+ [ ] 函数解析
- [x] Ident
- [x] Property
- [ ] BinaryOperation
- [ ] UnaryOperation
- [x] BinaryOperation
- [x] UnaryOperation
- [ ] CaseClass
- [ ] Tuple

View file

@ -6,6 +6,8 @@ import minisql.parsing.*
import minisql.util.*
import minisql.ast.{
Ast,
Aggregation,
AggregationOperator,
Entity,
Map,
Property,
@ -14,6 +16,7 @@ import minisql.ast.{
PropertyAlias,
JoinType,
Join,
GroupBy,
given
}
import scala.quoted.*
@ -21,18 +24,25 @@ import scala.deriving.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type Agg[E] <: Quoted = Quoted
opaque type Grouped[G, E] = Quoted
opaque type Action[E] <: Quoted = Quoted
opaque type Update[E] <: Action[Long] = Quoted
opaque type Insert[E] <: Action[Long] = Quoted
opaque type InsertReturning[E] <: Action[E] = Quoted
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Insert {
extension [E](inline insert: Insert[E]) {
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
@ -47,12 +57,8 @@ object Insert {
}
}
opaque type InsertReturning[E] <: Action[E] = Quoted
sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
@ -64,6 +70,20 @@ object Joined {
}
}
trait GroupCollection[E] {
def size: Long = throw new NonQuotedException
}
object Grouped {
extension [G, E](inline g: Grouped[G, E]) {
inline def map[E1](
inline f: (p: (G, GroupCollection[E])) => E1
): Query[E1] = {
transform(g)(f)(Map.apply)
}
}
}
private inline def joinOn[E1, E2](
inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean
@ -148,6 +168,15 @@ object Query {
transform(e)(f)(Filter.apply)
}
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
transform(e)(f)(GroupBy.apply)
}
inline def size: Agg[Long] =
Aggregation(AggregationOperator.size, e)
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
}
}
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
q.value match {
compileImpl(q.value, q, idiom, n)
}
private def compileImpl[I <: Idiom, N <: NamingStrategy](
staticAst: Option[Ast],
runtimeAst: Expr[Ast],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
staticAst match {
case Some(ast) =>
val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile")
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
case None =>
report.info("Dynamic Query")
'{
$idiom.translate($q)(using $n)._2
$idiom.translate($runtimeAst)(using $n)._2
}
}

View file

@ -2,6 +2,7 @@ package minisql.ast
import minisql.NamingStrategy
import minisql.ParamEncoder
import minisql.idiom.MirrorIdiom
import scala.quoted.*
@ -21,6 +22,10 @@ sealed trait Ast {
override def apply(a: Ast) =
super.apply(a.neutral)
}.apply(this)
override def toString(): String = {
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
}
}
//************************************************************

View file

@ -137,6 +137,13 @@ private given FromExpr[Query] with {
Some(SortBy(b, p, s, o))
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
Some(GroupBy(b, p, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(p) }
val z: Ast = ${ Expr(body) }
GroupBy(x, y, z)
} =>
Some(GroupBy(b, p, body))
case '{ Distinct(${ Expr(a) }) } =>
Some(Distinct(a))
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a))
case o =>
import quotes.reflect.*
None
}
}
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements))
case '{ Block(List(${ Varargs(stmts) }*)) } =>
stmts.map { x =>
val o = x.asInstanceOf[Expr[Ast]].value
if (o.isEmpty) {
println(s"===================${x.show}")
}
}
None
}
}

View file

@ -7,7 +7,7 @@ import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.*
import scala.compiletime.*
import scala.util.Try
import scala.util.{Try, Success, Failure}
import scala.annotation.targetName
trait RowExtract[A, Row] {
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
)
}
inline def io[E](inline q: minisql.Agg[E])(using
e: ColumnDecoder.Aux[DBRow, E]
): DBIO[E] = {
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
case _: (E <:< Option[?]) =>
(rows: Iterable[DBRow]) =>
rows.toVector match {
case Vector() => Success(None.asInstanceOf[E])
case Vector(r) =>
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
case _ =>
(rows) =>
rows.toVector match {
case Vector(r) => RowExtract.single(e).extract(r)
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
}
val lifts = q.liftMap
val stmt = minisql.compile[I, N](q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = mapper
)
}
@targetName("ioQuery")
inline def io[E](
inline q: minisql.Query[E]

View file

@ -62,6 +62,7 @@ private[minisql] object Parsing {
.orElse(ifParser)
.orElse(traversableOperationParser)
.orElse(patMatchParser)
.orElse(aggParser)
.orElse {
case o =>
val str = scala.util.Try(o.show).getOrElse("")
@ -104,7 +105,6 @@ private[minisql] object Parsing {
lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
}
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
@ -115,6 +115,11 @@ private[minisql] object Parsing {
astParser
)
lazy val aggParser: Parser[ast.Aggregation] = {
case '{ ($t: minisql.GroupCollection[t]).size } =>
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
}
astParser(expr)
}

View file

@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body))
) =>
val bindStmts = binds.map {
case Bind(bn, _) =>
val bindStmts = binds.zipWithIndex.map {
case (Bind(bn, _), i) =>
val fidx = Expr(s"_${i + 1}")
'{
ast.Val(
ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1")
ast.Property(ast.Ident(${ Expr(t) }), $fidx)
)
}
}
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)

View file

@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals(
o.sql,
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
"SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
)
}
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
)
}
test("GroupBy") {
val o = testContext.io(
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
)
assertEquals(
o.sql,
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
)
}
}