Fix tuple unapply parsing & try add group by

This commit is contained in:
jilen 2025-07-15 14:29:12 +08:00
parent feeb9cab1e
commit 0b4a6cb0c4
9 changed files with 161 additions and 41 deletions

View file

@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
+ [x] 基础 Ast 及相关操作 + [x] 基础 Ast 及相关操作
+ [x] 验证 `inline``FromExpr` 是否生效 + [x] 验证 `inline``FromExpr` 是否生效
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理 + [ ] 验证 `lift` / `liftCaseClass` 如何处理
+ [ ] 验证 `Insert/Update` 实现 + [x] 验证 `Insert/Update` 实现
+ [ ] DSL + [ ] DSL
- [x] Map - [x] Map
- [ ] Filter/FlatMap/ConcatMap/Union - [x] Filter/FlatMap/ConcatMap/Union
- [ ] Join - [x] Join
- [ ] GroupBy/Aggeration - [x] GroupBy/Aggeration
+ [ ] 函数解析 + [ ] 函数解析
- [x] Ident - [x] Ident
- [x] Property - [x] Property
- [ ] BinaryOperation - [x] BinaryOperation
- [ ] UnaryOperation - [x] UnaryOperation
- [ ] CaseClass - [ ] CaseClass
- [ ] Tuple - [ ] Tuple

View file

@ -6,6 +6,8 @@ import minisql.parsing.*
import minisql.util.* import minisql.util.*
import minisql.ast.{ import minisql.ast.{
Ast, Ast,
Aggregation,
AggregationOperator,
Entity, Entity,
Map, Map,
Property, Property,
@ -14,6 +16,7 @@ import minisql.ast.{
PropertyAlias, PropertyAlias,
JoinType, JoinType,
Join, Join,
GroupBy,
given given
} }
import scala.quoted.* import scala.quoted.*
@ -21,18 +24,25 @@ import scala.deriving.*
import scala.compiletime.* import scala.compiletime.*
import scala.compiletime.ops.string.* import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap} import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted opaque type Query[E] <: Quoted = Quoted
opaque type Agg[E] <: Quoted = Quoted
opaque type Grouped[G, E] = Quoted
opaque type Action[E] <: Quoted = Quoted opaque type Action[E] <: Quoted = Quoted
opaque type Update[E] <: Action[Long] = Quoted opaque type Update[E] <: Action[Long] = Quoted
opaque type Insert[E] <: Action[Long] = Quoted opaque type Insert[E] <: Action[Long] = Quoted
opaque type InsertReturning[E] <: Action[E] = Quoted
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Insert { object Insert {
extension [E](inline insert: Insert[E]) { extension [E](inline insert: Insert[E]) {
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = { inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
@ -47,12 +57,8 @@ object Insert {
} }
} }
opaque type InsertReturning[E] <: Action[E] = Quoted
sealed trait Joined[E1, E2] sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined { object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] = def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
@ -64,6 +70,20 @@ object Joined {
} }
} }
trait GroupCollection[E] {
def size: Long = throw new NonQuotedException
}
object Grouped {
extension [G, E](inline g: Grouped[G, E]) {
inline def map[E1](
inline f: (p: (G, GroupCollection[E])) => E1
): Query[E1] = {
transform(g)(f)(Map.apply)
}
}
}
private inline def joinOn[E1, E2]( private inline def joinOn[E1, E2](
inline j: Joined[E1, E2], inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean inline f: (E1, E2) => Boolean
@ -148,6 +168,15 @@ object Query {
transform(e)(f)(Filter.apply) transform(e)(f)(Filter.apply)
} }
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
transform(e)(f)(GroupBy.apply)
}
inline def size: Agg[Long] =
Aggregation(AggregationOperator.size, e)
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
} }
} }
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
n: Expr[N] n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = { )(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.* import quotes.reflect.*
q.value match { compileImpl(q.value, q, idiom, n)
}
private def compileImpl[I <: Idiom, N <: NamingStrategy](
staticAst: Option[Ast],
runtimeAst: Expr[Ast],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
staticAst match {
case Some(ast) => case Some(ast) =>
val idiom = LoadObject[I].getOrElse( val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile") report.errorAndAbort(s"Idiom not known at compile")
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
case None => case None =>
report.info("Dynamic Query") report.info("Dynamic Query")
'{ '{
$idiom.translate($q)(using $n)._2 $idiom.translate($runtimeAst)(using $n)._2
} }
} }

View file

@ -2,6 +2,7 @@ package minisql.ast
import minisql.NamingStrategy import minisql.NamingStrategy
import minisql.ParamEncoder import minisql.ParamEncoder
import minisql.idiom.MirrorIdiom
import scala.quoted.* import scala.quoted.*
@ -21,6 +22,10 @@ sealed trait Ast {
override def apply(a: Ast) = override def apply(a: Ast) =
super.apply(a.neutral) super.apply(a.neutral)
}.apply(this) }.apply(this)
override def toString(): String = {
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
}
} }
//************************************************************ //************************************************************

