From 8103d4517890be42196eac54d012377251e39c00 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 15 Dec 2024 20:51:38 +0800 Subject: [PATCH 01/26] try parsing function body --- src/main/scala/minisql/Parser.scala | 127 ----- src/main/scala/minisql/ReturnAction.scala | 7 + src/main/scala/minisql/ast/Ast.scala | 13 +- src/main/scala/minisql/ast/AstOps.scala | 94 ++++ src/main/scala/minisql/ast/FromExprs.scala | 5 +- .../context/ReturnFieldCapability.scala | 64 +++ src/main/scala/minisql/dsl.scala | 18 +- src/main/scala/minisql/idiom/Idiom.scala | 23 + src/main/scala/minisql/idiom/LoadNaming.scala | 28 + .../scala/minisql/idiom/ReifyStatement.scala | 68 +++ src/main/scala/minisql/idiom/Statement.scala | 47 ++ .../minisql/idiom/StatementInterpolator.scala | 146 ++++++ .../scala/minisql/norm/AdHocReduction.scala | 52 ++ src/main/scala/minisql/norm/ApplyMap.scala | 160 ++++++ .../scala/minisql/norm/AttachToEntity.scala | 48 ++ .../{util => norm}/BetaReduction.scala | 4 +- .../scala/minisql/norm/ConcatBehavior.scala | 7 + .../scala/minisql/norm/EqualityBehavior.scala | 7 + .../scala/minisql/norm/ExpandReturning.scala | 74 +++ .../minisql/norm/FlattenOptionOperation.scala | 108 ++++ .../minisql/norm/NestImpureMappedInfix.scala | 76 +++ src/main/scala/minisql/norm/Normalize.scala | 51 ++ .../norm/NormalizeAggregationIdent.scala | 29 ++ .../norm/NormalizeNestedStructures.scala | 47 ++ .../minisql/norm/NormalizeReturning.scala | 154 ++++++ src/main/scala/minisql/norm/OrderTerms.scala | 29 ++ .../scala/minisql/norm/RenameProperties.scala | 491 ++++++++++++++++++ .../minisql/{util => norm}/Replacements.scala | 2 +- .../minisql/norm/SimplifyNullChecks.scala | 124 +++++ .../minisql/norm/SymbolicReduction.scala | 38 ++ .../norm/capture/AvoidAliasConflict.scala | 174 +++++++ .../minisql/norm/capture/AvoidCapture.scala | 9 + .../scala/minisql/norm/capture/Dealias.scala | 72 +++ .../capture/DemarcateExternalAliases.scala | 98 ++++ .../scala/minisql/parsing/BlockParsing.scala | 47 ++ 
.../scala/minisql/parsing/BoxingParsing.scala | 31 ++ .../scala/minisql/parsing/InfixParsing.scala | 13 + .../scala/minisql/parsing/LiftParsing.scala | 16 + .../minisql/parsing/OperationParsing.scala | 113 ++++ src/main/scala/minisql/parsing/Parser.scala | 47 ++ src/main/scala/minisql/parsing/Parsing.scala | 139 +++++ .../minisql/parsing/PatMatchParsing.scala | 49 ++ .../minisql/parsing/PropertyParsing.scala | 30 ++ .../parsing/TraversableOperationParsing.scala | 16 + .../scala/minisql/parsing/ValueParsing.scala | 71 +++ .../scala/minisql/util/Interpolator.scala | 15 +- src/main/scala/minisql/util/Message.scala | 75 +++ 47 files changed, 3000 insertions(+), 156 deletions(-) delete mode 100644 src/main/scala/minisql/Parser.scala create mode 100644 src/main/scala/minisql/ReturnAction.scala create mode 100644 src/main/scala/minisql/ast/AstOps.scala create mode 100644 src/main/scala/minisql/context/ReturnFieldCapability.scala create mode 100644 src/main/scala/minisql/idiom/Idiom.scala create mode 100644 src/main/scala/minisql/idiom/LoadNaming.scala create mode 100644 src/main/scala/minisql/idiom/ReifyStatement.scala create mode 100644 src/main/scala/minisql/idiom/Statement.scala create mode 100644 src/main/scala/minisql/idiom/StatementInterpolator.scala create mode 100644 src/main/scala/minisql/norm/AdHocReduction.scala create mode 100644 src/main/scala/minisql/norm/ApplyMap.scala create mode 100644 src/main/scala/minisql/norm/AttachToEntity.scala rename src/main/scala/minisql/{util => norm}/BetaReduction.scala (99%) create mode 100644 src/main/scala/minisql/norm/ConcatBehavior.scala create mode 100644 src/main/scala/minisql/norm/EqualityBehavior.scala create mode 100644 src/main/scala/minisql/norm/ExpandReturning.scala create mode 100644 src/main/scala/minisql/norm/FlattenOptionOperation.scala create mode 100644 src/main/scala/minisql/norm/NestImpureMappedInfix.scala create mode 100644 src/main/scala/minisql/norm/Normalize.scala create mode 100644 
src/main/scala/minisql/norm/NormalizeAggregationIdent.scala create mode 100644 src/main/scala/minisql/norm/NormalizeNestedStructures.scala create mode 100644 src/main/scala/minisql/norm/NormalizeReturning.scala create mode 100644 src/main/scala/minisql/norm/OrderTerms.scala create mode 100644 src/main/scala/minisql/norm/RenameProperties.scala rename src/main/scala/minisql/{util => norm}/Replacements.scala (98%) create mode 100644 src/main/scala/minisql/norm/SimplifyNullChecks.scala create mode 100644 src/main/scala/minisql/norm/SymbolicReduction.scala create mode 100644 src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala create mode 100644 src/main/scala/minisql/norm/capture/AvoidCapture.scala create mode 100644 src/main/scala/minisql/norm/capture/Dealias.scala create mode 100644 src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala create mode 100644 src/main/scala/minisql/parsing/BlockParsing.scala create mode 100644 src/main/scala/minisql/parsing/BoxingParsing.scala create mode 100644 src/main/scala/minisql/parsing/InfixParsing.scala create mode 100644 src/main/scala/minisql/parsing/LiftParsing.scala create mode 100644 src/main/scala/minisql/parsing/OperationParsing.scala create mode 100644 src/main/scala/minisql/parsing/Parser.scala create mode 100644 src/main/scala/minisql/parsing/Parsing.scala create mode 100644 src/main/scala/minisql/parsing/PatMatchParsing.scala create mode 100644 src/main/scala/minisql/parsing/PropertyParsing.scala create mode 100644 src/main/scala/minisql/parsing/TraversableOperationParsing.scala create mode 100644 src/main/scala/minisql/parsing/ValueParsing.scala create mode 100644 src/main/scala/minisql/util/Message.scala diff --git a/src/main/scala/minisql/Parser.scala b/src/main/scala/minisql/Parser.scala deleted file mode 100644 index 3b21d90..0000000 --- a/src/main/scala/minisql/Parser.scala +++ /dev/null @@ -1,127 +0,0 @@ -package minisql.parsing - -import minisql.ast -import minisql.ast.Ast -import 
scala.quoted.* - -type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]] - -private[minisql] inline def parseParamAt[F]( - inline f: F, - inline n: Int -): ast.Ident = ${ - parseParamAt('f, 'n) -} - -private[minisql] inline def parseBody[X]( - inline f: X -): ast.Ast = ${ - parseBody('f) -} - -private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using - Quotes -): Expr[ast.Ident] = { - - import quotes.reflect.* - - val pIdx = n.value.getOrElse( - report.errorAndAbort(s"Param index ${n.show} is not know") - ) - extractTerm(f.asTerm) match { - case Lambda(vals, _) => - vals(pIdx) match { - case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) } - } - } -} - -private[minisql] def parseBody[X]( - x: Expr[X] -)(using Quotes): Expr[Ast] = { - import quotes.reflect.* - extractTerm(x.asTerm) match { - case Lambda(vals, body) => - astParser(body.asExpr) - case o => - report.errorAndAbort(s"Can only parse function") - } -} -private def isNumeric(x: Expr[?])(using Quotes): Boolean = { - import quotes.reflect.* - x.asTerm.tpe match { - case t if t <:< TypeRepr.of[Int] => true - case t if t <:< TypeRepr.of[Long] => true - case t if t <:< TypeRepr.of[Float] => true - case t if t <:< TypeRepr.of[Double] => true - case t if t <:< TypeRepr.of[BigDecimal] => true - case t if t <:< TypeRepr.of[BigInt] => true - case t if t <:< TypeRepr.of[java.lang.Integer] => true - case t if t <:< TypeRepr.of[java.lang.Long] => true - case t if t <:< TypeRepr.of[java.lang.Float] => true - case t if t <:< TypeRepr.of[java.lang.Double] => true - case t if t <:< TypeRepr.of[java.math.BigDecimal] => true - case _ => false - } -} - -private def identParser(using Quotes): Parser[ast.Ident] = { - import quotes.reflect.* - { (x: Expr[?]) => - extractTerm(x.asTerm) match { - case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) }) - case _ => None - } - }.unlift - -} - -private lazy val astParser: Quotes ?=> Parser[Ast] = { - identParser.orElse(propertyParser(astParser)) -} - -private object 
IsPropertySelect { - def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = { - import quotes.reflect.* - x.asTerm match { - case Select(x, n) => Some(x.asExpr, n) - case _ => None - } - } -} - -def propertyParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.Property] = { - case IsPropertySelect(expr, n) => - '{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) } -} - -def optionOperationParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.OptionOperation] = { - case '{ ($x: Option[t]).isEmpty } => - '{ ast.OptionIsEmpty(${ astParser(x) }) } -} - -def binaryOperationParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.BinaryOperation] = { - ??? -} - -private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - unwrapTerm(x) -} diff --git a/src/main/scala/minisql/ReturnAction.scala b/src/main/scala/minisql/ReturnAction.scala new file mode 100644 index 0000000..63df7a0 --- /dev/null +++ b/src/main/scala/minisql/ReturnAction.scala @@ -0,0 +1,7 @@ +package minisql + +enum ReturnAction { + case ReturnNothing + case ReturnColumns(columns: List[String]) + case ReturnRecord +} diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 35feb70..48e3fae 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -1,6 +1,7 @@ package minisql.ast import minisql.NamingStrategy +import minisql.ParamEncoder import scala.quoted.* @@ -378,14 +379,20 @@ sealed trait ScalarLift extends Lift case class ScalarValueLift( name: String, - liftId: String + liftId: String, + value: Option[(Any, ParamEncoder[?])] ) extends ScalarLift +case class ScalarQueryLift( + 
name: String, + liftId: String +) extends ScalarLift {} + object ScalarLift { given ToExpr[ScalarLift] with { def apply(l: ScalarLift)(using Quotes) = l match { - case ScalarValueLift(n, id) => - '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } + case ScalarValueLift(n, id, v) => + '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) } } } } diff --git a/src/main/scala/minisql/ast/AstOps.scala b/src/main/scala/minisql/ast/AstOps.scala new file mode 100644 index 0000000..e0069ae --- /dev/null +++ b/src/main/scala/minisql/ast/AstOps.scala @@ -0,0 +1,94 @@ +package minisql.ast +object Implicits { + implicit class AstOps(body: Ast) { + private[minisql] def +||+(other: Ast) = + BinaryOperation(body, BooleanOperator.`||`, other) + private[minisql] def +&&+(other: Ast) = + BinaryOperation(body, BooleanOperator.`&&`, other) + private[minisql] def +==+(other: Ast) = + BinaryOperation(body, EqualityOperator.`==`, other) + private[minisql] def +!=+(other: Ast) = + BinaryOperation(body, EqualityOperator.`!=`, other) + } +} + +object +||+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two)) + case _ => None + } + } +} + +object +&&+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two)) + case _ => None + } + } +} + +val EqOp = EqualityOperator.`==` +val NeqOp = EqualityOperator.`!=` + +object +==+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, EqOp, two) => Some((one, two)) + case _ => None + } + } +} + +object +!=+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, NeqOp, two) => Some((one, two)) + case _ => None + } + } +} + +object IsNotNullCheck { + def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue) + + def unapply(ast: Ast): Option[Ast] = { + ast match { + case BinaryOperation(cond, NeqOp, NullValue) => 
Some(cond) + case _ => None + } + } +} + +object IsNullCheck { + def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue) + + def unapply(ast: Ast): Option[Ast] = { + ast match { + case BinaryOperation(cond, EqOp, NullValue) => Some(cond) + case _ => None + } + } +} + +object IfExistElseNull { + def apply(exists: Ast, `then`: Ast) = + If(IsNotNullCheck(exists), `then`, NullValue) + + def unapply(ast: Ast) = ast match { + case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t)) + case _ => None + } +} + +object IfExist { + def apply(exists: Ast, `then`: Ast, otherwise: Ast) = + If(IsNotNullCheck(exists), `then`, otherwise) + + def unapply(ast: Ast) = ast match { + case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e)) + case _ => None + } +} diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index b66e4b9..26694fb 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -45,8 +45,9 @@ private given FromExpr[Infix] with { private given FromExpr[ScalarValueLift] with { def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] = x match { - case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } => - Some(ScalarValueLift(n, id)) + case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } => + // don't cared about value here, a little tricky + Some(ScalarValueLift(n, id, null)) } } diff --git a/src/main/scala/minisql/context/ReturnFieldCapability.scala b/src/main/scala/minisql/context/ReturnFieldCapability.scala new file mode 100644 index 0000000..cadbb78 --- /dev/null +++ b/src/main/scala/minisql/context/ReturnFieldCapability.scala @@ -0,0 +1,64 @@ +package minisql.context + +sealed trait ReturningCapability + +/** + * Data cannot be returned Insert/Update/etc... clauses in the target database. + */ +sealed trait ReturningNotSupported extends ReturningCapability + +/** + * Returning a single field from Insert/Update/etc... 
clauses is supported. This + * is the most common databases e.g. MySQL, Sqlite, and H2 (although as of + * h2database/h2database#1972 this may change. See #1496 regarding this. + * Typically this needs to be setup in the JDBC + * `connection.prepareStatement(sql, Array("returnColumn"))`. + */ +sealed trait ReturningSingleFieldSupported extends ReturningCapability + +/** + * Returning multiple columns from Insert/Update/etc... clauses is supported. + * This generally means that columns besides auto-incrementing ones can be + * returned. This is supported by Oracle. In JDBC, the following is done: + * `connection.prepareStatement(sql, Array("column1, column2, ..."))`. + */ +sealed trait ReturningMultipleFieldSupported extends ReturningCapability + +/** + * An actual `RETURNING` clause is supported in the SQL dialect of the specified + * database e.g. Postgres. this typically means that columns returned from + * Insert/Update/etc... clauses can have other database operations done on them + * such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In + * JDBC, the following is done: `connection.prepareStatement(sql, + * Statement.RETURN_GENERATED_KEYS))`. 
+ */ +sealed trait ReturningClauseSupported extends ReturningCapability + +object ReturningNotSupported extends ReturningNotSupported +object ReturningSingleFieldSupported extends ReturningSingleFieldSupported +object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported +object ReturningClauseSupported extends ReturningClauseSupported + +trait Capabilities { + def idiomReturningCapability: ReturningCapability +} + +trait CanReturnClause extends Capabilities { + override def idiomReturningCapability: ReturningClauseSupported = + ReturningClauseSupported +} + +trait CanReturnField extends Capabilities { + override def idiomReturningCapability: ReturningSingleFieldSupported = + ReturningSingleFieldSupported +} + +trait CanReturnMultiField extends Capabilities { + override def idiomReturningCapability: ReturningMultipleFieldSupported = + ReturningMultipleFieldSupported +} + +trait CannotReturn extends Capabilities { + override def idiomReturningCapability: ReturningNotSupported = + ReturningNotSupported +} diff --git a/src/main/scala/minisql/dsl.scala b/src/main/scala/minisql/dsl.scala index 43cc9c0..ace3d8f 100644 --- a/src/main/scala/minisql/dsl.scala +++ b/src/main/scala/minisql/dsl.scala @@ -29,20 +29,6 @@ private inline def transform[A, B](inline q1: Quoted)( inline def query[E](inline table: String): EntityQuery[E] = Entity(table, Nil) -inline def compile(inline x: Ast): Option[String] = ${ - compileImpl('{ x }) -} - -private def compileImpl( - x: Expr[Ast] -)(using Quotes): Expr[Option[String]] = { - import quotes.reflect.* - x.value match { - case Some(xv) => '{ Some(${ Expr(xv.toString()) }) } - case None => '{ None } - } -} - extension [A, B](inline f1: A => B) { private inline def param0 = parsing.parseParamAt(f1, 0) private inline def body = parsing.parseBody(f1) @@ -54,6 +40,6 @@ extension [A1, A2, B](inline f1: (A1, A2) => B) { private inline def body = parsing.parseBody(f1) } -case class Foo(id: Int) +def lift[X](x: X)(using e: 
ParamEncoder[X]): X = throw NonQuotedException() -inline def queryFooId = query[Foo]("foo").map(_.id) +class NonQuotedException extends Exception("Cannot be used at runtime") diff --git a/src/main/scala/minisql/idiom/Idiom.scala b/src/main/scala/minisql/idiom/Idiom.scala new file mode 100644 index 0000000..7e0cd01 --- /dev/null +++ b/src/main/scala/minisql/idiom/Idiom.scala @@ -0,0 +1,23 @@ +package minisql.idiom + +import minisql.NamingStrategy +import minisql.ast._ +import minisql.context.Capabilities + +trait Idiom extends Capabilities { + + def emptySetContainsToken(field: Token): Token = StringToken("FALSE") + + def defaultAutoGeneratedToken(field: Token): Token = StringToken( + "DEFAULT VALUES" + ) + + def liftingPlaceholder(index: Int): String + + def translate(ast: Ast)(implicit naming: NamingStrategy): (Ast, Statement) + + def format(queryString: String): String = queryString + + def prepareForProbing(string: String): String + +} diff --git a/src/main/scala/minisql/idiom/LoadNaming.scala b/src/main/scala/minisql/idiom/LoadNaming.scala new file mode 100644 index 0000000..2405080 --- /dev/null +++ b/src/main/scala/minisql/idiom/LoadNaming.scala @@ -0,0 +1,28 @@ +package minisql.idiom + +import scala.util.Try +import scala.quoted._ +import minisql.NamingStrategy +import minisql.util.CollectTry +import minisql.util.LoadObject +import minisql.CompositeNamingStrategy + +object LoadNaming { + + def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry { + strategies[C].map(LoadObject[NamingStrategy](_)) + }.map(NamingStrategy(_)) + + private def strategies[C](using Quotes, Type[C]) = { + import quotes.reflect.* + val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy] + val ct = TypeRepr.of[C] + if (isComposite) { + ct.typeArgs.filterNot { t => + t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing] + } + } else { + List(ct) + } + } +} diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala 
b/src/main/scala/minisql/idiom/ReifyStatement.scala new file mode 100644 index 0000000..a3e8902 --- /dev/null +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -0,0 +1,68 @@ +package minisql.idiom + +import minisql.ast._ +import minisql.util.Interleave +import minisql.idiom.StatementInterpolator._ +import scala.annotation.tailrec +import scala.collection.immutable.{Map => SMap} + +object ReifyStatement { + + def apply( + liftingPlaceholder: Int => String, + emptySetContainsToken: Token => Token, + statement: Statement, + liftMap: SMap[String, (Any, Any)] + ): (String, List[ScalarValueLift]) = { + val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) + token2string(expanded, liftingPlaceholder) + } + + private def token2string( + token: Token, + liftingPlaceholder: Int => String + ): (String, List[ScalarValueLift]) = { + + val liftBuilder = List.newBuilder[ScalarValueLift] + val sqlBuilder = StringBuilder() + @tailrec + def loop( + workList: Seq[Token], + liftingSize: Int + ): Unit = workList match { + case Seq() => () + case head +: tail => + head match { + case StringToken(s2) => + sqlBuilder ++= s2 + loop(tail, liftingSize) + case SetContainsToken(a, op, b) => + loop( + stmt"$a $op ($b)" +: tail, + liftingSize + ) + case ScalarLiftToken(lift: ScalarValueLift) => + sqlBuilder ++= liftingPlaceholder(liftingSize) + liftBuilder += lift + loop(tail, liftingSize + 1) + case ScalarLiftToken(o) => + throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}") + case Statement(tokens) => + loop( + tokens.foldRight(tail)(_ +: _), + liftingSize + ) + } + } + loop(Vector(token), 0) + sqlBuilder.toString() -> liftBuilder.result() + } + + private def expandLiftings( + statement: Statement, + emptySetContainsToken: Token => Token, + liftMap: SMap[String, (Any, Any)] + ): (Token) = { + statement + } +} diff --git a/src/main/scala/minisql/idiom/Statement.scala b/src/main/scala/minisql/idiom/Statement.scala new file mode 100644 index 0000000..987ee2f 
--- /dev/null +++ b/src/main/scala/minisql/idiom/Statement.scala @@ -0,0 +1,47 @@ +package minisql.idiom + +import scala.quoted._ +import minisql.ast._ + +sealed trait Token + +case class StringToken(string: String) extends Token { + override def toString = string +} + +case class ScalarLiftToken(lift: ScalarLift) extends Token { + override def toString = s"lift(${lift.name})" +} + +case class Statement(tokens: List[Token]) extends Token { + override def toString = tokens.mkString +} + +case class SetContainsToken(a: Token, op: Token, b: Token) extends Token { + override def toString = s"${a.toString} ${op.toString} (${b.toString})" +} + +object Statement { + given ToExpr[Statement] with { + def apply(t: Statement)(using Quotes): Expr[Statement] = { + '{ Statement(${ Expr(t.tokens) }) } + } + } +} + +object Token { + + given ToExpr[Token] with { + def apply(t: Token)(using Quotes): Expr[Token] = { + t match { + case StringToken(s) => + '{ StringToken(${ Expr(s) }) } + case ScalarLiftToken(l) => + '{ ScalarLiftToken(${ Expr(l) }) } + case SetContainsToken(a, op, b) => + '{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) } + case s: Statement => Expr(s) + } + } + } +} diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala new file mode 100644 index 0000000..b732da1 --- /dev/null +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -0,0 +1,146 @@ +package minisql.idiom + +import minisql.ast._ +import minisql.util.Interleave +import minisql.util.Messages._ + +import scala.collection.mutable.ListBuffer + +object StatementInterpolator { + + trait Tokenizer[T] { + def token(v: T): Token + } + + object Tokenizer { + def apply[T](f: T => Token) = new Tokenizer[T] { + def token(v: T) = f(v) + } + def withFallback[T]( + fallback: Tokenizer[T] => Tokenizer[T] + )(pf: PartialFunction[T, Token]) = + new Tokenizer[T] { + private val stable = fallback(this) + override def token(v: T) 
= pf.applyOrElse(v, stable.token) + } + } + + implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) { + def token = tokenizer.token(v) + } + + implicit def stringTokenizer: Tokenizer[String] = + Tokenizer[String] { + case string => StringToken(string) + } + + implicit def liftTokenizer: Tokenizer[Lift] = + Tokenizer[Lift] { + case lift: ScalarLift => ScalarLiftToken(lift) + case lift => + fail( + s"Can't tokenize a non-scalar lifting. ${lift.name}\n" + + s"\n" + + s"This might happen because:\n" + + s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" + + s" to `Some[A]` or `None.type`. For example:\n" + + s" run(query[Users].update(_.optionalField -> lift(Some(value))))" + + s" In that case, make sure the type is `Option`:\n" + + s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" + + s" or\n" + + s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" + + s"\n" + + s"* You are trying to insert or update whole Embedded case class. 
For example:\n" + + s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" + + s" In that case, make sure you are updating individual columns, for example:\n" + + s" run(query[Users].update(\n" + + s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" + + s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" + + s" ))" + ) + } + + implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) + implicit def statementTokenizer: Tokenizer[Statement] = + Tokenizer[Statement](identity) + implicit def stringTokenTokenizer: Tokenizer[StringToken] = + Tokenizer[StringToken](identity) + implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = + Tokenizer[ScalarLiftToken](identity) + + extension [T](list: List[T]) { + def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = { + val l1 = list.map(_.token) + val l2 = List.fill(l1.size - 1)(StringToken(sep)) + Statement(Interleave(l1, l2)) + } + } + + implicit def listTokenizer[T](implicit + tokenize: Tokenizer[T] + ): Tokenizer[List[T]] = + Tokenizer[List[T]] { + case list => list.mkStmt() + } + + extension (sc: StringContext) { + + def flatten(tokens: List[Token]): List[Token] = { + + def unestStatements(tokens: List[Token]): List[Token] = { + tokens.flatMap { + case Statement(innerTokens) => unestStatements(innerTokens) + case token => token :: Nil + } + } + + def mergeStringTokens(tokens: List[Token]): List[Token] = { + val (resultBuilder, leftTokens) = + tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) { + case ((builder, acc), stringToken: StringToken) => + val str = stringToken.string + if (str.nonEmpty) + acc += stringToken.string + (builder, acc) + case ((builder, prev), b) if prev.isEmpty => + (builder += b.token, prev) + case ((builder, prev), b) /* if prev.nonEmpty */ => + builder += StringToken(prev.result().mkString) + builder += b.token + (builder, new ListBuffer[String]) + } + if (leftTokens.nonEmpty) + resultBuilder += 
StringToken(leftTokens.result().mkString) + resultBuilder.result() + } + + (unestStatements) + .andThen(mergeStringTokens) + .apply(tokens) + } + + def checkLengths( + args: scala.collection.Seq[Any], + parts: Seq[String] + ): Unit = + if (parts.length != args.length + 1) + throw new IllegalArgumentException( + "wrong number of arguments (" + args.length + + ") for interpolated string with " + parts.length + " parts" + ) + + def stmt(args: Token*): Statement = { + checkLengths(args, sc.parts) + val partsIterator = sc.parts.iterator + val argsIterator = args.iterator + val bldr = List.newBuilder[Token] + bldr += StringToken(partsIterator.next()) + while (argsIterator.hasNext) { + bldr += argsIterator.next() + bldr += StringToken(partsIterator.next()) + } + val tokens = flatten(bldr.result()) + Statement(tokens) + } + } +} diff --git a/src/main/scala/minisql/norm/AdHocReduction.scala b/src/main/scala/minisql/norm/AdHocReduction.scala new file mode 100644 index 0000000..10ad9a8 --- /dev/null +++ b/src/main/scala/minisql/norm/AdHocReduction.scala @@ -0,0 +1,52 @@ +package minisql.norm + +import minisql.ast.BinaryOperation +import minisql.ast.BooleanOperator +import minisql.ast.Filter +import minisql.ast.FlatMap +import minisql.ast.Map +import minisql.ast.Query +import minisql.ast.Union +import minisql.ast.UnionAll + +object AdHocReduction { + + def unapply(q: Query) = + q match { + + // --------------------------- + // *.filter + + // a.filter(b => c).filter(d => e) => + // a.filter(b => c && e[d := b]) + case Filter(Filter(a, b, c), d, e) => + val er = BetaReduction(e, d -> b) + Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er))) + + // --------------------------- + // flatMap.* + + // a.flatMap(b => c).map(d => e) => + // a.flatMap(b => c.map(d => e)) + case Map(FlatMap(a, b, c), d, e) => + Some(FlatMap(a, b, Map(c, d, e))) + + // a.flatMap(b => c).filter(d => e) => + // a.flatMap(b => c.filter(d => e)) + case Filter(FlatMap(a, b, c), d, e) => + 
Some(FlatMap(a, b, Filter(c, d, e))) + + // a.flatMap(b => c.union(d)) + // a.flatMap(b => c).union(a.flatMap(b => d)) + case FlatMap(a, b, Union(c, d)) => + Some(Union(FlatMap(a, b, c), FlatMap(a, b, d))) + + // a.flatMap(b => c.unionAll(d)) + // a.flatMap(b => c).unionAll(a.flatMap(b => d)) + case FlatMap(a, b, UnionAll(c, d)) => + Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d))) + + case other => None + } + +} diff --git a/src/main/scala/minisql/norm/ApplyMap.scala b/src/main/scala/minisql/norm/ApplyMap.scala new file mode 100644 index 0000000..e5ddb0c --- /dev/null +++ b/src/main/scala/minisql/norm/ApplyMap.scala @@ -0,0 +1,160 @@ +package minisql.norm + +import minisql.ast._ + +object ApplyMap { + + private def isomorphic(e: Ast, c: Ast, alias: Ident) = + BetaReduction(e, alias -> c) == c + + object InfixedTailOperation { + + def hasImpureInfix(ast: Ast) = + CollectAst(ast) { + case i @ Infix(_, _, false, _) => i + }.nonEmpty + + def unapply(ast: Ast): Option[Ast] = + ast match { + case cc: CaseClass if hasImpureInfix(cc) => Some(cc) + case tup: Tuple if hasImpureInfix(tup) => Some(tup) + case p: Property if hasImpureInfix(p) => Some(p) + case b: BinaryOperation if hasImpureInfix(b) => Some(b) + case u: UnaryOperation if hasImpureInfix(u) => Some(u) + case i @ Infix(_, _, false, _) => Some(i) + case _ => None + } + } + + object MapWithoutInfixes { + def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = + ast match { + case Map(a, b, InfixedTailOperation(c)) => None + case Map(a, b, c) => Some((a, b, c)) + case _ => None + } + } + + object DetachableMap { + def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = + ast match { + case Map(a: GroupBy, b, c) => None + case Map(a, b, InfixedTailOperation(c)) => None + case Map(a, b, c) => Some((a, b, c)) + case _ => None + } + } + + def unapply(q: Query): Option[Query] = + q match { + + case Map(a: GroupBy, b, c) if (b == c) => None + case Map(a: DistinctOn, b, c) => None + case Map(a: Nested, b, c) if (b == c) => None + 
case Nested(DetachableMap(a: Join, b, c)) => None + + // map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) => + // map(i => (i.i, i.l)).distinct + case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) => + Some(Distinct(Map(a, b, c))) + + // a.map(b => c).map(d => e) => + // a.map(b => e[d := c]) + case before @ Map(MapWithoutInfixes(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(Map(a, b, er)) + + // a.map(b => b) => + // a + case Map(a: Query, b, c) if (b == c) => + Some(a) + + // a.map(b => c).flatMap(d => e) => + // a.flatMap(b => e[d := c]) + case FlatMap(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(FlatMap(a, b, er)) + + // a.map(b => c).filter(d => e) => + // a.filter(b => e[d := c]).map(b => c) + case Filter(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(Map(Filter(a, b, er), b, c)) + + // a.map(b => c).sortBy(d => e) => + // a.sortBy(b => e[d := c]).map(b => c) + case SortBy(DetachableMap(a, b, c), d, e, f) => + val er = BetaReduction(e, d -> c) + Some(Map(SortBy(a, b, er, f), b, c)) + + // a.map(b => c).sortBy(d => e).distinct => + // a.sortBy(b => e[d := c]).map(b => c).distinct + case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) => + val er = BetaReduction(e, d -> c) + Some(Distinct(Map(SortBy(a, b, er, f), b, c))) + + // a.map(b => c).groupBy(d => e) => + // a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c))) + case GroupBy(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + val x = Ident("x") + val x1 = Property( + Ident("x"), + "_1" + ) // These introduced property should not be renamed + val x2 = Property(Ident("x"), "_2") // due to any naming convention. 
+ val body = Tuple(List(x1, Map(x2, b, c))) + Some(Map(GroupBy(a, b, er), x, body)) + + // a.map(b => c).drop(d) => + // a.drop(d).map(b => c) + case Drop(DetachableMap(a, b, c), d) => + Some(Map(Drop(a, d), b, c)) + + // a.map(b => c).take(d) => + // a.drop(d).map(b => c) + case Take(DetachableMap(a, b, c), d) => + Some(Map(Take(a, d), b, c)) + + // a.map(b => c).nested => + // a.nested.map(b => c) + case Nested(DetachableMap(a, b, c)) => + Some(Map(Nested(a), b, c)) + + // a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on) + // a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2])) + case Join( + tpe, + DetachableMap(a, b, c), + DetachableMap(d, e, f), + iA, + iB, + on + ) => + val onr = BetaReduction(on, iA -> c, iB -> f) + val t = Ident("t") + val t1 = BetaReduction(c, b -> Property(t, "_1")) + val t2 = BetaReduction(f, e -> Property(t, "_2")) + Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2)))) + + // a.*join(b.map(c => d)).on((iA, iB) => on) + // a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2])) + case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) => + val onr = BetaReduction(on, iB -> d) + val t = Ident("t") + val t1 = Property(t, "_1") + val t2 = BetaReduction(d, c -> Property(t, "_2")) + Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2)))) + + // a.map(b => c).*join(d).on((iA, iB) => on) + // a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2)) + case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) => + val onr = BetaReduction(on, iA -> c) + val t = Ident("t") + val t1 = BetaReduction(c, b -> Property(t, "_1")) + val t2 = Property(t, "_2") + Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2)))) + + case other => None + } +} diff --git a/src/main/scala/minisql/norm/AttachToEntity.scala b/src/main/scala/minisql/norm/AttachToEntity.scala new file mode 100644 index 0000000..e5608e9 --- /dev/null +++ b/src/main/scala/minisql/norm/AttachToEntity.scala @@ -0,0 
// -- src/main/scala/minisql/norm/AttachToEntity.scala --
package minisql.norm

import minisql.util.Messages.fail
import minisql.ast._

/**
 * Rewrites a query so that `f` is applied at the innermost `Entity` (or
 * `Infix`) found in the query tree, threading the closest enclosing binder
 * down as the alias. Fails if no entity can be located.
 */
object AttachToEntity {

  /** Matches the AST nodes that can serve as an attachment point. */
  private object IsEntity {
    def unapply(q: Ast): Option[Ast] =
      q match {
        case e: Entity => Some(e)
        case i: Infix  => Some(i)
        case _         => None
      }
  }

  /**
   * @param f     builds the replacement query from the found entity and alias
   * @param alias the binder of the closest enclosing combinator, if any
   * @param q     the query to rewrite
   */
  def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
      q: Ast
  ): Ast =
    q match {

      // The entity is the immediate source: apply `f` right here, reusing
      // the combinator's own binder as the alias.
      case Map(IsEntity(a), b, c)        => Map(f(a, b), b, c)
      case FlatMap(IsEntity(a), b, c)    => FlatMap(f(a, b), b, c)
      case ConcatMap(IsEntity(a), b, c)  => ConcatMap(f(a, b), b, c)
      case Filter(IsEntity(a), b, c)     => Filter(f(a, b), b, c)
      case SortBy(IsEntity(a), b, c, d)  => SortBy(f(a, b), b, c, d)
      case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)

      // These shapes are not descended into: attach to the query as a whole.
      case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
          _: FlatJoin =>
        f(q, alias.getOrElse(Ident("x")))

      // Otherwise recurse into the source, updating the alias whenever the
      // combinator introduces a binder of its own.
      case Map(a: Query, b, c)        => Map(apply(f, Some(b))(a), b, c)
      case FlatMap(a: Query, b, c)    => FlatMap(apply(f, Some(b))(a), b, c)
      case ConcatMap(a: Query, b, c)  => ConcatMap(apply(f, Some(b))(a), b, c)
      case Filter(a: Query, b, c)     => Filter(apply(f, Some(b))(a), b, c)
      case SortBy(a: Query, b, c, d)  => SortBy(apply(f, Some(b))(a), b, c, d)
      case Take(a: Query, b)          => Take(apply(f, alias)(a), b)
      case Drop(a: Query, b)          => Drop(apply(f, alias)(a), b)
      case Aggregation(op, a: Query)  => Aggregation(op, apply(f, alias)(a))
      case Distinct(a: Query)         => Distinct(apply(f, alias)(a))
      case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)

      // A bare entity at the top level.
      case IsEntity(e) => f(e, alias.getOrElse(Ident("x")))

      case _ => fail(s"Can't find an 'Entity' in '$q'")
    }
}
// NOTE: BetaReduction.scala was moved from `minisql.util` to `minisql.norm`;
// its header now reads `package minisql.norm` / `import minisql.ast._`
// (the body is otherwise unchanged).

// -- src/main/scala/minisql/norm/ConcatBehavior.scala --
package minisql.norm

/** How an idiom treats string concatenation during normalization. */
trait ConcatBehavior
object ConcatBehavior {
  /** ANSI-standard concatenation semantics. */
  case object AnsiConcat extends ConcatBehavior
  /** Vendor-specific (non-ANSI) concatenation semantics. */
  case object NonAnsiConcat extends ConcatBehavior
}

// -- src/main/scala/minisql/norm/EqualityBehavior.scala --
package minisql.norm

/** How an idiom treats equality comparisons during normalization. */
trait EqualityBehavior
object EqualityBehavior {
  /** ANSI-standard equality semantics. */
  case object AnsiEquality extends EqualityBehavior
  /** Vendor-specific (non-ANSI) equality semantics. */
  case object NonAnsiEquality extends EqualityBehavior
}

// -- src/main/scala/minisql/norm/ExpandReturning.scala --
package minisql.norm

import minisql.ReturnAction.ReturnColumns
import minisql.{NamingStrategy, ReturnAction}
import minisql.ast._
import minisql.context.{
  ReturningClauseSupported,
  ReturningMultipleFieldSupported,
  ReturningNotSupported,
  ReturningSingleFieldSupported
}
import minisql.idiom.{Idiom, Statement}

/**
 * Take the `.returning` part in a query that contains it and return the array
 * of columns representing the returning section, along with any other
 * operations etc. that they might contain.
 */
object ExpandReturning {

  /**
   * Translate a returning clause into the `ReturnAction` supported by the
   * idiom's returning capability, rendering each returned column with `f`.
   *
   * @throws IllegalArgumentException
   *   if the idiom supports only a single returning column but several were
   *   specified, or does not support returning columns at all
   */
  def applyMap(
      returning: ReturningAction
  )(f: (Ast, Statement) => String)(
      idiom: Idiom,
      naming: NamingStrategy
  ): ReturnAction = {
    val initialExpand = ExpandReturning.apply(returning)(idiom, naming)

    // Shared by the multiple-field and single-field branches.
    def renderedColumns: ReturnAction =
      ReturnColumns(initialExpand.map {
        case (ast, statement) => f(ast, statement)
      })

    idiom.idiomReturningCapability match {
      case ReturningClauseSupported =>
        ReturnAction.ReturnRecord
      case ReturningMultipleFieldSupported =>
        renderedColumns
      case ReturningSingleFieldSupported =>
        if (initialExpand.length == 1)
          renderedColumns
        else
          throw new IllegalArgumentException(
            s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
          )
      case ReturningNotSupported =>
        throw new IllegalArgumentException(
          s"RETURNING columns are not allowed in the ${idiom} dialect."
        )
    }
  }

  /**
   * Expand the returning clause into one (column-expression, statement) pair
   * per returned column.
   */
  def apply(
      returning: ReturningAction
  )(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
    val ReturningAction(_, alias, properties) = returning: @unchecked

    // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
    // => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
    val dePropertized =
      Transform(properties) {
        case `alias` => ExternalIdent(alias.name)
      }

    // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
    // => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
    val deTuplified = dePropertized match {
      case Tuple(values)     => values
      case CaseClass(values) => values.map(_._2)
      case other             => List(other)
    }

    implicit val namingStrategy: NamingStrategy = naming
    deTuplified.map(v => idiom.translate(v))
  }
}
// -- src/main/scala/minisql/norm/FlattenOptionOperation.scala --
package minisql.norm

import minisql.ast.*
import minisql.ast.Implicits.*
import minisql.norm.ConcatBehavior.NonAnsiConcat

/**
 * Rewrites `Option`-level AST operations (`map`, `flatMap`, `exists`,
 * `forall`, `getOrElse`, ...) into plain expressions combined with null
 * checks. Bodies that contain a "non-fallthrough" element (an `If`, an
 * `Infix`, or — under `NonAnsiConcat` — a string concatenation) keep an
 * explicit existence guard; all other bodies are reduced directly.
 */
class FlattenOptionOperation(concatBehavior: ConcatBehavior)
    extends StatelessTransformer {

  // Build an isEmpty/nonEmpty check depending on the boolean constant compared against.
  private def emptyOrNot(b: Boolean, ast: Ast) =
    if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)

  /** Substitute `alias := ast` into `body` without adding a null guard. */
  def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
    apply(BetaReduction(body, alias -> ast))

  /** Substitute and OR with a null check (`forall` holds for the empty case). */
  def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
    val substituted = BetaReduction(body, alias -> ast)
    apply((IsNullCheck(ast) +||+ substituted): Ast)
  }

  /** True if `ast` contains an element that must keep an explicit guard. */
  def containsNonFallthroughElement(ast: Ast) =
    CollectAst(ast) {
      case If(_, _, _)       => true
      case Infix(_, _, _, _) => true
      case BinaryOperation(_, StringOperator.`concat`, _)
          if (concatBehavior == NonAnsiConcat) =>
        true
    }.nonEmpty

  override def apply(ast: Ast): Ast =
    ast match {

      // Option-of-table operations never need a guard: reduce directly.
      case OptionTableFlatMap(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableMap(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableExists(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableForall(ast, alias, body) =>
        uncheckedForall(ast, alias, body)

      // Wrapping/unwrapping operations are identities at the SQL level.
      case OptionFlatten(ast)   => apply(ast)
      case OptionSome(ast)      => apply(ast)
      case OptionApply(ast)     => apply(ast)
      case OptionOrNull(ast)    => apply(ast)
      case OptionGetOrNull(ast) => apply(ast)

      case OptionNone => NullValue

      // o.map(x => body).getOrElse(booleanConstant) =>
      //   body[x := o] OR o.(isEmpty|nonEmpty)
      case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
        apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)

      // o.getOrElse(default) => if (o != null) o else default
      case OptionGetOrElse(ast, body) =>
        apply(If(IsNotNullCheck(ast), ast, body))

      case OptionFlatMap(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val substituted = BetaReduction(body, alias -> ast)
          apply(IfExistElseNull(ast, substituted))
        } else {
          uncheckedReduction(ast, alias, body)
        }

      case OptionMap(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val substituted = BetaReduction(body, alias -> ast)
          apply(IfExistElseNull(ast, substituted))
        } else {
          uncheckedReduction(ast, alias, body)
        }

      case OptionForall(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val substituted = BetaReduction(body, alias -> ast)
          apply(
            (IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ substituted)): Ast
          )
        } else {
          uncheckedForall(ast, alias, body)
        }

      case OptionExists(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val substituted = BetaReduction(body, alias -> ast)
          apply((IsNotNullCheck(ast) +&&+ substituted): Ast)
        } else {
          uncheckedReduction(ast, alias, body)
        }

      // o.contains(v) => o == v
      case OptionContains(ast, body) =>
        apply((ast +==+ body): Ast)

      case other =>
        super.apply(other)
    }
}

// -- src/main/scala/minisql/norm/NestImpureMappedInfix.scala --
package minisql.norm

import minisql.ast._

/**
 * A problem occurred in the original way infixes were done in that it was
 * assumed that infix clauses represented pure functions. While this is true
 * of many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many
 * others e.g. `RAND()`, and most importantly `RANK()`. For this reason, the
 * operations that are done in `ApplyMap` on standard AST `Map` clauses cannot
 * be done, so additional safety checks were introduced there. On top of that,
 * this normalization step inserts a `Nested` AST element into every map that
 * contains an impure infix. See more information and examples in #1534.
 */
object NestImpureMappedInfix extends StatelessTransformer {

  /** Are there any impure infixes inside the specified ASTs? */
  def hasInfix(asts: Ast*): Boolean =
    asts.exists { ast =>
      CollectAst(ast) {
        case i @ Infix(_, _, false, _) => i
      }.nonEmpty
    }

  // Continue exploring into the Map to catch impure infix clauses further down.
  private def applyInside(m: Map) =
    Map(apply(m.query), m.alias, m.body)

  override def apply(ast: Ast): Ast =
    ast match {
      // Already nested: no reason to insert another Nested node.
      case Nested(Map(inner, a, b)) =>
        Nested(Map(apply(inner), a, b))

      case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) =>
        Map(
          Nested(applyInside(m)),
          x,
          CaseClass(values.map {
            case (name, _) =>
              // mappings of nested-query case class properties should not be renamed
              (name, Property(x, name))
          })
        )

      case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
        Map(
          Nested(applyInside(m)),
          x,
          Tuple(values.zipWithIndex.map {
            case (_, i) =>
              // mappings of nested-query tuple properties should not be renamed
              Property(x, s"_${i + 1}")
          })
        )

      // A lone impure infix, or an expression containing one: nest and
      // re-project it as the synthetic column `_1`.
      case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case other => super.apply(other)
    }
}
// -- src/main/scala/minisql/norm/Normalize.scala --
package minisql.norm

import minisql.ast.Ast
import minisql.ast.Query
import minisql.ast.StatelessTransformer
import minisql.norm.capture.AvoidCapture
import minisql.ast.Action
import minisql.util.Messages.trace
import minisql.util.Messages.TraceType.Normalizations

import scala.annotation.tailrec

/**
 * Entry point of the normalization pipeline: beta-reduces the AST, then
 * repeatedly applies the individual rewrite rules until none of them fires.
 */
object Normalize extends StatelessTransformer {

  override def apply(q: Ast): Ast =
    super.apply(BetaReduction(q))

  override def apply(q: Action): Action =
    NormalizeReturning(super.apply(q))

  override def apply(q: Query): Query =
    norm(AvoidCapture(q))

  // Tracer for logging which rule fired on which query.
  private def traceNorm[T](label: String) =
    trace[T](s"${label} (Normalize)", 1, Normalizations)

  // Fixed-point loop: keep rewriting as long as any rule produces a change.
  @tailrec
  private def norm(q: Query): Query =
    q match {
      case NormalizeNestedStructures(query) =>
        traceNorm("NormalizeNestedStructures")(query)
        norm(query)
      case ApplyMap(query) =>
        traceNorm("ApplyMap")(query)
        norm(query)
      case SymbolicReduction(query) =>
        traceNorm("SymbolicReduction")(query)
        norm(query)
      case AdHocReduction(query) =>
        traceNorm("AdHocReduction")(query)
        norm(query)
      case OrderTerms(query) =>
        traceNorm("OrderTerms")(query)
        norm(query)
      case NormalizeAggregationIdent(query) =>
        traceNorm("NormalizeAggregationIdent")(query)
        norm(query)
      case other =>
        other
    }
}

// -- src/main/scala/minisql/norm/NormalizeAggregationIdent.scala --
package minisql.norm

import minisql.ast._

/**
 * Aligns the identifier used inside an aggregated map with the identifier of
 * the enclosing property:
 *
 *   a => a.b.map(x => x.c).agg  =>  a => a.b.map(a => a.c).agg
 */
object NormalizeAggregationIdent {

  def unapply(q: Query) =
    q match {
      case Aggregation(
            op,
            Map(
              p @ Property(i: Ident, _),
              mi,
              Property.Opinionated(_: Ident, n, renameable, visibility)
            )
          ) if i != mi =>
        // in example above, if c in x.c is fixed then c in a.c should also be
        Some(
          Aggregation(
            op,
            Map(p, i, Property.Opinionated(i, n, renameable, visibility))
          )
        )

      case _ => None
    }
}

// -- src/main/scala/minisql/norm/NormalizeNestedStructures.scala --
package minisql.norm

import minisql.ast._

/**
 * Fires when normalizing any direct sub-structure of the query changes it;
 * yields the query rebuilt from the normalized children, or `None` if the
 * query is already in normal form.
 */
object NormalizeNestedStructures {

  def unapply(q: Query): Option[Query] =
    q match {
      case e: Entity           => None
      case Map(a, b, c)        => apply(a, c)(Map(_, b, _))
      case FlatMap(a, b, c)    => apply(a, c)(FlatMap(_, b, _))
      case ConcatMap(a, b, c)  => apply(a, c)(ConcatMap(_, b, _))
      case Filter(a, b, c)     => apply(a, c)(Filter(_, b, _))
      case SortBy(a, b, c, d)  => apply(a, c)(SortBy(_, b, _, d))
      case GroupBy(a, b, c)    => apply(a, c)(GroupBy(_, b, _))
      case Aggregation(a, b)   => apply(b)(Aggregation(a, _))
      case Take(a, b)          => apply(a, b)(Take.apply)
      case Drop(a, b)          => apply(a, b)(Drop.apply)
      case Union(a, b)         => apply(a, b)(Union.apply)
      case UnionAll(a, b)      => apply(a, b)(UnionAll.apply)
      case Distinct(a)         => apply(a)(Distinct.apply)
      case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
      case Nested(a)           => apply(a)(Nested.apply)
      case FlatJoin(t, a, iA, on) =>
        (Normalize(a), Normalize(on)) match {
          case (`a`, `on`) => None
          case (a, on)     => Some(FlatJoin(t, a, iA, on))
        }
      case Join(t, a, b, iA, iB, on) =>
        (Normalize(a), Normalize(b), Normalize(on)) match {
          case (`a`, `b`, `on`) => None
          case (a, b, on)       => Some(Join(t, a, b, iA, iB, on))
        }
    }

  // Normalize one child; None when it is unchanged.
  private def apply(a: Ast)(f: Ast => Query) =
    (Normalize(a)) match {
      case (`a`) => None
      case (a)   => Some(f(a))
    }

  // Normalize two children; None when both are unchanged.
  private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
    (Normalize(a), Normalize(b)) match {
      case (`a`, `b`) => None
      case (a, b)     => Some(f(a, b))
    }
}

// -- src/main/scala/minisql/norm/NormalizeReturning.scala --
package minisql.norm

import minisql.ast._
import minisql.norm.capture.AvoidAliasConflict

/**
 * When actions are used with a `.returning` clause, remove the columns used
 * in the returning clause from the action. E.g. for
 * `insert(Person(id, name)).returning(_.id)` remove the `id` column from the
 * original insert.
 */
object NormalizeReturning {

  def apply(e: Action): Action = {
    e match {
      case ReturningGenerated(a: Action, alias, body) =>
        // De-alias the body first so variable shadows won't accidentally be
        // interpreted as columns to remove from the insert/update action.
        // This typically occurs in advanced cases where actual queries are
        // used in the return clause, which is only supported in Postgres, e.g:
        //   query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
        // Since `t.id` is used both for the `returning` clause and the query
        // inside, it could accidentally be seen as a variable used in
        // `returning` and hence excluded from insertion, which is clearly not
        // the case. To fix this we change `t` into a different alias.
        val newBody = dealiasBody(body, alias)
        ReturningGenerated(apply(a, newBody, alias), alias, newBody)

      // For a regular return clause there is no need to exclude assignments
      // from insertion; however, we still need to de-alias the Action body in
      // case conflicts result. For example the query:
      //   query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
      // would incorrectly be interpreted as:
      //   INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t)
      // (note the 'id' in MAX coming from the inserted table instead of t)
      // whereas it should be:
      //   INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
      case Returning(a: Action, alias, body) =>
        val newBody = dealiasBody(body, alias)
        Returning(a, alias, newBody)

      case _ => e
    }
  }

  /**
   * In some situations, a query can exist inside of a `returning` clause. In
   * this case, we need to rename if the aliases used in that query override
   * the alias used in the `returning` clause otherwise they will be treated
   * as returning-clause aliases by ExpandReturning (i.e. they will become
   * ExternalAlias instances) and later be tokenized incorrectly.
   */
  private def dealiasBody(body: Ast, alias: Ident): Ast =
    Transform(body) {
      case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
    }

  // Strip the returned columns out of insert/update assignments.
  private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
    e match {
      case Insert(query, assignments) =>
        Insert(query, filterReturnedColumn(assignments, body, returningIdent))
      case Update(query, assignments) =>
        Update(query, filterReturnedColumn(assignments, body, returningIdent))
      case OnConflict(a: Action, target, act) =>
        OnConflict(apply(a, body, returningIdent), target, act)
      case _ => e
    }

  private def filterReturnedColumn(
      assignments: List[Assignment],
      column: Ast,
      returningIdent: Ident
  ): List[Assignment] =
    assignments.flatMap(filterReturnedColumn(_, column, returningIdent))

  /**
   * In situations like Property(Property(ident, foo), bar) pull out the
   * inner-most ident.
   */
  object NestedProperty {
    def unapply(ast: Property): Option[Ast] = {
      ast match {
        case p @ Property(subAst, _) => Some(innerMost(subAst))
      }
    }

    private def innerMost(ast: Ast): Ast = ast match {
      case Property(inner, _) => innerMost(inner)
      case other              => other
    }
  }

  /**
   * Remove the specified column from the assignment. For example, in a query
   * like `insert(Person(id, name)).returning(r => r.id)` we need to remove
   * the `id` column from the insertion. The value of `body` in this case will
   * be `Property(Ident(r), id)` and the assignment `p1` properties will
   * typically be `v.id` and `v.name` (the `v` variable is a default used for
   * `insert` queries).
   */
  private def filterReturnedColumn(
      assignment: Assignment,
      body: Ast,
      returningIdent: Ident
  ): Option[Assignment] =
    assignment match {
      case Assignment(_, p1: Property, _) => {
        // Pull out instances of the column usage. The `body` ast will
        // typically be Property(table, field) but it can also be a tuple
        // Tuple(List(Property(table, field1), Property(table, field2))) if
        // multiple things are returned, or even a query, since queries are
        // allowed in return sections. In all of these cases we need to pull
        // out the Property (e.g. t.id) to compare it against the assignment
        // and know what to exclude.
        val matchedProps =
          CollectAst(body) {
            case prop @ NestedProperty(Ident(name))
                if (name == returningIdent.name) =>
              prop
            case prop @ NestedProperty(ExternalIdent(name))
                if (name == returningIdent.name) =>
              prop
          }

        if (matchedProps.exists(matched => isSameProperties(p1, matched)))
          None
        else
          Some(assignment)
      }
      case assignment => Some(assignment)
    }

  /** Matches either kind of identifier node. */
  object SomeIdent {
    def unapply(ast: Ast): Option[Ast] =
      ast match {
        case id: Ident         => Some(id)
        case id: ExternalIdent => Some(id)
        case _                 => None
      }
  }

  /**
   * Is it the same property (but possibly of a different identity)? E.g.
   * `p.foo.bar` and `v.foo.bar`.
   */
  private def isSameProperties(p1: Property, p2: Property): Boolean =
    (p1.ast, p2.ast) match {
      case (SomeIdent(_), SomeIdent(_)) =>
        p1.name == p2.name
      // For Property(Property(id), name) == Property(Property(id), name) we
      // check the outer properties match before moving on to the inner ones.
      case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
        isSameProperties(pp1, pp2)
      case _ =>
        false
    }
}

// -- src/main/scala/minisql/norm/OrderTerms.scala --
package minisql.norm

import minisql.ast._

/** Reorders adjacent query operations into shapes later phases expect. */
object OrderTerms {

  def unapply(q: Query) =
    q match {

      // `take` directly over a grouped map is left untouched.
      case Take(Map(a: GroupBy, b, c), d) => None

      // a.sortBy(b => c).filter(d => e) =>
      // a.filter(d => e).sortBy(b => c)
      case Filter(SortBy(a, b, c, d), e, f) =>
        Some(SortBy(Filter(a, e, f), b, c, d))

      // a.flatMap(b => c).take(n).map(d => e) =>
      // a.flatMap(b => c).map(d => e).take(n)
      case Map(Take(fm: FlatMap, n), ma, mb) =>
        Some(Take(Map(fm, ma, mb), n))

      // a.flatMap(b => c).drop(n).map(d => e) =>
      // a.flatMap(b => c).map(d => e).drop(n)
      case Map(Drop(fm: FlatMap, n), ma, mb) =>
        Some(Drop(Map(fm, ma, mb), n))

      case other => None
    }
}
diff --git a/src/main/scala/minisql/norm/RenameProperties.scala b/src/main/scala/minisql/norm/RenameProperties.scala new file mode 100644 index 0000000..53d135c --- /dev/null +++ b/src/main/scala/minisql/norm/RenameProperties.scala @@ -0,0 +1,491 @@ +package minisql.norm + +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.Visible +import minisql.ast._ +import minisql.util.Interpolator + +object RenameProperties extends StatelessTransformer { + val interp = new Interpolator(3) + import interp._ + def traceDifferent(one: Any, two: Any) = + if (one != two) + trace"Replaced $one with $two".andLog() + else + trace"Replacements did not match".andLog() + + override def apply(q: Query): Query = + applySchemaOnly(q) + + override def apply(q: Action): Action = + applySchema(q) match { + case (q, schema) => q + } + + override def apply(e: Operation): Operation = + e match { + case UnaryOperation(o, c: Query) => + UnaryOperation(o,
applySchemaOnly(apply(c))) + case _ => super.apply(e) + } + + private def applySchema(q: Action): (Action, Schema) = + q match { + case Insert(q: Query, assignments) => + applySchema(q, assignments, Insert.apply) + case Update(q: Query, assignments) => + applySchema(q, assignments, Update.apply) + case Delete(q: Query) => + applySchema(q) match { + case (q, schema) => (Delete(q), schema) + } + case Returning(action: Action, alias, body) => + applySchema(action) match { + case (action, schema) => + val replace = + trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val bodyr = BetaReduction(body, replace*) + traceDifferent(body, bodyr) + (Returning(action, alias, bodyr), schema) + } + case ReturningGenerated(action: Action, alias, body) => + applySchema(action) match { + case (action, schema) => + val replace = + trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val bodyr = BetaReduction(body, replace*) + traceDifferent(body, bodyr) + (ReturningGenerated(action, alias, bodyr), schema) + } + case OnConflict(a: Action, target, act) => + applySchema(a) match { + case (action, schema) => + val targetR = target match { + case OnConflict.Properties(props) => + val propsR = props.map { prop => + val replace = + trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn` + replacements( + prop.ast, + schema + ) // A BetaReduction on a Property will always give back a Property + BetaReduction(prop, replace*).asInstanceOf[Property] + } + traceDifferent(props, propsR) + OnConflict.Properties(propsR) + case OnConflict.NoTarget => target + } + val actR = act match { + case OnConflict.Update(assignments) => + OnConflict.Update(replaceAssignments(assignments, schema)) + case _ => act + } + (OnConflict(action, targetR, actR), schema) + } + case q => (q, TupleSchema.empty) + } + + private def replaceAssignments( + a: 
List[Assignment], + schema: Schema + ): List[Assignment] = + a.map { + case Assignment(alias, prop, value) => + val replace = + trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val propR = BetaReduction(prop, replace*) + traceDifferent(prop, propR) + val valueR = BetaReduction(value, replace*) + traceDifferent(value, valueR) + Assignment(alias, propR, valueR) + } + + private def applySchema( + q: Query, + a: List[Assignment], + f: (Query, List[Assignment]) => Action + ): (Action, Schema) = + applySchema(q) match { + case (q, schema) => + (f(q, replaceAssignments(a, schema)), schema) + } + + private def applySchemaOnly(q: Query): Query = + applySchema(q) match { + case (q, _) => q + } + + object TupleIndex { + def unapply(s: String): Option[Int] = + if (s.matches("_[0-9]*")) + Some(s.drop(1).toInt - 1) + else + None + } + + sealed trait Schema { + def lookup(property: List[String]): Option[Schema] = + (property, this) match { + case (Nil, schema) => + trace"Nil at $property returning " `andReturn` + Some(schema) + case (path, e @ EntitySchema(_)) => + trace"Entity at $path returning " `andReturn` + Some(e.subSchemaOrEmpty(path)) + case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) => + trace"Case class at $property returning " `andReturn` + props(head).lookup(tail) + case (TupleIndex(idx) :: tail, TupleSchema(values)) + if values.contains(idx) => + trace"Tuple at at $property returning " `andReturn` + values(idx).lookup(tail) + case _ => + trace"Nothing found at $property returning " `andReturn` + None + } + } + + // Represents a nested property path to an identity i.e. Property(Property(... 
Ident(), ...)) + object PropertyMatroshka { + + def traverse(initial: Property): Option[(Ident, List[String])] = + initial match { + // If it's a nested-property walk inside and append the name to the result (if something is returned) + case Property(inner: Property, name) => + traverse(inner).map { case (id, list) => (id, list :+ name) } + // If it's a property with ident in the core, return that + case Property(id: Ident, name) => + Some((id, List(name))) + // Otherwise an ident property is not inside so don't return anything + case _ => + None + } + + def unapply(ast: Ast): Option[(Ident, List[String])] = + ast match { + case p: Property => traverse(p) + case _ => None + } + + } + + def protractSchema( + body: Ast, + ident: Ident, + schema: Schema + ): Option[Schema] = { + + def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] = + body match { + // if any values yield a sub-schema which is not an entity, recurse into that + case cc @ CaseClass(values) => + trace"Protracting CaseClass $cc into new schema:" `andReturn` + CaseClassSchema( + values.collect { + case (name, innerBody @ HierarchicalAstEntity()) => + (name, protractSchemaRecurse(innerBody, schema)) + // pass the schema into a recursive call an extract from it when we non tuple/caseclass element + case (name, innerBody @ PropertyMatroshka(`ident`, path)) => + (name, protractSchemaRecurse(innerBody, schema)) + // we have reached an ident i.e. 
recurse to pass the current schema into the case class + case (name, `ident`) => + (name, protractSchemaRecurse(ident, schema)) + }.collect { + case (name, Some(subSchema)) => (name, subSchema) + } + ).notEmpty + case tup @ Tuple(values) => + trace"Protracting Tuple $tup into new schema:" `andReturn` + TupleSchema + .fromIndexes( + values.zipWithIndex.collect { + case (innerBody @ HierarchicalAstEntity(), index) => + (index, protractSchemaRecurse(innerBody, schema)) + // pass the schema into a recursive call an extract from it when we non tuple/caseclass element + case (innerBody @ PropertyMatroshka(`ident`, path), index) => + (index, protractSchemaRecurse(innerBody, schema)) + // we have reached an ident i.e. recurse to pass the current schema into the tuple + case (`ident`, index) => + (index, protractSchemaRecurse(ident, schema)) + }.collect { + case (index, Some(subSchema)) => (index, subSchema) + } + ) + .notEmpty + + case prop @ PropertyMatroshka(`ident`, path) => + trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn` + schema match { + // case e: EntitySchema => Some(e) + case _ => schema.lookup(path) + } + case `ident` => + trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn` + Some(schema) + case other => + trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn` + Some(schema) + } + + protractSchemaRecurse(body, schema) + } + + case object EmptySchema extends Schema + case class EntitySchema(e: Entity) extends Schema { + def noAliases = e.properties.isEmpty + + private def subSchema(path: List[String]) = + EntitySchema( + Entity( + s"sub-${e.name}", + e.properties.flatMap { + case PropertyAlias(aliasPath, alias) => + if (aliasPath == path) + List(PropertyAlias(aliasPath, alias)) + else if (aliasPath.startsWith(path)) + List(PropertyAlias(aliasPath.diff(path), alias)) + else + List() + } + ) + ) + + def subSchemaOrEmpty(path: List[String]): 
Schema = + trace"Creating sub-schema for entity $e at path $path will be" andReturn { + val sub = subSchema(path) + if (sub.noAliases) EmptySchema else sub + } + + } + case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ ) + extends Schema { + def list = m.toList.sortBy(_._1) + def notEmpty = + if (this.m.nonEmpty) Some(this) else None + } + case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema { + def list = m.toList + def notEmpty = + if (this.m.nonEmpty) Some(this) else None + } + object CaseClassSchema { + def apply(property: String, value: Schema): CaseClassSchema = + CaseClassSchema(collection.Map(property -> value)) + def apply(list: List[(String, Schema)]): CaseClassSchema = + CaseClassSchema(list.toMap) + } + + object TupleSchema { + def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema = + TupleSchema(schemas.toMap) + + def apply(schemas: List[Schema]): TupleSchema = + TupleSchema(schemas.zipWithIndex.map(_.swap).toMap) + + def apply(index: Int, schema: Schema): TupleSchema = + TupleSchema(collection.Map(index -> schema)) + + def empty: TupleSchema = TupleSchema(List.empty) + } + + object HierarchicalAstEntity { + def unapply(ast: Ast): Boolean = + ast match { + case cc: CaseClass => true + case tup: Tuple => true + case _ => false + } + } + + private def applySchema(q: Query): (Query, Schema) = { + q match { + + // Don't understand why this is needed.... 
+ case Map(q: Query, x, p) => + applySchema(q) match { + case (q, subSchema) => + val replace = + trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn` + replacements(x, subSchema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + + val schema = + trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` { + protractSchema(prr, x, subSchema) + }.getOrElse(EmptySchema) + + (Map(q, x, prr), schema) + } + + case e: Entity => (e, EntitySchema(e)) + case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply) + case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o)) + case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply) + case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _)) + case Take(q: Query, n) => applySchema(q, Take(_, n)) + case Drop(q: Query, n) => applySchema(q, Drop(_, n)) + case Nested(q: Query) => applySchema(q, Nested.apply) + case Distinct(q: Query) => applySchema(q, Distinct.apply) + case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on)) + + case FlatMap(q: Query, x, p) => + applySchema(q, x, p, FlatMap.apply) match { + case (FlatMap(q, x, p: Query), oldSchema) => + val (pr, newSchema) = applySchema(p) + (FlatMap(q, x, pr), newSchema) + case (flatMap, oldSchema) => + (flatMap, TupleSchema.empty) + } + + case ConcatMap(q: Query, x, p) => + applySchema(q, x, p, ConcatMap.apply) match { + case (ConcatMap(q, x, p: Query), oldSchema) => + val (pr, newSchema) = applySchema(p) + (ConcatMap(q, x, pr), newSchema) + case (concatMap, oldSchema) => + (concatMap, TupleSchema.empty) + } + + case Join(typ, a: Query, b: Query, iA, iB, on) => + (applySchema(a), applySchema(b)) match { + case ((a, schemaA), (b, schemaB)) => + val combinedReplacements = + trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn { + val replaceA = 
replacements(iA, schemaA) + val replaceB = replacements(iB, schemaB) + replaceA ++ replaceB + } + val onr = BetaReduction(on, combinedReplacements*) + traceDifferent(on, onr) + (Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB))) + } + + case FlatJoin(typ, a: Query, iA, on) => + applySchema(a) match { + case (a, schemaA) => + val replaceA = + trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn` + replacements(iA, schemaA) + val onr = BetaReduction(on, replaceA*) + traceDifferent(on, onr) + (FlatJoin(typ, a, iA, onr), schemaA) + } + + case Map(q: Operation, x, p) if x == p => + (Map(apply(q), x, p), TupleSchema.empty) + + case Map(Infix(parts, params, pure, paren), x, p) => + val transformed = + params.map { + case q: Query => + val (qr, schema) = applySchema(q) + traceDifferent(q, qr) + (qr, Some(schema)) + case q => + (q, None) + } + + val schema = + transformed.collect { + case (_, Some(schema)) => schema + } match { + case e :: Nil => e + case ls => TupleSchema(ls) + } + val replace = + trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn` + replacements(x, schema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + + (Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema) + + case q => + (q, TupleSchema.empty) + } + } + + private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) = + applySchema(ast) match { + case (ast, schema) => + (f(ast), schema) + } + + private def applySchema[T]( + q: Query, + x: Ident, + p: Ast, + f: (Ast, Ident, Ast) => T + ): (T, Schema) = + applySchema(q) match { + case (q, schema) => + val replace = + trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn` + replacements(x, schema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + (f(q, x, prr), schema) + } + + private def replacements(base: Ast, 
schema: Schema): Seq[(Ast, Ast)] = + schema match { + // The entity renameable property should already have been marked as Fixed + case EntitySchema(Entity(entity, properties)) => + // trace"%4 Entity Schema: " andReturn + properties.flatMap { + // A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias) + // or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit. + // In this case, we want to rename the properties based on the property aliases as well as mark + // them Fixed since they should not be renamed based on + // the naming strategy wherever they are tokenized (e.g. in SqlIdiom) + case PropertyAlias(path, alias) => + def apply(base: Ast, path: List[String]): Ast = + path match { + case Nil => base + case head :: tail => apply(Property(base, head), tail) + } + List( + apply(base, path) -> Property.Opinionated( + base, + alias, + Fixed, + Visible + ) // Hidden properties cannot be renamed + ) + } + case tup: TupleSchema => + // trace"%4 Tuple Schema: " andReturn + tup.list.flatMap { + case (idx, value) => + replacements( + // Should not matter whether property is fixed or variable here + // since beta reduction ignores that + Property(base, s"_${idx + 1}"), + value + ) + } + case cc: CaseClassSchema => + // trace"%4 CaseClass Schema: " andReturn + cc.list.flatMap { + case (property, value) => + replacements( + // Should not matter whether property is fixed or variable here + // since beta reduction ignores that + Property(base, property), + value + ) + } + // Do nothing if it is an empty schema + case EmptySchema => List() + } +} diff --git a/src/main/scala/minisql/util/Replacements.scala b/src/main/scala/minisql/norm/Replacements.scala similarity index 98% rename from src/main/scala/minisql/util/Replacements.scala rename to src/main/scala/minisql/norm/Replacements.scala index f0982e2..4b4a955 100644 --- a/src/main/scala/minisql/util/Replacements.scala +++ 
b/src/main/scala/minisql/norm/Replacements.scala @@ -1,4 +1,4 @@ -package minisql.util +package minisql.norm import minisql.ast.Ast import scala.collection.immutable.Map diff --git a/src/main/scala/minisql/norm/SimplifyNullChecks.scala b/src/main/scala/minisql/norm/SimplifyNullChecks.scala new file mode 100644 index 0000000..a49b949 --- /dev/null +++ b/src/main/scala/minisql/norm/SimplifyNullChecks.scala @@ -0,0 +1,124 @@ +package minisql.norm + +import minisql.ast.* +import minisql.norm.EqualityBehavior.AnsiEquality + +/** + * Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in + * `FlattenOptionOperation` in order to resolve #1053, as well as to support + * non-ansi compliant string concatenation as outlined in #1295, large + * conditional composites became common. For example:
 case class
+ * Holder(value: Option[String])
/**
 * Collapses the redundant conditional composites produced by the null checks
 * that `FlattenOptionOperation` inserts for `map`/`flatMap`/`exists` (and by
 * non-ANSI string concatenation). E.g. a nested
 * `CASE WHEN (CASE WHEN a IS NOT NULL THEN x ELSE null END) IS NOT NULL …`
 * is reduced to a single `CASE WHEN a IS NOT NULL AND x IS NOT NULL …`.
 * Both the inner value and the derived expression must be null-checked,
 * because `Option[T].flatMap` can explicitly produce `null`.
 */
class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
    extends StatelessTransformer {

  override def apply(ast: Ast): Ast = {
    import minisql.ast.Implicits.*
    ast match {
      // Center rule: outer and inner null-conditionals test the same
      // condition/value, so one combined check suffices.
      case IfExist(
            IfExistElseNull(condA, thenA),
            IfExistElseNull(condB, thenB),
            otherwise
          ) if (condA == condB && thenA == thenB) =>
        apply(
          If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
        )

      // Left hand rule: the checked branch is itself a null-conditional.
      case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
        apply(
          If(
            IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
            value,
            otherwise
          )
        )

      // Right hand rule: nested else-null conditionals fold into one.
      case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
        apply(
          If(
            IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
            innerThen,
            NullValue
          )
        )

      // Under ANSI equality, `a.isDefined && b.isDefined && a == b` is
      // equivalent to plain `a == b` (NULL comparisons yield NULL/false).
      case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
            Optional(b)
          ) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
          if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
        apply(exp)

      case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
            a1
          ) `== or !=` Optional(_)))
          if (a == a1 && equalityBehavior == AnsiEquality) =>
        apply(exp)
      case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
            _
          ) `== or !=` Optional(b1)))
          if (b == b1 && equalityBehavior == AnsiEquality) =>
        apply(exp)

      // `Constant.isEmpty` is statically false, so a conjunction containing
      // it is false and the disjunction reduces to the other operand.
      case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
        apply(other)
      case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
        apply(other)
      case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
        apply(other)
      case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
        apply(other)

      // `Constant.isDefined` is statically true (and `isEmpty` false), so
      // these operands are neutral elements of &&/|| and can be dropped.
      case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left)
      case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
      case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left)
      case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)

      case other =>
        super.apply(other)
    }
  }

  // Matches either `a == b` or `a != b`; both are simplified the same way.
  object `== or !=` {
    def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
      case a +==+ b => Some((a, b))
      case a +!=+ b => Some((a, b))
      case _ => None
    }
  }

  /**
   * Looks inside optional wrappers (`OptionApply`, `OptionSome`) to pull out
   * the wrapped value; falls back to the value itself.
   */
  object Optional {
    def unapply(a: Ast): Option[Ast] = a match {
      case OptionApply(value) => Some(value)
      case OptionSome(value) => Some(value)
      case value => Some(value)
    }
  }
}
d).unionAll(b.flatMap(c => d)) + case FlatMap(UnionAll(a, b), c, d) => + Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d))) + + case other => None + } +} diff --git a/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala new file mode 100644 index 0000000..ed45c62 --- /dev/null +++ b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala @@ -0,0 +1,174 @@ +package minisql.norm.capture + +import minisql.ast.{ + Entity, + Filter, + FlatJoin, + FlatMap, + GroupBy, + Ident, + Join, + Map, + Query, + SortBy, + StatefulTransformer, + _ +} +import minisql.norm.{BetaReduction, Normalize} +import scala.collection.immutable.Set + +private[minisql] case class AvoidAliasConflict(state: Set[Ident]) + extends StatefulTransformer[Set[Ident]] { + + object Unaliased { + + private def isUnaliased(q: Ast): Boolean = + q match { + case Nested(q: Query) => isUnaliased(q) + case Take(q: Query, _) => isUnaliased(q) + case Drop(q: Query, _) => isUnaliased(q) + case Aggregation(_, q: Query) => isUnaliased(q) + case Distinct(q: Query) => isUnaliased(q) + case _: Entity | _: Infix => true + case _ => false + } + + def unapply(q: Ast): Option[Ast] = + q match { + case q if (isUnaliased(q)) => Some(q) + case _ => None + } + } + + override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) = + q match { + + case FlatMap(Unaliased(q), x, p) => + apply(x, p)(FlatMap(q, _, _)) + + case ConcatMap(Unaliased(q), x, p) => + apply(x, p)(ConcatMap(q, _, _)) + + case Map(Unaliased(q), x, p) => + apply(x, p)(Map(q, _, _)) + + case Filter(Unaliased(q), x, p) => + apply(x, p)(Filter(q, _, _)) + + case SortBy(Unaliased(q), x, p, o) => + apply(x, p)(SortBy(q, _, _, o)) + + case GroupBy(Unaliased(q), x, p) => + apply(x, p)(GroupBy(q, _, _)) + + case DistinctOn(Unaliased(q), x, p) => + apply(x, p)(DistinctOn(q, _, _)) + + case Join(t, a, b, iA, iB, o) => + val (ar, art) = apply(a) + val (br, brt) = art.apply(b) + val freshA = 
/**
 * Stateful transformer that renames query aliases so that no alias shadows
 * another one already in scope. `state` is the set of identifiers already
 * claimed; fresh names are derived by appending a numeric suffix.
 */
private[minisql] case class AvoidAliasConflict(state: Set[Ident])
    extends StatefulTransformer[Set[Ident]] {

  // Matches a query that carries no alias of its own (entity/infix possibly
  // wrapped in alias-free combinators), i.e. one whose binder we may rename
  // freely without affecting an inner scope.
  object Unaliased {

    private def isUnaliased(q: Ast): Boolean =
      q match {
        case Nested(q: Query) => isUnaliased(q)
        case Take(q: Query, _) => isUnaliased(q)
        case Drop(q: Query, _) => isUnaliased(q)
        case Aggregation(_, q: Query) => isUnaliased(q)
        case Distinct(q: Query) => isUnaliased(q)
        case _: Entity | _: Infix => true
        case _ => false
      }

    def unapply(q: Ast): Option[Ast] =
      q match {
        case q if (isUnaliased(q)) => Some(q)
        case _ => None
      }
  }

  override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
    q match {

      case FlatMap(Unaliased(q), x, p) =>
        apply(x, p)(FlatMap(q, _, _))

      case ConcatMap(Unaliased(q), x, p) =>
        apply(x, p)(ConcatMap(q, _, _))

      case Map(Unaliased(q), x, p) =>
        apply(x, p)(Map(q, _, _))

      case Filter(Unaliased(q), x, p) =>
        apply(x, p)(Filter(q, _, _))

      case SortBy(Unaliased(q), x, p, o) =>
        apply(x, p)(SortBy(q, _, _, o))

      case GroupBy(Unaliased(q), x, p) =>
        apply(x, p)(GroupBy(q, _, _))

      case DistinctOn(Unaliased(q), x, p) =>
        apply(x, p)(DistinctOn(q, _, _))

      case Join(t, a, b, iA, iB, o) =>
        // Rename both sides' aliases against the accumulated state, then
        // rewrite the join condition with the fresh names.
        val (ar, art) = apply(a)
        val (br, brt) = art.apply(b)
        val freshA = freshIdent(iA, brt.state)
        val freshB = freshIdent(iB, brt.state + freshA)
        val or = BetaReduction(o, iA -> freshA, iB -> freshB)
        val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
        (Join(t, ar, br, freshA, freshB, orr), orrt)

      case FlatJoin(t, a, iA, o) =>
        val (ar, art) = apply(a)
        val freshA = freshIdent(iA)
        val or = BetaReduction(o, iA -> freshA)
        val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
        (FlatJoin(t, ar, freshA, orr), orrt)

      // Aliased or compound cases: recurse structurally.
      case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
          _: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
          _: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
        super.apply(q)
    }

  // Shared renaming step: pick a fresh alias, beta-reduce the body to use
  // it, then keep transforming the body with the enlarged state.
  private def apply(x: Ident, p: Ast)(
      f: (Ident, Ast) => Query
  ): (Query, StatefulTransformer[Set[Ident]]) = {
    val fresh = freshIdent(x)
    val pr = BetaReduction(p, x -> fresh)
    val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
    (f(fresh, prr), t)
  }

  // Returns `x` unchanged when unclaimed, otherwise appends the smallest
  // numeric suffix producing an unclaimed name (x1, x2, ...).
  private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
    def loop(x: Ident, n: Int): Ident = {
      val fresh = Ident(s"${x.name}$n")
      if (!state.contains(fresh))
        fresh
      else
        loop(x, n + 1)
    }
    if (!state.contains(x))
      x
    else
      loop(x, 1)
  }

  /**
   * Function parameters can collide with variables used elsewhere in the
   * macro, e.g. `val q = quote { (v: Foo) => query[Foo].insert(v) };
   * run(q(lift(v)))`. Since `v` is used by actionMeta to map keys to values
   * for insertion, reusing it as a function argument corrupts the output SQL
   * (`VALUES (s,i,l,o)` instead of `(?,?,?,?)`). This renames such
   * conflicting parameters throughout the function body.
   */
  private def applyFunction(f: Function): Function = {
    val (newBody, _, newParams) =
      f.params.foldLeft((f.body, state, List[Ident]())) {
        case ((body, state, newParams), param) => {
          val fresh = freshIdent(param)
          val pr = BetaReduction(body, param -> fresh)
          val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
          (prr, t.state, newParams :+ fresh)
        }
      }
    Function(newParams, newBody)
  }

  // Same renaming, for the single alias of a `Foreach` (batch action) body.
  private def applyForeach(f: Foreach): Foreach = {
    val fresh = freshIdent(f.alias)
    val pr = BetaReduction(f.body, f.alias -> fresh)
    val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
    Foreach(f.query, fresh, prr)
  }
}

private[minisql] object AvoidAliasConflict {

  def apply(q: Query): Query =
    AvoidAliasConflict(Set[Ident]())(q) match {
      case (q, _) => q
    }

  /**
   * Make sure query parameters do not collide with parameters of an AST
   * function. Does this by walking through the function's subtree and
   * transforming any queries encountered.
   */
  def sanitizeVariables(
      f: Function,
      dangerousVariables: Set[Ident]
  ): Function = {
    AvoidAliasConflict(dangerousVariables).applyFunction(f)
  }

  /** Same as `sanitizeVariables` but for Foreach. */
  def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
    AvoidAliasConflict(dangerousVariables).applyForeach(f)
  }

  def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
    AvoidAliasConflict(dangerousVariables).apply(q) match {
      // Propagate aliasing changes to the rest of the query
      case (q, _) => Normalize(q)
    }
  }
}
+ */ + def sanitizeVariables( + f: Function, + dangerousVariables: Set[Ident] + ): Function = { + AvoidAliasConflict(dangerousVariables).applyFunction(f) + } + + /** Same is `sanitizeVariables` but for Foreach * */ + def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = { + AvoidAliasConflict(dangerousVariables).applyForeach(f) + } + + def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = { + AvoidAliasConflict(dangerousVariables).apply(q) match { + // Propagate aliasing changes to the rest of the query + case (q, _) => Normalize(q) + } + } +} diff --git a/src/main/scala/minisql/norm/capture/AvoidCapture.scala b/src/main/scala/minisql/norm/capture/AvoidCapture.scala new file mode 100644 index 0000000..788b551 --- /dev/null +++ b/src/main/scala/minisql/norm/capture/AvoidCapture.scala @@ -0,0 +1,9 @@ +package minisql.norm.capture + +import minisql.ast.Query + +object AvoidCapture { + + def apply(q: Query): Query = + Dealias(AvoidAliasConflict(q)) +} diff --git a/src/main/scala/minisql/norm/capture/Dealias.scala b/src/main/scala/minisql/norm/capture/Dealias.scala new file mode 100644 index 0000000..64a56ce --- /dev/null +++ b/src/main/scala/minisql/norm/capture/Dealias.scala @@ -0,0 +1,72 @@ +package minisql.norm.capture + +import minisql.ast._ +import minisql.norm.BetaReduction + +case class Dealias(state: Option[Ident]) + extends StatefulTransformer[Option[Ident]] { + + override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) = + q match { + case FlatMap(a, b, c) => + dealias(a, b, c)(FlatMap.apply) match { + case (FlatMap(a, b, c), _) => + val (cn, cnt) = apply(c) + (FlatMap(a, b, cn), cnt) + } + case ConcatMap(a, b, c) => + dealias(a, b, c)(ConcatMap.apply) match { + case (ConcatMap(a, b, c), _) => + val (cn, cnt) = apply(c) + (ConcatMap(a, b, cn), cnt) + } + case Map(a, b, c) => + dealias(a, b, c)(Map.apply) + case Filter(a, b, c) => + dealias(a, b, c)(Filter.apply) + case SortBy(a, b, c, d) => + dealias(a, 
b, c)(SortBy(_, _, _, d)) + case GroupBy(a, b, c) => + dealias(a, b, c)(GroupBy.apply) + case DistinctOn(a, b, c) => + dealias(a, b, c)(DistinctOn.apply) + case Take(a, b) => + val (an, ant) = apply(a) + (Take(an, b), ant) + case Drop(a, b) => + val (an, ant) = apply(a) + (Drop(an, b), ant) + case Union(a, b) => + val (an, _) = apply(a) + val (bn, _) = apply(b) + (Union(an, bn), Dealias(None)) + case UnionAll(a, b) => + val (an, _) = apply(a) + val (bn, _) = apply(b) + (UnionAll(an, bn), Dealias(None)) + case Join(t, a, b, iA, iB, o) => + val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _)) + val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _)) + (Join(t, an, bn, iAn, iBn, onn), Dealias(None)) + case FlatJoin(t, a, iA, o) => + val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _)) + (FlatJoin(t, an, iAn, on), Dealias(Some(iA))) + case _: Entity | _: Distinct | _: Aggregation | _: Nested => + (q, Dealias(None)) + } + + private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) = + apply(a) match { + case (an, t @ Dealias(Some(alias))) => + (f(an, alias, BetaReduction(c, b -> alias)), t) + case other => + (f(a, b, c), Dealias(Some(b))) + } +} + +object Dealias { + def apply(query: Query) = + new Dealias(None)(query) match { + case (q, _) => q + } +} diff --git a/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala new file mode 100644 index 0000000..7c95f7c --- /dev/null +++ b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala @@ -0,0 +1,98 @@ +package minisql.norm.capture + +import minisql.ast.* + +/** + * Walk through any Queries that a returning clause has and replace Ident of the + * returning variable with ExternalIdent so that in later steps involving filter + * simplification, it will not be mistakenly dealiased with a potential shadow. + * Take this query for instance:
 query[TestEntity]
+ * .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r =>
+ * (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) ) 
+ * The returning clause has an alias `Ident("r")` as well as the first filter + * clause. These two filters will be combined into one at which point the + * meaning of `r.i` in the 2nd filter will be confused for the first filter's + * alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this + * vunerable `r.i` in the second filter clause to an `ExternalIdent` before any + * of the simplifications are done. + * + * Note that we only want to do this for Queries inside of a `Returning` clause + * body. Other places where this needs to be done (e.g. in a Tuple that + * `Returning` returns) are done in `ExpandReturning`. + */ +private[minisql] case class DemarcateExternalAliases(externalIdent: Ident) + extends StatelessTransformer { + + def applyNonOverride(idents: Ident*)(ast: Ast) = + if (idents.forall(_ != externalIdent)) apply(ast) + else ast + + override def apply(ast: Ast): Ast = ast match { + + case FlatMap(q, i, b) => + FlatMap(apply(q), i, applyNonOverride(i)(b)) + + case ConcatMap(q, i, b) => + ConcatMap(apply(q), i, applyNonOverride(i)(b)) + + case Map(q, i, b) => + Map(apply(q), i, applyNonOverride(i)(b)) + + case Filter(q, i, b) => + Filter(apply(q), i, applyNonOverride(i)(b)) + + case SortBy(q, i, p, o) => + SortBy(apply(q), i, applyNonOverride(i)(p), o) + + case GroupBy(q, i, b) => + GroupBy(apply(q), i, applyNonOverride(i)(b)) + + case DistinctOn(q, i, b) => + DistinctOn(apply(q), i, applyNonOverride(i)(b)) + + case Join(t, a, b, iA, iB, o) => + Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o)) + + case FlatJoin(t, a, iA, o) => + FlatJoin(t, a, iA, applyNonOverride(iA)(o)) + + case p @ Property.Opinionated( + id @ Ident(_), + value, + renameable, + visibility + ) => + if (id == externalIdent) + Property.Opinionated( + ExternalIdent(externalIdent.name), + value, + renameable, + visibility + ) + else + p + + case other => + super.apply(other) + } +} + +object DemarcateExternalAliases { + + private def demarcateQueriesInBody(id: 
Ident, body: Ast) = + Transform(body) { + // Apply to the AST defined apply method about, not to the superclass method that takes Query + case q: Query => + new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast]) + } + + def apply(ast: Ast): Ast = ast match { + case Returning(a, id, body) => + Returning(a, id, demarcateQueriesInBody(id, body)) + case ReturningGenerated(a, id, body) => + val d = demarcateQueriesInBody(id, body) + ReturningGenerated(a, id, demarcateQueriesInBody(id, body)) + case other => + other + } +} diff --git a/src/main/scala/minisql/parsing/BlockParsing.scala b/src/main/scala/minisql/parsing/BlockParsing.scala new file mode 100644 index 0000000..ae8722c --- /dev/null +++ b/src/main/scala/minisql/parsing/BlockParsing.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +type SParser[X] = + (q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]] + +private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using + Quotes +): SParser[ast.Ast] = { + + import quotes.reflect.* + + @annotation.nowarn + lazy val valDefParser: SParser[ast.Val] = { + case ValDef(n, _, Some(b)) => + val body = astParser(b.asExpr) + '{ ast.Val(ast.Ident(${ Expr(n) }), $body) } + + } + valDefParser +} + +private[parsing] def blockParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + lazy val statementParser = statementParsing(astParser) + + termParser { + case Block(Nil, t) => astParser(t.asExpr) + case b @ Block(st, t) => + val asts = (st :+ t).map { + case e if e.isExpr => astParser(e.asExpr) + case `statementParser`(x) => x + case o => + report.errorAndAbort(s"Cannot parse statement: ${o.show}") + } + if (asts.size > 1) { + '{ ast.Block(${ Expr.ofList(asts) }) } + } else asts(0) + + } +} diff --git a/src/main/scala/minisql/parsing/BoxingParsing.scala b/src/main/scala/minisql/parsing/BoxingParsing.scala new file mode 100644 index 0000000..d7b7f31 --- 
/dev/null +++ b/src/main/scala/minisql/parsing/BoxingParsing.scala @@ -0,0 +1,31 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def boxingParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + case '{ BigDecimal.int2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.long2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.double2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v) + case '{ Predef.byte2Byte($v) } => astParser(v) + case '{ Predef.short2Short($v) } => astParser(v) + case '{ Predef.char2Character($v) } => astParser(v) + case '{ Predef.int2Integer($v) } => astParser(v) + case '{ Predef.long2Long($v) } => astParser(v) + case '{ Predef.float2Float($v) } => astParser(v) + case '{ Predef.double2Double($v) } => astParser(v) + case '{ Predef.boolean2Boolean($v) } => astParser(v) + case '{ Predef.augmentString($v) } => astParser(v) + case '{ Predef.Byte2byte($v) } => astParser(v) + case '{ Predef.Short2short($v) } => astParser(v) + case '{ Predef.Character2char($v) } => astParser(v) + case '{ Predef.Integer2int($v) } => astParser(v) + case '{ Predef.Long2long($v) } => astParser(v) + case '{ Predef.Float2float($v) } => astParser(v) + case '{ Predef.Double2double($v) } => astParser(v) + case '{ Predef.Boolean2boolean($v) } => astParser(v) + +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala new file mode 100644 index 0000000..7b173db --- /dev/null +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -0,0 +1,13 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Infix] = { + + import quotes.reflect.* + ??? 
/**
 * Parses `lift(x)` calls into `ast.ScalarValueLift`, recording a stable id
 * (owner fullName + "@" + name) and the value/encoder pair.
 */
private[parsing] def liftParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Lift] = {
  case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } =>
    import quotes.reflect.*
    val name = x.asTerm.symbol.fullName
    val liftId = x.asTerm.symbol.owner.fullName + "@" + name
    '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) }
}

/**
 * Parses unary/binary operations (equality, string, numeric, boolean) into
 * `ast.Operation` nodes. Order matters: universal ==/!= first, then string,
 * numeric, and boolean operators.
 */
private[parsing] def operationParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Operation] = {
  import quotes.reflect.*

  // Restrict numeric operator parsing to actual numeric receiver types.
  def isNumeric(t: TypeRepr) = {
    t <:< TypeRepr.of[Int]
    || t <:< TypeRepr.of[Long]
    || t <:< TypeRepr.of[Byte]
    || t <:< TypeRepr.of[Float]
    || t <:< TypeRepr.of[Double]
    || t <:< TypeRepr.of[java.math.BigDecimal]
    || t <:< TypeRepr.of[scala.math.BigDecimal]
  }

  def parseBinary(
      left: Expr[Any],
      right: Expr[Any],
      op: Expr[ast.BinaryOperator]
  ) = {
    val leftE = astParser(left)
    val rightE = astParser(right)
    '{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) }
  }

  def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = {
    val base = astParser(expr)
    '{ ast.UnaryOperation($op, $base) }

  }

  // == / != / equals work on any type.
  val universalOpParser: Parser[ast.BinaryOperation] = termParser {
    case Apply(Select(leftT, UniversalOp(op)), List(rightT)) =>
      parseBinary(leftT.asExpr, rightT.asExpr, op)
  }

  val stringOpParser: Parser[ast.Operation] = {
    case '{ ($x: String) + ($y: String) } =>
      parseBinary(x, y, '{ StringOperator.concat })
    case '{ ($x: String).startsWith($y) } =>
      parseBinary(x, y, '{ StringOperator.startsWith })
    case '{ ($x: String).split($y) } =>
      parseBinary(x, y, '{ StringOperator.split })
    case '{ ($x: String).toUpperCase } =>
      parseUnary(x, '{ StringOperator.toUpperCase })
    case '{ ($x: String).toLowerCase } =>
      parseUnary(x, '{ StringOperator.toLowerCase })
    case '{ ($x: String).toLong } =>
      parseUnary(x, '{ StringOperator.toLong })
    case '{ ($x: String).toInt } =>
      parseUnary(x, '{ StringOperator.toInt })
  }

  val numericOpParser = termParser {
    case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) =>
      parseBinary(lt.asExpr, rt.asExpr, op)
    // Unary minus, e.g. `-x`.
    case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) =>
      val leftExpr = astParser(leftTerm.asExpr)
      '{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) }

  }

  val booleanOpParser: Parser[ast.Operation] = {
    case '{ ($x: Boolean) && $y } =>
      parseBinary(x, y, '{ BooleanOperator.&& })
    case '{ ($x: Boolean) || $y } =>
      parseBinary(x, y, '{ BooleanOperator.|| })
    case '{ !($x: Boolean) } =>
      parseUnary(x, '{ BooleanOperator.! })
  }

  universalOpParser
    .orElse(stringOpParser)
    .orElse(numericOpParser)
    .orElse(booleanOpParser)
}

// Maps the method names of universal equality to AST operators.
private object UniversalOp {
  def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
    op match {
      case "==" | "equals" => Some('{ EqualityOperator.== })
      case "!=" => Some('{ EqualityOperator.!= })
      case _ => None
    }
}

// Maps numeric/comparison method names to AST operators.
private object NumericOp {
  def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
    op match {
      case "+" => Some('{ NumericOperator.+ })
      case "-" => Some('{ NumericOperator.- })
      case "*" => Some('{ NumericOperator.* })
      case "/" => Some('{ NumericOperator./ })
      case ">" => Some('{ NumericOperator.> })
      case ">=" => Some('{ NumericOperator.>= })
      case "<" => Some('{ NumericOperator.< })
      case "<=" => Some('{ NumericOperator.<= })
      case "%" => Some('{ NumericOperator.% })
      case _ => None
    }
}
/** Macro entry: extracts the `n`-th parameter of lambda `f` as an AST Ident. */
private[minisql] inline def parseParamAt[F](
    inline f: F,
    inline n: Int
): ast.Ident = ${
  parseParamAt('f, 'n)
}

/** Macro entry: parses the body of lambda `f` into an AST. */
private[minisql] inline def parseBody[X](
    inline f: X
): ast.Ast = ${
  parseBody('f)
}

private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
    Quotes
): Expr[ast.Ident] = {

  import quotes.reflect.*

  // The index must be a compile-time constant.
  val pIdx = n.value.getOrElse(
    // Fix: original message read "is not know".
    report.errorAndAbort(s"Param index ${n.show} is not known")
  )
  extractTerm(f.asTerm) match {
    case Lambda(vals, _) =>
      // Fix: bounds-check instead of throwing IndexOutOfBoundsException.
      vals.lift(pIdx) match {
        case Some(ValDef(name, _, _)) =>
          '{ ast.Ident(${ Expr(name) }) }
        case None =>
          report.errorAndAbort(s"Function has no parameter at index $pIdx")
      }
    // Fix: the original match was non-exhaustive and crashed the macro with
    // a MatchError on non-lambda input; report a proper diagnostic instead.
    case o =>
      report.errorAndAbort(s"Cannot extract a parameter from non-lambda: ${o.show}")
  }
}

private[minisql] def parseBody[X](
    x: Expr[X]
)(using Quotes): Expr[Ast] = {
  import quotes.reflect.*
  x.asTerm match {
    case Lambda(_, body) =>
      Parsing.parseExpr(body.asExpr)
    case o =>
      // Include what was seen to make the diagnostic actionable.
      report.errorAndAbort(s"Can only parse a function literal, got: ${o.show}")
  }
}

// A parser maps a quoted expression to a quoted AST node, partially.
type Parser[A] = PartialFunction[Expr[Any], Expr[A]]

// Lifts a Term-level partial function into an Expr-level Parser.
private def termParser[A](using q: Quotes)(
    pf: PartialFunction[q.reflect.Term, Expr[A]]
): Parser[A] = {
  import quotes.reflect._
  {
    case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm)
  }
}

// Identity adapter, useful for combining Expr-level parsers.
private def parser[A](
    f: PartialFunction[Expr[Any], Expr[A]]
)(using Quotes): Parser[A] = {
  case e if f.isDefinedAt(e) => f(e)
}

// Unwraps inlining, typing, casts and type applications around a term so
// pattern matches see the underlying tree.
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
  import quotes.reflect.*
  def unwrapTerm(t: Term): Term = t match {
    case Inlined(_, _, o) => unwrapTerm(o)
    case Block(Nil, last) => last
    case Typed(t, _) =>
      unwrapTerm(t)
    case Select(t, "$asInstanceOf$") =>
      unwrapTerm(t)
    case TypeApply(t, _) =>
      unwrapTerm(t)
    case o => o
  }
  unwrapTerm(x)
}
+ )(using q: Quotes): Expr[ast.Ast] = { + + import q.reflect._ + + def unwrapped( + f: Parser[ast.Ast] + ): Parser[ast.Ast] = { + case expr => + val t = expr.asTerm + f(extractTerm(t).asExpr) + } + + lazy val astParser: Parser[ast.Ast] = + unwrapped { + typedParser + .orElse(propertyParser) + .orElse(liftParser) + .orElse(identParser) + .orElse(valueParser) + .orElse(operationParser) + .orElse(constantParser) + .orElse(blockParser) + .orElse(boxingParser) + .orElse(ifParser) + .orElse(traversableOperationParser) + .orElse(patMatchParser) + .orElse(infixParser) + .orElse { + case o => + val str = scala.util.Try(o.show).getOrElse("") + report.errorAndAbort( + s"cannot parse ${str}", + o.asTerm.pos + ) + } + } + + lazy val typedParser: Parser[ast.Ast] = termParser { + case (Typed(t, _)) => + astParser(t.asExpr) + } + + lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser) + + lazy val valueParser: Parser[ast.Value] = valueParsing(astParser) + + lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser) + + lazy val constantParser: Parser[ast.Constant] = termParser { + case Literal(x) => + '{ ast.Constant(${ Literal(x).asExpr }) } + } + + lazy val identParser: Parser[ast.Ident] = termParser { + case x @ Ident(n) if x.symbol.isValDef => + '{ ast.Ident(${ Expr(n) }) } + } + + lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser) + + lazy val operationParser: Parser[ast.Operation] = operationParsing( + astParser + ) + + lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser) + + lazy val ifParser: Parser[ast.If] = { + case '{ if ($a) $b else $c } => + '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } + + } + lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) + + lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) + + lazy val traversableOperationParser: Parser[ast.IterableOperation] = + traversableOperationParsing(astParser) + + astParser(expr) + } + + private[minisql] inline 
def parse[A]( + inline a: A + ): ast.Ast = ${ + parseExpr('a) + } + +} diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala new file mode 100644 index 0000000..2db7652 --- /dev/null +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -0,0 +1,49 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def patMatchParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + termParser { + // Val defs that showd pattern variables will cause error + case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) => + val bm = binds.zipWithIndex.map { + case (Bind(n, ident), idx) => + n -> Select.unique(t, s"_${idx + 1}") + }.toMap + val tm = new TreeMap { + override def transformTerm(tree: Term)(owner: Symbol): Term = { + tree match { + case Ident(n) => bm(n) + case o => super.transformTerm(o)(owner) + } + } + } + val newBody = tm.transformTree(body)(e.symbol) + astParser(newBody.asExpr) + } + +} + +object IsTupleUnapply { + + def unapply(using + Quotes + )(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = { + import quotes.reflect.* + def isTupleNUnapply(x: Term) = { + val fn = x.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply") + } + t match { + case Unapply(m, _, binds) if isTupleNUnapply(m) => + Some(binds) + case _ => None + } + } +} diff --git a/src/main/scala/minisql/parsing/PropertyParsing.scala b/src/main/scala/minisql/parsing/PropertyParsing.scala new file mode 100644 index 0000000..292f281 --- /dev/null +++ b/src/main/scala/minisql/parsing/PropertyParsing.scala @@ -0,0 +1,30 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def propertyParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Property] = { + import quotes.reflect.* + + def isAccessor(s: Select) = { + 
s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name) + } + + val parseApply: Parser[ast.Property] = termParser { + case m @ Select(base, n) if isAccessor(m) => + val obj = astParser(base.asExpr) + '{ ast.Property($obj, ${ Expr(n) }) } + } + + val parseOptionGet: Parser[ast.Property] = { + case '{ ($e: Option[t]).get } => + report.errorAndAbort( + "Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead." + ) + } + parseApply.orElse(parseOptionGet) + +} diff --git a/src/main/scala/minisql/parsing/TraversableOperationParsing.scala b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala new file mode 100644 index 0000000..8f50098 --- /dev/null +++ b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private def traversableOperationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.IterableOperation] = { + case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } => + '{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) } + case '{ ($s: Set[e]).contains($i) } => + '{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) } + case '{ ($s: Seq[e]).contains($i) } => + '{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) } + +} diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala new file mode 100644 index 0000000..6c2fb9e --- /dev/null +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -0,0 +1,71 @@ +package minisql +package parsing + +import scala.quoted._ + +private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using + Quotes +): Parser[ast.Value] = { + + import quotes.reflect.* + + val parseTupleApply: Parser[ast.Tuple] = { + case IsTupleApply(args) => + val t = args.map(astParser) + '{ ast.Tuple(${ Expr.ofList(t) }) } + } + + val parseAssocTuple: Parser[ast.Tuple] = { + case '{ ($x: tx) 
-> ($y: ty) } => + '{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) } + } + + parseTupleApply.orElse(parseAssocTuple) +} + +private[minisql] object IsTupleApply { + + def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = { + + import quotes.reflect.* + + def isTupleNApply(t: Term) = { + val fn = t.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.apply") + } + + def isTupleXXLApply(t: Term) = { + t.symbol.fullName == "scala.runtime.TupleXXL$.apply" + } + + extractTerm(e.asTerm) match { + // TupleN(0-22).apply + case Apply(b, args) if isTupleNApply(b) => + Some(args.map(_.asExpr)) + case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) => + t.asExpr match { + case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } => + Some(args) + } + case o => + None + } + } +} + +private[parsing] object IsTuple2 { + + def unapply(using + Quotes + )(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] = + t match { + case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2)) + case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2)) + case _ => None + } + + def unapply(using + Quotes + )(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] = + unapply(t.asExpr) +} diff --git a/src/main/scala/minisql/util/Interpolator.scala b/src/main/scala/minisql/util/Interpolator.scala index 55e32cd..c63e984 100644 --- a/src/main/scala/minisql/util/Interpolator.scala +++ b/src/main/scala/minisql/util/Interpolator.scala @@ -11,20 +11,25 @@ import scala.util.matching.Regex class Interpolator( defaultIndent: Int = 0, qprint: AstPrinter = AstPrinter(), - out: PrintStream = System.out, + out: PrintStream = System.out ) { + + extension (sc: StringContext) { + def trace(elements: Any*) = new Traceable(sc, elements) + } + class Traceable(sc: StringContext, elementsSeq: Seq[Any]) { private val elementPrefix = "| " private sealed trait PrintElement private case class Str(str: String, first: Boolean) extends PrintElement - private case class Elem(value: String) 
extends PrintElement - private case object Separator extends PrintElement + private case class Elem(value: String) extends PrintElement + private case object Separator extends PrintElement private def generateStringForCommand(value: Any, indent: Int) = { val objectString = qprint(value) - val oneLine = objectString.fitsOnOneLine + val oneLine = objectString.fitsOnOneLine oneLine match { case true => s"${indent.prefix}> ${objectString}" case false => @@ -42,7 +47,7 @@ class Interpolator( private def readBuffers() = { def orZero(i: Int): Int = if (i < 0) 0 else i - val parts = sc.parts.iterator.toList + val parts = sc.parts.iterator.toList val elements = elementsSeq.toList.map(qprint(_)) val (firstStr, explicitIndent) = readFirst(parts.head) diff --git a/src/main/scala/minisql/util/Message.scala b/src/main/scala/minisql/util/Message.scala new file mode 100644 index 0000000..f748456 --- /dev/null +++ b/src/main/scala/minisql/util/Message.scala @@ -0,0 +1,75 @@ +package minisql.util + +import minisql.AstPrinter +import minisql.idiom.Idiom +import minisql.util.IndentUtil._ + +object Messages { + + private def variable(propName: String, envName: String, default: String) = + Option(System.getProperty(propName)) + .orElse(sys.env.get(envName)) + .getOrElse(default) + + private[util] val prettyPrint = + variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean + private[util] val debugEnabled = + variable("quill.macro.log", "quill_macro_log", "true").toBoolean + private[util] val traceEnabled = + variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean + private[util] val traceColors = + variable("quill.trace.color", "quill_trace_color,", "false").toBoolean + private[util] val traceOpinions = + variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean + private[util] val traceAstSimple = variable( + "quill.trace.ast.simple", + "quill_trace_ast_simple", + "false" + ).toBoolean + private[minisql] val cacheDynamicQueries = 
variable( + "quill.query.cacheDaynamic", + "query_query_cacheDaynamic", + "true" + ).toBoolean + private[util] val traces: List[TraceType] = + variable("quill.trace.types", "quill_trace_types", "standard") + .split(",") + .toList + .map(_.trim) + .flatMap(trace => + TraceType.values.filter(traceType => trace == traceType.value) + ) + + def tracesEnabled(tt: TraceType) = + traceEnabled && traces.contains(tt) + + sealed trait TraceType { def value: String } + object TraceType { + case object Normalizations extends TraceType { val value = "norm" } + case object Standard extends TraceType { val value = "standard" } + case object NestedQueryExpansion extends TraceType { val value = "nest" } + + def values: List[TraceType] = + List(Standard, Normalizations, NestedQueryExpansion) + } + + val qprint = AstPrinter() + + def fail(msg: String) = + throw new IllegalStateException(msg) + + def trace[T]( + label: String, + numIndent: Int = 0, + traceType: TraceType = TraceType.Standard + ) = + (v: T) => { + val indent = (0 to numIndent).map(_ => "").mkString(" ") + if (tracesEnabled(traceType)) + println(s"$indent$label\n${{ + if (traceColors) qprint.apply(v) + else qprint.apply(v) + }.split("\n").map(s"$indent " + _).mkString("\n")}") + v + } +} From a0ceea91a9bb6bfa0aba0f2e7d3bc91df4ff7124 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 15 Dec 2024 21:11:14 +0800 Subject: [PATCH 02/26] add one test case --- build.sbt | 3 ++- src/main/scala/minisql/parsing/Parsing.scala | 4 ++-- src/test/scala/minisql/parsing/ParsingSuite.scala | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 3 deletions(-) create mode 100644 src/test/scala/minisql/parsing/ParsingSuite.scala diff --git a/build.sbt b/build.sbt index 0a952af..2dc3b9c 100644 --- a/build.sbt +++ b/build.sbt @@ -1,8 +1,9 @@ name := "minisql" -scalaVersion := "3.6.2" +scalaVersion := "3.5.2" libraryDependencies ++= Seq( + "org.scalameta" %% "munit" % "1.0.3" % Test ) scalacOptions ++= Seq("-experimental", 
"-language:experimental.namedTuples") diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index 463112f..07da46a 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -75,7 +75,7 @@ private[minisql] object Parsing { .orElse(ifParser) .orElse(traversableOperationParser) .orElse(patMatchParser) - .orElse(infixParser) + // .orElse(infixParser) .orElse { case o => val str = scala.util.Try(o.show).getOrElse("") @@ -122,7 +122,7 @@ private[minisql] object Parsing { } lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) - lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) + // lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) lazy val traversableOperationParser: Parser[ast.IterableOperation] = traversableOperationParsing(astParser) diff --git a/src/test/scala/minisql/parsing/ParsingSuite.scala b/src/test/scala/minisql/parsing/ParsingSuite.scala new file mode 100644 index 0000000..043346d --- /dev/null +++ b/src/test/scala/minisql/parsing/ParsingSuite.scala @@ -0,0 +1,14 @@ +package minisql.parsing + +import minisql.ast.* + +class ParsingSuite extends munit.FunSuite { + + inline def testParseInline(inline x: Any, ast: Ast) = { + assertEquals(Parsing.parse(x), Ident("x")) + } + + test("Ident") { + val x = 1 + } +} From 2e7e7df4a3ff178ee634c1835037cb357dcfcf85 Mon Sep 17 00:00:00 2001 From: jilen Date: Tue, 17 Dec 2024 19:51:19 +0800 Subject: [PATCH 03/26] move package --- src/main/scala/minisql/Meta.scala | 3 + src/main/scala/minisql/ParamEncoder.scala | 15 ++++ src/main/scala/minisql/Quoted.scala | 86 +++++++++++++++++++ src/main/scala/minisql/context/Context.scala | 66 ++++++++++++++ src/main/scala/minisql/dsl.scala | 45 ---------- src/main/scala/minisql/idiom/Idiom.scala | 2 +- .../scala/minisql/idiom/ReifyStatement.scala | 2 +- .../scala/minisql/parsing/LiftParsing.scala | 2 +- 
.../minisql/parsing/OperationParsing.scala | 2 +- src/main/scala/minisql/util/CollectTry.scala | 20 ++++- src/main/scala/minisql/util/LoadObject.scala | 7 +- .../scala/minisql/parsing/ParsingSuite.scala | 30 ++++++- .../scala/minisql/parsing/QuerySuite.scala | 1 + 13 files changed, 227 insertions(+), 54 deletions(-) create mode 100644 src/main/scala/minisql/Meta.scala create mode 100644 src/main/scala/minisql/Quoted.scala create mode 100644 src/main/scala/minisql/context/Context.scala delete mode 100644 src/main/scala/minisql/dsl.scala create mode 100644 src/test/scala/minisql/parsing/QuerySuite.scala diff --git a/src/main/scala/minisql/Meta.scala b/src/main/scala/minisql/Meta.scala new file mode 100644 index 0000000..38d8fb9 --- /dev/null +++ b/src/main/scala/minisql/Meta.scala @@ -0,0 +1,3 @@ +package minisql + +type QueryMeta diff --git a/src/main/scala/minisql/ParamEncoder.scala b/src/main/scala/minisql/ParamEncoder.scala index a55c0a3..05ef348 100644 --- a/src/main/scala/minisql/ParamEncoder.scala +++ b/src/main/scala/minisql/ParamEncoder.scala @@ -1,8 +1,23 @@ package minisql +import scala.util.Try + trait ParamEncoder[E] { type Stmt def setParam(s: Stmt, idx: Int, v: E): Unit } + +trait ColumnDecoder[X] { + + type DBRow + + def decode(row: DBRow, idx: Int): Try[X] +} + +object ColumnDecoder { + type Aux[R, X] = ColumnDecoder[X] { + type DBRow = R + } +} diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala new file mode 100644 index 0000000..c7c9c7f --- /dev/null +++ b/src/main/scala/minisql/Quoted.scala @@ -0,0 +1,86 @@ +package minisql + +import minisql.* +import minisql.idiom.* +import minisql.parsing.* +import minisql.util.* +import minisql.ast.{Ast, Entity, Map, Property, Ident, Filter, given} +import scala.quoted.* +import scala.compiletime.* +import scala.compiletime.ops.string.* +import scala.collection.immutable.{Map => IMap} + +opaque type Quoted <: Ast = Ast + +opaque type Query[E] <: Quoted = Quoted + +opaque type 
EntityQuery[E] <: Query[E] = Query[E] + +object EntityQuery { + extension [E](inline e: EntityQuery[E]) { + inline def map[E1](inline f: E => E1): EntityQuery[E1] = { + transform(e)(f)(Map.apply) + } + + inline def filter(inline f: E => Boolean): EntityQuery[E] = { + transform(e)(f)(Filter.apply) + } + } +} + +private inline def transform[A, B](inline q1: Quoted)( + inline f: A => B +)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = { + fast(q1, f.param0, f.body) +} + +inline def query[E](inline table: String): EntityQuery[E] = + Entity(table, Nil) + +extension [A, B](inline f1: A => B) { + private inline def param0 = parsing.parseParamAt(f1, 0) + private inline def body = parsing.parseBody(f1) +} + +extension [A1, A2, B](inline f1: (A1, A2) => B) { + private inline def param0 = parsing.parseParamAt(f1, 0) + private inline def param1 = parsing.parseParamAt(f1, 1) + private inline def body = parsing.parseBody(f1) +} + +def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() + +class NonQuotedException extends Exception("Cannot be used at runtime") + +private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy]( + inline q: Quoted, + inline idiom: I, + inline naming: N +): Statement = ${ compileImpl[I, N]('q, 'idiom, 'naming) } + +private def compileImpl[I <: Idiom, N <: NamingStrategy]( + q: Expr[Quoted], + idiom: Expr[I], + n: Expr[N] +)(using Quotes, Type[I], Type[N]): Expr[Statement] = { + import quotes.reflect.* + q.value match { + case Some(ast) => + val idiom = LoadObject[I].getOrElse( + report.errorAndAbort(s"Idiom not known at compile") + ) + + val naming = LoadNaming + .static[N] + .getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile")) + + val stmt = idiom.translate(ast)(using naming) + Expr(stmt._2) + case None => + report.info("Dynamic Query") + '{ + $idiom.translate($q)(using $n)._2 + } + + } +} diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala new file 
mode 100644 index 0000000..47f5f2e --- /dev/null +++ b/src/main/scala/minisql/context/Context.scala @@ -0,0 +1,66 @@ +package minisql.context + +import scala.deriving.* +import scala.compiletime.* +import scala.util.Try +import minisql.util.* +import minisql.idiom.{Idiom, Statement} +import minisql.{NamingStrategy, ParamEncoder} +import minisql.ColumnDecoder + +trait Context[I <: Idiom, N <: NamingStrategy] { selft => + + val idiom: I + val naming: NamingStrategy + + type DBStatement + type DBRow + type DBResultSet + + trait RowExtract[A] { + def extract(row: DBRow): Try[A] + } + + object RowExtract { + + private class ExtractorImpl[A]( + decoders: IArray[Any], + m: Mirror.ProductOf[A] + ) extends RowExtract[A] { + def extract(row: DBRow): Try[A] = { + val decodedFields = decoders.zipWithIndex.traverse { + case (d, i) => + d.asInstanceOf[Decoder[?]].decode(row, i) + } + decodedFields.map { vs => + m.fromProduct(Tuple.fromIArray(vs)) + } + } + } + + inline given [P <: Product](using m: Mirror.ProductOf[P]): RowExtract[P] = { + val decoders = summonAll[Tuple.Map[m.MirroredElemTypes, Decoder]] + ExtractorImpl(decoders.toIArray.asInstanceOf, m) + } + } + + type Encoder[X] = ParamEncoder[X] { + type Stmt = DBStatement + } + + type Decoder[X] = ColumnDecoder.Aux[DBRow, X] + + type DBIO[X] = ( + statement: Statement, + params: (Any, Encoder[?]), + extract: RowExtract[X] + ) + + inline def io[E]( + inline q: minisql.Query[E] + )(using r: RowExtract[E]): DBIO[Seq[E]] = { + val statement = minisql.compile(q, idiom, naming) + ??? 
+ } + +} diff --git a/src/main/scala/minisql/dsl.scala b/src/main/scala/minisql/dsl.scala deleted file mode 100644 index ace3d8f..0000000 --- a/src/main/scala/minisql/dsl.scala +++ /dev/null @@ -1,45 +0,0 @@ -package minisql.dsl - -import minisql.* -import minisql.parsing.* -import minisql.ast.{Ast, Entity, Map, Property, Ident, given} -import scala.quoted.* -import scala.compiletime.* -import scala.compiletime.ops.string.* -import scala.collection.immutable.{Map => IMap} - -opaque type Quoted <: Ast = Ast - -opaque type Query[E] <: Quoted = Quoted - -opaque type EntityQuery[E] <: Query[E] = Query[E] - -extension [E](inline e: EntityQuery[E]) { - inline def map[E1](inline f: E => E1): EntityQuery[E1] = { - transform(e)(f)(Map.apply) - } -} - -private inline def transform[A, B](inline q1: Quoted)( - inline f: A => B -)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = { - fast(q1, f.param0, f.body) -} - -inline def query[E](inline table: String): EntityQuery[E] = - Entity(table, Nil) - -extension [A, B](inline f1: A => B) { - private inline def param0 = parsing.parseParamAt(f1, 0) - private inline def body = parsing.parseBody(f1) -} - -extension [A1, A2, B](inline f1: (A1, A2) => B) { - private inline def param0 = parsing.parseParamAt(f1, 0) - private inline def param1 = parsing.parseParamAt(f1, 1) - private inline def body = parsing.parseBody(f1) -} - -def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() - -class NonQuotedException extends Exception("Cannot be used at runtime") diff --git a/src/main/scala/minisql/idiom/Idiom.scala b/src/main/scala/minisql/idiom/Idiom.scala index 7e0cd01..43b6110 100644 --- a/src/main/scala/minisql/idiom/Idiom.scala +++ b/src/main/scala/minisql/idiom/Idiom.scala @@ -14,7 +14,7 @@ trait Idiom extends Capabilities { def liftingPlaceholder(index: Int): String - def translate(ast: Ast)(implicit naming: NamingStrategy): (Ast, Statement) + def translate(ast: Ast)(using naming: NamingStrategy): (Ast, Statement) def 
format(queryString: String): String = queryString diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index a3e8902..aea8322 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -63,6 +63,6 @@ object ReifyStatement { emptySetContainsToken: Token => Token, liftMap: SMap[String, (Any, Any)] ): (Token) = { - statement + ??? } } diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala index c01df89..9a0f32b 100644 --- a/src/main/scala/minisql/parsing/LiftParsing.scala +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -3,7 +3,7 @@ package minisql.parsing import scala.quoted.* import minisql.ParamEncoder import minisql.ast -import minisql.dsl.* +import minisql.* private[parsing] def liftParsing( astParser: => Parser[ast.Ast] diff --git a/src/main/scala/minisql/parsing/OperationParsing.scala b/src/main/scala/minisql/parsing/OperationParsing.scala index df93010..4425dd5 100644 --- a/src/main/scala/minisql/parsing/OperationParsing.scala +++ b/src/main/scala/minisql/parsing/OperationParsing.scala @@ -7,7 +7,7 @@ import minisql.ast.{ NumericOperator, BooleanOperator } -import minisql.dsl.* +import minisql.* import scala.quoted._ private[parsing] def operationParsing( diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index f0ee506..74a6984 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -1,6 +1,24 @@ package minisql.util -import scala.util.Try +import scala.util.* + +extension [A](xs: IArray[A]) { + private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = { + val out = IArray.newBuilder[Any] + var left: Option[Throwable] = None + xs.foreach { (v) => + if (!left.isDefined) { + f(v) match { + case Failure(e) => + left = Some(e) + case Success(r) => + out += r 
+ } + } + } + left.toLeft(out.result().asInstanceOf).toTry + } +} object CollectTry { def apply[T](list: List[Try[T]]): Try[List[T]] = diff --git a/src/main/scala/minisql/util/LoadObject.scala b/src/main/scala/minisql/util/LoadObject.scala index 83bbec0..8c13c8f 100644 --- a/src/main/scala/minisql/util/LoadObject.scala +++ b/src/main/scala/minisql/util/LoadObject.scala @@ -5,10 +5,15 @@ import scala.util.Try object LoadObject { + def apply[T](using Quotes, Type[T]): Try[T] = { + import quotes.reflect.* + apply(TypeRepr.of[T]) + } + def apply[T](using Quotes)(ot: quotes.reflect.TypeRepr): Try[T] = Try { import quotes.reflect.* val moduleClsName = ot.typeSymbol.companionModule.moduleClass.fullName - val moduleCls = Class.forName(moduleClsName) + val moduleCls = Class.forName(moduleClsName) val field = moduleCls .getFields() .find { f => diff --git a/src/test/scala/minisql/parsing/ParsingSuite.scala b/src/test/scala/minisql/parsing/ParsingSuite.scala index 043346d..b90f9f8 100644 --- a/src/test/scala/minisql/parsing/ParsingSuite.scala +++ b/src/test/scala/minisql/parsing/ParsingSuite.scala @@ -4,11 +4,35 @@ import minisql.ast.* class ParsingSuite extends munit.FunSuite { - inline def testParseInline(inline x: Any, ast: Ast) = { + test("Ident") { + val x = 1 assertEquals(Parsing.parse(x), Ident("x")) } - test("Ident") { - val x = 1 + test("NumericOperator.+") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a + b), + BinaryOperation(Ident("a"), NumericOperator.+, Ident("b")) + ) + } + + test("NumericOperator.-") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a - b), + BinaryOperation(Ident("a"), NumericOperator.-, Ident("b")) + ) + } + + test("NumericOperator.*") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a * b), + BinaryOperation(Ident("a"), NumericOperator.*, Ident("b")) + ) } } diff --git a/src/test/scala/minisql/parsing/QuerySuite.scala b/src/test/scala/minisql/parsing/QuerySuite.scala new file mode 100644 index 0000000..8b13789 
--- /dev/null +++ b/src/test/scala/minisql/parsing/QuerySuite.scala @@ -0,0 +1 @@ + From 59f969a232397412a357b20cd8da243b040e86a9 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 18 Dec 2024 16:09:08 +0800 Subject: [PATCH 04/26] test simple quoted ast --- README.md | 2 +- src/main/scala/minisql/Quoted.scala | 15 + src/main/scala/minisql/ast/FromExprs.scala | 3 +- .../scala/minisql/idiom/MirrorIdiom.scala | 355 ++++++++++++++++++ .../minisql/idiom/StatementInterpolator.scala | 16 +- src/main/scala/minisql/parsing/Parser.scala | 2 +- .../scala/minisql/parsing/QuotedSuite.scala | 35 ++ 7 files changed, 420 insertions(+), 8 deletions(-) create mode 100644 src/main/scala/minisql/idiom/MirrorIdiom.scala create mode 100644 src/test/scala/minisql/parsing/QuotedSuite.scala diff --git a/README.md b/README.md index de83bb9..6f82dea 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ 大部分场景不用在 `macro` 对 Ast 进行复杂模式匹配来分析代码。 -## 核心思路 使用 inline 和 `FromExpr` 代替大部分 parsing 工作 +## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作 `FromExpr` 是 `scala3` 内置的 typeclass,用来获取编译期值 。 diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index c7c9c7f..da7008e 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -52,6 +52,21 @@ def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() class NonQuotedException extends Exception("Cannot be used at runtime") +private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] = + ${ + compileTimeAstImpl('q) + } + +private def compileTimeAstImpl(e: Expr[Quoted])(using + Quotes +): Expr[Option[String]] = { + import quotes.reflect.* + e.value match { + case Some(v) => '{ Some(${ Expr(v.toString()) }) } + case None => '{ None } + } +} + private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy]( inline q: Quoted, inline idiom: I, diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 
26694fb..3541a76 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -130,7 +130,7 @@ private given FromExpr[Query] with { case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } => Some(SortBy(b, p, s, o)) case o => - println(s"Cannot extract ${o.show}") + println(s"Cannot extract ${o}") None } } @@ -146,6 +146,7 @@ private given FromExpr[BinaryOperator] with { case '{ NumericOperator.- } => Some(NumericOperator.-) case '{ NumericOperator.* } => Some(NumericOperator.*) case '{ NumericOperator./ } => Some(NumericOperator./) + case '{ NumericOperator.> } => Some(NumericOperator.>) case '{ StringOperator.split } => Some(StringOperator.split) case '{ StringOperator.startsWith } => Some(StringOperator.startsWith) case '{ StringOperator.concat } => Some(StringOperator.concat) diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala new file mode 100644 index 0000000..88aab8c --- /dev/null +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -0,0 +1,355 @@ +package minisql + +import minisql.ast.Renameable.{ByStrategy, Fixed} +import minisql.ast.Visibility.Hidden +import minisql.ast._ +import minisql.context.CanReturnClause +import minisql.idiom.{Idiom, SetContainsToken, Statement} +import minisql.idiom.StatementInterpolator.* +import minisql.norm.Normalize +import minisql.util.Interleave + +object MirrorIdiom extends MirrorIdiom +class MirrorIdiom extends MirrorIdiomBase with CanReturnClause + +object MirrorIdiomPrinting extends MirrorIdiom { + override def distinguishHidden: Boolean = true +} + +trait MirrorIdiomBase extends Idiom { + + def distinguishHidden: Boolean = false + + override def prepareForProbing(string: String) = string + + override def liftingPlaceholder(index: Int): String = "?" 
+ + override def translate( + ast: Ast + )(implicit naming: NamingStrategy): (Ast, Statement) = { + val normalizedAst = Normalize(ast) + (normalizedAst, stmt"${normalizedAst.token}") + } + + implicit def astTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Ast] = Tokenizer[Ast] { + case ast: Query => ast.token + case ast: Function => ast.token + case ast: Value => ast.token + case ast: Operation => ast.token + case ast: Action => ast.token + case ast: Ident => ast.token + case ast: ExternalIdent => ast.token + case ast: Property => ast.token + case ast: Infix => ast.token + case ast: OptionOperation => ast.token + case ast: IterableOperation => ast.token + case ast: Dynamic => ast.token + case ast: If => ast.token + case ast: Block => ast.token + case ast: Val => ast.token + case ast: Ordering => ast.token + case ast: Lift => ast.token + case ast: Assignment => ast.token + case ast: OnConflict.Excluded => ast.token + case ast: OnConflict.Existing => ast.token + } + + implicit def ifTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[If] = Tokenizer[If] { + case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}" + } + + implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] { + case Dynamic(tree) => stmt"${tree.toString.token}" + } + + implicit def blockTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Block] = Tokenizer[Block] { + case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }" + } + + implicit def valTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Val] = Tokenizer[Val] { + case Val(name, body) => stmt"val ${name.token} = ${body.token}" + } + + implicit def queryTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Query] = Tokenizer[Query] { + + case Entity.Opinionated(name, Nil, renameable) => + stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token})" + + case Entity.Opinionated(name, prop, renameable) => 
+ val properties = + prop.map(p => stmt"""_.${p.path.mkStmt(".")} -> "${p.alias.token}"""") + stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token}, ${properties.token})" + + case Filter(source, alias, body) => + stmt"${source.token}.filter(${alias.token} => ${body.token})" + + case Map(source, alias, body) => + stmt"${source.token}.map(${alias.token} => ${body.token})" + + case FlatMap(source, alias, body) => + stmt"${source.token}.flatMap(${alias.token} => ${body.token})" + + case ConcatMap(source, alias, body) => + stmt"${source.token}.concatMap(${alias.token} => ${body.token})" + + case SortBy(source, alias, body, ordering) => + stmt"${source.token}.sortBy(${alias.token} => ${body.token})(${ordering.token})" + + case GroupBy(source, alias, body) => + stmt"${source.token}.groupBy(${alias.token} => ${body.token})" + + case Aggregation(op, ast) => + stmt"${scopedTokenizer(ast)}.${op.token}" + + case Take(source, n) => + stmt"${source.token}.take(${n.token})" + + case Drop(source, n) => + stmt"${source.token}.drop(${n.token})" + + case Union(a, b) => + stmt"${a.token}.union(${b.token})" + + case UnionAll(a, b) => + stmt"${a.token}.unionAll(${b.token})" + + case Join(t, a, b, iA, iB, on) => + stmt"${a.token}.${t.token}(${b.token}).on((${iA.token}, ${iB.token}) => ${on.token})" + + case FlatJoin(t, a, iA, on) => + stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})" + + case Distinct(a) => + stmt"${a.token}.distinct" + + case DistinctOn(source, alias, body) => + stmt"${source.token}.distinctOn(${alias.token} => ${body.token})" + + case Nested(a) => + stmt"${a.token}.nested" + } + + implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] { + case TupleOrdering(elems) => stmt"Ord(${elems.token})" + case Asc => stmt"Ord.asc" + case Desc => stmt"Ord.desc" + case AscNullsFirst => stmt"Ord.ascNullsFirst" + case DescNullsFirst => stmt"Ord.descNullsFirst" + case AscNullsLast => stmt"Ord.ascNullsLast" + case DescNullsLast => 
stmt"Ord.descNullsLast" + } + + implicit def optionOperationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] { + case OptionTableFlatMap(ast, alias, body) => + stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})" + case OptionTableMap(ast, alias, body) => + stmt"${ast.token}.map((${alias.token}) => ${body.token})" + case OptionTableExists(ast, alias, body) => + stmt"${ast.token}.exists((${alias.token}) => ${body.token})" + case OptionTableForall(ast, alias, body) => + stmt"${ast.token}.forall((${alias.token}) => ${body.token})" + case OptionFlatten(ast) => stmt"${ast.token}.flatten" + case OptionGetOrElse(ast, body) => + stmt"${ast.token}.getOrElse(${body.token})" + case OptionFlatMap(ast, alias, body) => + stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})" + case OptionMap(ast, alias, body) => + stmt"${ast.token}.map((${alias.token}) => ${body.token})" + case OptionForall(ast, alias, body) => + stmt"${ast.token}.forall((${alias.token}) => ${body.token})" + case OptionExists(ast, alias, body) => + stmt"${ast.token}.exists((${alias.token}) => ${body.token})" + case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case OptionIsEmpty(ast) => stmt"${ast.token}.isEmpty" + case OptionNonEmpty(ast) => stmt"${ast.token}.nonEmpty" + case OptionIsDefined(ast) => stmt"${ast.token}.isDefined" + case OptionSome(ast) => stmt"Some(${ast.token})" + case OptionApply(ast) => stmt"Option(${ast.token})" + case OptionOrNull(ast) => stmt"${ast.token}.orNull" + case OptionGetOrNull(ast) => stmt"${ast.token}.getOrNull" + case OptionNone => stmt"None" + } + + implicit def traversableOperationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] { + case MapContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case SetContains(ast, body) => stmt"${ast.token}.contains(${body.token})" + case ListContains(ast, 
body) => stmt"${ast.token}.contains(${body.token})" + } + + implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { + case InnerJoin => stmt"join" + case LeftJoin => stmt"leftJoin" + case RightJoin => stmt"rightJoin" + case FullJoin => stmt"fullJoin" + } + + implicit def functionTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Function] = Tokenizer[Function] { + case Function(params, body) => stmt"(${params.token}) => ${body.token}" + } + + implicit def operationTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Operation] = Tokenizer[Operation] { + case UnaryOperation(op: PrefixUnaryOperator, ast) => + stmt"${op.token}${scopedTokenizer(ast)}" + case UnaryOperation(op: PostfixUnaryOperator, ast) => + stmt"${scopedTokenizer(ast)}.${op.token}" + case BinaryOperation(a, op @ SetOperator.`contains`, b) => + SetContainsToken(scopedTokenizer(b), op.token, a.token) + case BinaryOperation(a, op, b) => + stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}" + case FunctionApply(function, values) => + stmt"${scopedTokenizer(function)}.apply(${values.token})" + } + + implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] { + case o => stmt"${o.toString.token}" + } + + def tokenizeName(name: String, renameable: Renameable) = + renameable match { + case ByStrategy => name + case Fixed => s"`${name}`" + } + + def bracketIfHidden(name: String, visibility: Visibility) = + (distinguishHidden, visibility) match { + case (true, Hidden) => s"[$name]" + case _ => name + } + + implicit def propertyTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Property] = Tokenizer[Property] { + case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) => + stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}" + case Property.Opinionated(ref, name, renameable, visibility) => + stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), 
visibility).token}" + } + + implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] { + case Constant(v: String) => stmt""""${v.token}"""" + case Constant(()) => stmt"{}" + case Constant(v) => stmt"${v.toString.token}" + case NullValue => stmt"null" + case Tuple(values) => stmt"(${values.token})" + case CaseClass(values) => + stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})" + } + + implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] { + case Ident.Opinionated(name, visibility) => + stmt"${bracketIfHidden(name, visibility).token}" + } + + implicit val typeTokenizer: Tokenizer[ExternalIdent] = + Tokenizer[ExternalIdent] { + case e => stmt"${e.name.token}" + } + + implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] = + Tokenizer[OnConflict.Excluded] { + case OnConflict.Excluded(ident) => stmt"${ident.token}" + } + + implicit val existingTokenizer: Tokenizer[OnConflict.Existing] = + Tokenizer[OnConflict.Existing] { + case OnConflict.Existing(ident) => stmt"${ident.token}" + } + + implicit def actionTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Action] = Tokenizer[Action] { + case Update(query, assignments) => + stmt"${query.token}.update(${assignments.token})" + case Insert(query, assignments) => + stmt"${query.token}.insert(${assignments.token})" + case Delete(query) => stmt"${query.token}.delete" + case Returning(query, alias, body) => + stmt"${query.token}.returning((${alias.token}) => ${body.token})" + case ReturningGenerated(query, alias, body) => + stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})" + case Foreach(query, alias, body) => + stmt"${query.token}.foreach((${alias.token}) => ${body.token})" + case c: OnConflict => stmt"${c.token}" + } + + implicit def conflictTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[OnConflict] = { + + def targetProps(l: List[Property]) = l.map(p => + Transform(p) { + case Ident(_) => 
Ident("_") + } + ) + + implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] = + Tokenizer[OnConflict.Target] { + case OnConflict.NoTarget => stmt"" + case OnConflict.Properties(props) => + val listTokens = listTokenizer(astTokenizer).token(props) + stmt"(${listTokens})" + } + + val updateAssignsTokenizer = Tokenizer[Assignment] { + case Assignment(i, p, v) => + stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}" + } + + Tokenizer[OnConflict] { + case OnConflict(i, t, OnConflict.Update(assign)) => + stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})" + case OnConflict(i, t, OnConflict.Ignore) => + stmt"${i.token}.onConflictIgnore${t.token}" + } + } + + implicit def assignmentTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Assignment] = Tokenizer[Assignment] { + case Assignment(ident, property, value) => + stmt"${ident.token} => ${property.token} -> ${value.token}" + } + + implicit def infixTokenizer(implicit + liftTokenizer: Tokenizer[Lift] + ): Tokenizer[Infix] = Tokenizer[Infix] { + case Infix(parts, params, _, _) => + def tokenParam(ast: Ast) = + ast match { + case ast: Ident => stmt"$$${ast.token}" + case other => stmt"$${${ast.token}}" + } + + val pt = parts.map(_.token) + val pr = params.map(tokenParam) + val body = Statement(Interleave(pt, pr)) + stmt"""infix"${body.token}"""" + } + + private def scopedTokenizer( + ast: Ast + )(implicit liftTokenizer: Tokenizer[Lift]) = + ast match { + case _: Function => stmt"(${ast.token})" + case _: BinaryOperation => stmt"(${ast.token})" + case other => ast.token + } +} diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala index b732da1..3aa4d26 100644 --- a/src/main/scala/minisql/idiom/StatementInterpolator.scala +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -9,19 +9,25 @@ import scala.collection.mutable.ListBuffer object 
StatementInterpolator { trait Tokenizer[T] { - def token(v: T): Token + extension (v: T) { + def token: Token + } } object Tokenizer { - def apply[T](f: T => Token) = new Tokenizer[T] { - def token(v: T) = f(v) + def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] { + extension (v: T) { + def token: Token = f(v) + } } def withFallback[T]( fallback: Tokenizer[T] => Tokenizer[T] )(pf: PartialFunction[T, Token]) = new Tokenizer[T] { - private val stable = fallback(this) - override def token(v: T) = pf.applyOrElse(v, stable.token) + extension (v: T) { + private def stable = fallback(this) + override def token = pf.applyOrElse(v, stable.token) + } } } diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala index 4bc7265..91bfbc0 100644 --- a/src/main/scala/minisql/parsing/Parser.scala +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -38,7 +38,7 @@ private[minisql] def parseBody[X]( x: Expr[X] )(using Quotes): Expr[Ast] = { import quotes.reflect.* - x.asTerm match { + extractTerm(x.asTerm) match { case Lambda(vals, body) => Parsing.parseExpr(body.asExpr) case o => diff --git a/src/test/scala/minisql/parsing/QuotedSuite.scala b/src/test/scala/minisql/parsing/QuotedSuite.scala new file mode 100644 index 0000000..2bd02eb --- /dev/null +++ b/src/test/scala/minisql/parsing/QuotedSuite.scala @@ -0,0 +1,35 @@ +package minisql + +import minisql.ast.* + +class QuotedSuite extends munit.FunSuite { + private inline def testQuoted(label: String)( + inline x: Quoted, + expect: Ast + ) = test(label) { + assertEquals(compileTimeAst(x), Some(expect.toString())) + } + + case class Foo(id: Long) + + inline def Foos = query[Foo]("foo") + val entityFoo = Entity("foo", Nil) + val idx = Ident("x") + + testQuoted("EntityQuery")(Foos, entityFoo) + + testQuoted("Query/filter")( + Foos.filter(x => x.id > 0), + Filter( + entityFoo, + idx, + BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0)) + ) + ) + + 
testQuoted("Query/map")( + Foos.map(x => x.id), + Map(entityFoo, idx, Property(idx, "id")) + ) + +} From 6be96aba2cedc191aa1cad7dafe9c092f937a7c3 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 18 Dec 2024 19:28:43 +0800 Subject: [PATCH 05/26] save --- src/main/scala/minisql/ast/FromExprs.scala | 7 +++++++ src/main/scala/minisql/context/Context.scala | 4 ++++ src/main/scala/minisql/idiom/ReifyStatement.scala | 2 +- 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 3541a76..9f70b0d 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -123,6 +123,13 @@ private given FromExpr[Query] with { Some(FlatMap(b, id, body)) case '{ ConcatMap(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } => Some(ConcatMap(b, id, body)) + case '{ + val x: Ast = ${ Expr(b) } + val y: Ident = ${ Expr(id) } + val z: Ast = ${ Expr(body) } + ConcatMap(x, y, z) + } => + Some(ConcatMap(b, id, body)) case '{ Drop(${ Expr(b) }, ${ Expr(n) }) } => Some(Drop(b, n)) case '{ Take(${ Expr(b) }, ${ Expr[Ast](n) }) } => diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 47f5f2e..8b3b96f 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -56,6 +56,10 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => extract: RowExtract[X] ) + extension (ast: Ast) { + extractParams + } + inline def io[E]( inline q: minisql.Query[E] )(using r: RowExtract[E]): DBIO[Seq[E]] = { diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index aea8322..7c5b9a8 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -62,7 +62,7 @@ object ReifyStatement { statement: Statement, emptySetContainsToken: Token => Token, liftMap: 
SMap[String, (Any, Any)] - ): (Token) = { + ): Token = { ??? } } From 87f1b70b274b554977e4f9921a161b6c6ddb6fe5 Mon Sep 17 00:00:00 2001 From: jilen Date: Thu, 19 Dec 2024 11:42:23 +0800 Subject: [PATCH 06/26] try implement context --- src/main/scala/minisql/ast/Ast.scala | 5 ++- src/main/scala/minisql/context/Context.scala | 38 ++++++++++++---- .../scala/minisql/idiom/ReifyStatement.scala | 44 ++++++++++++++++--- src/main/scala/minisql/util/CollectTry.scala | 2 +- 4 files changed, 71 insertions(+), 18 deletions(-) diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 48e3fae..52446e3 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -385,8 +385,9 @@ case class ScalarValueLift( case class ScalarQueryLift( name: String, - liftId: String -) extends ScalarLift {} + liftId: String, + value: Option[(Seq[Any], ParamEncoder[?])] +) extends ScalarLift object ScalarLift { given ToExpr[ScalarLift] with { diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 8b3b96f..af33d80 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -4,9 +4,10 @@ import scala.deriving.* import scala.compiletime.* import scala.util.Try import minisql.util.* -import minisql.idiom.{Idiom, Statement} +import minisql.idiom.{Idiom, Statement, ReifyStatement} import minisql.{NamingStrategy, ParamEncoder} import minisql.ColumnDecoder +import minisql.ast.{Ast, ScalarValueLift, CollectAst} trait Context[I <: Idiom, N <: NamingStrategy] { selft => @@ -50,21 +51,40 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => type Decoder[X] = ColumnDecoder.Aux[DBRow, X] - type DBIO[X] = ( - statement: Statement, - params: (Any, Encoder[?]), - extract: RowExtract[X] + type DBIO[E] = ( + sql: String, + params: List[(Any, Encoder[?])], + mapper: Iterable[DBRow] => Try[E] ) extension (ast: Ast) { - extractParams + 
private def liftMap = { + val lifts = CollectAst.byType[ScalarValueLift](ast) + lifts.map(l => l.liftId -> l.value.get).toMap + } + } + + extension (stmt: Statement) { + def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) = + ReifyStatement( + idiom.liftingPlaceholder, + idiom.emptySetContainsToken, + stmt, + liftMap + ) } inline def io[E]( inline q: minisql.Query[E] - )(using r: RowExtract[E]): DBIO[Seq[E]] = { - val statement = minisql.compile(q, idiom, naming) - ??? + )(using r: RowExtract[E]): DBIO[IArray[E]] = { + val lifts = q.liftMap + val stmt = minisql.compile(q, idiom, naming) + val (sql, params) = stmt.expand(lifts) + ( + sql = sql, + params = params.map(_.value.get.asInstanceOf), + mapper = (rows) => rows.traverse(r.extract) + ) } } diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index 7c5b9a8..7a4a07a 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -1,8 +1,9 @@ package minisql.idiom -import minisql.ast._ +import minisql.ParamEncoder +import minisql.ast.* import minisql.util.Interleave -import minisql.idiom.StatementInterpolator._ +import minisql.idiom.StatementInterpolator.* import scala.annotation.tailrec import scala.collection.immutable.{Map => SMap} @@ -12,7 +13,7 @@ object ReifyStatement { liftingPlaceholder: Int => String, emptySetContainsToken: Token => Token, statement: Statement, - liftMap: SMap[String, (Any, Any)] + liftMap: SMap[String, (Any, ParamEncoder[?])] ): (String, List[ScalarValueLift]) = { val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) token2string(expanded, liftingPlaceholder) @@ -61,8 +62,39 @@ object ReifyStatement { private def expandLiftings( statement: Statement, emptySetContainsToken: Token => Token, - liftMap: SMap[String, (Any, Any)] - ): Token = { - ??? 
+ liftMap: SMap[String, (Any, ParamEncoder[?])] + ): (Token) = { + Statement { + val lb = List.newBuilder[Token] + statement.tokens.foldLeft(lb) { + case ( + tokens, + SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift)) + ) => + val (lv, le) = liftMap(lift.liftId) + lv.asInstanceOf[Iterable[Any]].toVector match { + case Vector() => tokens += emptySetContainsToken(a) + case values => + val liftings = values.zipWithIndex.map { + case (v, i) => + ScalarLiftToken( + ScalarValueLift( + s"${lift.name}[${i}]", + s"${lift.liftId}[${i}]", + Some(v -> le) + ) + ) + } + val separators = Vector.fill(liftings.size - 1)(StringToken(", ")) + (tokens += stmt"$a $op (") ++= Interleave( + liftings, + separators + ) += StringToken(")") + } + case (tokens, token) => + tokens += token + } + lb.result() + } } } diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index 74a6984..6b466be 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -2,7 +2,7 @@ package minisql.util import scala.util.* -extension [A](xs: IArray[A]) { +extension [A](xs: Iterable[A]) { private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = { val out = IArray.newBuilder[Any] var left: Option[Throwable] = None From 7f5092c3963c4b9a889be8ed7fdd54a7553eb05f Mon Sep 17 00:00:00 2001 From: jilen Date: Thu, 19 Dec 2024 12:36:44 +0800 Subject: [PATCH 07/26] add mirror context --- .../scala/minisql/context/MirrorContext.scala | 15 ++++++++ src/main/scala/minisql/context/mirror.scala | 35 +++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 src/main/scala/minisql/context/MirrorContext.scala create mode 100644 src/main/scala/minisql/context/mirror.scala diff --git a/src/main/scala/minisql/context/MirrorContext.scala b/src/main/scala/minisql/context/MirrorContext.scala new file mode 100644 index 0000000..c1be053 --- /dev/null +++ 
b/src/main/scala/minisql/context/MirrorContext.scala @@ -0,0 +1,15 @@ +package minisql + +import minisql.context.mirror.* + +class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy]( + val idiom: Idiom, + val naming: Naming +) extends context.Context[Idiom, Naming] { + + type DBRow = Row + + type DBResultSet = Iterable[DBRow] + + type DBStatement = IArray[Any] +} diff --git a/src/main/scala/minisql/context/mirror.scala b/src/main/scala/minisql/context/mirror.scala new file mode 100644 index 0000000..cd3b725 --- /dev/null +++ b/src/main/scala/minisql/context/mirror.scala @@ -0,0 +1,35 @@ +package minisql.context.mirror + +import minisql.{MirrorContext, NamingStrategy} +import minisql.idiom.Idiom +import minisql.util.Messages.fail +import scala.reflect.ClassTag + +/** +* No extra class defined +*/ +opaque type Row = IArray[Any] *: EmptyTuple + +extension (r: Row) { + + def data: IArray[Any] = r._1 + + def add(value: Any): Row = (r.data :+ value) *: EmptyTuple + + def apply[T](idx: Int)(using t: ClassTag[T]): T = { + r.data(idx) match { + case v: T => v + case other => + fail( + s"Invalid column type. 
Expected '${t.runtimeClass}', but got '$other'" + ) + } + } +} + +trait MirrorCodecs[I <: Idiom, N <: NamingStrategy] { + this: MirrorContext[I, N] => + + given byteEncoder: Encoder[Byte] + +} From 47cf808e8f67c2aaca1b01734fd4a59c8cdf9726 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Dec 2024 20:19:07 +0800 Subject: [PATCH 08/26] simplify typeclass --- src/main/scala/minisql/util/Show.scala | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/main/scala/minisql/util/Show.scala b/src/main/scala/minisql/util/Show.scala index b4acc97..3496caf 100644 --- a/src/main/scala/minisql/util/Show.scala +++ b/src/main/scala/minisql/util/Show.scala @@ -1,21 +1,20 @@ package minisql.util -object Show { - trait Show[T] { - def show(v: T): String +trait Show[T] { + extension (v: T) { + def show: String } +} - object Show { - def apply[T](f: T => String) = new Show[T] { - def show(v: T) = f(v) +object Show { + + def apply[T](f: T => String) = new Show[T] { + extension (v: T) { + def show: String = f(v) } } - implicit class Shower[T](v: T)(implicit shower: Show[T]) { - def show = shower.show(v) - } - - implicit def listShow[T](implicit shower: Show[T]): Show[List[T]] = + given listShow[T](using shower: Show[T]): Show[List[T]] = Show[List[T]] { case list => list.map(_.show).mkString(", ") } From cb0c6082d0b65788058b5e315204ddad562f493e Mon Sep 17 00:00:00 2001 From: jilen Date: Tue, 17 Jun 2025 17:31:36 +0800 Subject: [PATCH 09/26] Add statement --- .gitignore | 3 +- build.sbt | 4 +- project/build.properties | 2 +- src/main/scala/minisql/ParamEncoder.scala | 3 +- src/main/scala/minisql/context/Context.scala | 63 ++++++++++--------- .../scala/minisql/context/MirrorContext.scala | 3 +- src/main/scala/minisql/context/mirror.scala | 50 ++++++++++++--- .../scala/minisql/parsing/InfixParsing.scala | 1 - .../minisql/parsing/PropertyParsing.scala | 1 - .../scala/minisql/parsing/QuotedSuite.scala | 39 ++++-------- 10 files changed, 97 
insertions(+), 72 deletions(-) diff --git a/.gitignore b/.gitignore index 816c54d..a56226a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ target/ .bsp/ .metals/ .bloop/ -project/metals.sbt \ No newline at end of file +project/metals.sbt +.aider* diff --git a/build.sbt b/build.sbt index 2dc3b9c..67a69a6 100644 --- a/build.sbt +++ b/build.sbt @@ -1,9 +1,7 @@ name := "minisql" -scalaVersion := "3.5.2" +scalaVersion := "3.7.0" libraryDependencies ++= Seq( "org.scalameta" %% "munit" % "1.0.3" % Test ) - -scalacOptions ++= Seq("-experimental", "-language:experimental.namedTuples") diff --git a/project/build.properties b/project/build.properties index db1723b..e97b272 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.5 +sbt.version=1.10.10 diff --git a/src/main/scala/minisql/ParamEncoder.scala b/src/main/scala/minisql/ParamEncoder.scala index 05ef348..4d2abe4 100644 --- a/src/main/scala/minisql/ParamEncoder.scala +++ b/src/main/scala/minisql/ParamEncoder.scala @@ -3,10 +3,9 @@ package minisql import scala.util.Try trait ParamEncoder[E] { - type Stmt - def setParam(s: Stmt, idx: Int, v: E): Unit + def setParam(s: Stmt, idx: Int, v: E): Stmt } trait ColumnDecoder[X] { diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index af33d80..de05202 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -9,6 +9,40 @@ import minisql.{NamingStrategy, ParamEncoder} import minisql.ColumnDecoder import minisql.ast.{Ast, ScalarValueLift, CollectAst} +trait RowExtract[A, Row] { + def extract(row: Row): Try[A] +} + +object RowExtract { + + private def extractorImpl[A, Row]( + decoders: IArray[Any], + m: Mirror.ProductOf[A] + ): RowExtract[A, Row] = new RowExtract[A, Row] { + def extract(row: Row): Try[A] = { + val decodedFields = decoders.zipWithIndex.traverse { + case (d, i) => + d.asInstanceOf[ColumnDecoder.Aux[Row, 
?]].decode(row, i) + } + decodedFields.map { vs => + m.fromProduct(Tuple.fromIArray(vs)) + } + } + } + + inline given [P <: Product, Row, Decoder[_]](using + m: Mirror.ProductOf[P] + ): RowExtract[P, Row] = { + val decoders = + summonAll[ + Tuple.Map[m.MirroredElemTypes, [X] =>> ColumnDecoder[ + X + ] { type DBRow = Row }] + ] + extractorImpl(decoders.toIArray.asInstanceOf, m) + } +} + trait Context[I <: Idiom, N <: NamingStrategy] { selft => val idiom: I @@ -18,33 +52,6 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => type DBRow type DBResultSet - trait RowExtract[A] { - def extract(row: DBRow): Try[A] - } - - object RowExtract { - - private class ExtractorImpl[A]( - decoders: IArray[Any], - m: Mirror.ProductOf[A] - ) extends RowExtract[A] { - def extract(row: DBRow): Try[A] = { - val decodedFields = decoders.zipWithIndex.traverse { - case (d, i) => - d.asInstanceOf[Decoder[?]].decode(row, i) - } - decodedFields.map { vs => - m.fromProduct(Tuple.fromIArray(vs)) - } - } - } - - inline given [P <: Product](using m: Mirror.ProductOf[P]): RowExtract[P] = { - val decoders = summonAll[Tuple.Map[m.MirroredElemTypes, Decoder]] - ExtractorImpl(decoders.toIArray.asInstanceOf, m) - } - } - type Encoder[X] = ParamEncoder[X] { type Stmt = DBStatement } @@ -76,7 +83,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => inline def io[E]( inline q: minisql.Query[E] - )(using r: RowExtract[E]): DBIO[IArray[E]] = { + )(using r: RowExtract[E, DBRow]): DBIO[IArray[E]] = { val lifts = q.liftMap val stmt = minisql.compile(q, idiom, naming) val (sql, params) = stmt.expand(lifts) diff --git a/src/main/scala/minisql/context/MirrorContext.scala b/src/main/scala/minisql/context/MirrorContext.scala index c1be053..ba00db1 100644 --- a/src/main/scala/minisql/context/MirrorContext.scala +++ b/src/main/scala/minisql/context/MirrorContext.scala @@ -9,7 +9,6 @@ class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy]( type DBRow = Row - type DBResultSet = 
Iterable[DBRow] + type DBResultSet = ResultSet - type DBStatement = IArray[Any] } diff --git a/src/main/scala/minisql/context/mirror.scala b/src/main/scala/minisql/context/mirror.scala index cd3b725..a67fdd2 100644 --- a/src/main/scala/minisql/context/mirror.scala +++ b/src/main/scala/minisql/context/mirror.scala @@ -1,14 +1,17 @@ package minisql.context.mirror -import minisql.{MirrorContext, NamingStrategy} +import minisql.{MirrorContext, NamingStrategy, ParamEncoder, ColumnDecoder} import minisql.idiom.Idiom import minisql.util.Messages.fail +import scala.util.Try import scala.reflect.ClassTag /** * No extra class defined */ -opaque type Row = IArray[Any] *: EmptyTuple +opaque type Row = IArray[Any] *: EmptyTuple +opaque type ResultSet = Iterable[Row] +opaque type Statement = Map[Int, Any] extension (r: Row) { @@ -27,9 +30,42 @@ extension (r: Row) { } } -trait MirrorCodecs[I <: Idiom, N <: NamingStrategy] { - this: MirrorContext[I, N] => - - given byteEncoder: Encoder[Byte] - +type Encoder[E] = ParamEncoder[E] { + type Stmt = Statement } + +private def encoder[V]: Encoder[V] = new ParamEncoder[V] { + + type Stmt = Map[Int, Any] + + def setParam(s: Stmt, idx: Int, v: V): Stmt = { + s + (idx -> v) + } +} + +given Encoder[Long] = encoder[Long] + +type Decoder[A] = ColumnDecoder[A] { + type DBRow = Row +} + +private def apply[X](conv: Any => Option[X]): Decoder[X] = + new ColumnDecoder[X] { + type DBRow = Row + def decode(row: Row, idx: Int): Try[X] = { + row._1 + .lift(idx) + .flatMap { x => + conv(x) + } + .toRight(new Exception(s"Cannot convert value at ${idx}")) + .toTry + } + } + +given Decoder[Long] = apply(x => + x match { + case l: Long => Some(l) + case _ => None + } +) diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala index 7b173db..3bcdfc5 100644 --- a/src/main/scala/minisql/parsing/InfixParsing.scala +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -1,7 +1,6 @@ package 
minisql.parsing import minisql.ast -import minisql.dsl.* import scala.quoted.* private[parsing] def infixParsing( diff --git a/src/main/scala/minisql/parsing/PropertyParsing.scala b/src/main/scala/minisql/parsing/PropertyParsing.scala index 292f281..0e6946e 100644 --- a/src/main/scala/minisql/parsing/PropertyParsing.scala +++ b/src/main/scala/minisql/parsing/PropertyParsing.scala @@ -1,7 +1,6 @@ package minisql.parsing import minisql.ast -import minisql.dsl.* import scala.quoted._ private[parsing] def propertyParsing( diff --git a/src/test/scala/minisql/parsing/QuotedSuite.scala b/src/test/scala/minisql/parsing/QuotedSuite.scala index 2bd02eb..d2f8981 100644 --- a/src/test/scala/minisql/parsing/QuotedSuite.scala +++ b/src/test/scala/minisql/parsing/QuotedSuite.scala @@ -1,35 +1,22 @@ -package minisql +package minisql.parsing +import minisql.* import minisql.ast.* +import minisql.idiom.* +import minisql.NamingStrategy +import minisql.MirrorContext +import minisql.MirrorIdiom +import minisql.context.mirror.{*, given} class QuotedSuite extends munit.FunSuite { - private inline def testQuoted(label: String)( - inline x: Quoted, - expect: Ast - ) = test(label) { - assertEquals(compileTimeAst(x), Some(expect.toString())) - } + val ctx = new MirrorContext(MirrorIdiom, SnakeCase) case class Foo(id: Long) - inline def Foos = query[Foo]("foo") - val entityFoo = Entity("foo", Nil) - val idx = Ident("x") - - testQuoted("EntityQuery")(Foos, entityFoo) - - testQuoted("Query/filter")( - Foos.filter(x => x.id > 0), - Filter( - entityFoo, - idx, - BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0)) - ) - ) - - testQuoted("Query/map")( - Foos.map(x => x.id), - Map(entityFoo, idx, Property(idx, "id")) - ) + test("SimpleQuery") { + val o = ctx.io(query[Foo]("foo").filter(_.id > 0)) + println("============" + o) + o + } } From 63a9a0cad38700656bcd1164ae5e18a9b3652738 Mon Sep 17 00:00:00 2001 From: jilen Date: Tue, 17 Jun 2025 19:52:30 +0800 Subject: [PATCH 10/26] Allow 
both decoder and encoder --- src/main/scala/minisql/Quoted.scala | 3 +++ src/main/scala/minisql/context/Context.scala | 19 +++++++++++++++++-- .../scala/minisql/idiom/MirrorIdiom.scala | 2 +- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index da7008e..3ffb886 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -18,6 +18,7 @@ opaque type EntityQuery[E] <: Query[E] = Query[E] object EntityQuery { extension [E](inline e: EntityQuery[E]) { + inline def map[E1](inline f: E => E1): EntityQuery[E1] = { transform(e)(f)(Map.apply) } @@ -25,6 +26,7 @@ object EntityQuery { inline def filter(inline f: E => Boolean): EntityQuery[E] = { transform(e)(f)(Filter.apply) } + } } @@ -90,6 +92,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( .getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile")) val stmt = idiom.translate(ast)(using naming) + report.info(s"Static Query: ${stmt}") Expr(stmt._2) case None => report.info("Dynamic Query") diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index de05202..6f6bea5 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -15,6 +15,14 @@ trait RowExtract[A, Row] { object RowExtract { + private[context] def single[Row, E]( + decoder: ColumnDecoder.Aux[Row, E] + ): RowExtract[E, Row] = new RowExtract[E, Row] { + def extract(row: Row): Try[E] = { + decoder.decode(row, 0) + } + } + private def extractorImpl[A, Row]( decoders: IArray[Any], m: Mirror.ProductOf[A] @@ -83,14 +91,21 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => inline def io[E]( inline q: minisql.Query[E] - )(using r: RowExtract[E, DBRow]): DBIO[IArray[E]] = { + ): DBIO[IArray[E]] = { + + val extractor = summonFrom { + case e: RowExtract[E, DBRow] => e + case e: ColumnDecoder.Aux[DBRow, E] => + 
RowExtract.single(e) + } + val lifts = q.liftMap val stmt = minisql.compile(q, idiom, naming) val (sql, params) = stmt.expand(lifts) ( sql = sql, params = params.map(_.value.get.asInstanceOf), - mapper = (rows) => rows.traverse(r.extract) + mapper = (rows) => rows.traverse(extractor.extract) ) } diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index 88aab8c..b325b7d 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -304,7 +304,7 @@ trait MirrorIdiomBase extends Idiom { Tokenizer[OnConflict.Target] { case OnConflict.NoTarget => stmt"" case OnConflict.Properties(props) => - val listTokens = listTokenizer(astTokenizer).token(props) + val listTokens = listTokenizer(using astTokenizer).token(props) stmt"(${listTokens})" } From 1bc6baad688a1308b0a647ea3b73391d79df1d75 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 18 Jun 2025 16:59:06 +0800 Subject: [PATCH 11/26] add sql idiom --- src/main/scala/minisql/MirrorSqlDialect.scala | 52 ++ .../minisql/context/sql/ConcatSupport.scala | 17 + .../context/sql/OnConflictSupport.scala | 70 ++ .../context/sql/PositionalBindVariables.scala | 6 + .../sql/QuestionMarkBindVariables.scala | 6 + .../scala/minisql/context/sql/SqlIdiom.scala | 700 ++++++++++++++++++ .../scala/minisql/context/sql/SqlQuery.scala | 326 ++++++++ .../minisql/context/sql/VerifySqlQuery.scala | 122 +++ .../sql/norm/AddDropToNestedOrderBy.scala | 47 ++ .../context/sql/norm/ExpandDistinct.scala | 68 ++ .../minisql/context/sql/norm/ExpandJoin.scala | 49 ++ .../context/sql/norm/ExpandMappedInfix.scala | 12 + .../sql/norm/ExpandNestedQueries.scala | 147 ++++ .../sql/norm/FlattenGroupByAggregation.scala | 58 ++ .../context/sql/norm/SqlNormalize.scala | 53 ++ .../context/sql/norm/nested/Elements.scala | 29 + .../sql/norm/nested/ExpandSelect.scala | 262 +++++++ .../norm/nested/FindUnexpressedInfixes.scala | 83 +++ 
.../scala/minisql/norm/FreeVariables.scala | 120 +++ 19 files changed, 2227 insertions(+) create mode 100644 src/main/scala/minisql/MirrorSqlDialect.scala create mode 100644 src/main/scala/minisql/context/sql/ConcatSupport.scala create mode 100644 src/main/scala/minisql/context/sql/OnConflictSupport.scala create mode 100644 src/main/scala/minisql/context/sql/PositionalBindVariables.scala create mode 100644 src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala create mode 100644 src/main/scala/minisql/context/sql/SqlIdiom.scala create mode 100644 src/main/scala/minisql/context/sql/SqlQuery.scala create mode 100644 src/main/scala/minisql/context/sql/VerifySqlQuery.scala create mode 100644 src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala create mode 100644 src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala create mode 100644 src/main/scala/minisql/context/sql/norm/ExpandJoin.scala create mode 100644 src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala create mode 100644 src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala create mode 100644 src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala create mode 100644 src/main/scala/minisql/context/sql/norm/SqlNormalize.scala create mode 100644 src/main/scala/minisql/context/sql/norm/nested/Elements.scala create mode 100644 src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala create mode 100644 src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala create mode 100644 src/main/scala/minisql/norm/FreeVariables.scala diff --git a/src/main/scala/minisql/MirrorSqlDialect.scala b/src/main/scala/minisql/MirrorSqlDialect.scala new file mode 100644 index 0000000..563f770 --- /dev/null +++ b/src/main/scala/minisql/MirrorSqlDialect.scala @@ -0,0 +1,52 @@ +package minisql + +import minisql.context.{ + CanReturnClause, + CanReturnField, + CanReturnMultiField, + CannotReturn +} +import minisql.context.sql.idiom.SqlIdiom 
+import minisql.context.sql.idiom.QuestionMarkBindVariables +import minisql.context.sql.idiom.ConcatSupport + +trait MirrorSqlDialect + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnField + +trait MirrorSqlDialectWithReturnMulti + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnMultiField + +trait MirrorSqlDialectWithReturnClause + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnClause + +trait MirrorSqlDialectWithNoReturn + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CannotReturn + +object MirrorSqlDialect extends MirrorSqlDialect { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithReturnClause + extends MirrorSqlDialectWithReturnClause { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn { + override def prepareForProbing(string: String) = string +} diff --git a/src/main/scala/minisql/context/sql/ConcatSupport.scala b/src/main/scala/minisql/context/sql/ConcatSupport.scala new file mode 100644 index 0000000..39a2e64 --- /dev/null +++ b/src/main/scala/minisql/context/sql/ConcatSupport.scala @@ -0,0 +1,17 @@ +package minisql.context.sql.idiom + +import minisql.util.Messages + +trait ConcatSupport { + this: SqlIdiom => + + override def concatFunction = "UNNEST" +} + +trait NoConcatSupport { + this: SqlIdiom => + + override def concatFunction = Messages.fail( + s"`concatMap` not supported by ${this.getClass.getSimpleName}" + ) +} diff --git a/src/main/scala/minisql/context/sql/OnConflictSupport.scala b/src/main/scala/minisql/context/sql/OnConflictSupport.scala new file mode 100644 index 0000000..940d5bf --- /dev/null +++ 
b/src/main/scala/minisql/context/sql/OnConflictSupport.scala @@ -0,0 +1,70 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.idiom.StatementInterpolator._ +import minisql.idiom.Token +import minisql.NamingStrategy +import minisql.util.Messages.fail + +trait OnConflictSupport { + self: SqlIdiom => + + implicit def conflictTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OnConflict] = { + + val customEntityTokenizer = Tokenizer[Entity] { + case Entity.Opinionated(name, _, renameable) => + stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t" + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy)) { + case _: OnConflict.Excluded => stmt"EXCLUDED" + case OnConflict.Existing(a) => stmt"${a.token}" + case a: Action => + self + .actionTokenizer(customEntityTokenizer)( + actionAstTokenizer, + strategy + ) + .token(a) + } + + import OnConflict._ + + def doUpdateStmt(i: Token, t: Token, u: Update) = { + val assignments = u.assignments + .map(a => + stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(customAstTokenizer)}" + ) + .mkStmt() + + stmt"$i ON CONFLICT $t DO UPDATE SET $assignments" + } + + def doNothingStmt(i: Ast, t: Token) = + stmt"${i.token} ON CONFLICT $t DO NOTHING" + + implicit val conflictTargetPropsTokenizer: Tokenizer[Properties] = + Tokenizer[Properties] { + case OnConflict.Properties(props) => + stmt"(${props.map(n => n.renameable.fixedOr(n.name)(strategy.column(n.name))).mkStmt(",")})" + } + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[OnConflict] { + case OnConflict(_, NoTarget, _: Update) => + fail("'DO UPDATE' statement requires explicit conflict target") + case OnConflict(i, p: Properties, u: Update) => + doUpdateStmt(i.token, p.token, u) + + case OnConflict(i, NoTarget, Ignore) => + stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING" + case OnConflict(i, p: Properties, 
Ignore) => doNothingStmt(i, p.token) + } + + tokenizer(customAstTokenizer) + } +} diff --git a/src/main/scala/minisql/context/sql/PositionalBindVariables.scala b/src/main/scala/minisql/context/sql/PositionalBindVariables.scala new file mode 100644 index 0000000..fcd42ea --- /dev/null +++ b/src/main/scala/minisql/context/sql/PositionalBindVariables.scala @@ -0,0 +1,6 @@ +package minisql.context.sql.idiom + +trait PositionalBindVariables { self: SqlIdiom => + + override def liftingPlaceholder(index: Int): String = s"$$${index + 1}" +} diff --git a/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala b/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala new file mode 100644 index 0000000..f7ccf27 --- /dev/null +++ b/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala @@ -0,0 +1,6 @@ +package minisql.context.sql.idiom + +trait QuestionMarkBindVariables { self: SqlIdiom => + + override def liftingPlaceholder(index: Int): String = s"?" +} diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala new file mode 100644 index 0000000..c593153 --- /dev/null +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -0,0 +1,700 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.ast.BooleanOperator._ +import minisql.ast.Lift +import minisql.context.sql._ +import minisql.context.sql.norm._ +import minisql.idiom._ +import minisql.idiom.StatementInterpolator._ +import minisql.NamingStrategy +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.Hidden +import minisql.context.{ReturningCapability, ReturningClauseSupported} +import minisql.util.Interleave +import minisql.util.Messages.{fail, trace} +import minisql.idiom.Token +import minisql.norm.EqualityBehavior +import minisql.norm.ConcatBehavior +import minisql.norm.ConcatBehavior.AnsiConcat +import minisql.norm.EqualityBehavior.AnsiEquality +import minisql.norm.ExpandReturning + +trait 
SqlIdiom extends Idiom { + + override def prepareForProbing(string: String): String + + protected def concatBehavior: ConcatBehavior = AnsiConcat + protected def equalityBehavior: EqualityBehavior = AnsiEquality + + protected def actionAlias: Option[Ident] = None + + override def format(queryString: String): String = queryString + + def querifyAst(ast: Ast) = SqlQuery(ast) + + private def doTranslate(ast: Ast, cached: Boolean)(implicit + naming: NamingStrategy + ): (Ast, Statement) = { + val normalizedAst = + SqlNormalize(ast, concatBehavior, equalityBehavior) + + implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer + + val token = + normalizedAst match { + case q: Query => + val sql = querifyAst(q) + trace("sql")(sql) + VerifySqlQuery(sql).map(fail) + val expanded = new ExpandNestedQueries(naming)(sql, List()) + trace("expanded sql")(expanded) + val tokenized = expanded.token + trace("tokenized sql")(tokenized) + tokenized + case other => + other.token + } + + (normalizedAst, stmt"$token") + } + + override def translate( + ast: Ast + )(implicit naming: NamingStrategy): (Ast, Statement) = { + doTranslate(ast, false) + } + + def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] = + new Tokenizer[Ast] { + private val stableTokenizer = astTokenizer(this, naming) + + extension (v: Ast) { + def token = stableTokenizer.token(v) + } + + } + + def astTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Ast] = + Tokenizer[Ast] { + case a: Query => SqlQuery(a).token + case a: Operation => a.token + case a: Infix => a.token + case a: Action => a.token + case a: Ident => a.token + case a: ExternalIdent => a.token + case a: Property => a.token + case a: Value => a.token + case a: If => a.token + case a: Lift => a.token + case a: Assignment => a.token + case a: OptionOperation => a.token + case a @ ( + _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | + _: Block | _: Val | _: Ordering | _: 
IterableOperation | + _: OnConflict.Excluded | _: OnConflict.Existing + ) => + fail(s"Malformed or unsupported construct: $a.") + } + + implicit def ifTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[If] = Tokenizer[If] { + case ast: If => + def flatten(ast: Ast): (List[(Ast, Ast)], Ast) = + ast match { + case If(cond, a, b) => + val (l, e) = flatten(b) + ((cond, a) +: l, e) + case other => + (List(), other) + } + + val (l, e) = flatten(ast) + val conditions = + for ((cond, body) <- l) yield { + stmt"WHEN ${cond.token} THEN ${body.token}" + } + stmt"CASE ${conditions.mkStmt(" ")} ELSE ${e.token} END" + } + + def concatFunction: String + + protected def tokenizeGroupBy( + values: Ast + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token = + values.token + + protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ) { + import q._ + + def selectTokenizer = + select match { + case Nil => stmt"*" + case _ => select.token + } + + def distinctTokenizer = ( + distinct match { + case DistinctKind.Distinct => stmt"DISTINCT " + case DistinctKind.DistinctOn(props) => + stmt"DISTINCT ON (${props.token}) " + case DistinctKind.None => stmt"" + } + ) + + def withDistinct = stmt"$distinctTokenizer${selectTokenizer}" + + def withFrom = + from match { + case Nil => withDistinct + case head :: tail => + val t = tail.foldLeft(stmt"${head.token}") { + case (a, b: FlatJoinContext) => + stmt"$a ${(b: FromContext).token}" + case (a, b) => + stmt"$a, ${b.token}" + } + + stmt"$withDistinct FROM $t" + } + + def withWhere = + where match { + case None => withFrom + case Some(where) => stmt"$withFrom WHERE ${where.token}" + } + def withGroupBy = + groupBy match { + case None => withWhere + case Some(groupBy) => + stmt"$withWhere GROUP BY ${tokenizeGroupBy(groupBy)}" + } + def withOrderBy = + orderBy match { + case Nil => withGroupBy + case orderBy => 
stmt"$withGroupBy ${tokenOrderBy(orderBy)}" + } + def withLimitOffset = limitOffsetToken(withOrderBy).token((limit, offset)) + + def apply = stmt"SELECT $withLimitOffset" + } + + implicit def sqlQueryTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + case q: FlattenSqlQuery => + new FlattenSqlQueryTokenizerHelper(q).apply + case SetOperationSqlQuery(a, op, b) => + stmt"(${a.token}) ${op.token} (${b.token})" + case UnaryOperationSqlQuery(op, q) => + stmt"SELECT ${op.token} (${q.token})" + } + + protected def tokenizeColumn( + strategy: NamingStrategy, + column: String, + renameable: Renameable + ) = + renameable match { + case Fixed => column + case _ => strategy.column(column) + } + + protected def tokenizeTable( + strategy: NamingStrategy, + table: String, + renameable: Renameable + ) = + renameable match { + case Fixed => table + case _ => strategy.table(table) + } + + protected def tokenizeAlias(strategy: NamingStrategy, table: String) = + strategy.default(table) + + implicit def selectValueTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[SelectValue] = { + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[SelectValue] { + case SelectValue(ast, Some(alias), false) => { + stmt"${ast.token} AS ${alias.token}" + } + case SelectValue(ast, Some(alias), true) => + stmt"${concatFunction.token}(${ast.token}) AS ${alias.token}" + case selectValue => + val value = + selectValue match { + case SelectValue(Ident("?"), _, _) => "?".token + case SelectValue(Ident(name), _, _) => + stmt"${strategy.default(name).token}.*" + case SelectValue(ast, _, _) => ast.token + } + selectValue.concat match { + case true => stmt"${concatFunction.token}(${value.token})" + case false => value + } + } + + val customAstTokenizer = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case Aggregation(op, Ident(_) | Tuple(_)) => 
stmt"${op.token}(*)" + case Aggregation(op, Distinct(ast)) => + stmt"${op.token}(DISTINCT ${ast.token})" + case ast @ Aggregation(op, _: Query) => scopedTokenizer(ast) + case Aggregation(op, ast) => stmt"${op.token}(${ast.token})" + } + + tokenizer(customAstTokenizer) + } + + implicit def operationTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Operation] = Tokenizer[Operation] { + case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})" + case BinaryOperation(a, EqualityOperator.`==`, NullValue) => + stmt"${scopedTokenizer(a)} IS NULL" + case BinaryOperation(NullValue, EqualityOperator.`==`, b) => + stmt"${scopedTokenizer(b)} IS NULL" + case BinaryOperation(a, EqualityOperator.`!=`, NullValue) => + stmt"${scopedTokenizer(a)} IS NOT NULL" + case BinaryOperation(NullValue, EqualityOperator.`!=`, b) => + stmt"${scopedTokenizer(b)} IS NOT NULL" + case BinaryOperation(a, StringOperator.`startsWith`, b) => + stmt"${scopedTokenizer(a)} LIKE (${(BinaryOperation(b, StringOperator.`concat`, Constant("%")): Ast).token})" + case BinaryOperation(a, op @ StringOperator.`split`, b) => + stmt"${op.token}(${scopedTokenizer(a)}, ${scopedTokenizer(b)})" + case BinaryOperation(a, op @ SetOperator.`contains`, b) => + SetContainsToken(scopedTokenizer(b), op.token, a.token) + case BinaryOperation(a, op @ `&&`, b) => + (a, b) match { + case (BinaryOperation(_, `||`, _), BinaryOperation(_, `||`, _)) => + stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}" + case (BinaryOperation(_, `||`, _), _) => + stmt"${scopedTokenizer(a)} ${op.token} ${b.token}" + case (_, BinaryOperation(_, `||`, _)) => + stmt"${a.token} ${op.token} ${scopedTokenizer(b)}" + case _ => stmt"${a.token} ${op.token} ${b.token}" + } + case BinaryOperation(a, op @ `||`, b) => + stmt"${a.token} ${op.token} ${b.token}" + case BinaryOperation(a, op, b) => + stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}" + case e: FunctionApply => fail(s"Can't 
translate the ast to sql: '$e'") + } + + implicit def optionOperationTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] { + case OptionIsEmpty(ast) => stmt"${ast.token} IS NULL" + case OptionNonEmpty(ast) => stmt"${ast.token} IS NOT NULL" + case OptionIsDefined(ast) => stmt"${ast.token} IS NOT NULL" + case other => fail(s"Malformed or unsupported construct: $other.") + } + + implicit val setOperationTokenizer: Tokenizer[SetOperation] = + Tokenizer[SetOperation] { + case UnionOperation => stmt"UNION" + case UnionAllOperation => stmt"UNION ALL" + } + + protected def limitOffsetToken( + query: Statement + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + Tokenizer[(Option[Ast], Option[Ast])] { + case (None, None) => query + case (Some(limit), None) => stmt"$query LIMIT ${limit.token}" + case (Some(limit), Some(offset)) => + stmt"$query LIMIT ${limit.token} OFFSET ${offset.token}" + case (None, Some(offset)) => stmt"$query OFFSET ${offset.token}" + } + + protected def tokenOrderBy( + criterias: List[OrderByCriteria] + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + stmt"ORDER BY ${criterias.token}" + + implicit def sourceTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[FromContext] = Tokenizer[FromContext] { + case TableContext(name, alias) => + stmt"${name.token} ${tokenizeAlias(strategy, alias).token}" + case QueryContext(query, alias) => + stmt"(${query.token}) AS ${tokenizeAlias(strategy, alias).token}" + case InfixContext(infix, alias) if infix.noParen => + stmt"${(infix: Ast).token} AS ${strategy.default(alias).token}" + case InfixContext(infix, alias) => + stmt"(${(infix: Ast).token}) AS ${strategy.default(alias).token}" + case JoinContext(t, a, b, on) => + stmt"${a.token} ${t.token} ${b.token} ON ${on.token}" + case FlatJoinContext(t, a, on) => stmt"${t.token} ${a.token} ON ${on.token}" + 
} + + implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { + case InnerJoin => stmt"INNER JOIN" + case LeftJoin => stmt"LEFT JOIN" + case RightJoin => stmt"RIGHT JOIN" + case FullJoin => stmt"FULL JOIN" + } + + implicit def orderByCriteriaTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] { + case OrderByCriteria(ast, Asc) => stmt"${scopedTokenizer(ast)} ASC" + case OrderByCriteria(ast, Desc) => stmt"${scopedTokenizer(ast)} DESC" + case OrderByCriteria(ast, AscNullsFirst) => + stmt"${scopedTokenizer(ast)} ASC NULLS FIRST" + case OrderByCriteria(ast, DescNullsFirst) => + stmt"${scopedTokenizer(ast)} DESC NULLS FIRST" + case OrderByCriteria(ast, AscNullsLast) => + stmt"${scopedTokenizer(ast)} ASC NULLS LAST" + case OrderByCriteria(ast, DescNullsLast) => + stmt"${scopedTokenizer(ast)} DESC NULLS LAST" + } + + implicit val unaryOperatorTokenizer: Tokenizer[UnaryOperator] = + Tokenizer[UnaryOperator] { + case NumericOperator.`-` => stmt"-" + case BooleanOperator.`!` => stmt"NOT" + case StringOperator.`toUpperCase` => stmt"UPPER" + case StringOperator.`toLowerCase` => stmt"LOWER" + case StringOperator.`toLong` => stmt"" // cast is implicit + case StringOperator.`toInt` => stmt"" // cast is implicit + case SetOperator.`isEmpty` => stmt"NOT EXISTS" + case SetOperator.`nonEmpty` => stmt"EXISTS" + } + + implicit val aggregationOperatorTokenizer: Tokenizer[AggregationOperator] = + Tokenizer[AggregationOperator] { + case AggregationOperator.`min` => stmt"MIN" + case AggregationOperator.`max` => stmt"MAX" + case AggregationOperator.`avg` => stmt"AVG" + case AggregationOperator.`sum` => stmt"SUM" + case AggregationOperator.`size` => stmt"COUNT" + } + + implicit val binaryOperatorTokenizer: Tokenizer[BinaryOperator] = + Tokenizer[BinaryOperator] { + case EqualityOperator.`==` => stmt"=" + case EqualityOperator.`!=` => stmt"<>" + case BooleanOperator.`&&` => stmt"AND" + 
case BooleanOperator.`||` => stmt"OR" + case StringOperator.`concat` => stmt"||" + case StringOperator.`startsWith` => + fail("bug: this code should be unreachable") + case StringOperator.`split` => stmt"SPLIT" + case NumericOperator.`-` => stmt"-" + case NumericOperator.`+` => stmt"+" + case NumericOperator.`*` => stmt"*" + case NumericOperator.`>` => stmt">" + case NumericOperator.`>=` => stmt">=" + case NumericOperator.`<` => stmt"<" + case NumericOperator.`<=` => stmt"<=" + case NumericOperator.`/` => stmt"/" + case NumericOperator.`%` => stmt"%" + case SetOperator.`contains` => stmt"IN" + } + + implicit def propertyTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Property] = { + + def unnest(ast: Ast): (Ast, List[String]) = + ast match { + case Property.Opinionated(a, _, _, Hidden) => + unnest(a) match { + case (a, nestedName) => (a, nestedName) + } + // Append the property name. This includes tuple indexes. + case Property(a, name) => + unnest(a) match { + case (ast, nestedName) => + (ast, nestedName :+ name) + } + case a => (a, Nil) + } + + def tokenizePrefixedProperty( + name: String, + prefix: List[String], + strategy: NamingStrategy, + renameable: Renameable + ) = + renameable.fixedOr( + (prefix.mkString + name).token + )(tokenizeColumn(strategy, prefix.mkString + name, renameable).token) + + Tokenizer[Property] { + case Property.Opinionated( + ast, + name, + renameable, + _ /* Top level property cannot be invisible */ + ) => + // When we have things like Embedded tables, properties inside of one another needs to be un-nested. + // E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner + // property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias` + // as opposed to `realTable.embeddedTableAlias.realPropertyAlias`. 
+ unnest(ast) match { + // When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL + // returning clauses have no alias for the original table. I.e. INSERT [...] RETURNING idColumn there's no + // alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'. + // In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`. + // The exception to this is when a Query inside of a RETURNING clause is used. In that case, assume + // that there is an alias for the inserted table (i.e. `INSERT ... as theAlias values ... RETURNING`) + // and the instances of ExternalIdent use it. + case (ExternalIdent(_), prefix) => + stmt"${actionAlias + .map(alias => stmt"${scopedTokenizer(alias)}.") + .getOrElse(stmt"")}${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + + // In the rare case that the Ident is invisible, do not show it. See the Ident documentation for more info. + case (Ident.Opinionated(_, Hidden), prefix) => + stmt"${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + + // The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // becomes `realTable.realPropertyAlias`. 
+ case (ast, prefix) => + stmt"${scopedTokenizer(ast)}.${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + } + } + } + + implicit def valueTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Value] = Tokenizer[Value] { + case Constant(v: String) => stmt"'${v.token}'" + case Constant(()) => stmt"1" + case Constant(v) => stmt"${v.toString.token}" + case NullValue => stmt"null" + case Tuple(values) => stmt"${values.token}" + case CaseClass(values) => stmt"${values.map(_._2).token}" + } + + implicit def infixTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Infix] = Tokenizer[Infix] { + case Infix(parts, params, _, _) => + val pt = parts.map(_.token) + val pr = params.map(_.token) + Statement(Interleave(pt, pr)) + } + + implicit def identTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Ident] = + Tokenizer[Ident](e => strategy.default(e.name).token) + + implicit def externalIdentTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[ExternalIdent] = + Tokenizer[ExternalIdent](e => strategy.default(e.name).token) + + implicit def assignmentTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Assignment] = Tokenizer[Assignment] { + case Assignment(alias, prop, value) => + stmt"${prop.token} = ${scopedTokenizer(value)}" + } + + implicit def defaultAstTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Action] = { + val insertEntityTokenizer = Tokenizer[Entity] { + case Entity.Opinionated(name, _, renameable) => + stmt"INTO ${tokenizeTable(strategy, name, renameable).token}" + } + actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy) + } + + protected def actionAstTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ) = + 
Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case q: Query => astTokenizer.token(q) + case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL" + case Property( + Property.Opinionated(_, name, renameable, _), + "isDefined" + ) => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL" + case Property(Property.Opinionated(_, name, renameable, _), "nonEmpty") => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL" + case Property.Opinionated(_, name, renameable, _) => + renameable.fixedOr(name.token)( + tokenizeColumn(strategy, name, renameable).token + ) + } + + def returnListTokenizer(implicit + tokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[List[Ast]] = { + val customAstTokenizer = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case sq: Query => + stmt"(${tokenizer.token(sq)})" + } + + Tokenizer[List[Ast]] { + case list => + list.mkStmt(", ")(customAstTokenizer) + } + } + + protected def actionTokenizer( + insertEntityTokenizer: Tokenizer[Entity] + )(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Action] = + Tokenizer[Action] { + + case Insert(entity: Entity, assignments) => + val table = insertEntityTokenizer.token(entity) + val columns = assignments.map(_.property.token) + val values = assignments.map(_.value) + stmt"INSERT $table${actionAlias.map(alias => stmt" AS ${alias.token}").getOrElse(stmt"")} (${columns + .mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" + + case Update(table: Entity, assignments) => + stmt"UPDATE ${table.token}${actionAlias + .map(alias => stmt" AS ${alias.token}") + .getOrElse(stmt"")} SET ${assignments.token}" + + case Update(Filter(table: Entity, x, where), assignments) => + stmt"UPDATE 
${table.token}${actionAlias + .map(alias => stmt" AS ${alias.token}") + .getOrElse(stmt"")} SET ${assignments.token} WHERE ${where.token}" + + case Delete(Filter(table: Entity, x, where)) => + stmt"DELETE FROM ${table.token} WHERE ${where.token}" + + case Delete(table: Entity) => + stmt"DELETE FROM ${table.token}" + + case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) => + idiomReturningCapability match { + // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as + // a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5) + case ReturningClauseSupported + if (CollectAst.byType[Entity](prop).nonEmpty) => + SqlIdiom.withActionAlias(this, r) + case ReturningClauseSupported => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer + .token(ExpandReturning(r)(this, strategy).map(_._1))}" + case other => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + } + + case r @ ReturningAction(action, alias, prop) => + idiomReturningCapability match { + // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as + // a last resort since it is not even supported in all Postgres versions (i.e. 
only after 9.5) + case ReturningClauseSupported + if (CollectAst.byType[Entity](prop).nonEmpty) => + SqlIdiom.withActionAlias(this, r) + case ReturningClauseSupported => + stmt"${action.token} RETURNING ${returnListTokenizer.token( + ExpandReturning(r)(this, strategy).map(_._1) + )}" + case other => + stmt"${action.token}" + } + + case other => + fail(s"Action ast can't be translated to sql: '$other'") + } + + implicit def entityTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Entity] = Tokenizer[Entity] { + case Entity.Opinionated(name, _, renameable) => + tokenizeTable(strategy, name, renameable).token + } + + protected def scopedTokenizer(ast: Ast)(implicit tokenizer: Tokenizer[Ast]) = + ast match { + case _: Query => stmt"(${ast.token})" + case _: BinaryOperation => stmt"(${ast.token})" + case _: Tuple => stmt"(${ast.token})" + case _ => ast.token + } +} + +object SqlIdiom { + private[minisql] def copyIdiom( + parent: SqlIdiom, + newActionAlias: Option[Ident] + ) = + new SqlIdiom { + override protected def actionAlias: Option[Ident] = newActionAlias + override def prepareForProbing(string: String): String = + parent.prepareForProbing(string) + override def concatFunction: String = parent.concatFunction + override def liftingPlaceholder(index: Int): String = + parent.liftingPlaceholder(index) + override def idiomReturningCapability: ReturningCapability = + parent.idiomReturningCapability + } + + /** + * Construct a new instance of the specified idiom with `newActionAlias` + * variable specified so that actions (i.e. insert, and update) will be + * rendered with the specified alias. This is needed for RETURNING clauses + * that have queries inside. See #1509 for details. 
+ */ + private[minisql] def withActionAlias( + parentIdiom: SqlIdiom, + query: ReturningAction + )(implicit strategy: NamingStrategy) = { + val idiom = copyIdiom(parentIdiom, Some(query.alias)) + import idiom._ + + implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer( + new Tokenizer[Ast] { self => + extension (v: Ast) { + def token = astTokenizer(self, strategy).token(v) + } + }, + strategy + ) + + query match { + case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) => + stmt"INSERT INTO ${table.token} AS ${alias.name.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer + .token(ExpandReturning(r)(idiom, strategy).map(_._1))}" + case r @ ReturningAction(action, alias, prop) => + stmt"${action.token} RETURNING ${returnListTokenizer.token( + ExpandReturning(r)(idiom, strategy).map(_._1) + )}" + } + } +} diff --git a/src/main/scala/minisql/context/sql/SqlQuery.scala b/src/main/scala/minisql/context/sql/SqlQuery.scala new file mode 100644 index 0000000..06ec412 --- /dev/null +++ b/src/main/scala/minisql/context/sql/SqlQuery.scala @@ -0,0 +1,326 @@ +package minisql.context.sql + +import minisql.ast._ +import minisql.context.sql.norm.FlattenGroupByAggregation +import minisql.norm.BetaReduction +import minisql.util.Messages.fail +import minisql.{Literal, PseudoAst, NamingStrategy} + +case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) + +sealed trait FromContext +case class TableContext(entity: Entity, alias: String) extends FromContext +case class QueryContext(query: SqlQuery, alias: String) extends FromContext +case class InfixContext(infix: Infix, alias: String) extends FromContext +case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast) + extends FromContext +case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast) + extends FromContext + +sealed trait SqlQuery { + override def toString = { + import minisql.MirrorSqlDialect._ + import minisql.idiom.StatementInterpolator.* + 
given Tokenizer[SqlQuery] = sqlQueryTokenizer(using + defaultTokenizer(using Literal), + Literal + ) + summon[Tokenizer[SqlQuery]].token(this).toString() + } +} + +sealed trait SetOperation +case object UnionOperation extends SetOperation +case object UnionAllOperation extends SetOperation + +sealed trait DistinctKind { def isDistinct: Boolean } +case object DistinctKind { + case object Distinct extends DistinctKind { val isDistinct: Boolean = true } + case class DistinctOn(props: List[Ast]) extends DistinctKind { + val isDistinct: Boolean = true + } + case object None extends DistinctKind { val isDistinct: Boolean = false } +} + +case class SetOperationSqlQuery( + a: SqlQuery, + op: SetOperation, + b: SqlQuery +) extends SqlQuery + +case class UnaryOperationSqlQuery( + op: UnaryOperator, + q: SqlQuery +) extends SqlQuery + +case class SelectValue( + ast: Ast, + alias: Option[String] = None, + concat: Boolean = false +) extends PseudoAst { + override def toString: String = + s"${ast.toString}${alias.map("->" + _).getOrElse("")}" +} + +case class FlattenSqlQuery( + from: List[FromContext] = List(), + where: Option[Ast] = None, + groupBy: Option[Ast] = None, + orderBy: List[OrderByCriteria] = Nil, + limit: Option[Ast] = None, + offset: Option[Ast] = None, + select: List[SelectValue], + distinct: DistinctKind = DistinctKind.None +) extends SqlQuery + +object TakeDropFlatten { + def unapply(q: Query): Option[(Query, Option[Ast], Option[Ast])] = q match { + case Take(q: FlatMap, n) => Some((q, Some(n), None)) + case Drop(q: FlatMap, n) => Some((q, None, Some(n))) + case _ => None + } +} + +object SqlQuery { + + def apply(query: Ast): SqlQuery = + query match { + case Union(a, b) => + SetOperationSqlQuery(apply(a), UnionOperation, apply(b)) + case UnionAll(a, b) => + SetOperationSqlQuery(apply(a), UnionAllOperation, apply(b)) + case UnaryOperation(op, q: Query) => UnaryOperationSqlQuery(op, apply(q)) + case _: Operation | _: Value => + FlattenSqlQuery(select = 
List(SelectValue(query))) + case Map(q, a, b) if a == b => apply(q) + case TakeDropFlatten(q, limit, offset) => + flatten(q, "x").copy(limit = limit, offset = offset) + case q: Query => flatten(q, "x") + case infix: Infix => flatten(infix, "x") + case other => + fail( + s"Query not properly normalized. Please open a bug report. Ast: '$other'" + ) + } + + private def flatten(query: Ast, alias: String): FlattenSqlQuery = { + val (sources, finalFlatMapBody) = flattenContexts(query) + flatten(sources, finalFlatMapBody, alias) + } + + private def flattenContexts(query: Ast): (List[FromContext], Ast) = + query match { + case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Query) => + val source = this.source(q, alias) + val (nestedContexts, finalFlatMapBody) = flattenContexts(p) + (source +: nestedContexts, finalFlatMapBody) + case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Infix) => + fail(s"Infix can't be use as a `flatMap` body. $query") + case other => + (List.empty, other) + } + + object NestedNest { + def unapply(q: Ast): Option[Ast] = + q match { + case _: Nested => recurse(q) + case _ => None + } + + private def recurse(q: Ast): Option[Ast] = + q match { + case Nested(qn) => recurse(qn) + case other => Some(other) + } + } + + private def flatten( + sources: List[FromContext], + finalFlatMapBody: Ast, + alias: String + ): FlattenSqlQuery = { + + def select(alias: String) = SelectValue(Ident(alias), None) :: Nil + + def base(q: Ast, alias: String) = { + def nest(ctx: FromContext) = + FlattenSqlQuery(from = sources :+ ctx, select = select(alias)) + q match { + case Map(_: GroupBy, _, _) => nest(source(q, alias)) + case NestedNest(q) => nest(QueryContext(apply(q), alias)) + case q: ConcatMap => nest(QueryContext(apply(q), alias)) + case Join(tpe, a, b, iA, iB, on) => + val ctx = source(q, alias) + def aliases(ctx: FromContext): List[String] = + ctx match { + case TableContext(_, alias) => alias :: Nil + case QueryContext(_, alias) => alias :: Nil + case 
InfixContext(_, alias) => alias :: Nil + case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b) + case FlatJoinContext(_, a, _) => aliases(a) + } + FlattenSqlQuery( + from = ctx :: Nil, + select = aliases(ctx).map(a => SelectValue(Ident(a), None)) + ) + case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias) + case q if (sources == Nil) => flatten(sources, q, alias) + case other => nest(source(q, alias)) + } + } + + finalFlatMapBody match { + + case ConcatMap(q, Ident(alias), p) => + FlattenSqlQuery( + from = source(q, alias) :: Nil, + select = selectValues(p).map(_.copy(concat = true)) + ) + + case Map(GroupBy(q, x @ Ident(alias), g), a, p) => + val b = base(q, alias) + val select = BetaReduction(p, a -> Tuple(List(g, x))) + val flattenSelect = FlattenGroupByAggregation(x)(select) + b.copy(groupBy = Some(g), select = this.selectValues(flattenSelect)) + + case GroupBy(q, Ident(alias), p) => + fail("A `groupBy` clause must be followed by `map`.") + + case Map(q, Ident(alias), p) => + val b = base(q, alias) + val agg = b.select.collect { + case s @ SelectValue(_: Aggregation, _, _) => s + } + if (!b.distinct.isDistinct && agg.isEmpty) + b.copy(select = selectValues(p)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = selectValues(p) + ) + + case Filter(q, Ident(alias), p) => + val b = base(q, alias) + if (b.where.isEmpty) + b.copy(where = Some(p)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + where = Some(p), + select = select(alias) + ) + + case SortBy(q, Ident(alias), p, o) => + val b = base(q, alias) + val criterias = orderByCriterias(p, o) + if (b.orderBy.isEmpty) + b.copy(orderBy = criterias) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + orderBy = criterias, + select = select(alias) + ) + + case Aggregation(op, q: Query) => + val b = flatten(q, alias) + b.select match { + case head :: Nil if !b.distinct.isDistinct => + b.copy(select = 
List(head.copy(ast = Aggregation(op, head.ast)))) + case other => + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = List(SelectValue(Aggregation(op, Ident("*")))) + ) + } + + case Take(q, n) => + val b = base(q, alias) + if (b.limit.isEmpty) + b.copy(limit = Some(n)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + limit = Some(n), + select = select(alias) + ) + + case Drop(q, n) => + val b = base(q, alias) + if (b.offset.isEmpty && b.limit.isEmpty) + b.copy(offset = Some(n)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + offset = Some(n), + select = select(alias) + ) + + case Distinct(q: Query) => + val b = base(q, alias) + b.copy(distinct = DistinctKind.Distinct) + + case DistinctOn(q, Ident(alias), fields) => + val distinctList = + fields match { + case Tuple(values) => values + case other => List(other) + } + + q match { + // Ideally we don't need to make an extra sub-query for every single case of + // distinct-on but it only works when the parent AST is an entity. That's because DistinctOn + // selects from an alias of an outer clause. For example, query[Person].map(p => Name(p.firstName, p.lastName)).distinctOn(_.name) + // (Let's say Person(firstName, lastName, age), Name(first, last)) will turn into + // SELECT DISTINCT ON (p.name), p.firstName AS first, p.lastName AS last, p.age FROM Person + // This doesn't work beause `name` in `p.name` doesn't exist yet. Therefore we have to nest this in a subquery: + // SELECT DISTINCT ON (p.name) FROM (SELECT p.firstName AS first, p.lastName AS last, p.age FROM Person p) AS p + // The only exception to this is if we are directly selecting from an entity: + // query[Person].distinctOn(_.firstName) which should be fine: SELECT (x.firstName), x.firstName, x.lastName, a.age FROM Person x + // since all the fields inside the (...) of the DISTINCT ON must be contained in the entity. 
+ case _: Entity => + val b = base(q, alias) + b.copy(distinct = DistinctKind.DistinctOn(distinctList)) + case _ => + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = select(alias), + distinct = DistinctKind.DistinctOn(distinctList) + ) + } + + case other => + FlattenSqlQuery( + from = sources :+ source(other, alias), + select = select(alias) + ) + } + } + + private def selectValues(ast: Ast) = + ast match { + case Tuple(values) => values.map(SelectValue(_)) + case other => SelectValue(ast) :: Nil + } + + private def source(ast: Ast, alias: String): FromContext = + ast match { + case entity: Entity => TableContext(entity, alias) + case infix: Infix => InfixContext(infix, alias) + case Join(t, a, b, ia, ib, on) => + JoinContext(t, source(a, ia.name), source(b, ib.name), on) + case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on) + case Nested(q) => QueryContext(apply(q), alias) + case other => QueryContext(apply(other), alias) + } + + private def orderByCriterias(ast: Ast, ordering: Ast): List[OrderByCriteria] = + (ast, ordering) match { + case (Tuple(properties), ord: PropertyOrdering) => + properties.flatMap(orderByCriterias(_, ord)) + case (Tuple(properties), TupleOrdering(ord)) => + properties.zip(ord).flatMap { case (a, o) => orderByCriterias(a, o) } + case (a, o: PropertyOrdering) => List(OrderByCriteria(a, o)) + case other => fail(s"Invalid order by criteria $ast") + } +} diff --git a/src/main/scala/minisql/context/sql/VerifySqlQuery.scala b/src/main/scala/minisql/context/sql/VerifySqlQuery.scala new file mode 100644 index 0000000..82a3d59 --- /dev/null +++ b/src/main/scala/minisql/context/sql/VerifySqlQuery.scala @@ -0,0 +1,122 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.context.sql._ +import minisql.norm.FreeVariables + +case class Error(free: List[Ident], ast: Ast) +case class InvalidSqlQuery(errors: List[Error]) { + override def toString = + s"The monad composition can't be 
expressed using applicative joins. " + + errors + .map(error => + s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'." + ) + .mkString(", ") +} + +object VerifySqlQuery { + + def apply(query: SqlQuery): Option[String] = + verify(query).map(_.toString) + + private def verify(query: SqlQuery): Option[InvalidSqlQuery] = + query match { + case q: FlattenSqlQuery => verify(q) + case SetOperationSqlQuery(a, op, b) => verify(a).orElse(verify(b)) + case UnaryOperationSqlQuery(op, q) => verify(q) + } + + private def verifyFlatJoins(q: FlattenSqlQuery) = { + + def loop(l: List[FromContext], available: Set[String]): Set[String] = + l.foldLeft(available) { + case (av, TableContext(_, alias)) => Set(alias) + case (av, InfixContext(_, alias)) => Set(alias) + case (av, QueryContext(_, alias)) => Set(alias) + case (av, JoinContext(_, a, b, on)) => + av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av) + case (av, FlatJoinContext(_, a, on)) => + val nav = av ++ loop(a :: Nil, av) + val free = FreeVariables(on).map(_.name) + val invalid = free -- nav + require( + invalid.isEmpty, + s"Found an `ON` table reference of a table that is not available: $invalid. " + + "The `ON` condition can only use tables defined through explicit joins." + ) + nav + } + loop(q.from, Set()) + } + + private def verify(query: FlattenSqlQuery): Option[InvalidSqlQuery] = { + + verifyFlatJoins(query) + + val aliases = + query.from.flatMap(this.aliases).map(Ident(_)) :+ Ident("*") :+ Ident("?") + + def verifyAst(ast: Ast) = { + val freeVariables = + (FreeVariables(ast) -- aliases).toList + val freeIdents = + (CollectAst(ast) { + case ast: Property => None + case Aggregation(_, _: Ident) => None + case ast: Ident => Some(ast) + }).flatten + (freeVariables ++ freeIdents) match { + case Nil => None + case free => Some(Error(free, ast)) + } + } + + // Recursively expand children until values are fully flattened. Identities in all these should + // be skipped during verification. 
+ def expandSelect(sv: SelectValue): List[SelectValue] = + sv.ast match { + case Tuple(values) => + values.map(v => SelectValue(v)).flatMap(expandSelect(_)) + case CaseClass(values) => + values.map(v => SelectValue(v._2)).flatMap(expandSelect(_)) + case _ => List(sv) + } + + val freeVariableErrors: List[Error] = + query.where.flatMap(verifyAst).toList ++ + query.orderBy.map(_.ast).flatMap(verifyAst) ++ + query.limit.flatMap(verifyAst) ++ + query.select + .flatMap( + expandSelect(_) + ) // Expand tuple select clauses so their top-level identities are skipped + .map(_.ast) + .filterNot(_.isInstanceOf[Ident]) + .flatMap(verifyAst) ++ + query.from.flatMap { + case j: JoinContext => verifyAst(j.on) + case j: FlatJoinContext => verifyAst(j.on) + case _ => Nil + } + + val nestedErrors = + query.from.collect { + case QueryContext(query, alias) => verify(query).map(_.errors) + }.flatten.flatten + + (freeVariableErrors ++ nestedErrors) match { + case Nil => None + case errors => Some(InvalidSqlQuery(errors)) + } + } + + private def aliases(s: FromContext): List[String] = + s match { + case s: TableContext => List(s.alias) + case s: QueryContext => List(s.alias) + case s: InfixContext => List(s.alias) + case s: JoinContext => aliases(s.a) ++ aliases(s.b) + case s: FlatJoinContext => aliases(s.a) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala b/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala new file mode 100644 index 0000000..8fb0d20 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala @@ -0,0 +1,47 @@ +package minisql.context.sql.norm + +import minisql.ast.Constant +import minisql.context.sql.{FlattenSqlQuery, SqlQuery, _} + +/** + * In SQL Server, `Order By` clauses are only allowed in sub-queries if the + * sub-query has a `TOP` or `OFFSET` modifier. Otherwise an exception will be + * thrown. 
This transformation adds a 'dummy' `OFFSET 0` in this scenario (if an + * `Offset` clause does not exist already). + */ +object AddDropToNestedOrderBy { + + def applyInner(q: SqlQuery): SqlQuery = + q match { + case q: FlattenSqlQuery => + q.copy( + offset = + if (q.orderBy.nonEmpty) q.offset.orElse(Some(Constant(0))) + else q.offset, + from = q.from.map(applyInner(_)) + ) + + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(applyInner(a), op, applyInner(b)) + case UnaryOperationSqlQuery(op, a) => + UnaryOperationSqlQuery(op, applyInner(a)) + } + + private def applyInner(f: FromContext): FromContext = + f match { + case QueryContext(a, alias) => QueryContext(applyInner(a), alias) + case JoinContext(t, a, b, on) => + JoinContext(t, applyInner(a), applyInner(b), on) + case FlatJoinContext(t, a, on) => FlatJoinContext(t, applyInner(a), on) + case other => other + } + + def apply(q: SqlQuery): SqlQuery = + q match { + case q: FlattenSqlQuery => q.copy(from = q.from.map(applyInner(_))) + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(applyInner(a), op, applyInner(b)) + case UnaryOperationSqlQuery(op, a) => + UnaryOperationSqlQuery(op, applyInner(a)) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala b/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala new file mode 100644 index 0000000..9d03c1f --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala @@ -0,0 +1,68 @@ +package minisql.context.sql.norm + +import minisql.ast.Visibility.Hidden +import minisql.ast._ + +object ExpandDistinct { + + @annotation.tailrec + def hasJoin(q: Ast): Boolean = { + q match { + case _: Join => true + case Map(q, _, _) => hasJoin(q) + case Filter(q, _, _) => hasJoin(q) + case _ => false + } + } + + def apply(q: Ast): Ast = + q match { + case Distinct(q) => + Distinct(apply(q)) + case q => + Transform(q) { + case Aggregation(op, Distinct(q)) => + Aggregation(op, Distinct(apply(q))) + case 
Distinct(Map(q, x, cc @ Tuple(values))) => + Map( + Distinct(Map(q, x, cc)), + x, + Tuple(values.zipWithIndex.map { + case (_, i) => Property(x, s"_${i + 1}") + }) + ) + + // Situations like this: + // case class AdHocCaseClass(id: Int, name: String) + // val q = quote { + // query[SomeTable].map(st => AdHocCaseClass(st.id, st.name)).distinct + // } + // ... need some special treatment. Otherwise their values will not be correctly expanded. + case Distinct(Map(q, x, cc @ CaseClass(values))) => + Map( + Distinct(Map(q, x, cc)), + x, + CaseClass(values.map { + case (name, _) => (name, Property(x, name)) + }) + ) + + // Need some special handling to address issues with distinct returning a single embedded entity i.e: + // query[Parent].map(p => p.emb).distinct.map(e => (e.name, e.id)) + // cannot treat such a case normally or "confused" queries will result e.g: + // SELECT p.embname, p.embid FROM (SELECT DISTINCT emb.name /* Where the heck is 'emb' coming from? */ AS embname, emb.id AS embid FROM Parent p) AS p + case d @ Distinct( + Map(q, x, p @ Property.Opinionated(_, _, _, Hidden)) + ) => + d + + // Problems with distinct were first discovered in #1032. Basically, unless + // the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places + // such as "ORDER BY tableIdent" etc... 
+ case Distinct(Map(q, x, p)) => + val newMap = Map(q, x, Tuple(List(p))) + val newIdent = Ident(x.name) + Map(Distinct(newMap), newIdent, Property(newIdent, "_1")) + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala b/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala new file mode 100644 index 0000000..1677cf7 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala @@ -0,0 +1,49 @@ +package minisql.context.sql.norm + +import minisql.ast._ +import minisql.norm.BetaReduction +import minisql.norm.Normalize + +object ExpandJoin { + + def apply(q: Ast) = expand(q, None) + + def expand(q: Ast, id: Option[Ident]) = + Transform(q) { + case q @ Join(_, _, _, Ident(a), Ident(b), _) => + val (qr, tuple) = expandedTuple(q) + Map(qr, id.getOrElse(Ident(s"$a$b")), tuple) + } + + private def expandedTuple(q: Join): (Join, Tuple) = + q match { + + case Join(t, a: Join, b: Join, tA, tB, o) => + val (ar, at) = expandedTuple(a) + val (br, bt) = expandedTuple(b) + val or = BetaReduction(o, tA -> at, tB -> bt) + (Join(t, ar, br, tA, tB, or), Tuple(List(at, bt))) + + case Join(t, a: Join, b, tA, tB, o) => + val (ar, at) = expandedTuple(a) + val or = BetaReduction(o, tA -> at) + (Join(t, ar, b, tA, tB, or), Tuple(List(at, tB))) + + case Join(t, a, b: Join, tA, tB, o) => + val (br, bt) = expandedTuple(b) + val or = BetaReduction(o, tB -> bt) + (Join(t, a, br, tA, tB, or), Tuple(List(tA, bt))) + + case q @ Join(t, a, b, tA, tB, on) => + ( + Join(t, nestedExpand(a, tA), nestedExpand(b, tB), tA, tB, on), + Tuple(List(tA, tB)) + ) + } + + private def nestedExpand(q: Ast, id: Ident) = + Normalize(expand(q, Some(id))) match { + case Map(q, _, _) => q + case q => q + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala b/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala new file mode 100644 index 0000000..b1cc186 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala 
@@ -0,0 +1,12 @@ +package minisql.context.sql.norm + +import minisql.ast._ + +object ExpandMappedInfix { + def apply(q: Ast): Ast = { + Transform(q) { + case Map(Infix("" :: parts, (q: Query) :: params, pure, noParen), x, p) => + Infix("" :: parts, Map(q, x, p) :: params, pure, noParen) + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala new file mode 100644 index 0000000..56095e9 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala @@ -0,0 +1,147 @@ +package minisql.context.sql.norm + +import minisql.NamingStrategy +import minisql.ast.Ast +import minisql.ast.Ident +import minisql.ast._ +import minisql.ast.StatefulTransformer +import minisql.ast.Visibility.Visible +import minisql.context.sql._ + +import scala.collection.mutable.LinkedHashSet +import minisql.util.Interpolator +import minisql.util.Messages.TraceType.NestedQueryExpansion +import minisql.context.sql.norm.nested.ExpandSelect +import minisql.norm.BetaReduction + +import scala.collection.mutable + +class ExpandNestedQueries(strategy: NamingStrategy) { + + val interp = new Interpolator(3) + import interp._ + + def apply(q: SqlQuery, references: List[Property]): SqlQuery = + apply(q, LinkedHashSet.empty ++ references) + + // Using LinkedHashSet despite the fact that it is mutable because it has better characteristics than ListSet. + // Also this collection is strictly internal to ExpandNestedQueries and not exposed anywhere else.
+ private def apply( + q: SqlQuery, + references: LinkedHashSet[Property] + ): SqlQuery = + q match { + case q: FlattenSqlQuery => + val expand = expandNested( + q.copy(select = ExpandSelect(q.select, references, strategy)) + ) + trace"Expanded Nested Query $q into $expand".andLog() + expand + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(apply(a, references), op, apply(b, references)) + case UnaryOperationSqlQuery(op, q) => + UnaryOperationSqlQuery(op, apply(q, references)) + } + + private def expandNested(q: FlattenSqlQuery): SqlQuery = + q match { + case FlattenSqlQuery( + from, + where, + groupBy, + orderBy, + limit, + offset, + select, + distinct + ) => + val asts = Nil ++ select.map(_.ast) ++ where ++ groupBy ++ orderBy.map( + _.ast + ) ++ limit ++ offset + val expansions = q.from.map(expandContext(_, asts)) + val from = expansions.map(_._1) + val references = expansions.flatMap(_._2) + + val replacedRefs = references.map(ref => (ref, unhideAst(ref))) + + // Need to unhide properties that were used during the query + def replaceProps(ast: Ast) = + BetaReduction(ast, replacedRefs: _*) + def replacePropsOption(ast: Option[Ast]) = + ast.map(replaceProps(_)) + + val distinctKind = + q.distinct match { + case DistinctKind.DistinctOn(props) => + DistinctKind.DistinctOn(props.map(p => replaceProps(p))) + case other => other + } + + q.copy( + select = select.map(sv => sv.copy(ast = replaceProps(sv.ast))), + from = from, + where = replacePropsOption(where), + groupBy = replacePropsOption(groupBy), + orderBy = orderBy.map(ob => ob.copy(ast = replaceProps(ob.ast))), + limit = replacePropsOption(limit), + offset = replacePropsOption(offset), + distinct = distinctKind + ) + + } + + def unhideAst(ast: Ast): Ast = + Transform(ast) { + case Property.Opinionated(a, n, r, v) => + Property.Opinionated(unhideAst(a), n, r, Visible) + } + + private def unhideProperties(sv: SelectValue) = + sv.copy(ast = unhideAst(sv.ast)) + + private def expandContext( + s: 
FromContext, + asts: List[Ast] + ): (FromContext, LinkedHashSet[Property]) = + s match { + case QueryContext(q, alias) => + val refs = references(alias, asts) + (QueryContext(apply(q, refs), alias), refs) + case JoinContext(t, a, b, on) => + val (left, leftRefs) = expandContext(a, asts :+ on) + val (right, rightRefs) = expandContext(b, asts :+ on) + (JoinContext(t, left, right, on), leftRefs ++ rightRefs) + case FlatJoinContext(t, a, on) => + val (next, refs) = expandContext(a, asts :+ on) + (FlatJoinContext(t, next, on), refs) + case _: TableContext | _: InfixContext => + (s, new mutable.LinkedHashSet[Property]()) + } + + private def references(alias: String, asts: List[Ast]) = + LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)( + _.apply + )._2.state.references) +} + +case class State(ident: Ident, references: List[Property]) + +case class References(val state: State) extends StatefulTransformer[State] { + + import state._ + + override def apply(a: Ast) = + a match { + case `reference`(p) => (p, References(State(ident, references :+ p))) + case other => super.apply(a) + } + + object reference { + def unapply(p: Property): Option[Property] = + p match { + case Property(`ident`, name) => Some(p) + case Property(reference(_), name) => Some(p) + case other => None + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala new file mode 100644 index 0000000..30abb53 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala @@ -0,0 +1,58 @@ +package minisql.context.sql.norm + +import minisql.ast.Aggregation +import minisql.ast.Ast +import minisql.ast.Drop +import minisql.ast.Filter +import minisql.ast.FlatMap +import minisql.ast.Ident +import minisql.ast.Join +import minisql.ast.Map +import minisql.ast.Query +import minisql.ast.SortBy +import minisql.ast.StatelessTransformer +import minisql.ast.Take +import 
minisql.ast.Union +import minisql.ast.UnionAll +import minisql.norm.BetaReduction +import minisql.util.Messages.fail +import minisql.ast.ConcatMap + +case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer { + + override def apply(ast: Ast) = + ast match { + case q: Query if (isGroupByAggregation(q)) => + q match { + case Aggregation(op, Map(`agg`, ident, body)) => + Aggregation(op, BetaReduction(body, ident -> agg)) + case Map(`agg`, ident, body) => + BetaReduction(body, ident -> agg) + case q @ Aggregation(op, `agg`) => + q + case other => + fail(s"Invalid group by aggregation: '$other'") + } + case other => + super.apply(other) + } + + private[this] def isGroupByAggregation(ast: Ast): Boolean = + ast match { + case Aggregation(a, b) => isGroupByAggregation(b) + case Map(a, b, c) => isGroupByAggregation(a) + case FlatMap(a, b, c) => isGroupByAggregation(a) + case ConcatMap(a, b, c) => isGroupByAggregation(a) + case Filter(a, b, c) => isGroupByAggregation(a) + case SortBy(a, b, c, d) => isGroupByAggregation(a) + case Take(a, b) => isGroupByAggregation(a) + case Drop(a, b) => isGroupByAggregation(a) + case Union(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b) + case UnionAll(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b) + case Join(t, a, b, ta, tb, on) => + isGroupByAggregation(a) || isGroupByAggregation(b) + case `agg` => true + case other => false + } + +} diff --git a/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala new file mode 100644 index 0000000..c239b63 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala @@ -0,0 +1,53 @@ +package minisql.context.sql.norm + +import minisql.norm._ +import minisql.ast.Ast +import minisql.norm.ConcatBehavior.AnsiConcat +import minisql.norm.EqualityBehavior.AnsiEquality +import minisql.norm.capture.DemarcateExternalAliases +import minisql.util.Messages.trace + +object SqlNormalize { + 
def apply( + ast: Ast, + concatBehavior: ConcatBehavior = AnsiConcat, + equalityBehavior: EqualityBehavior = AnsiEquality + ) = + new SqlNormalize(concatBehavior, equalityBehavior)(ast) +} + +class SqlNormalize( + concatBehavior: ConcatBehavior, + equalityBehavior: EqualityBehavior +) { + + private val normalize = + (identity[Ast] _) + .andThen(trace("original")) + .andThen(DemarcateExternalAliases.apply _) + .andThen(trace("DemarcateReturningAliases")) + .andThen(new FlattenOptionOperation(concatBehavior).apply _) + .andThen(trace("FlattenOptionOperation")) + .andThen(new SimplifyNullChecks(equalityBehavior).apply _) + .andThen(trace("SimplifyNullChecks")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + // Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes + // on which RenameProperties relies + .andThen(RenameProperties.apply _) + .andThen(trace("RenameProperties")) + .andThen(ExpandDistinct.apply _) + .andThen(trace("ExpandDistinct")) + .andThen(NestImpureMappedInfix.apply _) + .andThen(trace("NestMappedInfix")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + .andThen(ExpandJoin.apply _) + .andThen(trace("ExpandJoin")) + .andThen(ExpandMappedInfix.apply _) + .andThen(trace("ExpandMappedInfix")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + + def apply(ast: Ast) = normalize(ast) +} diff --git a/src/main/scala/minisql/context/sql/norm/nested/Elements.scala b/src/main/scala/minisql/context/sql/norm/nested/Elements.scala new file mode 100644 index 0000000..1cdd629 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/Elements.scala @@ -0,0 +1,29 @@ +package minisql.context.sql.norm.nested + +import minisql.PseudoAst +import minisql.context.sql.SelectValue + +object Elements { + + /** + * In order to be able to reconstruct the original ordering of elements inside + * of a select clause, we need to keep track of their order, not only within + * the top-level select but 
also its order within any possible + * tuples/case-classes in which it is embedded. For example, in the + * query:
 query[Person].map(p => (p.id, (p.name, p.age))).nested
+   * // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person
+   * x) AS p 
Since the `p.name` and `p.age` elements are selected + * inside of a sub-tuple, their "order" is `List(2,1)` and `List(2,2)` + * respectively as opposed to `p.id` whose "order" is just `List(1)`. + * + * This class keeps track of the values needed in order to do this. + */ + case class OrderedSelect(order: List[Int], selectValue: SelectValue) + extends PseudoAst { + override def toString: String = s"[${order.mkString(",")}]${selectValue}" + } + object OrderedSelect { + def apply(order: Int, selectValue: SelectValue) = + new OrderedSelect(List(order), selectValue) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala new file mode 100644 index 0000000..a8fd5d6 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala @@ -0,0 +1,262 @@ +package minisql.context.sql.norm.nested + +import minisql.NamingStrategy +import minisql.ast.Property +import minisql.context.sql.SelectValue +import minisql.util.Interpolator +import minisql.util.Messages.TraceType.NestedQueryExpansion + +import scala.collection.mutable.LinkedHashSet +import minisql.context.sql.norm.nested.Elements._ +import minisql.ast._ +import minisql.norm.BetaReduction + +/** + * Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query + * construct exists) and flattens them from a nested-hierarchical structure (i.e. + * tuples inside case classes inside tuples etc..) into a single series of + * top-level select elements where needed. In cases where a user wants to select + * an element that contains an entire tuple (i.e. a sub-tuple of the outer + * select clause) we pull out the entire tuple that is being selected and leave + * it to the tokenizer to flatten later.
+ * + * The part about this operation that is tricky is if there are situations where + * there are infix clauses in a sub-query representing an element that has not + * been selected by the outer query but in order to ensure the SQL operation has + * the same meaning, we need to keep track of it. For example:
 val
+ * q = quote { query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int],
+ * p.name, p.id)).map(t => (t._2, t._3)) } run(q) // SELECT p._2, p._3 FROM
+ * (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
+ * 
Since `DISTINCT ON` significantly changes the behavior of the + * outer query, we need to keep track of it inside of the inner query. In order + * to do this, we need to keep track of the location of the infix in the inner + * query so that we can reconstruct it. This is why the `OrderedSelect` and + * `DoubleOrderedSelect` objects are used. See the notes on these classes for + * more detail. + * + * See issue #1597 for more details and another example. + */ +private class ExpandSelect( + selectValues: List[SelectValue], + references: LinkedHashSet[Property], + strategy: NamingStrategy +) { + val interp = new Interpolator(3) + import interp._ + + object TupleIndex { + def unapply(s: String): Option[Int] = + if (s.matches("_[0-9]*")) + Some(s.drop(1).toInt - 1) + else + None + } + + object MultiTupleIndex { + def unapply(s: String): Boolean = + if (s.matches("(_[0-9]+)+")) + true + else + false + } + + val select = + selectValues.zipWithIndex.map { + case (value, index) => OrderedSelect(index, value) + } + + def expandColumn(name: String, renameable: Renameable): String = + renameable.fixedOr(name)(strategy.column(name)) + + def apply: List[SelectValue] = + trace"Expanding Select values: $selectValues into references: $references" andReturn { + + def expandReference(ref: Property): OrderedSelect = + trace"Expanding: $ref from $select" andReturn { + + def expressIfTupleIndex(str: String) = + str match { + case MultiTupleIndex() => Some(str) + case _ => None + } + + def concat(alias: Option[String], idx: Int) = + Some(s"${alias.getOrElse("")}_${idx + 1}") + + val orderedSelect = ref match { + case pp @ Property(ast: Property, TupleIndex(idx)) => + trace"Reference is a sub-property of a tuple index: $idx. Walking inside." 
andReturn + expandReference(ast) match { + case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) => + trace"Expressing Element $idx of $elems " andReturn + OrderedSelect( + o :+ idx, + SelectValue(elems(idx), concat(alias, idx), c) + ) + case OrderedSelect(o, SelectValue(ast, alias, c)) => + trace"Appending $idx to $alias " andReturn + OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) + } + case pp @ Property.Opinionated( + ast: Property, + name, + renameable, + visible + ) => + trace"Reference is a sub-property. Walking inside." andReturn + expandReference(ast) match { + case OrderedSelect(o, SelectValue(ast, nested, c)) => + // Alias is the name of the column after the naming strategy + // The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the + // alias field when it's value is Some(T). + // Technically the aliases of a column should not be using naming strategies + // but this is an issue to fix at a later date. + + // In the current implementation, aliases we add nested tuple names to queries e.g. + // SELECT foo from + // SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors) + // Typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when + // this kind of query is the result of an applicative join that looks like this: + // query[baz].join(query[colors]).nested + // this may need to change based on how distinct appends table names instead of just tuple indexes + // into the property path. + + trace"...inside walk completed, continuing to return: " andReturn + OrderedSelect( + o, + SelectValue( + // Note: Pass invisible properties to be tokenized by the idiom, they should be excluded there + Property.Opinionated(ast, name, renameable, visible), + // Skip concatonation of invisible properties into the alias e.g. so it will be + Some( + s"${nested.getOrElse("")}${expandColumn(name, renameable)}" + ) + ) + ) + } + case pp @ Property(_, TupleIndex(idx)) => + trace"Reference is a tuple index: $idx from $select." 
andReturn + select(idx) match { + case OrderedSelect(o, SelectValue(ast, alias, c)) => + OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) + } + case pp @ Property.Opinionated(_, name, renameable, visible) => + select match { + case List( + OrderedSelect(o, SelectValue(cc: CaseClass, alias, c)) + ) => + // Currently case class element name is not being appended. Need to change that in order to ensure + // path name uniqueness in future. + val ((_, ast), index) = + cc.values.zipWithIndex.find(_._1._1 == name) match { + case Some(v) => v + case None => + throw new IllegalArgumentException( + s"Cannot find element $name in $cc" + ) + } + trace"Reference is a case class member: " andReturn + OrderedSelect( + o :+ index, + SelectValue(ast, Some(expandColumn(name, renameable)), c) + ) + case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) => + trace"Reference is an identifier: " andReturn + OrderedSelect( + o, + SelectValue( + Property.Opinionated(i, name, renameable, visible), + Some(name), + c + ) + ) + case other => + trace"Reference is unidentified: $other returning:" andReturn + OrderedSelect( + Integer.MAX_VALUE, + SelectValue( + Ident.Opinionated(name, visible), + Some(expandColumn(name, renameable)), + false + ) + ) + } + } + + // For certain very large queries where entities are unwrapped and then re-wrapped into CaseClass/Tuple constructs, + // the actual row-types can contain Tuple/CaseClass values. For this reason. They need to be beta-reduced again. 
+ val normalizedOrderedSelect = orderedSelect.copy(selectValue = + orderedSelect.selectValue.copy(ast = + BetaReduction(orderedSelect.selectValue.ast) + ) + ) + + trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" andReturn + normalizedOrderedSelect + } + + def deAliasWhenUneeded(os: OrderedSelect) = + os match { + case OrderedSelect( + _, + sv @ SelectValue(Property(Ident(_), propName), Some(alias), _) + ) if (propName == alias) => + trace"Detected select value with un-needed alias: $os removing it:" andReturn + os.copy(selectValue = sv.copy(alias = None)) + case _ => os + } + + references.toList match { + case Nil => select.map(_.selectValue) + case refs => { + // elements first need to be sorted by their order in the select clause. Since some may map to multiple + // properties when expanded, we want to maintain this order of properties as a secondary value. + val mappedRefs = + refs + // Expand all the references to properties that we have selected in the super query + .map(expandReference) + // Once all the recursive calls of expandReference are done, remove the alias if it is not needed. + // We cannot do this because during recursive calls, the aliases of outer clauses are used for inner ones. + .map(deAliasWhenUneeded(_)) + + trace"Mapped Refs: $mappedRefs".andLog() + + // are there any selects that have infix values which we have not already selected? We need to include + // them because they could be doing essential things e.g. RANK ... 
ORDER BY + val remainingSelectsWithInfixes = + trace"Searching Selects with Infix:" andReturn + new FindUnexpressedInfixes(select)(mappedRefs) + + implicit val ordering: scala.math.Ordering[List[Int]] = + new scala.math.Ordering[List[Int]] { + override def compare(x: List[Int], y: List[Int]): Int = + (x, y) match { + case (head1 :: tail1, head2 :: tail2) => + val diff = head1 - head2 + if (diff != 0) diff + else compare(tail1, tail2) + case (Nil, Nil) => 0 // List(1,2,3) == List(1,2,3) + case (head1, Nil) => -1 // List(1,2,3) < List(1,2) + case (Nil, head2) => 1 // List(1,2) > List(1,2,3) + } + } + + val sortedRefs = + (mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref => + ref.order + ) // (ref.order, ref.secondaryOrder) + + sortedRefs.map(_.selectValue) + } + } + } +} + +object ExpandSelect { + def apply( + selectValues: List[SelectValue], + references: LinkedHashSet[Property], + strategy: NamingStrategy + ): List[SelectValue] = + new ExpandSelect(selectValues, references, strategy).apply +} diff --git a/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala new file mode 100644 index 0000000..2ea1320 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala @@ -0,0 +1,83 @@ +package minisql.context.sql.norm.nested + +import minisql.context.sql.norm.nested.Elements._ +import minisql.util.Interpolator +import minisql.util.Messages.TraceType.NestedQueryExpansion +import minisql.ast._ +import minisql.context.sql.SelectValue + +/** + * The challenge with appeneding infixes (that have not been used but are still + * needed) back into the query, is that they could be inside of + * tuples/case-classes that have already been selected, or inside of sibling + * elements which have been selected. Take for instance a query that looks like + * this:
 query[Person].map(p => (p.name, (p.id,
+ * infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1)) 
In + * this situation, `p.id` which is the sibling of the non-selected infix has + * been selected via `p._2._1` (whose select-order is List(1,0) to represent 1st + * element in 2nd tuple). We need to add its sibling infix. + * + * Or take the following situation:
 query[Person].map(p => (p.name,
+ * (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
+ * 
In this case, we have selected the entire 2nd element including + * the infix. We need to know that `P._2._2` does not need to be selected since + * `p._2` was. + * + * In order to do these things, we use the `order` property from `OrderedSelect` + * in order to see which sub-sub-...-element has been selected. If `p._2` (that + * has order `List(1)`) has been selected, we know that any infixes inside of it + * e.g. `p._2._1` (ordering `List(1,0)`) does not need to be. + */ +class FindUnexpressedInfixes(select: List[OrderedSelect]) { + val interp = new Interpolator(3) + import interp._ + + def apply(refs: List[OrderedSelect]) = { + + def pathExists(path: List[Int]) = + refs.map(_.order).contains(path) + + def containsInfix(ast: Ast) = + CollectAst.byType[Infix](ast).length > 0 + + // build paths to every infix and see these paths were not selected already + def findMissingInfixes( + ast: Ast, + parentOrder: List[Int] + ): List[(Ast, List[Int])] = { + trace"Searching for infix: $ast in the sub-path $parentOrder".andLog() + if (pathExists(parentOrder)) + trace"No infixes found" andContinue + List() + else + ast match { + case Tuple(values) => + values.zipWithIndex + .filter(v => containsInfix(v._1)) + .flatMap { + case (ast, index) => + findMissingInfixes(ast, parentOrder :+ index) + } + case CaseClass(values) => + values.zipWithIndex + .filter(v => containsInfix(v._1._2)) + .flatMap { + case ((_, ast), index) => + findMissingInfixes(ast, parentOrder :+ index) + } + case other if (containsInfix(other)) => + trace"Found unexpressed infix inside $other in $parentOrder" + .andLog() + List((other, parentOrder)) + case _ => + List() + } + } + + select.flatMap { + case OrderedSelect(o, sv) => findMissingInfixes(sv.ast, o) + }.map { + case (ast, order) => OrderedSelect(order, SelectValue(ast)) + } + } +} diff --git a/src/main/scala/minisql/norm/FreeVariables.scala b/src/main/scala/minisql/norm/FreeVariables.scala new file mode 100644 index 0000000..9c63437 --- /dev/null +++ 
b/src/main/scala/minisql/norm/FreeVariables.scala @@ -0,0 +1,120 @@ +package minisql.norm + +import minisql.ast.* +import collection.immutable.Set + +case class State(seen: Set[Ident], free: Set[Ident]) + +case class FreeVariables(state: State) extends StatefulTransformer[State] { + + override def apply(ast: Ast): (Ast, StatefulTransformer[State]) = + ast match { + case ident: Ident if (!state.seen.contains(ident)) => + (ident, FreeVariables(State(state.seen, state.free + ident))) + case f @ Function(params, body) => + val (_, t) = + FreeVariables(State(state.seen ++ params, state.free))(body) + (f, FreeVariables(State(state.seen, state.free ++ t.state.free))) + case q @ Foreach(a, b, c) => + (q, free(a, b, c)) + case other => + super.apply(other) + } + + override def apply( + o: OptionOperation + ): (OptionOperation, StatefulTransformer[State]) = + o match { + case q @ OptionTableFlatMap(a, b, c) => + (q, free(a, b, c)) + case q @ OptionTableMap(a, b, c) => + (q, free(a, b, c)) + case q @ OptionTableExists(a, b, c) => + (q, free(a, b, c)) + case q @ OptionTableForall(a, b, c) => + (q, free(a, b, c)) + case q @ OptionFlatMap(a, b, c) => + (q, free(a, b, c)) + case q @ OptionMap(a, b, c) => + (q, free(a, b, c)) + case q @ OptionForall(a, b, c) => + (q, free(a, b, c)) + case q @ OptionExists(a, b, c) => + (q, free(a, b, c)) + case other => + super.apply(other) + } + + override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) = + e match { + case Assignment(a, b, c) => + val t = FreeVariables(State(state.seen + a, state.free)) + val (bt, btt) = t(b) + val (ct, ctt) = t(c) + ( + Assignment(a, bt, ct), + FreeVariables( + State(state.seen, state.free ++ btt.state.free ++ ctt.state.free) + ) + ) + } + + override def apply(action: Action): (Action, StatefulTransformer[State]) = + action match { + case q @ Returning(a, b, c) => + (q, free(a, b, c)) + case q @ ReturningGenerated(a, b, c) => + (q, free(a, b, c)) + case other => + super.apply(other) + } + + 
override def apply( + e: OnConflict.Target + ): (OnConflict.Target, StatefulTransformer[State]) = (e, this) + + override def apply(query: Query): (Query, StatefulTransformer[State]) = + query match { + case q @ Filter(a, b, c) => (q, free(a, b, c)) + case q @ Map(a, b, c) => (q, free(a, b, c)) + case q @ DistinctOn(a, b, c) => (q, free(a, b, c)) + case q @ FlatMap(a, b, c) => (q, free(a, b, c)) + case q @ ConcatMap(a, b, c) => (q, free(a, b, c)) + case q @ SortBy(a, b, c, d) => (q, free(a, b, c)) + case q @ GroupBy(a, b, c) => (q, free(a, b, c)) + case q @ FlatJoin(t, a, b, c) => (q, free(a, b, c)) + case q @ Join(t, a, b, iA, iB, on) => + val (_, freeA) = apply(a) + val (_, freeB) = apply(b) + val (_, freeOn) = + FreeVariables(State(state.seen + iA + iB, Set.empty))(on) + ( + q, + FreeVariables( + State( + state.seen, + state.free ++ freeA.state.free ++ freeB.state.free ++ freeOn.state.free + ) + ) + ) + case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll | + _: Aggregation | _: Distinct | _: Nested => + super.apply(query) + } + + private def free(a: Ast, ident: Ident, c: Ast) = { + val (_, ta) = apply(a) + val (_, tc) = FreeVariables(State(state.seen + ident, state.free))(c) + FreeVariables( + State(state.seen, state.free ++ ta.state.free ++ tc.state.free) + ) + } +} + +object FreeVariables { + def apply(ast: Ast): Set[Ident] = + new FreeVariables(State(Set.empty, Set.empty))(ast) match { + case (_, transformer) => + transformer.state.free + } +} From 17e97495b7728fed633438e1787d48c867359c85 Mon Sep 17 00:00:00 2001 From: jilen Date: Thu, 19 Jun 2025 18:49:14 +0800 Subject: [PATCH 12/26] Simplify Mirror Codec --- build.sbt | 9 +- .../scala/minisql/context/MirrorContext.scala | 15 +- src/main/scala/minisql/context/mirror.scala | 180 ++++++++++++------ .../context/sql/MirrorSqlContext.scala | 24 +++ .../{ => context/sql}/MirrorSqlDialect.scala | 2 +- .../context/sql/OnConflictSupport.scala | 8 +- .../minisql/context/sql/SqlContext.scala | 44 +++++ 
.../scala/minisql/context/sql/SqlIdiom.scala | 30 +-- .../scala/minisql/context/sql/SqlQuery.scala | 2 +- .../sql/norm/ExpandNestedQueries.scala | 2 +- .../sql/norm/FlattenGroupByAggregation.scala | 2 +- .../context/sql/norm/SqlNormalize.scala | 24 +-- .../sql/norm/nested/ExpandSelect.scala | 24 +-- .../norm/nested/FindUnexpressedInfixes.scala | 2 +- .../scala/minisql/idiom/MirrorIdiom.scala | 5 +- .../{parsing => mirror}/QuotedSuite.scala | 8 +- src/test/scala/minisql/mirror/context.scala | 6 + 17 files changed, 275 insertions(+), 112 deletions(-) create mode 100644 src/main/scala/minisql/context/sql/MirrorSqlContext.scala rename src/main/scala/minisql/{ => context/sql}/MirrorSqlDialect.scala (97%) create mode 100644 src/main/scala/minisql/context/sql/SqlContext.scala rename src/test/scala/minisql/{parsing => mirror}/QuotedSuite.scala (65%) create mode 100644 src/test/scala/minisql/mirror/context.scala diff --git a/build.sbt b/build.sbt index 67a69a6..86502a5 100644 --- a/build.sbt +++ b/build.sbt @@ -1,7 +1,14 @@ name := "minisql" -scalaVersion := "3.7.0" +scalaVersion := "3.7.1" libraryDependencies ++= Seq( "org.scalameta" %% "munit" % "1.0.3" % Test ) + +scalacOptions ++= Seq( + "-deprecation", + "-feature", + "-source:3.7-migration", + "-rewrite" +) diff --git a/src/main/scala/minisql/context/MirrorContext.scala b/src/main/scala/minisql/context/MirrorContext.scala index ba00db1..7501274 100644 --- a/src/main/scala/minisql/context/MirrorContext.scala +++ b/src/main/scala/minisql/context/MirrorContext.scala @@ -1,14 +1,23 @@ package minisql import minisql.context.mirror.* +import minisql.util.Messages.fail +import scala.reflect.ClassTag class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy]( val idiom: Idiom, val naming: Naming -) extends context.Context[Idiom, Naming] { +) extends context.Context[Idiom, Naming] + with MirrorCodecs { - type DBRow = Row + type DBRow = IArray[Any] *: EmptyTuple + type DBResultSet = Iterable[DBRow] + type DBStatement = 
Map[Int, Any] - type DBResultSet = ResultSet + extension (r: DBRow) { + + def data: IArray[Any] = r._1 + def add(value: Any): DBRow = (r.data :+ value) *: EmptyTuple + } } diff --git a/src/main/scala/minisql/context/mirror.scala b/src/main/scala/minisql/context/mirror.scala index a67fdd2..81b8f98 100644 --- a/src/main/scala/minisql/context/mirror.scala +++ b/src/main/scala/minisql/context/mirror.scala @@ -1,71 +1,139 @@ package minisql.context.mirror -import minisql.{MirrorContext, NamingStrategy, ParamEncoder, ColumnDecoder} -import minisql.idiom.Idiom +import minisql.MirrorContext +import java.time.LocalDate +import java.util.{Date, UUID} +import minisql.{ParamEncoder, ColumnDecoder} import minisql.util.Messages.fail +import scala.util.{Failure, Success, Try} import scala.util.Try import scala.reflect.ClassTag -/** -* No extra class defined -*/ -opaque type Row = IArray[Any] *: EmptyTuple -opaque type ResultSet = Iterable[Row] -opaque type Statement = Map[Int, Any] +trait MirrorCodecs { + ctx: MirrorContext[?, ?] => -extension (r: Row) { - - def data: IArray[Any] = r._1 - - def add(value: Any): Row = (r.data :+ value) *: EmptyTuple - - def apply[T](idx: Int)(using t: ClassTag[T]): T = { - r.data(idx) match { - case v: T => v - case other => - fail( - s"Invalid column type. 
Expected '${t.runtimeClass}', but got '$other'" - ) + final protected def mirrorEncoder[V]: Encoder[V] = new ParamEncoder[V] { + type Stmt = ctx.DBStatement + def setParam(s: Stmt, idx: Int, v: V): Stmt = { + s + (idx -> v) } } -} -type Encoder[E] = ParamEncoder[E] { - type Stmt = Statement -} + final protected def mirrorColumnDecoder[X]( + conv: Any => Option[X] + ): Decoder[X] = + new ColumnDecoder[X] { + type DBRow = ctx.DBRow + def decode(row: DBRow, idx: Int): Try[X] = { + row.data + .lift(idx) + .flatMap { x => + conv(x) + } + .toRight(new Exception(s"Cannot convert value at ${idx}")) + .toTry + } + } -private def encoder[V]: Encoder[V] = new ParamEncoder[V] { - - type Stmt = Map[Int, Any] - - def setParam(s: Stmt, idx: Int, v: V): Stmt = { - s + (idx -> v) - } -} - -given Encoder[Long] = encoder[Long] - -type Decoder[A] = ColumnDecoder[A] { - type DBRow = Row -} - -private def apply[X](conv: Any => Option[X]): Decoder[X] = - new ColumnDecoder[X] { - type DBRow = Row - def decode(row: Row, idx: Int): Try[X] = { - row._1 - .lift(idx) - .flatMap { x => - conv(x) + given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]] = { + new ColumnDecoder[Option[T]] { + type DBRow = ctx.DBRow + override def decode(row: DBRow, idx: Int): Try[Option[T]] = + row.data.lift(idx) match { + case Some(null) => Success(None) + case Some(value) => d.decode(row, idx).map(Some(_)) + case None => Success(None) } - .toRight(new Exception(s"Cannot convert value at ${idx}")) - .toTry } } -given Decoder[Long] = apply(x => - x match { - case l: Long => Some(l) - case _ => None - } -) + given optionEncoder[T](using e: Encoder[T]): Encoder[Option[T]] = + new ParamEncoder[Option[T]] { + type Stmt = ctx.DBStatement + override def setParam( + s: Stmt, + idx: Int, + v: Option[T] + ): Stmt = + v match { + case Some(value) => e.setParam(s, idx, value) + case None => + s + (idx -> null) + } + } + + // Implement all required decoders using mirrorColumnDecoder from MirrorCodecs + given 
stringDecoder: Decoder[String] = mirrorColumnDecoder[String](x => + x match { case s: String => Some(s); case _ => None } + ) + given bigDecimalDecoder: Decoder[BigDecimal] = + mirrorColumnDecoder[BigDecimal](x => + x match { + case bd: BigDecimal => Some(bd); case i: Int => Some(BigDecimal(i)); + case l: Long => Some(BigDecimal(l)); + case d: Double => Some(BigDecimal(d)); case _ => None + } + ) + given booleanDecoder: Decoder[Boolean] = mirrorColumnDecoder[Boolean](x => + x match { case b: Boolean => Some(b); case _ => None } + ) + given byteDecoder: Decoder[Byte] = mirrorColumnDecoder[Byte](x => + x match { + case b: Byte => Some(b); case i: Int => Some(i.toByte); case _ => None + } + ) + given shortDecoder: Decoder[Short] = mirrorColumnDecoder[Short](x => + x match { + case s: Short => Some(s); case i: Int => Some(i.toShort); case _ => None + } + ) + given intDecoder: Decoder[Int] = mirrorColumnDecoder[Int](x => + x match { case i: Int => Some(i); case _ => None } + ) + given longDecoder: Decoder[Long] = mirrorColumnDecoder[Long](x => + x match { + case l: Long => Some(l); case i: Int => Some(i.toLong); case _ => None + } + ) + given floatDecoder: Decoder[Float] = mirrorColumnDecoder[Float](x => + x match { + case f: Float => Some(f); case d: Double => Some(d.toFloat); + case _ => None + } + ) + given doubleDecoder: Decoder[Double] = mirrorColumnDecoder[Double](x => + x match { + case d: Double => Some(d); case f: Float => Some(f.toDouble); + case _ => None + } + ) + given byteArrayDecoder: Decoder[Array[Byte]] = + mirrorColumnDecoder[Array[Byte]](x => + x match { case ba: Array[Byte] => Some(ba); case _ => None } + ) + given dateDecoder: Decoder[Date] = mirrorColumnDecoder[Date](x => + x match { case d: Date => Some(d); case _ => None } + ) + given localDateDecoder: Decoder[LocalDate] = + mirrorColumnDecoder[LocalDate](x => + x match { case ld: LocalDate => Some(ld); case _ => None } + ) + given uuidDecoder: Decoder[UUID] = mirrorColumnDecoder[UUID](x => + x 
match { case uuid: UUID => Some(uuid); case _ => None } + ) + + // Implement all required encoders using mirrorEncoder from MirrorCodecs + given stringEncoder: Encoder[String] = mirrorEncoder[String] + given bigDecimalEncoder: Encoder[BigDecimal] = mirrorEncoder[BigDecimal] + given booleanEncoder: Encoder[Boolean] = mirrorEncoder[Boolean] + given byteEncoder: Encoder[Byte] = mirrorEncoder[Byte] + given shortEncoder: Encoder[Short] = mirrorEncoder[Short] + given intEncoder: Encoder[Int] = mirrorEncoder[Int] + given longEncoder: Encoder[Long] = mirrorEncoder[Long] + given floatEncoder: Encoder[Float] = mirrorEncoder[Float] + given doubleEncoder: Encoder[Double] = mirrorEncoder[Double] + given byteArrayEncoder: Encoder[Array[Byte]] = mirrorEncoder[Array[Byte]] + given dateEncoder: Encoder[Date] = mirrorEncoder[Date] + given localDateEncoder: Encoder[LocalDate] = mirrorEncoder[LocalDate] + given uuidEncoder: Encoder[UUID] = mirrorEncoder[UUID] +} diff --git a/src/main/scala/minisql/context/sql/MirrorSqlContext.scala b/src/main/scala/minisql/context/sql/MirrorSqlContext.scala new file mode 100644 index 0000000..0a035f9 --- /dev/null +++ b/src/main/scala/minisql/context/sql/MirrorSqlContext.scala @@ -0,0 +1,24 @@ +package minisql.context.sql + +import minisql.{NamingStrategy, MirrorContext} +import minisql.context.Context +import minisql.idiom.Idiom // Changed from minisql.idiom.* to avoid ambiguity with Statement +import minisql.context.mirror.MirrorCodecs +import minisql.context.ReturningClauseSupported +import minisql.context.ReturningCapability + +class MirrorSqlIdiom extends idiom.SqlIdiom { + override def concatFunction: String = "CONCAT" + override def idiomReturningCapability: ReturningCapability = + ReturningClauseSupported + + // Implementations previously provided by MirrorIdiomBase + override def prepareForProbing(string: String): String = string + override def liftingPlaceholder(index: Int): String = "?" 
+} +object MirrorSqlIdiom extends MirrorSqlIdiom + +class MirrorSqlContext[N <: NamingStrategy](naming: N) + extends MirrorContext[MirrorSqlIdiom, N](MirrorSqlIdiom, naming) + with SqlContext[MirrorSqlIdiom, N] + with MirrorCodecs {} diff --git a/src/main/scala/minisql/MirrorSqlDialect.scala b/src/main/scala/minisql/context/sql/MirrorSqlDialect.scala similarity index 97% rename from src/main/scala/minisql/MirrorSqlDialect.scala rename to src/main/scala/minisql/context/sql/MirrorSqlDialect.scala index 563f770..7ee69ba 100644 --- a/src/main/scala/minisql/MirrorSqlDialect.scala +++ b/src/main/scala/minisql/context/sql/MirrorSqlDialect.scala @@ -1,4 +1,4 @@ -package minisql +package minisql.context.sql import minisql.context.{ CanReturnClause, diff --git a/src/main/scala/minisql/context/sql/OnConflictSupport.scala b/src/main/scala/minisql/context/sql/OnConflictSupport.scala index 940d5bf..3d0d6ce 100644 --- a/src/main/scala/minisql/context/sql/OnConflictSupport.scala +++ b/src/main/scala/minisql/context/sql/OnConflictSupport.scala @@ -20,13 +20,13 @@ trait OnConflictSupport { } val customAstTokenizer = - Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy)) { + Tokenizer.withFallback[Ast](self.astTokenizer(using _, strategy)) { case _: OnConflict.Excluded => stmt"EXCLUDED" case OnConflict.Existing(a) => stmt"${a.token}" case a: Action => self .actionTokenizer(customEntityTokenizer)( - actionAstTokenizer, + using actionAstTokenizer, strategy ) .token(a) @@ -37,7 +37,7 @@ trait OnConflictSupport { def doUpdateStmt(i: Token, t: Token, u: Update) = { val assignments = u.assignments .map(a => - stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(customAstTokenizer)}" + stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(using customAstTokenizer)}" ) .mkStmt() @@ -65,6 +65,6 @@ trait OnConflictSupport { case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token) } - tokenizer(customAstTokenizer) + tokenizer(using 
customAstTokenizer) } } diff --git a/src/main/scala/minisql/context/sql/SqlContext.scala b/src/main/scala/minisql/context/sql/SqlContext.scala new file mode 100644 index 0000000..7c185d5 --- /dev/null +++ b/src/main/scala/minisql/context/sql/SqlContext.scala @@ -0,0 +1,44 @@ +package minisql.context.sql + +import java.time.LocalDate + +import minisql.idiom.{Idiom => BaseIdiom} +import java.util.{Date, UUID} + +import minisql.context.Context +import minisql.NamingStrategy + +trait SqlContext[Idiom <: BaseIdiom, Naming <: NamingStrategy] + extends Context[Idiom, Naming] { + + given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]] + given optionEncoder[T](using d: Encoder[T]): Encoder[Option[T]] + + given stringDecoder: Decoder[String] + given bigDecimalDecoder: Decoder[BigDecimal] + given booleanDecoder: Decoder[Boolean] + given byteDecoder: Decoder[Byte] + given shortDecoder: Decoder[Short] + given intDecoder: Decoder[Int] + given longDecoder: Decoder[Long] + given floatDecoder: Decoder[Float] + given doubleDecoder: Decoder[Double] + given byteArrayDecoder: Decoder[Array[Byte]] + given dateDecoder: Decoder[Date] + given localDateDecoder: Decoder[LocalDate] + given uuidDecoder: Decoder[UUID] + + given stringEncoder: Encoder[String] + given bigDecimalEncoder: Encoder[BigDecimal] + given booleanEncoder: Encoder[Boolean] + given byteEncoder: Encoder[Byte] + given shortEncoder: Encoder[Short] + given intEncoder: Encoder[Int] + given longEncoder: Encoder[Long] + given floatEncoder: Encoder[Float] + given doubleEncoder: Encoder[Double] + given byteArrayEncoder: Encoder[Array[Byte]] + given dateEncoder: Encoder[Date] + given localDateEncoder: Encoder[LocalDate] + given uuidEncoder: Encoder[UUID] +} diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index c593153..ba099b0 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -26,9 +26,7 
@@ trait SqlIdiom extends Idiom { protected def concatBehavior: ConcatBehavior = AnsiConcat protected def equalityBehavior: EqualityBehavior = AnsiEquality - - protected def actionAlias: Option[Ident] = None - + protected def actionAlias: Option[Ident] = None override def format(queryString: String): String = queryString def querifyAst(ast: Ast) = SqlQuery(ast) @@ -67,7 +65,7 @@ trait SqlIdiom extends Idiom { def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] = new Tokenizer[Ast] { - private val stableTokenizer = astTokenizer(this, naming) + private val stableTokenizer = astTokenizer(using this, naming) extension (v: Ast) { def token = stableTokenizer.token(v) @@ -249,7 +247,9 @@ trait SqlIdiom extends Idiom { } val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + Tokenizer.withFallback[Ast]( + SqlIdiom.this.astTokenizer(using _, strategy) + ) { case Aggregation(op, Ident(_) | Tuple(_)) => stmt"${op.token}(*)" case Aggregation(op, Distinct(ast)) => stmt"${op.token}(DISTINCT ${ast.token})" @@ -257,7 +257,7 @@ trait SqlIdiom extends Idiom { case Aggregation(op, ast) => stmt"${op.token}(${ast.token})" } - tokenizer(customAstTokenizer) + tokenizer(using customAstTokenizer) } implicit def operationTokenizer(implicit @@ -528,14 +528,14 @@ trait SqlIdiom extends Idiom { case Entity.Opinionated(name, _, renameable) => stmt"INTO ${tokenizeTable(strategy, name, renameable).token}" } - actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy) + actionTokenizer(insertEntityTokenizer)(using actionAstTokenizer, strategy) } protected def actionAstTokenizer(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy ) = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(using _, strategy)) { case q: Query => astTokenizer.token(q) case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") => 
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL" @@ -557,14 +557,16 @@ trait SqlIdiom extends Idiom { strategy: NamingStrategy ): Tokenizer[List[Ast]] = { val customAstTokenizer = - Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + Tokenizer.withFallback[Ast]( + SqlIdiom.this.astTokenizer(using _, strategy) + ) { case sq: Query => stmt"(${tokenizer.token(sq)})" } Tokenizer[List[Ast]] { case list => - list.mkStmt(", ")(customAstTokenizer) + list.mkStmt(", ")(using customAstTokenizer) } } @@ -653,7 +655,7 @@ object SqlIdiom { private[minisql] def copyIdiom( parent: SqlIdiom, newActionAlias: Option[Ident] - ) = + ): SqlIdiom = new SqlIdiom { override protected def actionAlias: Option[Ident] = newActionAlias override def prepareForProbing(string: String): String = @@ -678,10 +680,10 @@ object SqlIdiom { val idiom = copyIdiom(parentIdiom, Some(query.alias)) import idiom._ - implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer( + implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer(using new Tokenizer[Ast] { self => extension (v: Ast) { - def token = astTokenizer(self, strategy).token(v) + def token = astTokenizer(using self, strategy).token(v) } }, strategy @@ -695,6 +697,8 @@ object SqlIdiom { stmt"${action.token} RETURNING ${returnListTokenizer.token( ExpandReturning(r)(idiom, strategy).map(_._1) )}" + case r => + fail(s"Unsupported Returning construct: $r") } } } diff --git a/src/main/scala/minisql/context/sql/SqlQuery.scala b/src/main/scala/minisql/context/sql/SqlQuery.scala index 06ec412..2884662 100644 --- a/src/main/scala/minisql/context/sql/SqlQuery.scala +++ b/src/main/scala/minisql/context/sql/SqlQuery.scala @@ -19,7 +19,7 @@ case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast) sealed trait SqlQuery { override def toString = { - import minisql.MirrorSqlDialect._ + import MirrorSqlDialect.* import minisql.idiom.StatementInterpolator.* given 
Tokenizer[SqlQuery] = sqlQueryTokenizer(using defaultTokenizer(using Literal), diff --git a/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala index 56095e9..8063c7f 100644 --- a/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala +++ b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala @@ -66,7 +66,7 @@ class ExpandNestedQueries(strategy: NamingStrategy) { // Need to unhide properties that were used during the query def replaceProps(ast: Ast) = - BetaReduction(ast, replacedRefs: _*) + BetaReduction(ast, replacedRefs*) def replacePropsOption(ast: Option[Ast]) = ast.map(replaceProps(_)) diff --git a/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala index 30abb53..9b9b354 100644 --- a/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala +++ b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala @@ -37,7 +37,7 @@ case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer { super.apply(other) } - private[this] def isGroupByAggregation(ast: Ast): Boolean = + private def isGroupByAggregation(ast: Ast): Boolean = ast match { case Aggregation(a, b) => isGroupByAggregation(b) case Map(a, b, c) => isGroupByAggregation(a) diff --git a/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala index c239b63..bf80bf5 100644 --- a/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala +++ b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala @@ -22,31 +22,31 @@ class SqlNormalize( ) { private val normalize = - (identity[Ast] _) + (identity[Ast]) .andThen(trace("original")) - .andThen(DemarcateExternalAliases.apply _) + .andThen(DemarcateExternalAliases.apply) .andThen(trace("DemarcateReturningAliases")) - .andThen(new 
FlattenOptionOperation(concatBehavior).apply _) + .andThen(new FlattenOptionOperation(concatBehavior).apply) .andThen(trace("FlattenOptionOperation")) - .andThen(new SimplifyNullChecks(equalityBehavior).apply _) + .andThen(new SimplifyNullChecks(equalityBehavior).apply) .andThen(trace("SimplifyNullChecks")) - .andThen(Normalize.apply _) + .andThen(Normalize.apply) .andThen(trace("Normalize")) // Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes // on which RenameProperties relies - .andThen(RenameProperties.apply _) + .andThen(RenameProperties.apply) .andThen(trace("RenameProperties")) - .andThen(ExpandDistinct.apply _) + .andThen(ExpandDistinct.apply) .andThen(trace("ExpandDistinct")) - .andThen(NestImpureMappedInfix.apply _) + .andThen(NestImpureMappedInfix.apply) .andThen(trace("NestMappedInfix")) - .andThen(Normalize.apply _) + .andThen(Normalize.apply) .andThen(trace("Normalize")) - .andThen(ExpandJoin.apply _) + .andThen(ExpandJoin.apply) .andThen(trace("ExpandJoin")) - .andThen(ExpandMappedInfix.apply _) + .andThen(ExpandMappedInfix.apply) .andThen(trace("ExpandMappedInfix")) - .andThen(Normalize.apply _) + .andThen(Normalize.apply) .andThen(trace("Normalize")) def apply(ast: Ast) = normalize(ast) diff --git a/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala index a8fd5d6..1b8818e 100644 --- a/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala +++ b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala @@ -85,16 +85,16 @@ private class ExpandSelect( val orderedSelect = ref match { case pp @ Property(ast: Property, TupleIndex(idx)) => - trace"Reference is a sub-property of a tuple index: $idx. Walking inside." andReturn + trace"Reference is a sub-property of a tuple index: $idx. Walking inside." 
`andReturn` expandReference(ast) match { case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) => - trace"Expressing Element $idx of $elems " andReturn + trace"Expressing Element $idx of $elems " `andReturn` OrderedSelect( o :+ idx, SelectValue(elems(idx), concat(alias, idx), c) ) case OrderedSelect(o, SelectValue(ast, alias, c)) => - trace"Appending $idx to $alias " andReturn + trace"Appending $idx to $alias " `andReturn` OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) } case pp @ Property.Opinionated( @@ -103,7 +103,7 @@ private class ExpandSelect( renameable, visible ) => - trace"Reference is a sub-property. Walking inside." andReturn + trace"Reference is a sub-property. Walking inside." `andReturn` expandReference(ast) match { case OrderedSelect(o, SelectValue(ast, nested, c)) => // Alias is the name of the column after the naming strategy @@ -121,7 +121,7 @@ private class ExpandSelect( // this may need to change based on how distinct appends table names instead of just tuple indexes // into the property path. - trace"...inside walk completed, continuing to return: " andReturn + trace"...inside walk completed, continuing to return: " `andReturn` OrderedSelect( o, SelectValue( @@ -135,7 +135,7 @@ private class ExpandSelect( ) } case pp @ Property(_, TupleIndex(idx)) => - trace"Reference is a tuple index: $idx from $select." andReturn + trace"Reference is a tuple index: $idx from $select." 
`andReturn` select(idx) match { case OrderedSelect(o, SelectValue(ast, alias, c)) => OrderedSelect(o, SelectValue(ast, concat(alias, idx), c)) @@ -155,13 +155,13 @@ private class ExpandSelect( s"Cannot find element $name in $cc" ) } - trace"Reference is a case class member: " andReturn + trace"Reference is a case class member: " `andReturn` OrderedSelect( o :+ index, SelectValue(ast, Some(expandColumn(name, renameable)), c) ) case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) => - trace"Reference is an identifier: " andReturn + trace"Reference is an identifier: " `andReturn` OrderedSelect( o, SelectValue( @@ -171,7 +171,7 @@ private class ExpandSelect( ) ) case other => - trace"Reference is unidentified: $other returning:" andReturn + trace"Reference is unidentified: $other returning:" `andReturn` OrderedSelect( Integer.MAX_VALUE, SelectValue( @@ -191,7 +191,7 @@ private class ExpandSelect( ) ) - trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" andReturn + trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" `andReturn` normalizedOrderedSelect } @@ -201,7 +201,7 @@ private class ExpandSelect( _, sv @ SelectValue(Property(Ident(_), propName), Some(alias), _) ) if (propName == alias) => - trace"Detected select value with un-needed alias: $os removing it:" andReturn + trace"Detected select value with un-needed alias: $os removing it:" `andReturn` os.copy(selectValue = sv.copy(alias = None)) case _ => os } @@ -224,7 +224,7 @@ private class ExpandSelect( // are there any selects that have infix values which we have not already selected? We need to include // them because they could be doing essential things e.g. RANK ... 
ORDER BY val remainingSelectsWithInfixes = - trace"Searching Selects with Infix:" andReturn + trace"Searching Selects with Infix:" `andReturn` new FindUnexpressedInfixes(select)(mappedRefs) implicit val ordering: scala.math.Ordering[List[Int]] = diff --git a/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala index 2ea1320..7149ea2 100644 --- a/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala +++ b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala @@ -47,7 +47,7 @@ class FindUnexpressedInfixes(select: List[OrderedSelect]) { ): List[(Ast, List[Int])] = { trace"Searching for infix: $ast in the sub-path $parentOrder".andLog() if (pathExists(parentOrder)) - trace"No infixes found" andContinue + trace"No infixes found" `andContinue` List() else ast match { diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index b325b7d..fd18549 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -1,8 +1,9 @@ -package minisql +package minisql.idiom +import minisql.NamingStrategy import minisql.ast.Renameable.{ByStrategy, Fixed} import minisql.ast.Visibility.Hidden -import minisql.ast._ +import minisql.ast.* import minisql.context.CanReturnClause import minisql.idiom.{Idiom, SetContainsToken, Statement} import minisql.idiom.StatementInterpolator.* diff --git a/src/test/scala/minisql/parsing/QuotedSuite.scala b/src/test/scala/minisql/mirror/QuotedSuite.scala similarity index 65% rename from src/test/scala/minisql/parsing/QuotedSuite.scala rename to src/test/scala/minisql/mirror/QuotedSuite.scala index d2f8981..65b9cf3 100644 --- a/src/test/scala/minisql/parsing/QuotedSuite.scala +++ b/src/test/scala/minisql/mirror/QuotedSuite.scala @@ -1,20 +1,20 @@ -package minisql.parsing +package minisql.context.mirror import 
minisql.* import minisql.ast.* import minisql.idiom.* import minisql.NamingStrategy import minisql.MirrorContext -import minisql.MirrorIdiom import minisql.context.mirror.{*, given} class QuotedSuite extends munit.FunSuite { - val ctx = new MirrorContext(MirrorIdiom, SnakeCase) case class Foo(id: Long) + import mirrorContext.given + test("SimpleQuery") { - val o = ctx.io(query[Foo]("foo").filter(_.id > 0)) + val o = mirrorContext.io(query[Foo]("foo").filter(_.id > 0)) println("============" + o) o } diff --git a/src/test/scala/minisql/mirror/context.scala b/src/test/scala/minisql/mirror/context.scala new file mode 100644 index 0000000..240a475 --- /dev/null +++ b/src/test/scala/minisql/mirror/context.scala @@ -0,0 +1,6 @@ +package minisql.context.mirror + +import minisql.* +import minisql.idiom.MirrorIdiom + +val mirrorContext = new MirrorContext(MirrorIdiom, Literal) From 2b52ef32036937116306f8db335d57c7758bf064 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 22 Jun 2025 14:27:15 +0800 Subject: [PATCH 13/26] Add property alist --- src/main/scala/minisql/Quoted.scala | 25 ++++++++++++++++--- .../minisql/idiom/StatementInterpolator.scala | 21 ---------------- .../scala/minisql/mirror/QuotedSuite.scala | 7 +++++- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 3ffb886..a01ed3c 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -4,8 +4,18 @@ import minisql.* import minisql.idiom.* import minisql.parsing.* import minisql.util.* -import minisql.ast.{Ast, Entity, Map, Property, Ident, Filter, given} +import minisql.ast.{ + Ast, + Entity, + Map, + Property, + Ident, + Filter, + PropertyAlias, + given +} import scala.quoted.* +import scala.deriving.* import scala.compiletime.* import scala.compiletime.ops.string.* import scala.collection.immutable.{Map => IMap} @@ -36,8 +46,17 @@ private inline def transform[A, B](inline q1: Quoted)( 
fast(q1, f.param0, f.body) } -inline def query[E](inline table: String): EntityQuery[E] = - Entity(table, Nil) +inline def alias(inline from: String, inline to: String): PropertyAlias = + PropertyAlias(List(from), to) + +inline def query[E]( + inline table: String, + inline alias: PropertyAlias* +): EntityQuery[E] = + Entity( + table, + List(alias*) + ) extension [A, B](inline f1: A => B) { private inline def param0 = parsing.parseParamAt(f1, 0) diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala index 3aa4d26..056d58f 100644 --- a/src/main/scala/minisql/idiom/StatementInterpolator.scala +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -43,27 +43,6 @@ object StatementInterpolator { implicit def liftTokenizer: Tokenizer[Lift] = Tokenizer[Lift] { case lift: ScalarLift => ScalarLiftToken(lift) - case lift => - fail( - s"Can't tokenize a non-scalar lifting. ${lift.name}\n" + - s"\n" + - s"This might happen because:\n" + - s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" + - s" to `Some[A]` or `None.type`. For example:\n" + - s" run(query[Users].update(_.optionalField -> lift(Some(value))))" + - s" In that case, make sure the type is `Option`:\n" + - s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" + - s" or\n" + - s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" + - s"\n" + - s"* You are trying to insert or update whole Embedded case class. 
For example:\n" + - s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" + - s" In that case, make sure you are updating individual columns, for example:\n" + - s" run(query[Users].update(\n" + - s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" + - s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" + - s" ))" - ) } implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) diff --git a/src/test/scala/minisql/mirror/QuotedSuite.scala b/src/test/scala/minisql/mirror/QuotedSuite.scala index 65b9cf3..1d89348 100644 --- a/src/test/scala/minisql/mirror/QuotedSuite.scala +++ b/src/test/scala/minisql/mirror/QuotedSuite.scala @@ -14,7 +14,12 @@ class QuotedSuite extends munit.FunSuite { import mirrorContext.given test("SimpleQuery") { - val o = mirrorContext.io(query[Foo]("foo").filter(_.id > 0)) + val o = mirrorContext.io( + query[Foo]( + "foo", + alias("id", "id1") + ).filter(_.id > 0) + ) println("============" + o) o } From 184ab0b884556b91de7ebc710009162dbb16338e Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 22 Jun 2025 20:45:26 +0800 Subject: [PATCH 14/26] Add insert placeholder --- src/main/scala/minisql/Quoted.scala | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index a01ed3c..b1ddb0c 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -24,9 +24,33 @@ opaque type Quoted <: Ast = Ast opaque type Query[E] <: Quoted = Quoted +opaque type Action[E] <: Quoted = Quoted + +opaque type Insert <: Action[Long] = Quoted + +object Query { + + extension [E](inline e: Query[E]) { + + inline def map[E1](inline f: E => E1): Query[E1] = { + transform(e)(f)(Map.apply) + } + + inline def filter(inline f: E => Boolean): Query[E] = { + transform(e)(f)(Filter.apply) + } + + inline def withFilter(inline f: E => Boolean): Query[E] = { + transform(e)(f)(Filter.apply) + } + + } +} + opaque type 
EntityQuery[E] <: Query[E] = Query[E] object EntityQuery { + extension [E](inline e: EntityQuery[E]) { inline def map[E1](inline f: E => E1): EntityQuery[E1] = { @@ -37,6 +61,10 @@ object EntityQuery { transform(e)(f)(Filter.apply) } + inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { + ??? + } + } } From 3a9d15f015eb77ada76d70592d080d521c866d79 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 22 Jun 2025 21:20:23 +0800 Subject: [PATCH 15/26] Try add insert support --- src/main/scala/minisql/Quoted.scala | 68 +++++++++++++++++-- .../scala/minisql/mirror/QuotedSuite.scala | 4 ++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index b1ddb0c..af3b4c6 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -28,6 +28,28 @@ opaque type Action[E] <: Quoted = Quoted opaque type Insert <: Action[Long] = Quoted +private inline def quotedLift[X](x: X)(using + e: ParamEncoder[X] +): ast.ScalarValueLift = ${ + quotedLiftImpl[X]('x, 'e) +} + +private def quotedLiftImpl[X: Type]( + x: Expr[X], + e: Expr[ParamEncoder[X]] +)(using Quotes): Expr[ast.ScalarValueLift] = { + import quotes.reflect.* + val name = x.asTerm.symbol.fullName + val liftId = x.asTerm.symbol.owner.fullName + "@" + name + '{ + ast.ScalarValueLift( + ${ Expr(name) }, + ${ Expr(liftId) }, + Some(($x, $e)) + ) + } +} + object Query { extension [E](inline e: Query[E]) { @@ -62,12 +84,50 @@ object EntityQuery { } inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { - ??? 
+ val entity = e.asInstanceOf[ast.Entity] + val assignments = transformCaseClassToAssignments[E](v, entity.name) + ast.Insert(entity, assignments) } - } } +private inline def transformCaseClassToAssignments[E]( + v: E, + entityName: String +)(using m: Mirror.ProductOf[E]): List[ast.Assignment] = ${ + transformCaseClassToAssignmentsImpl[E]('v, 'entityName) +} + +private def transformCaseClassToAssignmentsImpl[E: Type]( + v: Expr[E], + entityName: Expr[String] +)(using Quotes): Expr[List[ast.Assignment]] = { + import quotes.reflect.* + + val fields = TypeRepr.of[E].typeSymbol.caseFields + val assignments = fields.map { field => + val fieldName = field.name + val fieldType = field.tree match { + case v: ValDef => v.tpt.tpe + case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName") + } + fieldType.asType match { + case '[t] => + '{ + ast.Assignment( + ast.Ident($entityName), + ast.Property(ast.Ident($entityName), ${ Expr(fieldName) }), + quotedLift[t](${ Select.unique(v.asTerm, fieldName).asExprOf[t] })( + using summonInline[ParamEncoder[t]] + ) + ) + } + } + } + + Expr.ofList(assignments) +} + private inline def transform[A, B](inline q1: Quoted)( inline f: A => B )(inline fast: (Ast, Ident, Ast) => Ast): Quoted = { @@ -102,9 +162,7 @@ def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() class NonQuotedException extends Exception("Cannot be used at runtime") private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] = - ${ - compileTimeAstImpl('q) - } + ${ compileTimeAstImpl('q) } private def compileTimeAstImpl(e: Expr[Quoted])(using Quotes diff --git a/src/test/scala/minisql/mirror/QuotedSuite.scala b/src/test/scala/minisql/mirror/QuotedSuite.scala index 1d89348..e8f4b39 100644 --- a/src/test/scala/minisql/mirror/QuotedSuite.scala +++ b/src/test/scala/minisql/mirror/QuotedSuite.scala @@ -24,4 +24,8 @@ class QuotedSuite extends munit.FunSuite { o } + test("Insert") { + ??? 
+ } + } From 24f7f6aec04296e6715b9dd7209216a6bfa819a5 Mon Sep 17 00:00:00 2001 From: jilen Date: Fri, 27 Jun 2025 19:50:11 +0800 Subject: [PATCH 16/26] Convert to using --- build.sbt | 2 ++ src/main/scala/minisql/Quoted.scala | 26 +++++++------- src/main/scala/minisql/ast/FromExprs.scala | 3 +- src/main/scala/minisql/context/Context.scala | 31 +++++++++++++++-- .../scala/minisql/context/sql/SqlIdiom.scala | 8 ++--- .../scala/minisql/idiom/MirrorIdiom.scala | 2 +- .../scala/minisql/idiom/ReifyStatement.scala | 5 +-- .../minisql/idiom/StatementInterpolator.scala | 34 +++++++++---------- .../scala/minisql/parsing/LiftParsing.scala | 5 +-- src/main/scala/minisql/util/CollectTry.scala | 2 +- .../scala/minisql/util/QuotesHelper.scala | 24 +++++++++++++ .../{mirror => context/sql}/QuotedSuite.scala | 18 ++++++---- .../scala/minisql/context/sql/context.scala | 5 +++ src/test/scala/minisql/mirror/context.scala | 6 ---- 14 files changed, 112 insertions(+), 59 deletions(-) create mode 100644 src/main/scala/minisql/util/QuotesHelper.scala rename src/test/scala/minisql/{mirror => context/sql}/QuotedSuite.scala (57%) create mode 100644 src/test/scala/minisql/context/sql/context.scala delete mode 100644 src/test/scala/minisql/mirror/context.scala diff --git a/build.sbt b/build.sbt index 86502a5..d869492 100644 --- a/build.sbt +++ b/build.sbt @@ -6,6 +6,8 @@ libraryDependencies ++= Seq( "org.scalameta" %% "munit" % "1.0.3" % Test ) +javaOptions ++= Seq("-Xss16m") + scalacOptions ++= Seq( "-deprecation", "-feature", diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index af3b4c6..3256d00 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -39,8 +39,8 @@ private def quotedLiftImpl[X: Type]( e: Expr[ParamEncoder[X]] )(using Quotes): Expr[ast.ScalarValueLift] = { import quotes.reflect.* - val name = x.asTerm.symbol.fullName - val liftId = x.asTerm.symbol.owner.fullName + "@" + name + val name = 
x.asTerm.show + val liftId = liftIdOfExpr(x) '{ ast.ScalarValueLift( ${ Expr(name) }, @@ -84,23 +84,19 @@ object EntityQuery { } inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { - val entity = e.asInstanceOf[ast.Entity] - val assignments = transformCaseClassToAssignments[E](v, entity.name) - ast.Insert(entity, assignments) + ast.Insert(e, transformCaseClassToAssignments[E](v)) } } } private inline def transformCaseClassToAssignments[E]( - v: E, - entityName: String + v: E )(using m: Mirror.ProductOf[E]): List[ast.Assignment] = ${ - transformCaseClassToAssignmentsImpl[E]('v, 'entityName) + transformCaseClassToAssignmentsImpl[E]('v) } private def transformCaseClassToAssignmentsImpl[E: Type]( - v: Expr[E], - entityName: Expr[String] + v: Expr[E] )(using Quotes): Expr[List[ast.Assignment]] = { import quotes.reflect.* @@ -115,10 +111,10 @@ private def transformCaseClassToAssignmentsImpl[E: Type]( case '[t] => '{ ast.Assignment( - ast.Ident($entityName), - ast.Property(ast.Ident($entityName), ${ Expr(fieldName) }), - quotedLift[t](${ Select.unique(v.asTerm, fieldName).asExprOf[t] })( - using summonInline[ParamEncoder[t]] + ast.Ident("v"), + ast.Property(ast.Ident("v"), ${ Expr(fieldName) }), + quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using + summonInline[ParamEncoder[t]] ) ) } @@ -186,8 +182,10 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( n: Expr[N] )(using Quotes, Type[I], Type[N]): Expr[Statement] = { import quotes.reflect.* + println(s"Start q.value") q.value match { case Some(ast) => + println(s"Finish q.value: ${ast}") val idiom = LoadObject[I].getOrElse( report.errorAndAbort(s"Idiom not known at compile") ) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 9f70b0d..e527a6f 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -46,8 +46,7 @@ private given FromExpr[ScalarValueLift] with { def unapply(x: 
Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] = x match { case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } => - // don't cared about value here, a little tricky - Some(ScalarValueLift(n, id, null)) + Some(ScalarValueLift(n, id, None)) } } diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 6f6bea5..c469064 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -1,13 +1,14 @@ package minisql.context -import scala.deriving.* -import scala.compiletime.* -import scala.util.Try import minisql.util.* import minisql.idiom.{Idiom, Statement, ReifyStatement} import minisql.{NamingStrategy, ParamEncoder} import minisql.ColumnDecoder import minisql.ast.{Ast, ScalarValueLift, CollectAst} +import scala.deriving.* +import scala.compiletime.* +import scala.util.Try +import scala.annotation.targetName trait RowExtract[A, Row] { def extract(row: Row): Try[A] @@ -89,6 +90,30 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => ) } + @targetName("ioAction") + inline def io[E](inline q: minisql.Action[E]): DBIO[E] = { + val extractor = summonFrom { + case e: RowExtract[E, DBRow] => e + case e: ColumnDecoder.Aux[DBRow, E] => + RowExtract.single(e) + } + + val lifts = q.liftMap + val stmt = minisql.compile(q, idiom, naming) + val (sql, params) = stmt.expand(lifts) + ( + sql = sql, + params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]), + mapper = (rows) => + rows + .traverse(extractor.extract) + .flatMap( + _.headOption.toRight(new Exception(s"No value return")).toTry + ) + ) + } + + @targetName("ioQuery") inline def io[E]( inline q: minisql.Query[E] ): DBIO[IArray[E]] = { diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index ba099b0..dffd56b 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ 
-31,13 +31,13 @@ trait SqlIdiom extends Idiom { def querifyAst(ast: Ast) = SqlQuery(ast) - private def doTranslate(ast: Ast, cached: Boolean)(implicit + private def doTranslate(ast: Ast, cached: Boolean)(using naming: NamingStrategy ): (Ast, Statement) = { val normalizedAst = SqlNormalize(ast, concatBehavior, equalityBehavior) - implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer + given Tokenizer[Ast] = defaultTokenizer val token = normalizedAst match { @@ -63,7 +63,7 @@ trait SqlIdiom extends Idiom { doTranslate(ast, false) } - def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] = + def defaultTokenizer(using naming: NamingStrategy): Tokenizer[Ast] = new Tokenizer[Ast] { private val stableTokenizer = astTokenizer(using this, naming) @@ -73,7 +73,7 @@ trait SqlIdiom extends Idiom { } - def astTokenizer(implicit + def astTokenizer(using astTokenizer: Tokenizer[Ast], strategy: NamingStrategy ): Tokenizer[Ast] = diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index fd18549..1507919 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -305,7 +305,7 @@ trait MirrorIdiomBase extends Idiom { Tokenizer[OnConflict.Target] { case OnConflict.NoTarget => stmt"" case OnConflict.Properties(props) => - val listTokens = listTokenizer(using astTokenizer).token(props) + val listTokens = props.token stmt"(${listTokens})" } diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index 7a4a07a..4206238 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -16,11 +16,12 @@ object ReifyStatement { liftMap: SMap[String, (Any, ParamEncoder[?])] ): (String, List[ScalarValueLift]) = { val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) - token2string(expanded, liftingPlaceholder) + 
token2string(expanded, liftMap, liftingPlaceholder) } private def token2string( token: Token, + liftMap: SMap[String, (Any, ParamEncoder[?])], liftingPlaceholder: Int => String ): (String, List[ScalarValueLift]) = { @@ -44,7 +45,7 @@ object ReifyStatement { ) case ScalarLiftToken(lift: ScalarValueLift) => sqlBuilder ++= liftingPlaceholder(liftingSize) - liftBuilder += lift + liftBuilder += lift.copy(value = liftMap.get(lift.liftId)) loop(tail, liftingSize + 1) case ScalarLiftToken(o) => throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}") diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala index 056d58f..2893c8d 100644 --- a/src/main/scala/minisql/idiom/StatementInterpolator.scala +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -8,12 +8,20 @@ import scala.collection.mutable.ListBuffer object StatementInterpolator { + extension [T](list: List[T]) { + private[minisql] def mkStmt( + sep: String = ", " + )(using tokenize: Tokenizer[T]) = { + val l1 = list.map(_.token) + val l2 = List.fill(l1.size - 1)(StringToken(sep)) + Statement(Interleave(l1, l2)) + } + } trait Tokenizer[T] { extension (v: T) { def token: Token } } - object Tokenizer { def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] { extension (v: T) { @@ -31,37 +39,29 @@ object StatementInterpolator { } } - implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) { + extension [T](v: T)(using tokenizer: Tokenizer[T]) { def token = tokenizer.token(v) } - implicit def stringTokenizer: Tokenizer[String] = + given stringTokenizer: Tokenizer[String] = Tokenizer[String] { case string => StringToken(string) } - implicit def liftTokenizer: Tokenizer[Lift] = + given liftTokenizer: Tokenizer[Lift] = Tokenizer[Lift] { case lift: ScalarLift => ScalarLiftToken(lift) } - implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) - implicit def statementTokenizer: 
Tokenizer[Statement] = + given tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) + given statementTokenizer: Tokenizer[Statement] = Tokenizer[Statement](identity) - implicit def stringTokenTokenizer: Tokenizer[StringToken] = + given stringTokenTokenizer: Tokenizer[StringToken] = Tokenizer[StringToken](identity) - implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = + given liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = Tokenizer[ScalarLiftToken](identity) - extension [T](list: List[T]) { - def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = { - val l1 = list.map(_.token) - val l2 = List.fill(l1.size - 1)(StringToken(sep)) - Statement(Interleave(l1, l2)) - } - } - - implicit def listTokenizer[T](implicit + given listTokenizer[T](using tokenize: Tokenizer[T] ): Tokenizer[List[T]] = Tokenizer[List[T]] { diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala index 9a0f32b..f0aba8a 100644 --- a/src/main/scala/minisql/parsing/LiftParsing.scala +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -4,13 +4,14 @@ import scala.quoted.* import minisql.ParamEncoder import minisql.ast import minisql.* +import minisql.util.* private[parsing] def liftParsing( astParser: => Parser[ast.Ast] )(using Quotes): Parser[ast.Lift] = { case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } => import quotes.reflect.* - val name = x.asTerm.symbol.fullName - val liftId = x.asTerm.symbol.owner.fullName + "@" + name + val name = x.show + val liftId = liftIdOfExpr(x) '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) } } diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index 6b466be..f4572b8 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -20,7 +20,7 @@ extension [A](xs: Iterable[A]) { } } -object CollectTry { +private[minisql] object CollectTry { def 
apply[T](list: List[Try[T]]): Try[List[T]] = list.foldLeft(Try(List.empty[T])) { case (list, t) => diff --git a/src/main/scala/minisql/util/QuotesHelper.scala b/src/main/scala/minisql/util/QuotesHelper.scala new file mode 100644 index 0000000..6ecbc76 --- /dev/null +++ b/src/main/scala/minisql/util/QuotesHelper.scala @@ -0,0 +1,24 @@ +package minisql.util + +import scala.quoted.* + +private[minisql] def splicePkgPath(using Quotes) = { + import quotes.reflect.* + def recurse(sym: Symbol): String = + sym match { + case s if s.isPackageDef => s.fullName + case s if s.isNoSymbol => "" + case _ => + recurse(sym.maybeOwner) + } + recurse(Symbol.spliceOwner) +} + +private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = { + import quotes.reflect.* + val name = x.asTerm.show + val packageName = splicePkgPath + val pos = x.asTerm.pos + val fileName = pos.sourceFile.name + s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}" +} diff --git a/src/test/scala/minisql/mirror/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala similarity index 57% rename from src/test/scala/minisql/mirror/QuotedSuite.scala rename to src/test/scala/minisql/context/sql/QuotedSuite.scala index e8f4b39..c5fda24 100644 --- a/src/test/scala/minisql/mirror/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -1,4 +1,4 @@ -package minisql.context.mirror +package minisql.context.sql import minisql.* import minisql.ast.* @@ -9,23 +9,27 @@ import minisql.context.mirror.{*, given} class QuotedSuite extends munit.FunSuite { - case class Foo(id: Long) + case class Foo(id: Long, name: String) - import mirrorContext.given + inline def Foos = query[Foo]("foo") + + import testContext.given test("SimpleQuery") { - val o = mirrorContext.io( + val o = testContext.io( query[Foo]( "foo", alias("id", "id1") ).filter(_.id > 0) ) - println("============" + o) - o + println(o) } test("Insert") { + val v: Foo = Foo(0L, "foo") + + val o = 
testContext.io(Foos.insert(v)) + println(o) ??? } - } diff --git a/src/test/scala/minisql/context/sql/context.scala b/src/test/scala/minisql/context/sql/context.scala new file mode 100644 index 0000000..d3d36fa --- /dev/null +++ b/src/test/scala/minisql/context/sql/context.scala @@ -0,0 +1,5 @@ +package minisql.context.sql + +import minisql.* + +val testContext = new MirrorSqlContext(Literal) diff --git a/src/test/scala/minisql/mirror/context.scala b/src/test/scala/minisql/mirror/context.scala deleted file mode 100644 index 240a475..0000000 --- a/src/test/scala/minisql/mirror/context.scala +++ /dev/null @@ -1,6 +0,0 @@ -package minisql.context.mirror - -import minisql.* -import minisql.idiom.MirrorIdiom - -val mirrorContext = new MirrorContext(MirrorIdiom, Literal) From 2753f01001e9c157cc0612afd3866decae7c9e04 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Jun 2025 10:25:38 +0800 Subject: [PATCH 17/26] fix naming --- src/main/scala/minisql/Quoted.scala | 2 -- src/main/scala/minisql/context/Context.scala | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 3256d00..7abc115 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -182,10 +182,8 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( n: Expr[N] )(using Quotes, Type[I], Type[N]): Expr[Statement] = { import quotes.reflect.* - println(s"Start q.value") q.value match { case Some(ast) => - println(s"Finish q.value: ${ast}") val idiom = LoadObject[I].getOrElse( report.errorAndAbort(s"Idiom not known at compile") ) diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index c469064..875d11e 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -55,7 +55,7 @@ object RowExtract { trait Context[I <: Idiom, N <: NamingStrategy] { selft => val idiom: I - val 
naming: NamingStrategy + val naming: N type DBStatement type DBRow @@ -99,7 +99,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => } val lifts = q.liftMap - val stmt = minisql.compile(q, idiom, naming) + val stmt = minisql.compile[I, N](q, idiom, naming) val (sql, params) = stmt.expand(lifts) ( sql = sql, @@ -125,7 +125,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => } val lifts = q.liftMap - val stmt = minisql.compile(q, idiom, naming) + val stmt = minisql.compile[I, N](q, idiom, naming) val (sql, params) = stmt.expand(lifts) ( sql = sql, From a1201a67aa6dbd27ab075a76582ef255d5268432 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Jun 2025 16:12:12 +0800 Subject: [PATCH 18/26] Add more test case. Expand query elements --- src/main/scala/minisql/Quoted.scala | 24 +++- src/main/scala/minisql/ast/Ast.scala | 15 ++- src/main/scala/minisql/ast/FromExprs.scala | 7 +- src/main/scala/minisql/context/Context.scala | 10 +- .../scala/minisql/ast/FromExprsSuite.scala | 127 ++++++++++++++++++ 5 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 src/test/scala/minisql/ast/FromExprsSuite.scala diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 7abc115..3d4164d 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -52,8 +52,14 @@ private def quotedLiftImpl[X: Type]( object Query { + private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast + extension [E](inline e: Query[E]) { + private[minisql] inline def expanded: Query[E] = { + Query(expandFields[E](e)) + } + inline def map[E1](inline f: E => E1): Query[E1] = { transform(e)(f)(Map.apply) } @@ -157,10 +163,10 @@ def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() class NonQuotedException extends Exception("Cannot be used at runtime") -private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] = +private[minisql] inline def compileTimeAst(inline q: 
Ast): Option[String] = ${ compileTimeAstImpl('q) } -private def compileTimeAstImpl(e: Expr[Quoted])(using +private def compileTimeAstImpl(e: Expr[Ast])(using Quotes ): Expr[Option[String]] = { import quotes.reflect.* @@ -203,3 +209,17 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( } } + +private inline def expandFields[E](inline base: Ast): Ast = + ${ expandFieldsImpl[E]('base) } + +private def expandFieldsImpl[E](baseExpr: Expr[Ast])(using + Quotes, + Type[E] +): Expr[Ast] = { + import quotes.reflect.* + val values = TypeRepr.of[E].typeSymbol.caseFields.map { f => + '{ Property(ast.Ident("x"), ${ Expr(f.name) }) } + } + '{ Map(${ baseExpr }, ast.Ident("x"), ast.Tuple(${ Expr.ofList(values) })) } +} diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 52446e3..86407e1 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -59,9 +59,9 @@ object Entity { object Opinionated { inline def apply( - name: String, - properties: List[PropertyAlias], - renameableNew: Renameable + inline name: String, + inline properties: List[PropertyAlias], + inline renameableNew: Renameable ): Entity = Entity(name, properties, renameableNew) def unapply(e: Entity) = @@ -154,11 +154,14 @@ case class Ident(name: String, visibility: Visibility) extends Ast { * ExpandNestedQueries phase, needs to be marked invisible. 
*/ object Ident { - def apply(name: String): Ident = Ident(name, Visibility.neutral) - def unapply(p: Ident) = Some((p.name)) + inline def apply(inline name: String): Ident = Ident(name, Visibility.neutral) + def unapply(p: Ident) = Some((p.name)) object Opinionated { - def apply(name: String, visibilityNew: Visibility): Ident = + inline def apply( + inline name: String, + inline visibilityNew: Visibility + ): Ident = Ident(name, visibilityNew) def unapply(p: Ident) = Some((p.name, p.visibility)) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index e527a6f..c6d48cc 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -52,7 +52,7 @@ private given FromExpr[ScalarValueLift] with { private given FromExpr[Ident] with { def unapply(x: Expr[Ident])(using Quotes): Option[Ident] = x match { - case '{ Ident(${ Expr(n) }) } => Some(Ident(n)) + case '{ Ident(${ Expr(n) }, ${ Expr(v) }) } => Some(Ident(n, v)) } } @@ -136,7 +136,7 @@ private given FromExpr[Query] with { case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } => Some(SortBy(b, p, s, o)) case o => - println(s"Cannot extract ${o}") + println(s"Cannot extract ${o.show}") None } } @@ -274,10 +274,11 @@ private def extractTerm(using Quotes)(x: quotes.reflect.Term) = { } extension (e: Expr[Any]) { - def toTerm(using Quotes) = { + private def toTerm(using Quotes) = { import quotes.reflect.* e.asTerm } + } private def fromBlock(using diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 875d11e..bf945c3 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -118,14 +118,14 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => inline q: minisql.Query[E] ): DBIO[IArray[E]] = { - val extractor = summonFrom { - case e: RowExtract[E, DBRow] => e + val (stmt, extractor) = summonFrom { 
+ case e: RowExtract[E, DBRow] => + minisql.compile[I, N](q.expanded, idiom, naming) -> e case e: ColumnDecoder.Aux[DBRow, E] => - RowExtract.single(e) - } + minisql.compile[I, N](q, idiom, naming) -> RowExtract.single(e) + }: @unchecked val lifts = q.liftMap - val stmt = minisql.compile[I, N](q, idiom, naming) val (sql, params) = stmt.expand(lifts) ( sql = sql, diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala new file mode 100644 index 0000000..ea6d14b --- /dev/null +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -0,0 +1,127 @@ +package minisql.ast + +import munit.FunSuite +import minisql.ast.* +import scala.quoted.* + +class FromExprsSuite extends FunSuite { + + // Helper to test both compile-time and runtime extraction + inline def testFor[A <: Ast](label: String)(inline ast: A) = { + test(label) { + // Test compile-time extraction + val compileTimeResult = minisql.compileTimeAst(ast) + assert(compileTimeResult.contains(ast.toString)) + } + } + + testFor("Ident") { + Ident("test") + } + + testFor("Ident with visibility") { + Ident.Opinionated("test", Visibility.Hidden) + } + + testFor("Property") { + Property(Ident("a"), "b") + } + + testFor("Property with opinions") { + Property.Opinionated(Ident("a"), "b", Renameable.Fixed, Visibility.Visible) + } + + testFor("BinaryOperation") { + BinaryOperation(Ident("a"), EqualityOperator.==, Ident("b")) + } + + testFor("UnaryOperation") { + UnaryOperation(BooleanOperator.!, Ident("flag")) + } + + testFor("ScalarValueLift") { + ScalarValueLift("name", "id", None) + } + + testFor("Ordering") { + Asc + } + + testFor("TupleOrdering") { + TupleOrdering(List(Asc, Desc)) + } + + testFor("Entity") { + Entity("people", Nil) + } + + testFor("Entity with properties") { + Entity("people", List(PropertyAlias(List("name"), "full_name"))) + } + + testFor("Action - Insert") { + Insert( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + 
} + + testFor("Action - Update") { + Update( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + } + + testFor("If expression") { + If(Ident("cond"), Ident("then"), Ident("else")) + } + + testFor("Infix") { + Infix( + List("func(", ")"), + List(Ident("param")), + pure = true, + noParen = false + ) + } + + testFor("OptionOperation - OptionMap") { + OptionMap(Ident("opt"), Ident("x"), Ident("x")) + } + + testFor("OptionOperation - OptionFlatMap") { + OptionFlatMap(Ident("opt"), Ident("x"), Ident("x")) + } + + testFor("OptionOperation - OptionGetOrElse") { + OptionGetOrElse(Ident("opt"), Ident("default")) + } + + testFor("Join") { + Join( + InnerJoin, + Ident("a"), + Ident("b"), + Ident("a1"), + Ident("b1"), + BinaryOperation(Ident("a1.id"), EqualityOperator.==, Ident("b1.id")) + ) + } + + testFor("Distinct") { + Distinct(Ident("query")) + } + + testFor("GroupBy") { + GroupBy(Ident("query"), Ident("alias"), Ident("body")) + } + + testFor("Aggregation") { + Aggregation(AggregationOperator.avg, Ident("field")) + } + + testFor("CaseClass") { + CaseClass(List(("name", Ident("value")))) + } +} From 23c048460904b16c364c29d3e7cae386844dce90 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Jun 2025 17:02:18 +0800 Subject: [PATCH 19/26] More instance --- src/main/scala/minisql/ast/FromExprs.scala | 173 ++++++++++++++++-- src/main/scala/minisql/ast/JoinType.scala | 24 ++- .../scala/minisql/context/sql/SqlIdiom.scala | 8 +- .../scala/minisql/idiom/MirrorIdiom.scala | 8 +- .../scala/minisql/ast/FromExprsSuite.scala | 2 +- 5 files changed, 190 insertions(+), 25 deletions(-) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index c6d48cc..aca73db 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -70,7 +70,7 @@ private given FromExpr[Property] with { } => Some(Property(a, n, r, v)) case o => - println(s"Cannot extrat ${o.show}") + 
println(s"Cannot extract ${o.show}") None } } @@ -82,6 +82,8 @@ private given FromExpr[Ordering] with { case '{ Desc } => Some(Desc) case '{ AscNullsFirst } => Some(AscNullsFirst) case '{ AscNullsLast } => Some(AscNullsLast) + case '{ DescNullsFirst } => Some(DescNullsFirst) + case '{ DescNullsLast } => Some(DescNullsLast) case '{ TupleOrdering($xs) } => xs.value.map(TupleOrdering(_)) } } @@ -135,6 +137,35 @@ private given FromExpr[Query] with { Some(Take(b, n)) case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } => Some(SortBy(b, p, s, o)) + case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } => + Some(GroupBy(b, p, body)) + case '{ Distinct(${ Expr(a) }) } => + Some(Distinct(a)) + case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } => + Some(DistinctOn(q, a, body)) + case '{ Aggregation(${ Expr(op) }, ${ Expr(a) }) } => + Some(Aggregation(op, a)) + case '{ Union(${ Expr(a) }, ${ Expr(b) }) } => + Some(Union(a, b)) + case '{ UnionAll(${ Expr(a) }, ${ Expr(b) }) } => + Some(UnionAll(a, b)) + case '{ + Join( + ${ Expr(t) }, + ${ Expr(a) }, + ${ Expr(b) }, + ${ Expr(ia) }, + ${ Expr(ib) }, + ${ Expr(on) } + ) + } => + Some(Join(t, a, b, ia, ib, on)) + case '{ + FlatJoin(${ Expr(t) }, ${ Expr(a) }, ${ Expr(ia) }, ${ Expr(on) }) + } => + Some(FlatJoin(t, a, ia, on)) + case '{ Nested(${ Expr(a) }) } => + Some(Nested(a)) case o => println(s"Cannot extract ${o.show}") None @@ -153,17 +184,21 @@ private given FromExpr[BinaryOperator] with { case '{ NumericOperator.* } => Some(NumericOperator.*) case '{ NumericOperator./ } => Some(NumericOperator./) case '{ NumericOperator.> } => Some(NumericOperator.>) + case '{ NumericOperator.>= } => Some(NumericOperator.>=) + case '{ NumericOperator.< } => Some(NumericOperator.<) + case '{ NumericOperator.<= } => Some(NumericOperator.<=) + case '{ NumericOperator.% } => Some(NumericOperator.%) case '{ StringOperator.split } => Some(StringOperator.split) case '{ StringOperator.startsWith } => 
Some(StringOperator.startsWith) case '{ StringOperator.concat } => Some(StringOperator.concat) case '{ BooleanOperator.&& } => Some(BooleanOperator.&&) case '{ BooleanOperator.|| } => Some(BooleanOperator.||) + case '{ SetOperator.contains } => Some(SetOperator.contains) } } } private given FromExpr[UnaryOperator] with { - def unapply(x: Expr[UnaryOperator])(using Quotes): Option[UnaryOperator] = { x match { case '{ BooleanOperator.! } => Some(BooleanOperator.!) @@ -171,6 +206,33 @@ private given FromExpr[UnaryOperator] with { case '{ StringOperator.toLowerCase } => Some(StringOperator.toLowerCase) case '{ StringOperator.toLong } => Some(StringOperator.toLong) case '{ StringOperator.toInt } => Some(StringOperator.toInt) + case '{ NumericOperator.- } => Some(NumericOperator.-) + case '{ SetOperator.nonEmpty } => Some(SetOperator.nonEmpty) + case '{ SetOperator.isEmpty } => Some(SetOperator.isEmpty) + } + } +} + +private given FromExpr[AggregationOperator] with { + def unapply( + x: Expr[AggregationOperator] + )(using Quotes): Option[AggregationOperator] = { + x match { + case '{ AggregationOperator.min } => Some(AggregationOperator.min) + case '{ AggregationOperator.max } => Some(AggregationOperator.max) + case '{ AggregationOperator.avg } => Some(AggregationOperator.avg) + case '{ AggregationOperator.sum } => Some(AggregationOperator.sum) + case '{ AggregationOperator.size } => Some(AggregationOperator.size) + } + } +} + +private given FromExpr[Operator] with { + def unapply(x: Expr[Operator])(using Quotes): Option[Operator] = { + x match { + case '{ $x: BinaryOperator } => x.value + case '{ $x: UnaryOperator } => x.value + case '{ $x: AggregationOperator } => x.value } } } @@ -225,18 +287,21 @@ private given FromExpr[Action] with { extension [A](xs: Seq[Expr[A]]) { private def sequence(using FromExpr[A], Quotes): Option[List[A]] = { - val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) => - for { - _r <- r - _x <- x.value - } yield _r += _x + if 
(xs.isEmpty) Some(Nil) + else { + val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) => + for { + _r <- r + _x <- x.value + } yield _r += _x + } + acc.map(_.result()) } - acc.map(b => b.result()) } } -private given FromExpr[Constant] with { - def unapply(x: Expr[Constant])(using Quotes): Option[Constant] = { +private given FromExpr[Value] with { + def unapply(x: Expr[Value])(using Quotes): Option[Value] = { import quotes.reflect.{Constant => *, *} x match { case '{ Constant($ce) } => @@ -244,8 +309,92 @@ private given FromExpr[Constant] with { case Literal(v) => Some(Constant(v.value)) } + case '{ NullValue } => + Some(NullValue) + case '{ $x: CaseClass } => x.value } + } +} +private given FromExpr[OptionOperation] with { + def unapply( + x: Expr[OptionOperation] + )(using Quotes): Option[OptionOperation] = { + x match { + case '{ OptionFlatten(${ Expr(ast) }) } => + Some(OptionFlatten(ast)) + case '{ OptionGetOrElse(${ Expr(ast) }, ${ Expr(body) }) } => + Some(OptionGetOrElse(ast, body)) + case '{ + OptionFlatMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionFlatMap(ast, alias, body)) + case '{ OptionMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) } => + Some(OptionMap(ast, alias, body)) + case '{ + OptionForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionForall(ast, alias, body)) + case '{ + OptionExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionExists(ast, alias, body)) + case '{ OptionContains(${ Expr(ast) }, ${ Expr(body) }) } => + Some(OptionContains(ast, body)) + case '{ OptionIsEmpty(${ Expr(ast) }) } => + Some(OptionIsEmpty(ast)) + case '{ OptionNonEmpty(${ Expr(ast) }) } => + Some(OptionNonEmpty(ast)) + case '{ OptionIsDefined(${ Expr(ast) }) } => + Some(OptionIsDefined(ast)) + case '{ + OptionTableFlatMap( + ${ Expr(ast) }, + ${ Expr(alias) }, + ${ Expr(body) } + ) + } => + Some(OptionTableFlatMap(ast, alias, body)) + case '{ + OptionTableMap(${ Expr(ast) }, 
${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableMap(ast, alias, body)) + case '{ + OptionTableExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableExists(ast, alias, body)) + case '{ + OptionTableForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableForall(ast, alias, body)) + case '{ OptionNone } => Some(OptionNone) + case '{ OptionSome(${ Expr(ast) }) } => Some(OptionSome(ast)) + case '{ OptionApply(${ Expr(ast) }) } => Some(OptionApply(ast)) + case '{ OptionOrNull(${ Expr(ast) }) } => Some(OptionOrNull(ast)) + case '{ OptionGetOrNull(${ Expr(ast) }) } => Some(OptionGetOrNull(ast)) + case _ => None + } + } +} + +private given FromExpr[CaseClass] with { + def unapply(x: Expr[CaseClass])(using Quotes): Option[CaseClass] = { + import quotes.reflect.* + x match { + case '{ CaseClass(${ Expr(values) }) } => + // Verify the values are properly structured as List[(String, Ast)] + try { + Some(CaseClass(values)) + } catch { + case e: Exception => + report.warning( + s"Failed to extract CaseClass values: ${e.getMessage}", + x.asTerm.pos + ) + None + } + case _ => None + } } } @@ -316,12 +465,14 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] { case '{ $x: Property } => x.value case '{ $x: Ident } => x.value case '{ $x: Tuple } => x.value - case '{ $x: Constant } => x.value + case '{ $x: Value } => x.value case '{ $x: Operation } => x.value case '{ $x: Ordering } => x.value case '{ $x: Action } => x.value case '{ $x: If } => x.value case '{ $x: Infix } => x.value + case '{ $x: CaseClass } => x.value + case '{ $x: OptionOperation } => x.value case o => import quotes.reflect.* report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos) diff --git a/src/main/scala/minisql/ast/JoinType.scala b/src/main/scala/minisql/ast/JoinType.scala index bcb623b..911b4f2 100644 --- a/src/main/scala/minisql/ast/JoinType.scala +++ b/src/main/scala/minisql/ast/JoinType.scala @@ -1,8 +1,22 @@ package minisql.ast 
-sealed trait JoinType +import scala.quoted.* -case object InnerJoin extends JoinType -case object LeftJoin extends JoinType -case object RightJoin extends JoinType -case object FullJoin extends JoinType +enum JoinType { + case InnerJoin + case LeftJoin + case RightJoin + case FullJoin +} + +object JoinType { + given FromExpr[JoinType] with { + + def unapply(x: Expr[JoinType])(using Quotes): Option[JoinType] = x match { + case '{ JoinType.InnerJoin } => Some(JoinType.InnerJoin) + case '{ JoinType.LeftJoin } => Some(JoinType.LeftJoin) + case '{ JoinType.RightJoin } => Some(JoinType.RightJoin) + case '{ JoinType.FullJoin } => Some(JoinType.FullJoin) + } + } +} diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index dffd56b..daeee76 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -346,10 +346,10 @@ trait SqlIdiom extends Idiom { } implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { - case InnerJoin => stmt"INNER JOIN" - case LeftJoin => stmt"LEFT JOIN" - case RightJoin => stmt"RIGHT JOIN" - case FullJoin => stmt"FULL JOIN" + case JoinType.InnerJoin => stmt"INNER JOIN" + case JoinType.LeftJoin => stmt"LEFT JOIN" + case JoinType.RightJoin => stmt"RIGHT JOIN" + case JoinType.FullJoin => stmt"FULL JOIN" } implicit def orderByCriteriaTokenizer(implicit diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index 1507919..2630288 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -192,10 +192,10 @@ trait MirrorIdiomBase extends Idiom { } implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { - case InnerJoin => stmt"join" - case LeftJoin => stmt"leftJoin" - case RightJoin => stmt"rightJoin" - case FullJoin => stmt"fullJoin" + case JoinType.InnerJoin => stmt"join" + case 
JoinType.LeftJoin => stmt"leftJoin" + case JoinType.RightJoin => stmt"rightJoin" + case JoinType.FullJoin => stmt"fullJoin" } implicit def functionTokenizer(implicit diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index ea6d14b..051da57 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -100,7 +100,7 @@ class FromExprsSuite extends FunSuite { testFor("Join") { Join( - InnerJoin, + JoinType.InnerJoin, Ident("a"), Ident("b"), Ident("a1"), From c1f26a0704dc2eda54f39e542f695ac22cfa4018 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 29 Jun 2025 19:15:27 +0800 Subject: [PATCH 20/26] Assert sql --- src/main/scala/minisql/Quoted.scala | 2 +- src/main/scala/minisql/ast/Ast.scala | 15 ++++++++------- src/main/scala/minisql/ast/FromExprs.scala | 1 + src/main/scala/minisql/context/sql/SqlIdiom.scala | 14 ++++++++------ src/main/scala/minisql/idiom/MirrorIdiom.scala | 14 +++++++------- src/test/scala/minisql/ast/FromExprsSuite.scala | 4 ++-- .../scala/minisql/context/sql/QuotedSuite.scala | 10 ++++++---- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 3d4164d..abb5e5f 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -199,7 +199,7 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( .getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile")) val stmt = idiom.translate(ast)(using naming) - report.info(s"Static Query: ${stmt}") + report.info(s"Static Query: ${stmt._2}") Expr(stmt._2) case None => report.info("Dynamic Query") diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 86407e1..42386c7 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -85,13 +85,14 @@ case class SortBy(query: Ast, alias: Ident, 
criterias: Ast, ordering: Ordering) sealed trait Ordering extends Ast case class TupleOrdering(elems: List[Ordering]) extends Ordering -sealed trait PropertyOrdering extends Ordering -case object Asc extends PropertyOrdering -case object Desc extends PropertyOrdering -case object AscNullsFirst extends PropertyOrdering -case object DescNullsFirst extends PropertyOrdering -case object AscNullsLast extends PropertyOrdering -case object DescNullsLast extends PropertyOrdering +enum PropertyOrdering extends Ordering { + case Asc + case Desc + case AscNullsFirst + case DescNullsFirst + case AscNullsLast + case DescNullsLast +} case class GroupBy(query: Ast, alias: Ident, body: Ast) extends Query diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index aca73db..f6d0192 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -77,6 +77,7 @@ private given FromExpr[Property] with { private given FromExpr[Ordering] with { def unapply(x: Expr[Ordering])(using Quotes): Option[Ordering] = { + import PropertyOrdering.* x match { case '{ Asc } => Some(Asc) case '{ Desc } => Some(Desc) diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index daeee76..bc10d3d 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -356,15 +356,17 @@ trait SqlIdiom extends Idiom { astTokenizer: Tokenizer[Ast], strategy: NamingStrategy ): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] { - case OrderByCriteria(ast, Asc) => stmt"${scopedTokenizer(ast)} ASC" - case OrderByCriteria(ast, Desc) => stmt"${scopedTokenizer(ast)} DESC" - case OrderByCriteria(ast, AscNullsFirst) => + case OrderByCriteria(ast, PropertyOrdering.Asc) => + stmt"${scopedTokenizer(ast)} ASC" + case OrderByCriteria(ast, PropertyOrdering.Desc) => + stmt"${scopedTokenizer(ast)} DESC" + case 
OrderByCriteria(ast, PropertyOrdering.AscNullsFirst) => stmt"${scopedTokenizer(ast)} ASC NULLS FIRST" - case OrderByCriteria(ast, DescNullsFirst) => + case OrderByCriteria(ast, PropertyOrdering.DescNullsFirst) => stmt"${scopedTokenizer(ast)} DESC NULLS FIRST" - case OrderByCriteria(ast, AscNullsLast) => + case OrderByCriteria(ast, PropertyOrdering.AscNullsLast) => stmt"${scopedTokenizer(ast)} ASC NULLS LAST" - case OrderByCriteria(ast, DescNullsLast) => + case OrderByCriteria(ast, PropertyOrdering.DescNullsLast) => stmt"${scopedTokenizer(ast)} DESC NULLS LAST" } diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index 2630288..9e04134 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -141,13 +141,13 @@ trait MirrorIdiomBase extends Idiom { } implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] { - case TupleOrdering(elems) => stmt"Ord(${elems.token})" - case Asc => stmt"Ord.asc" - case Desc => stmt"Ord.desc" - case AscNullsFirst => stmt"Ord.ascNullsFirst" - case DescNullsFirst => stmt"Ord.descNullsFirst" - case AscNullsLast => stmt"Ord.ascNullsLast" - case DescNullsLast => stmt"Ord.descNullsLast" + case TupleOrdering(elems) => stmt"Ord(${elems.token})" + case PropertyOrdering.Asc => stmt"Ord.asc" + case PropertyOrdering.Desc => stmt"Ord.desc" + case PropertyOrdering.AscNullsFirst => stmt"Ord.ascNullsFirst" + case PropertyOrdering.DescNullsFirst => stmt"Ord.descNullsFirst" + case PropertyOrdering.AscNullsLast => stmt"Ord.ascNullsLast" + case PropertyOrdering.DescNullsLast => stmt"Ord.descNullsLast" } implicit def optionOperationTokenizer(implicit diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 051da57..bbc6973 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -44,11 +44,11 @@ class 
FromExprsSuite extends FunSuite { } testFor("Ordering") { - Asc + PropertyOrdering.Asc } testFor("TupleOrdering") { - TupleOrdering(List(Asc, Desc)) + TupleOrdering(List(PropertyOrdering.Asc, PropertyOrdering.Desc)) } testFor("Entity") { diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala index c5fda24..4ef6d67 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -20,16 +20,18 @@ class QuotedSuite extends munit.FunSuite { query[Foo]( "foo", alias("id", "id1") - ).filter(_.id > 0) + ).filter(x => x.id > 0) ) - println(o) + assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0") } test("Insert") { val v: Foo = Foo(0L, "foo") val o = testContext.io(Foos.insert(v)) - println(o) - ??? + assertEquals( + o.sql, + "INSERT INTO foo (id,name) VALUES (?, ?)" + ) } } From adc60400a782346b09663b882edd8f7d1cc900fb Mon Sep 17 00:00:00 2001 From: jilen Date: Mon, 30 Jun 2025 19:33:39 +0800 Subject: [PATCH 21/26] unified extractTerm --- src/main/scala/minisql/Quoted.scala | 57 ++++++++++++++++++- src/main/scala/minisql/ast/FromExprs.scala | 23 +------- src/main/scala/minisql/parsing/Parser.scala | 1 + src/main/scala/minisql/parsing/Parsing.scala | 20 +------ .../minisql/parsing/PatMatchParsing.scala | 11 +--- .../scala/minisql/parsing/ValueParsing.scala | 1 + .../scala/minisql/util/QuotesHelper.scala | 17 ++++++ .../minisql/context/sql/QuotedSuite.scala | 9 +++ 8 files changed, 86 insertions(+), 53 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index abb5e5f..3440061 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -12,6 +12,8 @@ import minisql.ast.{ Ident, Filter, PropertyAlias, + JoinType, + Join, given } import scala.quoted.* @@ -28,6 +30,54 @@ opaque type Action[E] <: Quoted = Quoted opaque type Insert <: Action[Long] = 
Quoted +sealed trait Joined[E1, E2] + +opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted + +object Joined { + + def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] = + new Joined[E1, E2] {} + + extension [E1, E2](inline j: Joined[E1, E2]) { + inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] = + joinOn(j, f) + } +} + +private inline def joinOn[E1, E2]( + inline j: Joined[E1, E2], + inline f: (E1, E2) => Boolean +): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body) + +extension [E1, E2](inline j: Joined[E1, E2]) { + private inline def toJoinQuery( + inline aliasA: Ident, + inline aliasB: Ident, + inline on: Ast + ): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) } +} + +private def joinQueryOf[E1, E2]( + x: Expr[Joined[E1, E2]], + aliasA: Expr[Ident], + aliasB: Expr[Ident], + on: Expr[Ast] +)(using Quotes, Type[E1], Type[E2]): Expr[Join] = { + import quotes.reflect.* + extractTerm(x.asTerm).asExpr match { + case '{ + Joined[E1, E2]($jt, $a, $b) + } => + '{ + Join($jt, $a, $b, $aliasA, $aliasB, $on) + } + case o => + println("====================---" + o.show) + throw new Exception(s"Fail") + } +} + private inline def quotedLift[X](x: X)(using e: ParamEncoder[X] ): ast.ScalarValueLift = ${ @@ -52,14 +102,15 @@ private def quotedLiftImpl[X: Type]( object Query { - private[minisql] inline def apply[E](inline ast: Ast): Query[E] = ast - extension [E](inline e: Query[E]) { private[minisql] inline def expanded: Query[E] = { - Query(expandFields[E](e)) + expandFields[E](e) } + inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, E1] = + Joined[E, E1](JoinType.LeftJoin, e, e1) + inline def map[E1](inline f: E => E1): Query[E1] = { transform(e)(f)(Map.apply) } diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index f6d0192..8029af8 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -69,9 +69,6 @@ private given 
FromExpr[Property] with { ) } => Some(Property(a, n, r, v)) - case o => - println(s"Cannot extract ${o.show}") - None } } @@ -168,7 +165,7 @@ private given FromExpr[Query] with { case '{ Nested(${ Expr(a) }) } => Some(Nested(a)) case o => - println(s"Cannot extract ${o.show}") + // println(s"Cannot extract ${o.show}") None } } @@ -406,23 +403,6 @@ private given FromExpr[If] with { } } -private def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - val o = unwrapTerm(x) - o -} - extension (e: Expr[Any]) { private def toTerm(using Quotes) = { import quotes.reflect.* @@ -434,7 +414,6 @@ extension (e: Expr[Any]) { private def fromBlock(using Quotes )(block: quotes.reflect.Block): Option[Ast] = { - println(s"Show block ${block.show}") import quotes.reflect.* val empty: Option[List[Ast]] = Some(Nil) val stmts = block.statements.foldLeft(empty) { (r, stmt) => diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala index 91bfbc0..4235398 100644 --- a/src/main/scala/minisql/parsing/Parser.scala +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -3,6 +3,7 @@ package minisql.parsing import minisql.ast import minisql.ast.Ast import scala.quoted.* +import minisql.util.* private[minisql] inline def parseParamAt[F]( inline f: F, diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index 07da46a..370133b 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -9,7 +9,7 @@ import scala.annotation.tailrec import minisql.ast.Implicits._ import minisql.ast.Renameable.Fixed import minisql.ast.Visibility.{Hidden, Visible} -import 
minisql.util.Interleave +import minisql.util.{Interleave, extractTerm} import scala.quoted.* type Parser[A] = PartialFunction[Expr[Any], Expr[A]] @@ -29,22 +29,6 @@ private def parser[A]( case e if f.isDefinedAt(e) => f(e) } -private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - unwrapTerm(x) -} - private[minisql] object Parsing { def parseExpr( @@ -103,7 +87,7 @@ private[minisql] object Parsing { } lazy val identParser: Parser[ast.Ident] = termParser { - case x @ Ident(n) if x.symbol.isValDef => + case x @ Ident(n) => '{ ast.Ident(${ Expr(n) }) } } diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala index 2db7652..b8bd924 100644 --- a/src/main/scala/minisql/parsing/PatMatchParsing.scala +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -16,16 +16,7 @@ private[parsing] def patMatchParsing( case (Bind(n, ident), idx) => n -> Select.unique(t, s"_${idx + 1}") }.toMap - val tm = new TreeMap { - override def transformTerm(tree: Term)(owner: Symbol): Term = { - tree match { - case Ident(n) => bm(n) - case o => super.transformTerm(o)(owner) - } - } - } - val newBody = tm.transformTree(body)(e.symbol) - astParser(newBody.asExpr) + blockParsing(astParser)(body.asExpr) } } diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala index 6c2fb9e..e7d4ae9 100644 --- a/src/main/scala/minisql/parsing/ValueParsing.scala +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -2,6 +2,7 @@ package minisql package parsing import scala.quoted._ +import minisql.util.* private[parsing] def valueParsing(astParser: => 
Parser[ast.Ast])(using Quotes diff --git a/src/main/scala/minisql/util/QuotesHelper.scala b/src/main/scala/minisql/util/QuotesHelper.scala index 6ecbc76..fe93aa7 100644 --- a/src/main/scala/minisql/util/QuotesHelper.scala +++ b/src/main/scala/minisql/util/QuotesHelper.scala @@ -22,3 +22,20 @@ private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = { val fileName = pos.sourceFile.name s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}" } + +private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { + import quotes.reflect.* + def unwrapTerm(t: Term): Term = t match { + case Inlined(_, _, o) => unwrapTerm(o) + case Block(Nil, last) => last + case Typed(t, _) => + unwrapTerm(t) + case Select(t, "$asInstanceOf$") => + unwrapTerm(t) + case TypeApply(t, _) => + unwrapTerm(t) + case o => o + } + val o = unwrapTerm(x) + o +} diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala index 4ef6d67..17ea26b 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -34,4 +34,13 @@ class QuotedSuite extends munit.FunSuite { "INSERT INTO foo (id,name) VALUES (?, ?)" ) } + + test("LeftJoin") { + val o = testContext + .io(Foos.leftJoin(Foos).on((f1, f2) => f1.id == f2.id).map { + case (f1, f2) => (f1.id, f2.id) + }) + + println(o) + } } From f5e43657b39e248213e622bb52c536570810e2fe Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 2 Jul 2025 10:33:47 +0800 Subject: [PATCH 22/26] =?UTF-8?q?=20=20=E5=A2=9E=E5=8A=A0=E8=A7=A3?= =?UTF-8?q?=E6=9E=90=20`case=20(x,=20y)=20=3D>`=20=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- build.sbt | 3 +- project/build.properties | 2 +- src/main/scala/minisql/Quoted.scala | 3 -- src/main/scala/minisql/ast/FromExprs.scala | 43 +++++++------------ 
.../scala/minisql/parsing/BlockParsing.scala | 19 ++++++++ .../scala/minisql/parsing/InfixParsing.scala | 11 ----- src/main/scala/minisql/parsing/Parsing.scala | 3 -- .../minisql/parsing/PatMatchParsing.scala | 22 +++++++--- .../scala/minisql/ast/FromExprsSuite.scala | 10 +++++ 9 files changed, 63 insertions(+), 53 deletions(-) diff --git a/build.sbt b/build.sbt index d869492..509877e 100644 --- a/build.sbt +++ b/build.sbt @@ -3,10 +3,9 @@ name := "minisql" scalaVersion := "3.7.1" libraryDependencies ++= Seq( - "org.scalameta" %% "munit" % "1.0.3" % Test + "org.scalameta" %% "munit" % "1.1.1" % Test ) -javaOptions ++= Seq("-Xss16m") scalacOptions ++= Seq( "-deprecation", diff --git a/project/build.properties b/project/build.properties index e97b272..bbb0b60 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.10.10 +sbt.version=1.11.2 diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 3440061..265927c 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -72,9 +72,6 @@ private def joinQueryOf[E1, E2]( '{ Join($jt, $a, $b, $aliasA, $aliasB, $on) } - case o => - println("====================---" + o.show) - throw new Exception(s"Fail") } } diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 8029af8..072cf21 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -164,9 +164,6 @@ private given FromExpr[Query] with { Some(FlatJoin(t, a, ia, on)) case '{ Nested(${ Expr(a) }) } => Some(Nested(a)) - case o => - // println(s"Cannot extract ${o.show}") - None } } @@ -403,6 +400,20 @@ private given FromExpr[If] with { } } +private given FromExpr[Block] with { + def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match { + case '{ Block(${ Expr(statements) }) } => + Some(Block(statements)) + } +} + +private given FromExpr[Val] with { + def 
unapply(x: Expr[Val])(using Quotes): Option[Val] = x match { + case '{ Val(${ Expr(n) }, ${ Expr(b) }) } => + Some(Val(n, b)) + } +} + extension (e: Expr[Any]) { private def toTerm(using Quotes) = { import quotes.reflect.* @@ -411,30 +422,6 @@ extension (e: Expr[Any]) { } -private def fromBlock(using - Quotes -)(block: quotes.reflect.Block): Option[Ast] = { - import quotes.reflect.* - val empty: Option[List[Ast]] = Some(Nil) - val stmts = block.statements.foldLeft(empty) { (r, stmt) => - stmt match { - case ValDef(n, _, Some(body)) => - r.flatMap { astList => - body.asExprOf[Ast].value.map { v => - astList :+ v - } - } - case o => - None - } - } - stmts.flatMap { stmts => - block.expr.asExprOf[Ast].value.map { last => - minisql.ast.Block(stmts :+ last) - } - } -} - given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] { def unapply(e: Expr[Ast])(using Quotes): Option[Ast] = { @@ -444,6 +431,7 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] { case '{ $x: ScalarValueLift } => x.value case '{ $x: Property } => x.value case '{ $x: Ident } => x.value + case '{ $x: Val } => x.value case '{ $x: Tuple } => x.value case '{ $x: Value } => x.value case '{ $x: Operation } => x.value @@ -453,6 +441,7 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] { case '{ $x: Infix } => x.value case '{ $x: CaseClass } => x.value case '{ $x: OptionOperation } => x.value + case '{ $x: Block } => x.value case o => import quotes.reflect.* report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos) diff --git a/src/main/scala/minisql/parsing/BlockParsing.scala b/src/main/scala/minisql/parsing/BlockParsing.scala index ae8722c..02475e4 100644 --- a/src/main/scala/minisql/parsing/BlockParsing.scala +++ b/src/main/scala/minisql/parsing/BlockParsing.scala @@ -22,6 +22,25 @@ private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using valDefParser } +private[parsing] def parseBlockList( + astParser: => Parser[ast.Ast], + e: Expr[Any] +)(using Quotes): 
List[Expr[ast.Ast]] = { + import quotes.reflect.* + + lazy val statementParser = statementParsing(astParser) + + e.asTerm match { + case Block(st, t) => + (st :+ t).map { + case e if e.isExpr => astParser(e.asExpr) + case `statementParser`(x) => x + case o => + report.errorAndAbort(s"Cannot parse statement: ${o.show}") + } + } +} + private[parsing] def blockParsing( astParser: => Parser[ast.Ast] )(using Quotes): Parser[ast.Ast] = { diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala index 3bcdfc5..8b13789 100644 --- a/src/main/scala/minisql/parsing/InfixParsing.scala +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -1,12 +1 @@ -package minisql.parsing -import minisql.ast -import scala.quoted.* - -private[parsing] def infixParsing( - astParser: => Parser[ast.Ast] -)(using Quotes): Parser[ast.Infix] = { - - import quotes.reflect.* - ??? -} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index 370133b..f044292 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -59,7 +59,6 @@ private[minisql] object Parsing { .orElse(ifParser) .orElse(traversableOperationParser) .orElse(patMatchParser) - // .orElse(infixParser) .orElse { case o => val str = scala.util.Try(o.show).getOrElse("") @@ -106,8 +105,6 @@ private[minisql] object Parsing { } lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) - // lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) - lazy val traversableOperationParser: Parser[ast.IterableOperation] = traversableOperationParsing(astParser) diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala index b8bd924..9a7686c 100644 --- a/src/main/scala/minisql/parsing/PatMatchParsing.scala +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -11,12 +11,22 @@ private[parsing] def 
patMatchParsing( termParser { // Val defs that showd pattern variables will cause error - case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) => - val bm = binds.zipWithIndex.map { - case (Bind(n, ident), idx) => - n -> Select.unique(t, s"_${idx + 1}") - }.toMap - blockParsing(astParser)(body.asExpr) + case e @ Match( + Ident(t), + List(CaseDef(IsTupleUnapply(binds), None, body)) + ) => + val bindStmts = binds.map { + case Bind(bn, _) => + '{ + ast.Val( + ast.Ident(${ Expr(bn) }), + ast.Property(ast.Ident(${ Expr(t) }), "_1") + ) + } + } + + val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr) + '{ ast.Block(${ Expr.ofList(allStmts.toList) }) } } } diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index bbc6973..4e6d8c9 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -124,4 +124,14 @@ class FromExprsSuite extends FunSuite { testFor("CaseClass") { CaseClass(List(("name", Ident("value")))) } + + testFor("Block") { // Also tested Val + Block( + List( + Val(Ident("x"), Constant(1)), + Val(Ident("y"), Constant(2)), + BinaryOperation(Ident("x"), NumericOperator.+, Ident("y")) + ) + ) + } } From 48cb1003bbcbcd7af7266bf9d9cfa0ac710ddc7f Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 2 Jul 2025 12:07:54 +0800 Subject: [PATCH 23/26] =?UTF-8?q?=E6=94=AF=E6=8C=81Infix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/scala/minisql/SqlInfix.scala | 12 ++++++++ .../scala/minisql/parsing/InfixParsing.scala | 28 +++++++++++++++++++ src/main/scala/minisql/parsing/Parsing.scala | 11 ++++++-- .../scala/minisql/ast/FromExprsSuite.scala | 9 ++++++ .../minisql/context/sql/QuotedSuite.scala | 7 +++++ 5 files changed, 65 insertions(+), 2 deletions(-) create mode 100644 src/main/scala/minisql/SqlInfix.scala diff --git a/src/main/scala/minisql/SqlInfix.scala 
b/src/main/scala/minisql/SqlInfix.scala new file mode 100644 index 0000000..ea5fe6a --- /dev/null +++ b/src/main/scala/minisql/SqlInfix.scala @@ -0,0 +1,12 @@ +package minisql + +import minisql.ast.Ast +import scala.quoted.* + +sealed trait InfixValue { + def as[T]: T +} + +extension (sc: StringContext) { + def infix(args: Any*): InfixValue = throw NonQuotedException() +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala index 8b13789..1b19cad 100644 --- a/src/main/scala/minisql/parsing/InfixParsing.scala +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -1 +1,29 @@ +package minisql.parsing +import minisql.ast +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Infix] = { + import quotes.reflect.* + { + case '{ ($x: minisql.InfixValue).as[t] } => infixParsing(astParser)(x) + case '{ + minisql.infix(StringContext(${ Varargs(partsExprs) }*))(${ + Varargs(argsExprs) + }*) + } => + val parts = partsExprs.map { p => + p.value.getOrElse( + report.errorAndAbort( + s"Expected a string literal in StringContext parts, but got: ${p.show}" + ) + ) + }.toList + + val params = argsExprs.map(arg => astParser(arg)).toList + + '{ ast.Infix(${ Expr(parts) }, ${ Expr.ofList(params) }, true, false) } + } +} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index f044292..8794e7f 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -41,8 +41,10 @@ private[minisql] object Parsing { f: Parser[ast.Ast] ): Parser[ast.Ast] = { case expr => - val t = expr.asTerm - f(extractTerm(t).asExpr) + val t = extractTerm(expr.asTerm) + if (t.isExpr) + f(t.asExpr) + else f(expr) } lazy val astParser: Parser[ast.Ast] = @@ -50,6 +52,7 @@ private[minisql] object Parsing { typedParser .orElse(propertyParser) .orElse(liftParser) + .orElse(infixParser) 
.orElse(identParser) .orElse(valueParser) .orElse(operationParser) @@ -108,6 +111,10 @@ private[minisql] object Parsing { lazy val traversableOperationParser: Parser[ast.IterableOperation] = traversableOperationParsing(astParser) + lazy val infixParser: Parser[ast.Infix] = infixParsing( + astParser + ) + astParser(expr) } diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 4e6d8c9..8820c12 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -86,6 +86,15 @@ class FromExprsSuite extends FunSuite { ) } + testFor("Infix with different parameters") { + Infix( + List("?", " + ", "?"), + List(Constant(1), Constant(2)), + pure = true, + noParen = true + ) + } + testFor("OptionOperation - OptionMap") { OptionMap(Ident("opt"), Ident("x"), Ident("x")) } diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala index 17ea26b..467a10c 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -43,4 +43,11 @@ class QuotedSuite extends munit.FunSuite { println(o) } + + test("Infix string interpolation") { + val o = testContext.io( + Foos.map(f => infix"CONCAT(${f.name}, ' ', ${f.id})".as[String]) + ) + assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f") + } } From 06850823d7efdb547ddabb272cad0ee6c856821e Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 2 Jul 2025 13:54:53 +0800 Subject: [PATCH 24/26] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9B=B4=E5=A4=9Ajoin?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/scala/minisql/Quoted.scala | 15 +++++++++++++-- ...tedSuite.scala => MirrorSqlContextSuite.scala} | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) rename src/test/scala/minisql/context/sql/{QuotedSuite.scala => MirrorSqlContextSuite.scala} (90%) 
diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 265927c..1efc593 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -105,8 +105,19 @@ object Query { expandFields[E](e) } - inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, E1] = - Joined[E, E1](JoinType.LeftJoin, e, e1) + inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, Option[E1]] = + Joined[E, Option[E1]](JoinType.LeftJoin, e, e1) + + inline def rightJoin[E1](inline e1: Query[E1]): Joined[Option[E], E1] = + Joined[Option[E], E1](JoinType.RightJoin, e, e1) + + inline def join[E1](inline e1: Query[E1]): Joined[E, E1] = + Joined[E, E1](JoinType.InnerJoin, e, e1) + + inline def fullJoin[E1]( + inline e1: Query[E1] + ): Joined[Option[E], Option[E1]] = + Joined[Option[E], Option[E1]](JoinType.FullJoin, e, e1) inline def map[E1](inline f: E => E1): Query[E1] = { transform(e)(f)(Map.apply) diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala similarity index 90% rename from src/test/scala/minisql/context/sql/QuotedSuite.scala rename to src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index 467a10c..ff26942 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -7,7 +7,7 @@ import minisql.NamingStrategy import minisql.MirrorContext import minisql.context.mirror.{*, given} -class QuotedSuite extends munit.FunSuite { +class MirrorSqlContextSuite extends munit.FunSuite { case class Foo(id: Long, name: String) @@ -37,7 +37,7 @@ class QuotedSuite extends munit.FunSuite { test("LeftJoin") { val o = testContext - .io(Foos.leftJoin(Foos).on((f1, f2) => f1.id == f2.id).map { + .io(Foos.join(Foos).on((f1, f2) => f1.id == f2.id).map { case (f1, f2) => (f1.id, f2.id) }) From ed1952b91522de1e9225cd306d71b5fc0ddd19b2 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 
2 Jul 2025 15:39:42 +0800 Subject: [PATCH 25/26] =?UTF-8?q?=E5=A2=9E=E5=8A=A0returningGenerated?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/scala/minisql/Quoted.scala | 26 ++++++++++++++----- src/main/scala/minisql/ast/FromExprs.scala | 19 +++++++++++++- .../scala/minisql/ast/FromExprsSuite.scala | 16 ++++++++++++ .../context/sql/MirrorSqlContextSuite.scala | 10 +++++++ 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 1efc593..45376f1 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -28,7 +28,23 @@ opaque type Query[E] <: Quoted = Quoted opaque type Action[E] <: Quoted = Quoted -opaque type Insert <: Action[Long] = Quoted +opaque type Insert[E] <: Action[Long] = Quoted + +object Insert { + extension [E](inline insert: Insert[E]) { + inline def returning[E1](inline f: E => E1): InsertReturning[E1] = { + transform(insert)(f)(ast.Returning.apply) + } + + inline def returningGenerated[E1]( + inline f: E => E1 + ): InsertReturning[E1] = { + transform(insert)(f)(ast.ReturningGenerated.apply) + } + } +} + +opaque type InsertReturning[E] <: Action[E] = Quoted sealed trait Joined[E1, E2] @@ -66,9 +82,7 @@ private def joinQueryOf[E1, E2]( )(using Quotes, Type[E1], Type[E2]): Expr[Join] = { import quotes.reflect.* extractTerm(x.asTerm).asExpr match { - case '{ - Joined[E1, E2]($jt, $a, $b) - } => + case '{ Joined[E1, E2]($jt, $a, $b) } => '{ Join($jt, $a, $b, $aliasA, $aliasB, $on) } @@ -148,7 +162,7 @@ object EntityQuery { transform(e)(f)(Filter.apply) } - inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { + inline def insert(v: E): Insert[E] = { ast.Insert(e, transformCaseClassToAssignments[E](v)) } } @@ -156,7 +170,7 @@ object EntityQuery { private inline def transformCaseClassToAssignments[E]( v: E -)(using m: Mirror.ProductOf[E]): List[ast.Assignment] = ${ +): 
List[ast.Assignment] = ${ transformCaseClassToAssignmentsImpl[E]('v) } diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 072cf21..08686b3 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -270,12 +270,29 @@ private given FromExpr[Action] with { ass.sequence.map { ass1 => Update(a, ass1) } - case '{ Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } => + case '{ + Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) + } => + Some(Returning(act, id, body)) + case '{ + val x: Ast = ${ Expr(act) } + val y: Ident = ${ Expr(id) } + val z: Ast = ${ Expr(body) } + Returning(x, y, z) + } => Some(Returning(act, id, body)) case '{ ReturningGenerated(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } => Some(ReturningGenerated(act, id, body)) + case '{ + val x: Ast = ${ Expr(act) } + val y: Ident = ${ Expr(id) } + val z: Ast = ${ Expr(body) } + ReturningGenerated(x, y, z) + } => + Some(ReturningGenerated(act, id, body)) + } } } diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 8820c12..9cab3b1 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -73,6 +73,22 @@ class FromExprsSuite extends FunSuite { ) } + testFor("Action - Returning") { + Returning( + Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Ident("x"), + Property(Ident("x"), "id") + ) + } + + testFor("Action - ReturningGenerated") { + ReturningGenerated( + Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Ident("x"), + Property(Ident("x"), "generatedId") + ) + } + testFor("If expression") { If(Ident("cond"), Ident("then"), Ident("else")) } diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index 
ff26942..0b79e14 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -35,6 +35,16 @@ class MirrorSqlContextSuite extends munit.FunSuite { ) } + test("InsertReturningGenerated") { + val v: Foo = Foo(0L, "foo") + + val o = testContext.io(Foos.insert(v).returningGenerated(_.id)) + assertEquals( + o.sql, + "INSERT INTO foo (name) VALUES (?) RETURNING id" + ) + } + test("LeftJoin") { val o = testContext .io(Foos.join(Foos).on((f1, f2) => f1.id == f2.id).map { From 071b27abcf9bc2afce2ce07ba163f87e6f059477 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 2 Jul 2025 15:46:40 +0800 Subject: [PATCH 26/26] Better test case --- .../scala/minisql/ast/FromExprsSuite.scala | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 9cab3b1..70a0d6d 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -75,7 +75,10 @@ class FromExprsSuite extends FunSuite { testFor("Action - Returning") { Returning( - Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Insert( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ), Ident("x"), Property(Ident("x"), "id") ) @@ -83,12 +86,35 @@ class FromExprsSuite extends FunSuite { testFor("Action - ReturningGenerated") { ReturningGenerated( - Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Insert( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ), Ident("x"), Property(Ident("x"), "generatedId") ) } + testFor("Action - Val outside") { + val p1 = Update( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + val p2 = Ident("x") + val p3 = Property(Ident("x"), "id") + Returning(p1, p2, p3) + } + + testFor("Action - 
ReturningGenerated with Update") { + val p1 = Update( + Ident("table"), + List(Assignment(Ident("x"), Ident("col"), Ident("val"))) + ) + val p2 = Ident("x") + val p3 = Property(Ident("x"), "id") + ReturningGenerated(p1, p2, p3) + } + testFor("If expression") { If(Ident("cond"), Ident("then"), Ident("else")) }