优化匹配逻辑

This commit is contained in:
jilen 2024-12-09 17:03:38 +08:00
parent bd1dbaa8e2
commit 1adada0c79
3 changed files with 136 additions and 38 deletions

View file

@ -6,16 +6,45 @@ import scala.quoted.*
type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]
private[minisql] def parseFunction[X](
x: Expr[X]
)(using Quotes): Expr[(List[Ast], Ast)] = {
private[minisql] inline def parseParamAt[A, B](
inline f: A => B,
inline n: Int
): ast.Ident = ${
parseParamAt('f, 'n)
}
private[minisql] inline def parseBody[X](
inline f: X
): ast.Ast = ${
parseBody('f)
}
private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
Quotes
): Expr[ast.Ident] = {
import quotes.reflect.*
x.asTerm match {
case Lambda(vals, body) =>
val paramExprs = vals.map {
val pIdx = n.value.getOrElse(
report.errorAndAbort(s"Param index ${n.show} is not know")
)
extractTerm(f.asTerm) match {
case Lambda(vals, _) =>
vals(pIdx) match {
case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
}
???
}
}
private[minisql] def parseBody[X](
x: Expr[X]
)(using Quotes): Expr[Ast] = {
import quotes.reflect.*
extractTerm(x.asTerm) match {
case Lambda(vals, body) =>
astParser(body.asExpr)
case o =>
report.errorAndAbort(s"Can only parse function")
}
}
private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
@ -34,18 +63,68 @@ private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
case _ => false
}
}
private def identParser(using Quotes): Parser[ast.Ident] = {
import quotes.reflect.*
{ (x: Expr[?]) =>
extractTerm(x.asTerm) match {
case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
case _ => None
}
}.unlift
}
private lazy val astParser: Quotes ?=> Parser[Ast] = {
identParser.orElse(propertyParser(astParser))
}
private object IsPropertySelect {
def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
import quotes.reflect.*
x.asTerm match {
case Select(x, n) => Some(x.asExpr, n)
case _ => None
}
}
}
def propertyParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.Property] = {
case IsPropertySelect(expr, n) =>
'{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
}
def optionOperationParser(
astParser: Parser[Ast]
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.OptionOperation] = {
case '{ ($x: Option[t]).isEmpty } =>
'{ ast.OptionIsEmpty(${ astParser(x) }) }
}
def binaryOperationParser(
astParser: Parser[Ast]
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.BinaryOperation] = {
???
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
println(s"Before extract ${x.show}")
println(s"After extract ${o.show}")
o
}

View file

@ -93,6 +93,13 @@ private given FromExpr[Query] with {
Some(Entity(n, ps, ren))
case '{ Entity(${ Expr(n) }, ${ Expr(ps) }) } =>
Some(Entity(n, ps, Renameable.neutral))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(id) }
val z: Ast = ${ Expr(body) }
Map(x, y, z)
} =>
Some(Map(b, id, body))
case '{ Map(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
Some(Map(b, id, body))
case '{ Filter(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
@ -241,8 +248,6 @@ private def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
case o => o
}
val o = unwrapTerm(x)
println(s"From ========== ${x.show}")
println(s"To ========== ${o.show}")
o
}
@ -256,6 +261,7 @@ extension (e: Expr[Any]) {
private def fromBlock(using
Quotes
)(block: quotes.reflect.Block): Option[Ast] = {
println(s"Show block ${block.show}")
import quotes.reflect.*
val empty: Option[List[Ast]] = Some(Nil)
val stmts = block.statements.foldLeft(empty) { (r, stmt) =>
@ -266,7 +272,6 @@ private def fromBlock(using
astList :+ v
}
}
case o =>
None
}
@ -281,12 +286,7 @@ private def fromBlock(using
given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
def unapply(e: Expr[Ast])(using Quotes): Option[Ast] = {
val et = extractTerm(e.toTerm)
et match {
case b: quotes.reflect.Block => fromBlock(b).map(BetaReduction(_))
case b: quotes.reflect.Ident => Some(Ident(b.name))
case o =>
o.asExpr match {
e match {
case '{ $x: Query } => x.value
case '{ $x: ScalarValueLift } => x.value
case '{ $x: Property } => x.value
@ -301,6 +301,4 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
case o => None
}
}
}
}

View file

@ -1,5 +1,6 @@
package minisql.dsl
import minisql.parsing
import minisql.ast.{Ast, Entity, Map, Property, Ident, given}
import scala.quoted.*
import scala.compiletime.*
@ -11,31 +12,47 @@ sealed trait Dsl {
trait Query[E] extends Dsl
case class EntityQuery[E](val ast: Ast) extends Query[E]
case class EntityQuery[E](ast: Ast) extends Query[E]
extension [E](inline e: EntityQuery[E]) {
inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
transform(e.ast)(f)(Map.apply)(EntityQuery.apply[E1])
}
}
inline def mapAst[E1](inline f: Ast => Ast): EntityQuery[E1] =
EntityQuery[E1](f(e.ast))
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt[A, B](f1, 0)
private inline def body = parsing.parseBody(f1)
}
private inline def transform[D1 <: Dsl, D2 <: Dsl, A, B](inline ast: Ast)(
inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast)(inline f2: Ast => D2): D2 = {
f2(fast(ast, f.param0, f.body))
}
given FromExpr[EntityQuery[?]] with {
def unapply(x: Expr[EntityQuery[?]])(using Quotes): Option[EntityQuery[?]] = {
x match {
case '{ val x: Ast = ${ Expr(ast) }; EntityQuery(x) } =>
Some(EntityQuery(ast))
case '{ EntityQuery(${ Expr(ast) }) } =>
Some(EntityQuery(ast))
case _ =>
import quotes.reflect.*
println(s"cannot unlift ${x.asTerm}")
println(s"cannot unlift ${x.show}: ${x.asTerm.getClass}")
None
}
}
}
given FromExpr[Dsl] with {
def unapply(d: Expr[Dsl])(using Quotes): Option[Dsl] = d match {
def unapply(x: Expr[Dsl])(using Quotes): Option[Dsl] = {
import quotes.reflect.*
x match {
case '{ ($x: EntityQuery[?]) } => x.value
}
}
}
inline def query[E](inline table: String) =
EntityQuery[E](Entity(table, Nil))
@ -50,3 +67,7 @@ private def compileImpl(x: Expr[Dsl])(using Quotes): Expr[Option[String]] = {
}
}
case class Foo(id: Long)
inline def queryFooId = query[Foo]("foo").map(_.id)