View file

@ -137,6 +137,13 @@ private given FromExpr[Query] with {
Some(SortBy(b, p, s, o)) Some(SortBy(b, p, s, o))
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } => case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
Some(GroupBy(b, p, body)) Some(GroupBy(b, p, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(p) }
val z: Ast = ${ Expr(body) }
GroupBy(x, y, z)
} =>
Some(GroupBy(b, p, body))
case '{ Distinct(${ Expr(a) }) } => case '{ Distinct(${ Expr(a) }) } =>
Some(Distinct(a)) Some(Distinct(a))
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } => case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
Some(FlatJoin(t, a, ia, on)) Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } => case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a)) Some(Nested(a))
case o =>
import quotes.reflect.*
None
} }
} }
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match { def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } => case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements)) Some(Block(statements))
case '{ Block(List(${ Varargs(stmts) }*)) } =>
stmts.map { x =>
val o = x.asInstanceOf[Expr[Ast]].value
if (o.isEmpty) {
println(s"===================${x.show}")
}
}
None
} }
} }

View file

@ -2,58 +2,58 @@ package minisql.ast
sealed trait Operator sealed trait Operator
sealed trait UnaryOperator extends Operator sealed trait UnaryOperator extends Operator
sealed trait PrefixUnaryOperator extends UnaryOperator sealed trait PrefixUnaryOperator extends UnaryOperator
sealed trait PostfixUnaryOperator extends UnaryOperator sealed trait PostfixUnaryOperator extends UnaryOperator
sealed trait BinaryOperator extends Operator sealed trait BinaryOperator extends Operator
object EqualityOperator { object EqualityOperator {
case object `==` extends BinaryOperator case object `==` extends BinaryOperator
case object `!=` extends BinaryOperator case object `!=` extends BinaryOperator
inline def Eq = `==` inline def Eq = `==`
inline def Neq = `!=` inline def Neq = `!=`
} }
object BooleanOperator { object BooleanOperator {
case object `!` extends PrefixUnaryOperator case object `!` extends PrefixUnaryOperator
case object `&&` extends BinaryOperator case object `&&` extends BinaryOperator
case object `||` extends BinaryOperator case object `||` extends BinaryOperator
} }
object StringOperator { object StringOperator {
case object `concat` extends BinaryOperator case object `concat` extends BinaryOperator
case object `startsWith` extends BinaryOperator case object `startsWith` extends BinaryOperator
case object `split` extends BinaryOperator case object `split` extends BinaryOperator
case object `toUpperCase` extends PostfixUnaryOperator case object `toUpperCase` extends PostfixUnaryOperator
case object `toLowerCase` extends PostfixUnaryOperator case object `toLowerCase` extends PostfixUnaryOperator
case object `toLong` extends PostfixUnaryOperator case object `toLong` extends PostfixUnaryOperator
case object `toInt` extends PostfixUnaryOperator case object `toInt` extends PostfixUnaryOperator
} }
object NumericOperator { object NumericOperator {
case object `-` extends BinaryOperator with PrefixUnaryOperator case object `-` extends BinaryOperator with PrefixUnaryOperator
case object `+` extends BinaryOperator case object `+` extends BinaryOperator
case object `*` extends BinaryOperator case object `*` extends BinaryOperator
case object `>` extends BinaryOperator case object `>` extends BinaryOperator
case object `>=` extends BinaryOperator case object `>=` extends BinaryOperator
case object `<` extends BinaryOperator case object `<` extends BinaryOperator
case object `<=` extends BinaryOperator case object `<=` extends BinaryOperator
case object `/` extends BinaryOperator case object `/` extends BinaryOperator
case object `%` extends BinaryOperator case object `%` extends BinaryOperator
} }
object SetOperator { object SetOperator {
case object `contains` extends BinaryOperator case object `contains` extends BinaryOperator
case object `nonEmpty` extends PostfixUnaryOperator case object `nonEmpty` extends PostfixUnaryOperator
case object `isEmpty` extends PostfixUnaryOperator case object `isEmpty` extends PostfixUnaryOperator
} }
sealed trait AggregationOperator extends Operator sealed trait AggregationOperator extends Operator
object AggregationOperator { object AggregationOperator {
case object `min` extends AggregationOperator case object `min` extends AggregationOperator
case object `max` extends AggregationOperator case object `max` extends AggregationOperator
case object `avg` extends AggregationOperator case object `avg` extends AggregationOperator
case object `sum` extends AggregationOperator case object `sum` extends AggregationOperator
case object `size` extends AggregationOperator case object `size` extends AggregationOperator
} }

