Fix tuple unapply parsing & try add group by
This commit is contained in:
parent
feeb9cab1e
commit
0b4a6cb0c4
9 changed files with 161 additions and 41 deletions
12
README.md
12
README.md
|
@ -32,16 +32,16 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
|
|||
+ [x] 基础 Ast 及相关操作
|
||||
+ [x] 验证 `inline` 和 `FromExpr` 是否生效
|
||||
+ [ ] 验证 `lift` / `liftCaseClass` 如何处理
|
||||
+ [ ] 验证 `Insert/Update` 实现
|
||||
+ [x] 验证 `Insert/Update` 实现
|
||||
+ [ ] DSL
|
||||
- [x] Map
|
||||
- [ ] Filter/FlatMap/ConcatMap/Union
|
||||
- [ ] Join
|
||||
- [ ] GroupBy/Aggeration
|
||||
- [x] Filter/FlatMap/ConcatMap/Union
|
||||
- [x] Join
|
||||
- [x] GroupBy/Aggeration
|
||||
+ [ ] 函数解析
|
||||
- [x] Ident
|
||||
- [x] Property
|
||||
- [ ] BinaryOperation
|
||||
- [ ] UnaryOperation
|
||||
- [x] BinaryOperation
|
||||
- [x] UnaryOperation
|
||||
- [ ] CaseClass
|
||||
- [ ] Tuple
|
||||
|
|
|
@ -6,6 +6,8 @@ import minisql.parsing.*
|
|||
import minisql.util.*
|
||||
import minisql.ast.{
|
||||
Ast,
|
||||
Aggregation,
|
||||
AggregationOperator,
|
||||
Entity,
|
||||
Map,
|
||||
Property,
|
||||
|
@ -14,6 +16,7 @@ import minisql.ast.{
|
|||
PropertyAlias,
|
||||
JoinType,
|
||||
Join,
|
||||
GroupBy,
|
||||
given
|
||||
}
|
||||
import scala.quoted.*
|
||||
|
@ -21,18 +24,25 @@ import scala.deriving.*
|
|||
import scala.compiletime.*
|
||||
import scala.compiletime.ops.string.*
|
||||
import scala.collection.immutable.{Map => IMap}
|
||||
import scala.util.NotGiven
|
||||
|
||||
opaque type Quoted <: Ast = Ast
|
||||
|
||||
opaque type Query[E] <: Quoted = Quoted
|
||||
|
||||
opaque type Agg[E] <: Quoted = Quoted
|
||||
|
||||
opaque type Grouped[G, E] = Quoted
|
||||
|
||||
opaque type Action[E] <: Quoted = Quoted
|
||||
|
||||
opaque type Update[E] <: Action[Long] = Quoted
|
||||
|
||||
opaque type Insert[E] <: Action[Long] = Quoted
|
||||
|
||||
opaque type InsertReturning[E] <: Action[E] = Quoted
|
||||
|
||||
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
|
||||
|
||||
object Insert {
|
||||
extension [E](inline insert: Insert[E]) {
|
||||
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
|
||||
|
@ -47,12 +57,8 @@ object Insert {
|
|||
}
|
||||
}
|
||||
|
||||
opaque type InsertReturning[E] <: Action[E] = Quoted
|
||||
|
||||
sealed trait Joined[E1, E2]
|
||||
|
||||
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
|
||||
|
||||
object Joined {
|
||||
|
||||
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
|
||||
|
@ -64,6 +70,20 @@ object Joined {
|
|||
}
|
||||
}
|
||||
|
||||
trait GroupCollection[E] {
|
||||
def size: Long = throw new NonQuotedException
|
||||
}
|
||||
|
||||
object Grouped {
|
||||
extension [G, E](inline g: Grouped[G, E]) {
|
||||
inline def map[E1](
|
||||
inline f: (p: (G, GroupCollection[E])) => E1
|
||||
): Query[E1] = {
|
||||
transform(g)(f)(Map.apply)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private inline def joinOn[E1, E2](
|
||||
inline j: Joined[E1, E2],
|
||||
inline f: (E1, E2) => Boolean
|
||||
|
@ -148,6 +168,15 @@ object Query {
|
|||
transform(e)(f)(Filter.apply)
|
||||
}
|
||||
|
||||
inline def groupBy[G](inline f: E => G): Grouped[G, E] = {
|
||||
transform(e)(f)(GroupBy.apply)
|
||||
}
|
||||
|
||||
inline def size: Agg[Long] =
|
||||
Aggregation(AggregationOperator.size, e)
|
||||
|
||||
inline def min: Agg[Long] = Aggregation(AggregationOperator.min, e)
|
||||
inline def max: Agg[Long] = Aggregation(AggregationOperator.min, e)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -308,7 +337,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
|||
n: Expr[N]
|
||||
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
|
||||
import quotes.reflect.*
|
||||
q.value match {
|
||||
compileImpl(q.value, q, idiom, n)
|
||||
}
|
||||
|
||||
private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
||||
staticAst: Option[Ast],
|
||||
runtimeAst: Expr[Ast],
|
||||
idiom: Expr[I],
|
||||
n: Expr[N]
|
||||
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
|
||||
import quotes.reflect.*
|
||||
staticAst match {
|
||||
case Some(ast) =>
|
||||
val idiom = LoadObject[I].getOrElse(
|
||||
report.errorAndAbort(s"Idiom not known at compile")
|
||||
|
@ -324,7 +363,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy](
|
|||
case None =>
|
||||
report.info("Dynamic Query")
|
||||
'{
|
||||
$idiom.translate($q)(using $n)._2
|
||||
$idiom.translate($runtimeAst)(using $n)._2
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package minisql.ast
|
|||
|
||||
import minisql.NamingStrategy
|
||||
import minisql.ParamEncoder
|
||||
import minisql.idiom.MirrorIdiom
|
||||
|
||||
import scala.quoted.*
|
||||
|
||||
|
@ -21,6 +22,10 @@ sealed trait Ast {
|
|||
override def apply(a: Ast) =
|
||||
super.apply(a.neutral)
|
||||
}.apply(this)
|
||||
|
||||
override def toString(): String = {
|
||||
MirrorIdiom.translate(this)(using minisql.Literal)._2.toString()
|
||||
}
|
||||
}
|
||||
|
||||
//************************************************************
|
||||
|
|
|
@ -137,6 +137,13 @@ private given FromExpr[Query] with {
|
|||
Some(SortBy(b, p, s, o))
|
||||
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
|
||||
Some(GroupBy(b, p, body))
|
||||
case '{
|
||||
val x: Ast = ${ Expr(b) }
|
||||
val y: Ident = ${ Expr(p) }
|
||||
val z: Ast = ${ Expr(body) }
|
||||
GroupBy(x, y, z)
|
||||
} =>
|
||||
Some(GroupBy(b, p, body))
|
||||
case '{ Distinct(${ Expr(a) }) } =>
|
||||
Some(Distinct(a))
|
||||
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
|
||||
|
@ -164,6 +171,9 @@ private given FromExpr[Query] with {
|
|||
Some(FlatJoin(t, a, ia, on))
|
||||
case '{ Nested(${ Expr(a) }) } =>
|
||||
Some(Nested(a))
|
||||
case o =>
|
||||
import quotes.reflect.*
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -421,6 +431,14 @@ private given FromExpr[Block] with {
|
|||
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
|
||||
case '{ Block(${ Expr(statements) }) } =>
|
||||
Some(Block(statements))
|
||||
case '{ Block(List(${ Varargs(stmts) }*)) } =>
|
||||
stmts.map { x =>
|
||||
val o = x.asInstanceOf[Expr[Ast]].value
|
||||
if (o.isEmpty) {
|
||||
println(s"===================${x.show}")
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,58 +2,58 @@ package minisql.ast
|
|||
|
||||
sealed trait Operator
|
||||
|
||||
sealed trait UnaryOperator extends Operator
|
||||
sealed trait PrefixUnaryOperator extends UnaryOperator
|
||||
sealed trait UnaryOperator extends Operator
|
||||
sealed trait PrefixUnaryOperator extends UnaryOperator
|
||||
sealed trait PostfixUnaryOperator extends UnaryOperator
|
||||
sealed trait BinaryOperator extends Operator
|
||||
sealed trait BinaryOperator extends Operator
|
||||
|
||||
object EqualityOperator {
|
||||
case object `==` extends BinaryOperator
|
||||
case object `!=` extends BinaryOperator
|
||||
inline def Eq = `==`
|
||||
inline def Eq = `==`
|
||||
inline def Neq = `!=`
|
||||
}
|
||||
|
||||
object BooleanOperator {
|
||||
case object `!` extends PrefixUnaryOperator
|
||||
case object `!` extends PrefixUnaryOperator
|
||||
case object `&&` extends BinaryOperator
|
||||
case object `||` extends BinaryOperator
|
||||
}
|
||||
|
||||
object StringOperator {
|
||||
case object `concat` extends BinaryOperator
|
||||
case object `startsWith` extends BinaryOperator
|
||||
case object `split` extends BinaryOperator
|
||||
case object `concat` extends BinaryOperator
|
||||
case object `startsWith` extends BinaryOperator
|
||||
case object `split` extends BinaryOperator
|
||||
case object `toUpperCase` extends PostfixUnaryOperator
|
||||
case object `toLowerCase` extends PostfixUnaryOperator
|
||||
case object `toLong` extends PostfixUnaryOperator
|
||||
case object `toInt` extends PostfixUnaryOperator
|
||||
case object `toLong` extends PostfixUnaryOperator
|
||||
case object `toInt` extends PostfixUnaryOperator
|
||||
}
|
||||
|
||||
object NumericOperator {
|
||||
case object `-` extends BinaryOperator with PrefixUnaryOperator
|
||||
case object `+` extends BinaryOperator
|
||||
case object `*` extends BinaryOperator
|
||||
case object `>` extends BinaryOperator
|
||||
case object `-` extends BinaryOperator with PrefixUnaryOperator
|
||||
case object `+` extends BinaryOperator
|
||||
case object `*` extends BinaryOperator
|
||||
case object `>` extends BinaryOperator
|
||||
case object `>=` extends BinaryOperator
|
||||
case object `<` extends BinaryOperator
|
||||
case object `<` extends BinaryOperator
|
||||
case object `<=` extends BinaryOperator
|
||||
case object `/` extends BinaryOperator
|
||||
case object `%` extends BinaryOperator
|
||||
case object `/` extends BinaryOperator
|
||||
case object `%` extends BinaryOperator
|
||||
}
|
||||
|
||||
object SetOperator {
|
||||
case object `contains` extends BinaryOperator
|
||||
case object `nonEmpty` extends PostfixUnaryOperator
|
||||
case object `isEmpty` extends PostfixUnaryOperator
|
||||
case object `isEmpty` extends PostfixUnaryOperator
|
||||
}
|
||||
|
||||
sealed trait AggregationOperator extends Operator
|
||||
|
||||
object AggregationOperator {
|
||||
case object `min` extends AggregationOperator
|
||||
case object `max` extends AggregationOperator
|
||||
case object `avg` extends AggregationOperator
|
||||
case object `sum` extends AggregationOperator
|
||||
case object `min` extends AggregationOperator
|
||||
case object `max` extends AggregationOperator
|
||||
case object `avg` extends AggregationOperator
|
||||
case object `sum` extends AggregationOperator
|
||||
case object `size` extends AggregationOperator
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import minisql.ColumnDecoder
|
|||
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
|
||||
import scala.deriving.*
|
||||
import scala.compiletime.*
|
||||
import scala.util.Try
|
||||
import scala.util.{Try, Success, Failure}
|
||||
import scala.annotation.targetName
|
||||
|
||||
trait RowExtract[A, Row] {
|
||||
|
@ -113,6 +113,46 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
|
|||
)
|
||||
}
|
||||
|
||||
inline def io[E](inline q: minisql.Agg[E])(using
|
||||
e: ColumnDecoder.Aux[DBRow, E]
|
||||
): DBIO[E] = {
|
||||
val mapper: Iterable[DBRow] => Try[E] = summonFrom {
|
||||
case _: (E <:< Option[?]) =>
|
||||
(rows: Iterable[DBRow]) =>
|
||||
rows.toVector match {
|
||||
case Vector() => Success(None.asInstanceOf[E])
|
||||
case Vector(r) =>
|
||||
RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E])
|
||||
case o =>
|
||||
Failure(
|
||||
new IllegalStateException(
|
||||
s"Expect agg value, got ${o.size} rows"
|
||||
)
|
||||
)
|
||||
}
|
||||
case _ =>
|
||||
(rows) =>
|
||||
rows.toVector match {
|
||||
case Vector(r) => RowExtract.single(e).extract(r)
|
||||
case o =>
|
||||
Failure(
|
||||
new IllegalStateException(
|
||||
s"Expect agg value, got ${o.size} rows"
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
val lifts = q.liftMap
|
||||
val stmt = minisql.compile[I, N](q, idiom, naming)
|
||||
val (sql, params) = stmt.expand(lifts)
|
||||
(
|
||||
sql = sql,
|
||||
params = params.map(_.value.get.asInstanceOf),
|
||||
mapper = mapper
|
||||
)
|
||||
}
|
||||
|
||||
@targetName("ioQuery")
|
||||
inline def io[E](
|
||||
inline q: minisql.Query[E]
|
||||
|
|
|
@ -62,6 +62,7 @@ private[minisql] object Parsing {
|
|||
.orElse(ifParser)
|
||||
.orElse(traversableOperationParser)
|
||||
.orElse(patMatchParser)
|
||||
.orElse(aggParser)
|
||||
.orElse {
|
||||
case o =>
|
||||
val str = scala.util.Try(o.show).getOrElse("")
|
||||
|
@ -104,7 +105,6 @@ private[minisql] object Parsing {
|
|||
lazy val ifParser: Parser[ast.If] = {
|
||||
case '{ if ($a) $b else $c } =>
|
||||
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
|
||||
|
||||
}
|
||||
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
|
||||
|
||||
|
@ -115,6 +115,11 @@ private[minisql] object Parsing {
|
|||
astParser
|
||||
)
|
||||
|
||||
lazy val aggParser: Parser[ast.Aggregation] = {
|
||||
case '{ ($t: minisql.GroupCollection[t]).size } =>
|
||||
'{ ast.Aggregation(ast.AggregationOperator.size, ${ astParser(t) }) }
|
||||
}
|
||||
|
||||
astParser(expr)
|
||||
}
|
||||
|
||||
|
|
|
@ -15,14 +15,16 @@ private[parsing] def patMatchParsing(
|
|||
Ident(t),
|
||||
List(CaseDef(IsTupleUnapply(binds), None, body))
|
||||
) =>
|
||||
val bindStmts = binds.map {
|
||||
case Bind(bn, _) =>
|
||||
val bindStmts = binds.zipWithIndex.map {
|
||||
case (Bind(bn, _), i) =>
|
||||
val fidx = Expr(s"_${i + 1}")
|
||||
'{
|
||||
ast.Val(
|
||||
ast.Ident(${ Expr(bn) }),
|
||||
ast.Property(ast.Ident(${ Expr(t) }), "_1")
|
||||
ast.Property(ast.Ident(${ Expr(t) }), $fidx)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
|
||||
|
|
|
@ -97,7 +97,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
|
|||
|
||||
assertEquals(
|
||||
o.sql,
|
||||
"SELECT f1.id, f1.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
|
||||
"SELECT f1.id, f2.id FROM foo f1 INNER JOIN foo f2 ON f1.id = f2.id"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -143,4 +143,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
|
|||
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
|
||||
)
|
||||
}
|
||||
|
||||
test("GroupBy") {
|
||||
val o = testContext.io(
|
||||
Foos.groupBy(f => f.age).map { case (age, fs) => (age, fs.size) }
|
||||
)
|
||||
assertEquals(
|
||||
o.sql,
|
||||
"SELECT f.age, COUNT(*) FROM foo f GROUP BY f.age"
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue