Fix tuple unapply parsing & try add group by

This commit is contained in:
jilen 2025-07-15 14:29:12 +08:00
parent feeb9cab1e
commit 0b4a6cb0c4
9 changed files with 161 additions and 41 deletions

View file

@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
+ [x] 基础 Ast 及相关操作
+ [x] 验证 `inline``FromExpr` 是否生效
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理
+ [ ] 验证 `Insert/Update` 实现
+ [x] 验证 `Insert/Update` 实现
+ [ ] DSL
- [x] Map
- [ ] Filter/FlatMap/ConcatMap/Union
- [ ] Join
- [ ] GroupBy/Aggeration
- [x] Filter/FlatMap/ConcatMap/Union
- [x] Join
- [x] GroupBy/Aggeration
+ [ ] 函数解析
- [x] Ident
- [x] Property
- [ ] BinaryOperation
- [ ] UnaryOperation
- [x] BinaryOperation
- [x] UnaryOperation
- [ ] CaseClass
- [ ] Tuple

View file

@ -6,6 +6,8 @@ import minisql.parsing.*
import minisql.util.*
import minisql.ast.{
Ast,
Aggregation,
AggregationOperator,
Entity,
Map,
Property,
@ -14,6 +16,7 @@ import minisql.ast.{
PropertyAlias,
JoinType,
Join,
GroupBy,
given
}
import scala.quoted.*
@ -21,18 +24,25 @@ import scala.deriving.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type Agg[E] <: Quoted = Quoted
opaque type Grouped[G, E] = Quoted
opaque type Action[E] <: Quoted = Quoted
opaque type Update[E] <: Action[Long] = Quoted
opaque type Insert[E] <: Action[Long] = Quoted
opaque type InsertReturning[E] <: Action[E] = Quoted
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Insert {
extension [E](inline insert: Insert[E]) {
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
@ -47,12 +57,8 @@ object Insert {
}
}
opaque type InsertReturning[E] <: Action[E] = Quoted
sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
@ -64,6 +70,20 @@ object Joined {
}
}
trait GroupCollection[E] {
def size: Long = throw new NonQuotedException
}
object Grouped {
extension [G, E](inline g: Grouped[G, E]) {
inline def map[E1](
inline f: (p: (G, GroupCollection[E])) => E1
): Query[E1] = {
transform(g)(f)(Map.apply)
}
}
}
private inline def joinOn[E1, E2](
inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean
@ -148,6 +168,15 @@ object Query {
transform(e)(f)(Filter.apply)
}
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
transform(e)(f)(GroupBy.apply)
}
inline def size: Agg[Long] =
Aggregation(AggregationOperator.size, e)
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
}
}
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
q.value match {
compileImpl(q.value, q, idiom, n)
}
private def compileImpl[I <: Idiom, N <: NamingStrategy](
staticAst: Option[Ast],
runtimeAst: Expr[Ast],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
staticAst match {
case Some(ast) =>
val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile")
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
case None =>
report.info("Dynamic Query")
'{
$idiom.translate($q)(using $n)._2
$idiom.translate($runtimeAst)(using $n)._2
}
}

View file

@ -2,6 +2,7 @@ package minisql.ast
import minisql.NamingStrategy
import minisql.ParamEncoder
import minisql.idiom.MirrorIdiom
import scala.quoted.*
@ -21,6 +22,10 @@ sealed trait Ast {
override def apply(a: Ast) =
super.apply(a.neutral)
}.apply(this)
override def toString(): String = {
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
}
}
//************************************************************

View file

@ -137,6 +137,13 @@ private given FromExpr[Query] with {
Some(SortBy(b, p, s, o))
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
Some(GroupBy(b, p, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(p) }
val z: Ast = ${ Expr(body) }
GroupBy(x, y, z)
} =>
Some(GroupBy(b, p, body))
case '{ Distinct(${ Expr(a) }) } =>
Some(Distinct(a))
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a))
case o =>
import quotes.reflect.*
None
}
}
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements))
case '{ Block(List(${ Varargs(stmts) }*)) } =>
stmts.map { x =>
val o = x.asInstanceOf[Expr[Ast]].value
if (o.isEmpty) {
println(s"===================${x.show}")
}
}
None
}
}

View file

