Fix tuple unapply parsing & try add group by
This commit is contained in:
parent
feeb9cab1e
commit
0b4a6cb0c4
9 changed files with 161 additions and 41 deletions
12
README.md
12
README.md
|
@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
|
||||||
+ [x] 基础 Ast 及相关操作
|
+ [x] 基础 Ast 及相关操作
|
||||||
+ [x] 验证 `inline` 和 `FromExpr` 是否生效
|
+ [x] 验证 `inline` 和 `FromExpr` 是否生效
|
||||||
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理
|
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理
|
||||||
+ [ ] 验证 `Insert/Update` 实现
|
+ [x] 验证 `Insert/Update` 实现
|
||||||
+ [ ] DSL
|
+ [ ] DSL
|
||||||
- [x] Map
|
- [x] Map
|
||||||
- [ ] Filter/FlatMap/ConcatMap/Union
|
- [x] Filter/FlatMap/ConcatMap/Union
|
||||||
- [ ] Join
|
- [x] Join
|
||||||
- [ ] GroupBy/Aggeration
|
- [x] GroupBy/Aggeration
|
||||||
+ [ ] 函数解析
|
+ [ ] 函数解析
|
||||||
- [x] Ident
|
- [x] Ident
|
||||||
- [x] Property
|
- [x] Property
|
||||||
- [ ] BinaryOperation
|
- [x] BinaryOperation
|
||||||
- [ ] UnaryOperation
|
- [x] UnaryOperation
|
||||||
- [ ] CaseClass
|
- [ ] CaseClass
|
||||||
- [ ] Tuple
|
- [ ] Tuple
|
||||||
|
|
|
@ -6,6 +6,8 @@ import minisql.parsing.*
|
||||||
import minisql.util.*
|
import minisql.util.*
|
||||||
import minisql.ast.{
|
import minisql.ast.{
|
||||||
Ast,
|
Ast,
|
||||||
|
Aggregation,
|
||||||
|
AggregationOperator,
|
||||||
Entity,
|
Entity,
|
||||||
Map,
|
Map,
|
||||||
Property,
|
Property,
|
||||||
|
@ -14,6 +16,7 @@ import minisql.ast.{
|
||||||
PropertyAlias,
|
PropertyAlias,
|
||||||
JoinType,
|
JoinType,
|
||||||
Join,
|
Join,
|
||||||
|
GroupBy,
|
||||||
given
|
given
|
||||||
}
|
}
|
||||||
import scala.quoted.*
|
import scala.quoted.*
|
||||||
|
@ -21,18 +24,25 @@ import scala.deriving.*
|
||||||
import scala.compiletime.*
|
import scala.compiletime.*
|
||||||
import scala.compiletime.ops.string.*
|
import scala.compiletime.ops.string.*
|
||||||
import scala.collection.immutable.{Map => IMap}
|
import scala.collection.immutable.{Map => IMap}
|
||||||
import scala.util.NotGiven
|
|
||||||
|
|
||||||
opaque type Quoted <: Ast = Ast
|
opaque type Quoted <: Ast = Ast
|
||||||
|
|
||||||
opaque type Query[E] <: Quoted = Quoted
|
opaque type Query[E] <: Quoted = Quoted
|
||||||
|
|
||||||
|
opaque type Agg[E] <: Quoted = Quoted
|
||||||
|
|
||||||
|
opaque type Grouped[G, E] = Quoted
|
||||||
|
|
||||||
opaque type Action[E] <: Quoted = Quoted
|
opaque type Action[E] <: Quoted = Quoted
|
||||||
|
|
||||||
opaque type Update[E] <: Action[Long] = Quoted
|
opaque type Update[E] <: Action[Long] = Quoted
|
||||||
|
|
||||||
opaque type Insert[E] <: Action[Long] = Quoted
|
opaque type Insert[E] <: Action[Long] = Quoted
|
||||||
|
|
||||||
|
opaque type InsertReturning[E] <: Action[E] = Quoted
|
||||||
|
|
||||||
|
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
|
||||||
|
|
||||||
object Insert {
|
object Insert {
|
||||||
extension [E](inline insert: Insert[E]) {
|
extension [E](inline insert: Insert[E]) {
|
||||||
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
|
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
|
||||||
|
@ -47,12 +57,8 @@ object Insert {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
opaque type InsertReturning[E] <: Action[E] = Quoted
|
|
||||||
|
|
||||||
sealed trait Joined[E1, E2]
|
sealed trait Joined[E1, E2]
|
||||||
|
|
||||||
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
|
|
||||||
|
|
||||||
object Joined {
|
object Joined {
|
||||||
|
|
||||||
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
|
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
|
||||||
|
@ -64,6 +70,20 @@ object Joined {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait GroupCollection[E] {
|
||||||
|
def size: Long = throw new NonQuotedException
|
||||||
|
}
|
||||||
|
|
||||||
|
object Grouped {
|
||||||
|
extension [G, E](inline g: Grouped[G, E]) {
|
||||||
|
inline def map[E1](
|
||||||
|
inline f: (p: (G, GroupCollection[E])) => E1
|
||||||
|
): Query[E1] = {
|
||||||
|
transform(g)(f)(Map.apply)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private inline def joinOn[E1, E2](
|
private inline def joinOn[E1, E2](
|
||||||
inline j: Joined[E1, E2],
|
inline j: Joined[E1, E2],
|
||||||
inline f: (E1, E2) => Boolean
|
inline f: (E1, E2) => Boolean
|
||||||
|
@ -148,6 +168,15 @@ object Query {
|
||||||
transform(e)(f)(Filter.apply)
|
transform(e)(f)(Filter.apply)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
|
||||||
|
transform(e)(f)(GroupBy.apply)
|
||||||
|
}
|
||||||
|
|
||||||
|
inline def size: Agg[Long] =
|
||||||
|
Aggregation(AggregationOperator.size, e)
|
||||||
|
|
||||||
|
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
|
||||||
|
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
||||||
n: Expr[N]
|
n: Expr[N]
|
||||||
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
|
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
|
||||||
import quotes.reflect.*
|
import quotes.reflect.*
|
||||||
q.value match {
|
compileImpl(q.value, q, idiom, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
||||||
|
staticAst: Option[Ast],
|
||||||
|
runtimeAst: Expr[Ast],
|
||||||
|
idiom: Expr[I],
|
||||||
|
n: Expr[N]
|
||||||
|
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
|
||||||
|
import quotes.reflect.*
|
||||||
|
staticAst match {
|
||||||
case Some(ast) =>
|
case Some(ast) =>
|
||||||
val idiom = LoadObject[I].getOrElse(
|
val idiom = LoadObject[I].getOrElse(
|
||||||
report.errorAndAbort(s"Idiom not known at compile")
|
report.errorAndAbort(s"Idiom not known at compile")
|
||||||
|
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
||||||
case None =>
|
case None =>
|
||||||
report.info("Dynamic Query")
|
report.info("Dynamic Query")
|
||||||
'{
|
'{
|
||||||
$idiom.translate($q)(using $n)._2
|
$idiom.translate($runtimeAst)(using $n)._2
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package minisql.ast
|
||||||
|
|
||||||
import minisql.NamingStrategy
|
import minisql.NamingStrategy
|
||||||
import minisql.ParamEncoder
|
import minisql.ParamEncoder
|
||||||
|
import minisql.idiom.MirrorIdiom
|
||||||
|
|
||||||
import scala.quoted.*
|
import scala.quoted.*
|
||||||
|
|
||||||
|
@ -21,6 +22,10 @@ sealed trait Ast {
|
||||||
override def apply(a: Ast) =
|
override def apply(a: Ast) =
|
||||||
super.apply(a.neutral)
|
super.apply(a.neutral)
|
||||||
}.apply(this)
|
}.apply(this)
|
||||||
|
|
||||||
|
override def toString(): String = {
|
||||||
|
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//************************************************************
|
//************************************************************
|
||||||
|
|
|
@ -137,6 +137,13 @@ private given FromExpr[Query] with {
|
||||||
Some(SortBy(b, p, s, o))
|
Some(SortBy(b, p, s, o))
|
||||||
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
|
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
|
||||||
Some(GroupBy(b, p, body))
|
Some(GroupBy(b, p, body))
|
||||||
|
case '{
|
||||||
|
val x: Ast = ${ Expr(b) }
|
||||||
|
val y: Ident = ${ Expr(p) }
|
||||||
|
val z: Ast = ${ Expr(body) }
|
||||||
|
GroupBy(x, y, z)
|
||||||
|
} =>
|
||||||
|
Some(GroupBy(b, p, body))
|
||||||
case '{ Distinct(${ Expr(a) }) } =>
|
case '{ Distinct(${ Expr(a) }) } =>
|
||||||
Some(Distinct(a))
|
Some(Distinct(a))
|
||||||
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
|
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
|
||||||
|
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
|
||||||
Some(FlatJoin(t, a, ia, on))
|
Some(FlatJoin(t, a, ia, on))
|
||||||
case '{ Nested(${ Expr(a) }) } =>
|
case '{ Nested(${ Expr(a) }) } =>
|
||||||
Some(Nested(a))
|
Some(Nested(a))
|
||||||
|
case o =>
|
||||||
|
import quotes.reflect.*
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
|
||||||
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
|
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
|
||||||
case '{ Block(${ Expr(statements) }) } =>
|
case '{ Block(${ Expr(statements) }) } =>
|
||||||
Some(Block(statements))
|
Some(Block(statements))
|
||||||
|
case '{ Block(List(${ Varargs(stmts) }*)) } =>
|
||||||
|
stmts.map { x =>
|
||||||
|
val o = x.asInstanceOf[Expr[Ast]].value
|
||||||
|
if (o.isEmpty) {
|
||||||
|
println(s"===================${x.show}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import minisql.ColumnDecoder
|
||||||
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
|
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
|
||||||
import scala.deriving.*
|
import scala.deriving.*
|
||||||
import scala.compiletime.*
|
import scala.compiletime.*
|
||||||
import scala.util.Try
|
import scala.util.{Try, Success, Failure}
|
||||||
import scala.annotation.targetName
|
import scala.annotation.targetName
|
||||||
|
|
||||||
trait RowExtract[A, Row] {
|
trait RowExtract[A, Row] {
|
||||||
|
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline def io[E](inline q: minisql.Agg[E])(using
|
||||||
|
e: ColumnDecoder.Aux[DBRow, E]
|
||||||
|
): DBIO[E] = {
|
||||||
|
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
|
||||||
|
case _: (E <:< Option[?]) =>
|
||||||
|
(rows: Iterable[DBRow]) =>
|
||||||
|
rows.toVector match {
|
||||||
|
case Vector() => Success(None.asInstanceOf[E])
|
||||||
|
case Vector(r) =>
|
||||||
|
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
|
||||||
|
case o =>
|
||||||
|
Failure(
|
||||||
|
new IllegalStateException(
|
||||||
|
s"Expect agg value, got ${o.size} rows"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case _ =>
|
||||||
|
(rows) =>
|
||||||
|
rows.toVector match {
|
||||||
|
case Vector(r) => RowExtract.single(e).extract(r)
|
||||||
|
case o =>
|
||||||
|
Failure(
|
||||||
|
new IllegalStateException(
|
||||||
|
s"Expect agg value, got ${o.size} rows"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
val lifts = q.liftMap
|
||||||
|
val stmt = minisql.compile[I, N](q, idiom, naming)
|
||||||
|
val (sql, params) = stmt.expand(lifts)
|
||||||
|
(
|
||||||
|
sql = sql,
|
||||||
|
params = params.map(_.value.get.asInstanceOf),
|
||||||
|
mapper = mapper
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
@targetName("ioQuery")
|
@targetName("ioQuery")
|
||||||
inline def io[E](
|
inline def io[E](
|
||||||
inline q: minisql.Query[E]
|
inline q: minisql.Query[E]
|
||||||
|
|
|
@ -62,6 +62,7 @@ private[minisql] object Parsing {
|
||||||
.orElse(ifParser)
|
.orElse(ifParser)
|
||||||
.orElse(traversableOperationParser)
|
.orElse(traversableOperationParser)
|
||||||
.orElse(patMatchParser)
|
.orElse(patMatchParser)
|
||||||
|
.orElse(aggParser)
|
||||||
.orElse {
|
.orElse {
|
||||||
case o =>
|
case o =>
|
||||||
val str = scala.util.Try(o.show).getOrElse("")
|
val str = scala.util.Try(o.show).getOrElse("")
|
||||||
|
@ -104,7 +105,6 @@ private[minisql] object Parsing {
|
||||||
lazy val ifParser: Parser[ast.If] = {
|
lazy val ifParser: Parser[ast.If] = {
|
||||||
case '{ if ($a) $b else $c } =>
|
case '{ if ($a) $b else $c } =>
|
||||||
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
|
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
|
||||||
|
|
||||||
}
|
}
|
||||||
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
|
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
|
||||||
|
|
||||||
|
@ -115,6 +115,11 @@ private[minisql] object Parsing {
|
||||||
astParser
|
astParser
|
||||||
)
|
)
|
||||||
|
|
||||||
|
lazy val aggParser: Parser[ast.Aggregation] = {
|
||||||
|
case '{ ($t: minisql.GroupCollection[t]).size } =>
|
||||||
|
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
|
||||||
|
}
|
||||||
|
|
||||||
astParser(expr)
|
astParser(expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
|
||||||
Ident(t),
|
Ident(t),
|
||||||
List(CaseDef(IsTupleUnapply(binds), None, body))
|
List(CaseDef(IsTupleUnapply(binds), None, body))
|
||||||
) =>
|
) =>
|
||||||
val bindStmts = binds.map {
|
val bindStmts = binds.zipWithIndex.map {
|
||||||
case Bind(bn, _) =>
|
case (Bind(bn, _), i) =>
|
||||||
|
val fidx = Expr(s"_${i + 1}")
|
||||||
'{
|
'{
|
||||||
ast.Val(
|
ast.Val(
|
||||||
ast.Ident(${ Expr(bn) }),
|
ast.Ident(${ Expr(bn) }),
|
||||||
ast.Property(ast.Ident(${ Expr(t) }), "_1")
|
ast.Property(ast.Ident(${ Expr(t) }), $fidx)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
|
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
|
||||||
|
|
|
@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
|
||||||
|
|
||||||
assertEquals(
|
assertEquals(
|
||||||
o.sql,
|
o.sql,
|
||||||
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
|
"SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
|
||||||
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
|
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("GroupBy") {
|
||||||
|
val o = testContext.io(
|
||||||
|
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
|
||||||
|
)
|
||||||
|
assertEquals(
|
||||||
|
o.sql,
|
||||||
|
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue