unified extractTerm

This commit is contained in:
jilen 2025-06-30 19:33:39 +08:00
parent c1f26a0704
commit adc60400a7
8 changed files with 86 additions and 53 deletions

View file

@ -12,6 +12,8 @@ import minisql.ast.{
Ident,
Filter,
PropertyAlias,
JoinType,
Join,
given
}
import scala.quoted.*
@ -28,6 +30,54 @@ opaque type Action[E] <: Quoted = Quoted
opaque type Insert <: Action[Long] = Quoted
sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
new Joined[E1, E2] {}
extension [E1, E2](inline j: Joined[E1, E2]) {
inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] =
joinOn(j, f)
}
}
private inline def joinOn[E1, E2](
inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean
): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body)
extension [E1, E2](inline j: Joined[E1, E2]) {
private inline def toJoinQuery(
inline aliasA: Ident,
inline aliasB: Ident,
inline on: Ast
): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) }
}
private def joinQueryOf[E1, E2](
x: Expr[Joined[E1, E2]],
aliasA: Expr[Ident],
aliasB: Expr[Ident],
on: Expr[Ast]
)(using Quotes, Type[E1], Type[E2]): Expr[Join] = {
import quotes.reflect.*
extractTerm(x.asTerm).asExpr match {
case '{
Joined[E1, E2]($jt, $a, $b)
} =>
'{
Join($jt, $a, $b, $aliasA, $aliasB, $on)
}
case o =>
println("====================---" + o.show)
throw new Exception(s"Fail")
}
}
private inline def quotedLift[X](x: X)(using
e: ParamEncoder[X]
): ast.ScalarValueLift = ${
@ -52,14 +102,15 @@ private def quotedLiftImpl[X: Type](
object Query {
private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast
extension [E](inline e: Query[E]) {
private[minisql] inline def expanded: Query[E] = {
Query(expandFields[E](e))
expandFields[E](e)
}
inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, E1] =
Joined[E, E1](JoinType.LeftJoin, e, e1)
inline def map[E1](inline f: E => E1): Query[E1] = {
transform(e)(f)(Map.apply)
}

View file

@ -69,9 +69,6 @@ private given FromExpr[Property] with {
)
} =>
Some(Property(a, n, r, v))
case o =>
println(s"Cannot extract ${o.show}")
None
}
}
@ -168,7 +165,7 @@ private given FromExpr[Query] with {
case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a))
case o =>
println(s"Cannot extract ${o.show}")
// println(s"Cannot extract ${o.show}")
None
}
}
@ -406,23 +403,6 @@ private given FromExpr[If] with {
}
}
private def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
o
}
extension (e: Expr[Any]) {
private def toTerm(using Quotes) = {
import quotes.reflect.*
@ -434,7 +414,6 @@ extension (e: Expr[Any]) {
private def fromBlock(using
Quotes
)(block: quotes.reflect.Block): Option[Ast] = {
println(s"Show block ${block.show}")
import quotes.reflect.*
val empty: Option[List[Ast]] = Some(Nil)
val stmts = block.statements.foldLeft(empty) { (r, stmt) =>

View file

@ -3,6 +3,7 @@ package minisql.parsing
import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
import minisql.util.*
private[minisql] inline def parseParamAt[F](
inline f: F,

View file

@ -9,7 +9,7 @@ import scala.annotation.tailrec
import minisql.ast.Implicits._
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.{Hidden, Visible}
import minisql.util.Interleave
import minisql.util.{Interleave, extractTerm}
import scala.quoted.*
type Parser[A] = PartialFunction[Expr[Any], Expr[A]]
@ -29,22 +29,6 @@ private def parser[A](
case e if f.isDefinedAt(e) => f(e)
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
unwrapTerm(x)
}
private[minisql] object Parsing {
def parseExpr(
@ -103,7 +87,7 @@ private[minisql] object Parsing {
}
lazy val identParser: Parser[ast.Ident] = termParser {
case x @ Ident(n) if x.symbol.isValDef =>
case x @ Ident(n) =>
'{ ast.Ident(${ Expr(n) }) }
}

View file

@ -16,16 +16,7 @@ private[parsing] def patMatchParsing(
case (Bind(n, ident), idx) =>
n -> Select.unique(t, s"_${idx + 1}")
}.toMap
val tm = new TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = {
tree match {
case Ident(n) => bm(n)
case o => super.transformTerm(o)(owner)
}
}
}
val newBody = tm.transformTree(body)(e.symbol)
astParser(newBody.asExpr)
blockParsing(astParser)(body.asExpr)
}
}

View file

@ -2,6 +2,7 @@ package minisql
package parsing
import scala.quoted._
import minisql.util.*
private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using
Quotes

View file

@ -22,3 +22,20 @@ private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = {
val fileName = pos.sourceFile.name
s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}"
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
o
}

View file

@ -34,4 +34,13 @@ class QuotedSuite extends munit.FunSuite {
"INSERT INTO foo (id,name) VALUES (?, ?)"
)
}
test("LeftJoin") {
val o = testContext
.io(Foos.leftJoin(Foos).on((f1, f2) => f1.id == f2.id).map {
case (f1, f2) => (f1.id, f2.id)
})
println(o)
}
}