View file

@ -7,7 +7,7 @@ import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst} import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.* import scala.deriving.*
import scala.compiletime.* import scala.compiletime.*
import scala.util.Try import scala.util.{Try, Success, Failure}
import scala.annotation.targetName import scala.annotation.targetName
trait RowExtract[A, Row] { trait RowExtract[A, Row] {
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
) )
} }
inline def io[E](inline q: minisql.Agg[E])(using
e: ColumnDecoder.Aux[DBRow, E]
): DBIO[E] = {
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
case _: (E <:< Option[?]) =>
(rows: Iterable[DBRow]) =>
rows.toVector match {
case Vector() => Success(None.asInstanceOf[E])
case Vector(r) =>
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
case _ =>
(rows) =>
rows.toVector match {
case Vector(r) => RowExtract.single(e).extract(r)
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
}
val lifts = q.liftMap
val stmt = minisql.compile[I, N](q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = mapper
)
}
@targetName("ioQuery") @targetName("ioQuery")
inline def io[E]( inline def io[E](
inline q: minisql.Query[E] inline q: minisql.Query[E]

View file

@ -62,6 +62,7 @@ private[minisql] object Parsing {
.orElse(ifParser) .orElse(ifParser)
.orElse(traversableOperationParser) .orElse(traversableOperationParser)
.orElse(patMatchParser) .orElse(patMatchParser)
.orElse(aggParser)
.orElse { .orElse {
case o => case o =>
val str = scala.util.Try(o.show).getOrElse("") val str = scala.util.Try(o.show).getOrElse("")
@ -104,7 +105,6 @@ private[minisql] object Parsing {
lazy val ifParser: Parser[ast.If] = { lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } => case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
} }
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
@ -115,6 +115,11 @@ private[minisql] object Parsing {
astParser astParser
) )
lazy val aggParser: Parser[ast.Aggregation] = {
case '{ ($t: minisql.GroupCollection[t]).size } =>
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
}
astParser(expr) astParser(expr)
} }

View file

@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
Ident(t), Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body)) List(CaseDef(IsTupleUnapply(binds), None, body))
) => ) =>
val bindStmts = binds.map { val bindStmts = binds.zipWithIndex.map {
case Bind(bn, _) => case (Bind(bn, _), i) =>
val fidx = Expr(s"_${i + 1}")
'{ '{
ast.Val( ast.Val(
ast.Ident(${ Expr(bn) }), ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1") ast.Property(ast.Ident(${ Expr(t) }), $fidx)
) )
} }
} }
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr) val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)

View file

@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals( assertEquals(
o.sql, o.sql,
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id" "SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
) )
} }
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET name = ?, age = ? WHERE id = 1" "UPDATE foo SET name = ?, age = ? WHERE id = 1"
) )
} }
test("GroupBy") {
val o = testContext.io(
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
)
assertEquals(
o.sql,
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
)
}
} }