unified extractTerm

This commit is contained in:
jilen 2025-06-30 19:33:39 +08:00
parent c1f26a0704
commit adc60400a7
8 changed files with 86 additions and 53 deletions

View file

@ -12,6 +12,8 @@ import minisql.ast.{
Ident, Ident,
Filter, Filter,
PropertyAlias, PropertyAlias,
JoinType,
Join,
given given
} }
import scala.quoted.* import scala.quoted.*
@ -28,6 +30,54 @@ opaque type Action[E] <: Quoted = Quoted
opaque type Insert <: Action[Long] = Quoted opaque type Insert <: Action[Long] = Quoted
sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
new Joined[E1, E2] {}
extension [E1, E2](inline j: Joined[E1, E2]) {
inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] =
joinOn(j, f)
}
}
private inline def joinOn[E1, E2](
inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean
): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body)
extension [E1, E2](inline j: Joined[E1, E2]) {
private inline def toJoinQuery(
inline aliasA: Ident,
inline aliasB: Ident,
inline on: Ast
): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) }
}
private def joinQueryOf[E1, E2](
x: Expr[Joined[E1, E2]],
aliasA: Expr[Ident],
aliasB: Expr[Ident],
on: Expr[Ast]
)(using Quotes, Type[E1], Type[E2]): Expr[Join] = {
import quotes.reflect.*
extractTerm(x.asTerm).asExpr match {
case '{
Joined[E1, E2]($jt, $a, $b)
} =>
'{
Join($jt, $a, $b, $aliasA, $aliasB, $on)
}
case o =>
println("====================---" + o.show)
throw new Exception(s"Fail")
}
}
private inline def quotedLift[X](x: X)(using private inline def quotedLift[X](x: X)(using
e: ParamEncoder[X] e: ParamEncoder[X]
): ast.ScalarValueLift = ${ ): ast.ScalarValueLift = ${
@ -52,14 +102,15 @@ private def quotedLiftImpl[X: Type](
object Query { object Query {
private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast
extension [E](inline e: Query[E]) { extension [E](inline e: Query[E]) {
private[minisql] inline def expanded: Query[E] = { private[minisql] inline def expanded: Query[E] = {
Query(expandFields[E](e)) expandFields[E](e)
} }
inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, E1] =
Joined[E, E1](JoinType.LeftJoin, e, e1)
inline def map[E1](inline f: E => E1): Query[E1] = { inline def map[E1](inline f: E => E1): Query[E1] = {
transform(e)(f)(Map.apply) transform(e)(f)(Map.apply)
} }

View file

@ -69,9 +69,6 @@ private given FromExpr[Property] with {
) )
} => } =>
Some(Property(a, n, r, v)) Some(Property(a, n, r, v))
case o =>
println(s"Cannot extract ${o.show}")
None
} }
} }
@ -168,7 +165,7 @@ private given FromExpr[Query] with {
case '{ Nested(${ Expr(a) }) } => case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a)) Some(Nested(a))
case o => case o =>
println(s"Cannot extract ${o.show}") // println(s"Cannot extract ${o.show}")
None None
} }
} }
@ -406,23 +403,6 @@ private given FromExpr[If] with {
} }
} }
private def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
o
}
extension (e: Expr[Any]) { extension (e: Expr[Any]) {
private def toTerm(using Quotes) = { private def toTerm(using Quotes) = {
import quotes.reflect.* import quotes.reflect.*
@ -434,7 +414,6 @@ extension (e: Expr[Any]) {
private def fromBlock(using private def fromBlock(using
Quotes Quotes
)(block: quotes.reflect.Block): Option[Ast] = { )(block: quotes.reflect.Block): Option[Ast] = {
println(s"Show block ${block.show}")
import quotes.reflect.* import quotes.reflect.*
val empty: Option[List[Ast]] = Some(Nil) val empty: Option[List[Ast]] = Some(Nil)
val stmts = block.statements.foldLeft(empty) { (r, stmt) => val stmts = block.statements.foldLeft(empty) { (r, stmt) =>

View file

@ -3,6 +3,7 @@ package minisql.parsing
import minisql.ast import minisql.ast
import minisql.ast.Ast import minisql.ast.Ast
import scala.quoted.* import scala.quoted.*
import minisql.util.*
private[minisql] inline def parseParamAt[F]( private[minisql] inline def parseParamAt[F](
inline f: F, inline f: F,

View file

@ -9,7 +9,7 @@ import scala.annotation.tailrec
import minisql.ast.Implicits._ import minisql.ast.Implicits._
import minisql.ast.Renameable.Fixed import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.{Hidden, Visible} import minisql.ast.Visibility.{Hidden, Visible}
import minisql.util.Interleave import minisql.util.{Interleave, extractTerm}
import scala.quoted.* import scala.quoted.*
type Parser[A] = PartialFunction[Expr[Any], Expr[A]] type Parser[A] = PartialFunction[Expr[Any], Expr[A]]
@ -29,22 +29,6 @@ private def parser[A](
case e if f.isDefinedAt(e) => f(e) case e if f.isDefinedAt(e) => f(e)
} }
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
unwrapTerm(x)
}
private[minisql] object Parsing { private[minisql] object Parsing {
def parseExpr( def parseExpr(
@ -103,7 +87,7 @@ private[minisql] object Parsing {
} }
lazy val identParser: Parser[ast.Ident] = termParser { lazy val identParser: Parser[ast.Ident] = termParser {
case x @ Ident(n) if x.symbol.isValDef => case x @ Ident(n) =>
'{ ast.Ident(${ Expr(n) }) } '{ ast.Ident(${ Expr(n) }) }
} }

View file

@ -16,16 +16,7 @@ private[parsing] def patMatchParsing(
case (Bind(n, ident), idx) => case (Bind(n, ident), idx) =>
n -> Select.unique(t, s"_${idx + 1}") n -> Select.unique(t, s"_${idx + 1}")
}.toMap }.toMap
val tm = new TreeMap { blockParsing(astParser)(body.asExpr)
override def transformTerm(tree: Term)(owner: Symbol): Term = {
tree match {
case Ident(n) => bm(n)
case o => super.transformTerm(o)(owner)
}
}
}
val newBody = tm.transformTree(body)(e.symbol)
astParser(newBody.asExpr)
} }
} }

View file

@ -2,6 +2,7 @@ package minisql
package parsing package parsing
import scala.quoted._ import scala.quoted._
import minisql.util.*
private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using
Quotes Quotes

View file

@ -22,3 +22,20 @@ private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = {
val fileName = pos.sourceFile.name val fileName = pos.sourceFile.name
s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}" s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}"
} }
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
o
}

View file

@ -34,4 +34,13 @@ class QuotedSuite extends munit.FunSuite {
"INSERT INTO foo (id,name) VALUES (?, ?)" "INSERT INTO foo (id,name) VALUES (?, ?)"
) )
} }
test("LeftJoin") {
val o = testContext
.io(Foos.leftJoin(Foos).on((f1, f2) => f1.id == f2.id).map {
case (f1, f2) => (f1.id, f2.id)
})
println(o)
}
} }