From 59f969a232397412a357b20cd8da243b040e86a9 Mon Sep 17 00:00:00 2001
From: jilen
Date: Wed, 18 Dec 2024 16:09:08 +0800
Subject: [PATCH] test simple quoted ast

---
 README.md                                     |   2 +-
 src/main/scala/minisql/Quoted.scala           |  15 +
 src/main/scala/minisql/ast/FromExprs.scala    |   3 +-
 .../scala/minisql/idiom/MirrorIdiom.scala     | 355 ++++++++++++++++++
 .../minisql/idiom/StatementInterpolator.scala |  16 +-
 src/main/scala/minisql/parsing/Parser.scala   |   2 +-
 .../scala/minisql/parsing/QuotedSuite.scala   |  35 ++
 7 files changed, 420 insertions(+), 8 deletions(-)
 create mode 100644 src/main/scala/minisql/idiom/MirrorIdiom.scala
 create mode 100644 src/test/scala/minisql/parsing/QuotedSuite.scala

diff --git a/README.md b/README.md
index de83bb9..6f82dea 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 In most scenarios there is no need to analyze code with complex pattern matching on the Ast inside a `macro`.
 
-## Core idea: use inline and `FromExpr` to replace most of the parsing work
+## Core idea: use inline and `FromExpr` to replace part of the parsing work
 
 `FromExpr` is a typeclass built into `scala3` for extracting values at compile time.
diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala
index c7c9c7f..da7008e 100644
--- a/src/main/scala/minisql/Quoted.scala
+++ b/src/main/scala/minisql/Quoted.scala
@@ -52,6 +52,21 @@ def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()
 
 class NonQuotedException extends Exception("Cannot be used at runtime")
 
+private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] =
+  ${
+    compileTimeAstImpl('q)
+  }
+
+private def compileTimeAstImpl(e: Expr[Quoted])(using
+    Quotes
+): Expr[Option[String]] = {
+  import quotes.reflect.*
+  e.value match {
+    case Some(v) => '{ Some(${ Expr(v.toString()) }) }
+    case None => '{ None }
+  }
+}
+
 private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy](
     inline q: Quoted,
     inline idiom: I,
diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala
index 26694fb..3541a76 100644
--- a/src/main/scala/minisql/ast/FromExprs.scala
+++ b/src/main/scala/minisql/ast/FromExprs.scala
@@ -130,7 +130,7 @@ private given FromExpr[Query] with {
       case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } =>
         Some(SortBy(b, p, s, o))
       case o =>
-        println(s"Cannot extract ${o.show}")
+        println(s"Cannot extract ${o}")
         None
     }
   }
@@ -146,6 +146,7 @@ private given FromExpr[BinaryOperator] with {
     case '{ NumericOperator.- } => Some(NumericOperator.-)
     case '{ NumericOperator.* } => Some(NumericOperator.*)
     case '{ NumericOperator./ } => Some(NumericOperator./)
+    case '{ NumericOperator.> } => Some(NumericOperator.>)
     case '{ StringOperator.split } => Some(StringOperator.split)
     case '{ StringOperator.startsWith } => Some(StringOperator.startsWith)
     case '{ StringOperator.concat } => Some(StringOperator.concat)
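Note on the two changes above: `compileTimeAst` surfaces the compile-time
extracted Ast of a `Quoted` block as an `Option[String]` (`Some` when the whole
tree is recoverable through the `FromExpr` instances, `None` otherwise), and
`FromExpr[BinaryOperator]` gains the `>` case that the new test suite below
exercises. A minimal usage sketch (`Person` and `people` are illustrative names,
not part of this patch; the call site must sit inside package `minisql` because
the helper is `private[minisql]`):

    import minisql.*

    case class Person(name: String, age: Int)
    inline def People = query[Person]("people")

    // Evaluated at compile time: roughly Some("Filter(Entity(people,...),...)")
    // when extraction succeeds, None when some node has no FromExpr case.
    val rendered: Option[String] = compileTimeAst(People.filter(p => p.age > 0))

diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala
new file mode 100644
index 0000000..88aab8c
--- /dev/null
+++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala
@@ -0,0 +1,355 @@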
+package minisql
+
+import minisql.ast.Renameable.{ByStrategy, Fixed}
+import minisql.ast.Visibility.Hidden
+import minisql.ast._
+import minisql.context.CanReturnClause
+import minisql.idiom.{Idiom, SetContainsToken, Statement}
+import minisql.idiom.StatementInterpolator.*
+import minisql.norm.Normalize
+import minisql.util.Interleave
+
+object MirrorIdiom extends MirrorIdiom
+class MirrorIdiom extends MirrorIdiomBase with CanReturnClause
+
+object MirrorIdiomPrinting extends MirrorIdiom {
+  override def distinguishHidden: Boolean = true
+}
+
+trait MirrorIdiomBase extends Idiom {
+
+  def distinguishHidden: Boolean = false
+
+  override def prepareForProbing(string: String) = string
+
+  override def liftingPlaceholder(index: Int): String = "?"
+
+  override def translate(
+      ast: Ast
+  )(implicit naming: NamingStrategy): (Ast, Statement) = {
+    val normalizedAst = Normalize(ast)
+    (normalizedAst, stmt"${normalizedAst.token}")
+  }
+
+  implicit def astTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Ast] = Tokenizer[Ast] {
+    case ast: Query => ast.token
+    case ast: Function => ast.token
+    case ast: Value => ast.token
+    case ast: Operation => ast.token
+    case ast: Action => ast.token
+    case ast: Ident => ast.token
+    case ast: ExternalIdent => ast.token
+    case ast: Property => ast.token
+    case ast: Infix => ast.token
+    case ast: OptionOperation => ast.token
+    case ast: IterableOperation => ast.token
+    case ast: Dynamic => ast.token
+    case ast: If => ast.token
+    case ast: Block => ast.token
+    case ast: Val => ast.token
+    case ast: Ordering => ast.token
+    case ast: Lift => ast.token
+    case ast: Assignment => ast.token
+    case ast: OnConflict.Excluded => ast.token
+    case ast: OnConflict.Existing => ast.token
+  }
+
+  implicit def ifTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[If] = Tokenizer[If] {
+    case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}"
+  }
+
+  implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] {
+    case Dynamic(tree) => stmt"${tree.toString.token}"
+  }
+
+  implicit def blockTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Block] = Tokenizer[Block] {
+    case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }"
+  }
+
+  implicit def valTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Val] = Tokenizer[Val] {
+    case Val(name, body) => stmt"val ${name.token} = ${body.token}"
+  }
+
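+  // Each tokenizer below renders one family of Ast nodes back into
+  // minisql-style Scala source via the `stmt` interpolator, so a translated
+  // tree reads like the query expression that produced it.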
stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})" + + case Distinct(a) => + stmt"${a.token}.distinct" + + case DistinctOn(source, alias, body) => + stmt"${source.token}.distinctOn(${alias.token} => ${body.token})" + + case Nested(a) => + stmt"${a.token}.nested" + } + + implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] { + case TupleOrdering(elems) => stmt"Ord(${elems.token})" + case Asc => stmt"Ord.asc" + case Desc => stmt"Ord.desc" + case AscNullsFirst => stmt"Ord.ascNullsFirst" + case DescNullsFirst => stmt"Ord.descNullsFirst" + case AscNullsLast => stmt"Ord.ascNullsLast" + case DescNullsLast => stmt"Ord.descNullsLast" + } + + implicit def optionOperationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] { + case OptionTableFlatMap(ast, alias, body) => + stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})" + case OptionTableMap(ast, alias, body) => + stmt"${ast.token}.map((${alias.token}) => ${body.token})" + case OptionTableExists(ast, alias, body) => + stmt"${ast.token}.exists((${alias.token}) => ${body.token})" + case OptionTableForall(ast, alias, body) => + stmt"${ast.token}.forall((${alias.token}) => ${body.token})" + case OptionFlatten(ast) => stmt"${ast.token}.flatten" + case OptionGetOrElse(ast, body) => + stmt"${ast.token}.getOrElse(${body.token})" + case OptionFlatMap(ast, alias, body) => + stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})" + case OptionMap(ast, alias, body) => + stmt"${ast.token}.map((${alias.token}) => ${body.token})" + case OptionForall(ast, alias, body) => + stmt"${ast.token}.forall((${alias.token}) => ${body.token})" + case OptionExists(ast, alias, body) => + stmt"${ast.token}.exists((${alias.token}) => ${body.token})" + case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case OptionIsEmpty(ast) => stmt"${ast.token}.isEmpty" + case OptionNonEmpty(ast) => stmt"${ast.token}.nonEmpty" + case OptionIsDefined(ast) => stmt"${ast.token}.isDefined" + case OptionSome(ast) => stmt"Some(${ast.token})" + case OptionApply(ast) => stmt"Option(${ast.token})" + case OptionOrNull(ast) => stmt"${ast.token}.orNull" + case OptionGetOrNull(ast) => stmt"${ast.token}.getOrNull" + case OptionNone => stmt"None" + } + + implicit def traversableOperationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] { + case MapContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case SetContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case ListContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + } + + implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { + case InnerJoin => stmt"join" + case LeftJoin => stmt"leftJoin" + case RightJoin => stmt"rightJoin" + case FullJoin => stmt"fullJoin" + } + + implicit def functionTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Function] = Tokenizer[Function] { + case Function(params, body) => stmt"(${params.token}) => ${body.token}" + } + + implicit def operationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Operation] = Tokenizer[Operation] { + case UnaryOperation(op: PrefixUnaryOperator, ast) => + stmt"${op.token}${scopedTokenizer(ast)}" + case UnaryOperation(op: PostfixUnaryOperator, ast) => + stmt"${scopedTokenizer(ast)}.${op.token}" + case BinaryOperation(a, op @ SetOperator.`contains`, b) => + SetContainsToken(scopedTokenizer(b), op.token, 
+  implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] {
+    case o => stmt"${o.toString.token}"
+  }
+
+  def tokenizeName(name: String, renameable: Renameable) =
+    renameable match {
+      case ByStrategy => name
+      case Fixed => s"`${name}`"
+    }
+
+  def bracketIfHidden(name: String, visibility: Visibility) =
+    (distinguishHidden, visibility) match {
+      case (true, Hidden) => s"[$name]"
+      case _ => name
+    }
+
+  implicit def propertyTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Property] = Tokenizer[Property] {
+    case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) =>
+      stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
+    case Property.Opinionated(ref, name, renameable, visibility) =>
+      stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
+  }
+
+  implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] {
+    case Constant(v: String) => stmt""""${v.token}""""
+    case Constant(()) => stmt"{}"
+    case Constant(v) => stmt"${v.toString.token}"
+    case NullValue => stmt"null"
+    case Tuple(values) => stmt"(${values.token})"
+    case CaseClass(values) =>
+      stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})"
+  }
+
+  implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] {
+    case Ident.Opinionated(name, visibility) =>
+      stmt"${bracketIfHidden(name, visibility).token}"
+  }
+
+  implicit val typeTokenizer: Tokenizer[ExternalIdent] =
+    Tokenizer[ExternalIdent] {
+      case e => stmt"${e.name.token}"
+    }
+
+  implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] =
+    Tokenizer[OnConflict.Excluded] {
+      case OnConflict.Excluded(ident) => stmt"${ident.token}"
+    }
+
+  implicit val existingTokenizer: Tokenizer[OnConflict.Existing] =
+    Tokenizer[OnConflict.Existing] {
+      case OnConflict.Existing(ident) => stmt"${ident.token}"
+    }
+
+  implicit def actionTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Action] = Tokenizer[Action] {
+    case Update(query, assignments) =>
+      stmt"${query.token}.update(${assignments.token})"
+    case Insert(query, assignments) =>
+      stmt"${query.token}.insert(${assignments.token})"
+    case Delete(query) => stmt"${query.token}.delete"
+    case Returning(query, alias, body) =>
+      stmt"${query.token}.returning((${alias.token}) => ${body.token})"
+    case ReturningGenerated(query, alias, body) =>
+      stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})"
+    case Foreach(query, alias, body) =>
+      stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
+    case c: OnConflict => stmt"${c.token}"
+  }
+
+  implicit def conflictTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[OnConflict] = {
+
+    def targetProps(l: List[Property]) = l.map(p =>
+      Transform(p) {
+        case Ident(_) => Ident("_")
+      }
+    )
+
+    implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] =
+      Tokenizer[OnConflict.Target] {
+        case OnConflict.NoTarget => stmt""
+        case OnConflict.Properties(props) =>
+          val listTokens = listTokenizer(astTokenizer).token(targetProps(props))
+          stmt"(${listTokens})"
+      }
+
+    val updateAssignsTokenizer = Tokenizer[Assignment] {
+      case Assignment(i, p, v) =>
+        stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
+    }
+
+    Tokenizer[OnConflict] {
+      case OnConflict(i, t, OnConflict.Update(assign)) =>
+        stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
+      case OnConflict(i, t, OnConflict.Ignore) =>
+        stmt"${i.token}.onConflictIgnore${t.token}"
+    }
+  }
+
+  implicit def assignmentTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Assignment] = Tokenizer[Assignment] {
+    case Assignment(ident, property, value) =>
+      stmt"${ident.token} => ${property.token} -> ${value.token}"
+  }
+
+  implicit def infixTokenizer(implicit
+      liftTokenizer: Tokenizer[Lift]
+  ): Tokenizer[Infix] = Tokenizer[Infix] {
+    case Infix(parts, params, _, _) =>
+      def tokenParam(ast: Ast) =
+        ast match {
+          case ast: Ident => stmt"$$${ast.token}"
+          case other => stmt"$${${ast.token}}"
+        }
+
+      val pt = parts.map(_.token)
+      val pr = params.map(tokenParam)
+      val body = Statement(Interleave(pt, pr))
+      stmt"""infix"${body.token}""""
+  }
+
+  private def scopedTokenizer(
+      ast: Ast
+  )(implicit liftTokenizer: Tokenizer[Lift]) =
+    ast match {
+      case _: Function => stmt"(${ast.token})"
+      case _: BinaryOperation => stmt"(${ast.token})"
+      case other => ast.token
+    }
+}
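The `.token` calls throughout MirrorIdiom resolve through the extension-based
`Tokenizer` refactor in the next file: once `token` is an extension member of
the typeclass, any implicit `Tokenizer[T]` instance in scope makes `x.token`
available directly on `x: T`. A self-contained sketch of the pattern
(standalone names, not this repo's API):

    trait Show[T] {
      extension (v: T) def show: String
    }

    object Show {
      def apply[T](f: T => String): Show[T] = new Show[T] {
        extension (v: T) def show: String = f(v)
      }
    }

    given Show[Int] = Show(i => s"Int($i)")

    @main def demo(): Unit = println(42.show) // resolves via the given Show[Int]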
diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala
index b732da1..3aa4d26 100644
--- a/src/main/scala/minisql/idiom/StatementInterpolator.scala
+++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala
@@ -9,19 +9,25 @@ import scala.collection.mutable.ListBuffer
 object StatementInterpolator {
 
   trait Tokenizer[T] {
-    def token(v: T): Token
+    extension (v: T) {
+      def token: Token
+    }
   }
 
   object Tokenizer {
-    def apply[T](f: T => Token) = new Tokenizer[T] {
-      def token(v: T) = f(v)
+    def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] {
+      extension (v: T) {
+        def token: Token = f(v)
+      }
     }
 
     def withFallback[T](
         fallback: Tokenizer[T] => Tokenizer[T]
     )(pf: PartialFunction[T, Token]) = new Tokenizer[T] {
-      private val stable = fallback(this)
-      override def token(v: T) = pf.applyOrElse(v, stable.token)
+      extension (v: T) {
+        private def stable = fallback(this)
+        override def token = pf.applyOrElse(v, stable.token)
+      }
     }
   }
 
diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala
index 4bc7265..91bfbc0 100644
--- a/src/main/scala/minisql/parsing/Parser.scala
+++ b/src/main/scala/minisql/parsing/Parser.scala
@@ -38,7 +38,7 @@ private[minisql] def parseBody[X](
     x: Expr[X]
 )(using Quotes): Expr[Ast] = {
   import quotes.reflect.*
-  x.asTerm match {
+  extractTerm(x.asTerm) match {
     case Lambda(vals, body) =>
       Parsing.parseExpr(body.asExpr)
     case o =>
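The Parser change routes the tree through `extractTerm` before matching:
inline expansion typically wraps the lambda in `Inlined`/`Typed` nodes, so a
bare `Lambda(...)` pattern would no longer match. `extractTerm` already exists
in this repo's parsing utilities; a typical shape for such a helper looks like
this (illustrative only, not the repo's exact definition):

    import scala.quoted.*

    def extractTerm(using q: Quotes)(t: q.reflect.Term): q.reflect.Term = {
      import q.reflect.*
      t match {
        case Inlined(_, Nil, inner) => extractTerm(inner) // peel inline wrappers
        case Typed(inner, _) => extractTerm(inner) // drop type ascriptions
        case other => other
      }
    }

diff --git a/src/test/scala/minisql/parsing/QuotedSuite.scala b/src/test/scala/minisql/parsing/QuotedSuite.scala
new file mode 100644
index 0000000..2bd02eb
--- /dev/null
+++ b/src/test/scala/minisql/parsing/QuotedSuite.scala
@@ -0,0 +1,35 @@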
+package minisql
+
+import minisql.ast.*
+
+class QuotedSuite extends munit.FunSuite {
+  private inline def testQuoted(label: String)(
+      inline x: Quoted,
+      expect: Ast
+  ) = test(label) {
+    assertEquals(compileTimeAst(x), Some(expect.toString()))
+  }
+
+  case class Foo(id: Long)
+
+  inline def Foos = query[Foo]("foo")
+  val entityFoo = Entity("foo", Nil)
+  val idx = Ident("x")
+
+  testQuoted("EntityQuery")(Foos, entityFoo)
+
+  testQuoted("Query/filter")(
+    Foos.filter(x => x.id > 0),
+    Filter(
+      entityFoo,
+      idx,
+      BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0))
+    )
+  )
+
+  testQuoted("Query/map")(
+    Foos.map(x => x.id),
+    Map(entityFoo, idx, Property(idx, "id"))
+  )
+
+}
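Since `testQuoted` compares against `expect.toString()`, composed queries can be
checked the same way by nesting the expected Ast nodes. A possible further case
for the suite (a sketch, not part of this patch; it only uses constructs the
existing tests already extract):

    testQuoted("Query/filterMap")(
      Foos.filter(x => x.id > 0).map(x => x.id),
      Map(
        Filter(
          entityFoo,
          idx,
          BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0))
        ),
        idx,
        Property(idx, "id")
      )
    )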