@ -2,58 +2,58 @@ package minisql.ast
sealed trait Operator
sealed trait UnaryOperator extends Operator
sealed trait PrefixUnaryOperator extends UnaryOperator
sealed trait UnaryOperator extends Operator
sealed trait PrefixUnaryOperator extends UnaryOperator
sealed trait PostfixUnaryOperator extends UnaryOperator
sealed trait BinaryOperator extends Operator
sealed trait BinaryOperator extends Operator
object EqualityOperator {
case object `==` extends BinaryOperator
case object `!=` extends BinaryOperator
inline def Eq = `==`
inline def Eq = `==`
inline def Neq = `!=`
}
object BooleanOperator {
case object `!` extends PrefixUnaryOperator
case object `!` extends PrefixUnaryOperator
case object `&&` extends BinaryOperator
case object `||` extends BinaryOperator
}
object StringOperator {
case object `concat` extends BinaryOperator
case object `startsWith` extends BinaryOperator
case object `split` extends BinaryOperator
case object `concat` extends BinaryOperator
case object `startsWith` extends BinaryOperator
case object `split` extends BinaryOperator
case object `toUpperCase` extends PostfixUnaryOperator
case object `toLowerCase` extends PostfixUnaryOperator
case object `toLong` extends PostfixUnaryOperator
case object `toInt` extends PostfixUnaryOperator
case object `toLong` extends PostfixUnaryOperator
case object `toInt` extends PostfixUnaryOperator
}
object NumericOperator {
case object `-` extends BinaryOperator with PrefixUnaryOperator
case object `+` extends BinaryOperator
case object `*` extends BinaryOperator
case object `>` extends BinaryOperator
case object `-` extends BinaryOperator with PrefixUnaryOperator
case object `+` extends BinaryOperator
case object `*` extends BinaryOperator
case object `>` extends BinaryOperator
case object `>=` extends BinaryOperator
case object `<` extends BinaryOperator
case object `<` extends BinaryOperator
case object `<=` extends BinaryOperator
case object `/` extends BinaryOperator
case object `%` extends BinaryOperator
case object `/` extends BinaryOperator
case object `%` extends BinaryOperator
}
object SetOperator {
case object `contains` extends BinaryOperator
case object `nonEmpty` extends PostfixUnaryOperator
case object `isEmpty` extends PostfixUnaryOperator
case object `isEmpty` extends PostfixUnaryOperator
}
sealed trait AggregationOperator extends Operator
object AggregationOperator {
case object `min` extends AggregationOperator
case object `max` extends AggregationOperator
case object `avg` extends AggregationOperator
case object `sum` extends AggregationOperator
case object `min` extends AggregationOperator
case object `max` extends AggregationOperator
case object `avg` extends AggregationOperator
case object `sum` extends AggregationOperator
case object `size` extends AggregationOperator
}

View file

@ -7,7 +7,7 @@ import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.*
import scala.compiletime.*
import scala.util.Try
import scala.util.{Try, Success, Failure}
import scala.annotation.targetName
trait RowExtract[A, Row] {
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
)
}
inline def io[E](inline q: minisql.Agg[E])(using
e: ColumnDecoder.Aux[DBRow, E]
): DBIO[E] = {
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
case _: (E <:< Option[?]) =>
(rows: Iterable[DBRow]) =>
rows.toVector match {
case Vector() => Success(None.asInstanceOf[E])
case Vector(r) =>
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
case _ =>
(rows) =>
rows.toVector match {
case Vector(r) => RowExtract.single(e).extract(r)
case o =>
Failure(
new IllegalStateException(
s"Expect agg value, got ${o.size} rows"
)
)
}
}
val lifts = q.liftMap
val stmt = minisql.compile[I, N](q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = mapper
)
}
@targetName("ioQuery")
inline def io[E](
inline q: minisql.Query[E]

View file

@ -62,6 +62,7 @@ private[minisql] object Parsing {
.orElse(ifParser)
.orElse(traversableOperationParser)
.orElse(patMatchParser)
.orElse(aggParser)
.orElse {
case o =>
val str = scala.util.Try(o.show).getOrElse("")
@ -104,7 +105,6 @@ private[minisql] object Parsing {
lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
}
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
@ -115,6 +115,11 @@ private[minisql] object Parsing {
astParser
)
lazy val aggParser: Parser[ast.Aggregation] = {
case '{ ($t: minisql.GroupCollection[t]).size } =>
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
}
astParser(expr)
}

View file

@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body))
) =>
val bindStmts = binds.map {
case Bind(bn, _) =>
val bindStmts = binds.zipWithIndex.map {
case (Bind(bn, _), i) =>
val fidx = Expr(s"_${i + 1}")
'{
ast.Val(
ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1")
ast.Property(ast.Ident(${ Expr(t) }), $fidx)
)
}
}
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)

View file

@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals(
o.sql,
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
"SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
)
}
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
)
}
test("GroupBy") {
val o = testContext.io(
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
)
assertEquals(
o.sql,
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
)
}
}