增加解析 case (x, y) => 函数定义

This commit is contained in:
jilen 2025-07-02 10:33:47 +08:00
parent adc60400a7
commit f5e43657b3
9 changed files with 63 additions and 53 deletions

View file

@ -3,10 +3,9 @@ name := "minisql"
scalaVersion := "3.7.1"
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % "1.0.3" % Test
"org.scalameta" %% "munit" % "1.1.1" % Test
)
javaOptions ++= Seq("-Xss16m")
scalacOptions ++= Seq(
"-deprecation",

View file

@ -1 +1 @@
sbt.version=1.10.10
sbt.version=1.11.2

View file

@ -72,9 +72,6 @@ private def joinQueryOf[E1, E2](
'{
Join($jt, $a, $b, $aliasA, $aliasB, $on)
}
case o =>
println("====================---" + o.show)
throw new Exception(s"Fail")
}
}

View file

@ -164,9 +164,6 @@ private given FromExpr[Query] with {
Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a))
case o =>
// println(s"Cannot extract ${o.show}")
None
}
}
@ -403,6 +400,20 @@ private given FromExpr[If] with {
}
}
private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements))
}
}
private given FromExpr[Val] with {
def unapply(x: Expr[Val])(using Quotes): Option[Val] = x match {
case '{ Val(${ Expr(n) }, ${ Expr(b) }) } =>
Some(Val(n, b))
}
}
extension (e: Expr[Any]) {
private def toTerm(using Quotes) = {
import quotes.reflect.*
@ -411,30 +422,6 @@ extension (e: Expr[Any]) {
}
private def fromBlock(using
Quotes
)(block: quotes.reflect.Block): Option[Ast] = {
import quotes.reflect.*
val empty: Option[List[Ast]] = Some(Nil)
val stmts = block.statements.foldLeft(empty) { (r, stmt) =>
stmt match {
case ValDef(n, _, Some(body)) =>
r.flatMap { astList =>
body.asExprOf[Ast].value.map { v =>
astList :+ v
}
}
case o =>
None
}
}
stmts.flatMap { stmts =>
block.expr.asExprOf[Ast].value.map { last =>
minisql.ast.Block(stmts :+ last)
}
}
}
given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
def unapply(e: Expr[Ast])(using Quotes): Option[Ast] = {
@ -444,6 +431,7 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
case '{ $x: ScalarValueLift } => x.value
case '{ $x: Property } => x.value
case '{ $x: Ident } => x.value
case '{ $x: Val } => x.value
case '{ $x: Tuple } => x.value
case '{ $x: Value } => x.value
case '{ $x: Operation } => x.value
@ -453,6 +441,7 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
case '{ $x: Infix } => x.value
case '{ $x: CaseClass } => x.value
case '{ $x: OptionOperation } => x.value
case '{ $x: Block } => x.value
case o =>
import quotes.reflect.*
report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos)

View file

@ -22,6 +22,25 @@ private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using
valDefParser
}
private[parsing] def parseBlockList(
astParser: => Parser[ast.Ast],
e: Expr[Any]
)(using Quotes): List[Expr[ast.Ast]] = {
import quotes.reflect.*
lazy val statementParser = statementParsing(astParser)
e.asTerm match {
case Block(st, t) =>
(st :+ t).map {
case e if e.isExpr => astParser(e.asExpr)
case `statementParser`(x) => x
case o =>
report.errorAndAbort(s"Cannot parse statement: ${o.show}")
}
}
}
private[parsing] def blockParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {

View file

@ -1,12 +1 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def infixParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Infix] = {
import quotes.reflect.*
???
}

View file

@ -59,7 +59,6 @@ private[minisql] object Parsing {
.orElse(ifParser)
.orElse(traversableOperationParser)
.orElse(patMatchParser)
// .orElse(infixParser)
.orElse {
case o =>
val str = scala.util.Try(o.show).getOrElse("")
@ -106,8 +105,6 @@ private[minisql] object Parsing {
}
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
// lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser)
lazy val traversableOperationParser: Parser[ast.IterableOperation] =
traversableOperationParsing(astParser)

View file

@ -11,12 +11,22 @@ private[parsing] def patMatchParsing(
termParser {
// Val defs that showd pattern variables will cause error
case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) =>
val bm = binds.zipWithIndex.map {
case (Bind(n, ident), idx) =>
n -> Select.unique(t, s"_${idx + 1}")
}.toMap
blockParsing(astParser)(body.asExpr)
case e @ Match(
Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body))
) =>
val bindStmts = binds.map {
case Bind(bn, _) =>
'{
ast.Val(
ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1")
)
}
}
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
'{ ast.Block(${ Expr.ofList(allStmts.toList) }) }
}
}

View file

@ -124,4 +124,14 @@ class FromExprsSuite extends FunSuite {
testFor("CaseClass") {
CaseClass(List(("name", Ident("value"))))
}
testFor("Block") { // Also tested Val
Block(
List(
Val(Ident("x"), Constant(1)),
Val(Ident("y"), Constant(2)),
BinaryOperation(Ident("x"), NumericOperator.+, Ident("y"))
)
)
}
}