From 8103d4517890be42196eac54d012377251e39c00 Mon Sep 17 00:00:00 2001 From: jilen Date: Sun, 15 Dec 2024 20:51:38 +0800 Subject: [PATCH] try parsing function body --- src/main/scala/minisql/Parser.scala | 127 ----- src/main/scala/minisql/ReturnAction.scala | 7 + src/main/scala/minisql/ast/Ast.scala | 13 +- src/main/scala/minisql/ast/AstOps.scala | 94 ++++ src/main/scala/minisql/ast/FromExprs.scala | 5 +- .../context/ReturnFieldCapability.scala | 64 +++ src/main/scala/minisql/dsl.scala | 18 +- src/main/scala/minisql/idiom/Idiom.scala | 23 + src/main/scala/minisql/idiom/LoadNaming.scala | 28 + .../scala/minisql/idiom/ReifyStatement.scala | 68 +++ src/main/scala/minisql/idiom/Statement.scala | 47 ++ .../minisql/idiom/StatementInterpolator.scala | 146 ++++++ .../scala/minisql/norm/AdHocReduction.scala | 52 ++ src/main/scala/minisql/norm/ApplyMap.scala | 160 ++++++ .../scala/minisql/norm/AttachToEntity.scala | 48 ++ .../{util => norm}/BetaReduction.scala | 4 +- .../scala/minisql/norm/ConcatBehavior.scala | 7 + .../scala/minisql/norm/EqualityBehavior.scala | 7 + .../scala/minisql/norm/ExpandReturning.scala | 74 +++ .../minisql/norm/FlattenOptionOperation.scala | 108 ++++ .../minisql/norm/NestImpureMappedInfix.scala | 76 +++ src/main/scala/minisql/norm/Normalize.scala | 51 ++ .../norm/NormalizeAggregationIdent.scala | 29 ++ .../norm/NormalizeNestedStructures.scala | 47 ++ .../minisql/norm/NormalizeReturning.scala | 154 ++++++ src/main/scala/minisql/norm/OrderTerms.scala | 29 ++ .../scala/minisql/norm/RenameProperties.scala | 491 ++++++++++++++++++ .../minisql/{util => norm}/Replacements.scala | 2 +- .../minisql/norm/SimplifyNullChecks.scala | 124 +++++ .../minisql/norm/SymbolicReduction.scala | 38 ++ .../norm/capture/AvoidAliasConflict.scala | 174 +++++++ .../minisql/norm/capture/AvoidCapture.scala | 9 + .../scala/minisql/norm/capture/Dealias.scala | 72 +++ .../capture/DemarcateExternalAliases.scala | 98 ++++ .../scala/minisql/parsing/BlockParsing.scala | 47 ++ .../scala/minisql/parsing/BoxingParsing.scala | 31 ++ .../scala/minisql/parsing/InfixParsing.scala | 13 + .../scala/minisql/parsing/LiftParsing.scala | 16 + .../minisql/parsing/OperationParsing.scala | 113 ++++ src/main/scala/minisql/parsing/Parser.scala | 47 ++ src/main/scala/minisql/parsing/Parsing.scala | 139 +++++ .../minisql/parsing/PatMatchParsing.scala | 49 ++ .../minisql/parsing/PropertyParsing.scala | 30 ++ .../parsing/TraversableOperationParsing.scala | 16 + .../scala/minisql/parsing/ValueParsing.scala | 71 +++ .../scala/minisql/util/Interpolator.scala | 15 +- src/main/scala/minisql/util/Message.scala | 75 +++ 47 files changed, 3000 insertions(+), 156 deletions(-) delete mode 100644 src/main/scala/minisql/Parser.scala create mode 100644 src/main/scala/minisql/ReturnAction.scala create mode 100644 src/main/scala/minisql/ast/AstOps.scala create mode 100644 src/main/scala/minisql/context/ReturnFieldCapability.scala create mode 100644 src/main/scala/minisql/idiom/Idiom.scala create mode 100644 src/main/scala/minisql/idiom/LoadNaming.scala create mode 100644 src/main/scala/minisql/idiom/ReifyStatement.scala create mode 100644 src/main/scala/minisql/idiom/Statement.scala create mode 100644 src/main/scala/minisql/idiom/StatementInterpolator.scala create mode 100644 src/main/scala/minisql/norm/AdHocReduction.scala create mode 100644 src/main/scala/minisql/norm/ApplyMap.scala create mode 100644 src/main/scala/minisql/norm/AttachToEntity.scala rename src/main/scala/minisql/{util => norm}/BetaReduction.scala (99%) create mode 100644 src/main/scala/minisql/norm/ConcatBehavior.scala create mode 100644 src/main/scala/minisql/norm/EqualityBehavior.scala create mode 100644 src/main/scala/minisql/norm/ExpandReturning.scala create mode 100644 src/main/scala/minisql/norm/FlattenOptionOperation.scala create mode 100644 src/main/scala/minisql/norm/NestImpureMappedInfix.scala create mode 100644 src/main/scala/minisql/norm/Normalize.scala create mode 100644 src/main/scala/minisql/norm/NormalizeAggregationIdent.scala create mode 100644 src/main/scala/minisql/norm/NormalizeNestedStructures.scala create mode 100644 src/main/scala/minisql/norm/NormalizeReturning.scala create mode 100644 src/main/scala/minisql/norm/OrderTerms.scala create mode 100644 src/main/scala/minisql/norm/RenameProperties.scala rename src/main/scala/minisql/{util => norm}/Replacements.scala (98%) create mode 100644 src/main/scala/minisql/norm/SimplifyNullChecks.scala create mode 100644 src/main/scala/minisql/norm/SymbolicReduction.scala create mode 100644 src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala create mode 100644 src/main/scala/minisql/norm/capture/AvoidCapture.scala create mode 100644 src/main/scala/minisql/norm/capture/Dealias.scala create mode 100644 src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala create mode 100644 src/main/scala/minisql/parsing/BlockParsing.scala create mode 100644 src/main/scala/minisql/parsing/BoxingParsing.scala create mode 100644 src/main/scala/minisql/parsing/InfixParsing.scala create mode 100644 src/main/scala/minisql/parsing/LiftParsing.scala create mode 100644 src/main/scala/minisql/parsing/OperationParsing.scala create mode 100644 src/main/scala/minisql/parsing/Parser.scala create mode 100644 src/main/scala/minisql/parsing/Parsing.scala create mode 100644 src/main/scala/minisql/parsing/PatMatchParsing.scala create mode 100644 src/main/scala/minisql/parsing/PropertyParsing.scala create mode 100644 src/main/scala/minisql/parsing/TraversableOperationParsing.scala create mode 100644 src/main/scala/minisql/parsing/ValueParsing.scala create mode 100644 src/main/scala/minisql/util/Message.scala diff --git a/src/main/scala/minisql/Parser.scala b/src/main/scala/minisql/Parser.scala deleted file mode 100644 index 3b21d90..0000000 --- a/src/main/scala/minisql/Parser.scala +++ /dev/null @@ -1,127 +0,0 @@ -package minisql.parsing - -import minisql.ast -import minisql.ast.Ast -import scala.quoted.* - -type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]] - -private[minisql] inline def parseParamAt[F]( - inline f: F, - inline n: Int -): ast.Ident = ${ - parseParamAt('f, 'n) -} - -private[minisql] inline def parseBody[X]( - inline f: X -): ast.Ast = ${ - parseBody('f) -} - -private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using - Quotes -): Expr[ast.Ident] = { - - import quotes.reflect.* - - val pIdx = n.value.getOrElse( - report.errorAndAbort(s"Param index ${n.show} is not know") - ) - extractTerm(f.asTerm) match { - case Lambda(vals, _) => - vals(pIdx) match { - case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) } - } - } -} - -private[minisql] def parseBody[X]( - x: Expr[X] -)(using Quotes): Expr[Ast] = { - import quotes.reflect.* - extractTerm(x.asTerm) match { - case Lambda(vals, body) => - astParser(body.asExpr) - case o => - report.errorAndAbort(s"Can only parse function") - } -} -private def isNumeric(x: Expr[?])(using Quotes): Boolean = { - import quotes.reflect.* - x.asTerm.tpe match { - case t if t <:< TypeRepr.of[Int] => true - case t if t <:< TypeRepr.of[Long] => true - case t if t <:< TypeRepr.of[Float] => true - case t if t <:< TypeRepr.of[Double] => true - case t if t <:< TypeRepr.of[BigDecimal] => true - case t if t <:< TypeRepr.of[BigInt] => true - case t if t <:< TypeRepr.of[java.lang.Integer] => true - case t if t <:< TypeRepr.of[java.lang.Long] => true - case t if t <:< TypeRepr.of[java.lang.Float] => true - case t if t <:< TypeRepr.of[java.lang.Double] => true - case t if t <:< TypeRepr.of[java.math.BigDecimal] => true - case _ => false - } -} - -private def identParser(using Quotes): Parser[ast.Ident] = { - import quotes.reflect.* - { (x: Expr[?]) => - extractTerm(x.asTerm) match { - case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) }) - case _ => None - } - }.unlift - -} - -private lazy val astParser: Quotes ?=> Parser[Ast] = { - identParser.orElse(propertyParser(astParser)) -} - -private object IsPropertySelect { - def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = { - import quotes.reflect.* - x.asTerm match { - case Select(x, n) => Some(x.asExpr, n) - case _ => None - } - } -} - -def propertyParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.Property] = { - case IsPropertySelect(expr, n) => - '{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) } -} - -def optionOperationParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.OptionOperation] = { - case '{ ($x: Option[t]).isEmpty } => - '{ ast.OptionIsEmpty(${ astParser(x) }) } -} - -def binaryOperationParser( - astParser: => Parser[Ast] -)(using Quotes): Parser[ast.BinaryOperation] = { - ??? -} - -private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { - import quotes.reflect.* - def unwrapTerm(t: Term): Term = t match { - case Inlined(_, _, o) => unwrapTerm(o) - case Block(Nil, last) => last - case Typed(t, _) => - unwrapTerm(t) - case Select(t, "$asInstanceOf$") => - unwrapTerm(t) - case TypeApply(t, _) => - unwrapTerm(t) - case o => o - } - unwrapTerm(x) -} diff --git a/src/main/scala/minisql/ReturnAction.scala b/src/main/scala/minisql/ReturnAction.scala new file mode 100644 index 0000000..63df7a0 --- /dev/null +++ b/src/main/scala/minisql/ReturnAction.scala @@ -0,0 +1,7 @@ +package minisql + +enum ReturnAction { + case ReturnNothing + case ReturnColumns(columns: List[String]) + case ReturnRecord +} diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 35feb70..48e3fae 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -1,6 +1,7 @@ package minisql.ast import minisql.NamingStrategy +import minisql.ParamEncoder import scala.quoted.* @@ -378,14 +379,20 @@ sealed trait ScalarLift extends Lift case class ScalarValueLift( name: String, - liftId: String + liftId: String, + value: Option[(Any, ParamEncoder[?])] ) extends ScalarLift +case class ScalarQueryLift( + name: String, + liftId: String +) extends ScalarLift {} + object ScalarLift { given ToExpr[ScalarLift] with { def apply(l: ScalarLift)(using Quotes) = l match { - case ScalarValueLift(n, id) => - '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } + case ScalarValueLift(n, id, v) => + '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) } } } } diff --git a/src/main/scala/minisql/ast/AstOps.scala b/src/main/scala/minisql/ast/AstOps.scala new file mode 100644 index 0000000..e0069ae --- /dev/null +++ b/src/main/scala/minisql/ast/AstOps.scala @@ -0,0 +1,94 @@ +package minisql.ast +object Implicits { + implicit class AstOps(body: Ast) { + private[minisql] def +||+(other: Ast) = + BinaryOperation(body, BooleanOperator.`||`, other) + private[minisql] def +&&+(other: Ast) = + BinaryOperation(body, BooleanOperator.`&&`, other) + private[minisql] def +==+(other: Ast) = + BinaryOperation(body, EqualityOperator.`==`, other) + private[minisql] def +!=+(other: Ast) = + BinaryOperation(body, EqualityOperator.`!=`, other) + } +} + +object +||+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two)) + case _ => None + } + } +} + +object +&&+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two)) + case _ => None + } + } +} + +val EqOp = EqualityOperator.`==` +val NeqOp = EqualityOperator.`!=` + +object +==+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, EqOp, two) => Some((one, two)) + case _ => None + } + } +} + +object +!=+ { + def unapply(a: Ast): Option[(Ast, Ast)] = { + a match { + case BinaryOperation(one, NeqOp, two) => Some((one, two)) + case _ => None + } + } +} + +object IsNotNullCheck { + def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue) + + def unapply(ast: Ast): Option[Ast] = { + ast match { + case BinaryOperation(cond, NeqOp, NullValue) => Some(cond) + case _ => None + } + } +} + +object IsNullCheck { + def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue) + + def unapply(ast: Ast): Option[Ast] = { + ast match { + case BinaryOperation(cond, EqOp, NullValue) => Some(cond) + case _ => None + } + } +} + +object IfExistElseNull { + def apply(exists: Ast, `then`: Ast) = + If(IsNotNullCheck(exists), `then`, NullValue) + + def unapply(ast: Ast) = ast match { + case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t)) + case _ => None + } +} + +object IfExist { + def apply(exists: Ast, `then`: Ast, otherwise: Ast) = + If(IsNotNullCheck(exists), `then`, otherwise) + + def unapply(ast: Ast) = ast match { + case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e)) + case _ => None + } +} diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index b66e4b9..26694fb 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -45,8 +45,9 @@ private given FromExpr[Infix] with { private given FromExpr[ScalarValueLift] with { def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] = x match { - case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } => - Some(ScalarValueLift(n, id)) + case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } => + // don't cared about value here, a little tricky + Some(ScalarValueLift(n, id, null)) } } diff --git a/src/main/scala/minisql/context/ReturnFieldCapability.scala b/src/main/scala/minisql/context/ReturnFieldCapability.scala new file mode 100644 index 0000000..cadbb78 --- /dev/null +++ b/src/main/scala/minisql/context/ReturnFieldCapability.scala @@ -0,0 +1,64 @@ +package minisql.context + +sealed trait ReturningCapability + +/** + * Data cannot be returned Insert/Update/etc... clauses in the target database. + */ +sealed trait ReturningNotSupported extends ReturningCapability + +/** + * Returning a single field from Insert/Update/etc... clauses is supported. This + * is the most common databases e.g. MySQL, Sqlite, and H2 (although as of + * h2database/h2database#1972 this may change. See #1496 regarding this. + * Typically this needs to be setup in the JDBC + * `connection.prepareStatement(sql, Array("returnColumn"))`. + */ +sealed trait ReturningSingleFieldSupported extends ReturningCapability + +/** + * Returning multiple columns from Insert/Update/etc... clauses is supported. + * This generally means that columns besides auto-incrementing ones can be + * returned. This is supported by Oracle. In JDBC, the following is done: + * `connection.prepareStatement(sql, Array("column1, column2, ..."))`. + */ +sealed trait ReturningMultipleFieldSupported extends ReturningCapability + +/** + * An actual `RETURNING` clause is supported in the SQL dialect of the specified + * database e.g. Postgres. this typically means that columns returned from + * Insert/Update/etc... clauses can have other database operations done on them + * such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In + * JDBC, the following is done: `connection.prepareStatement(sql, + * Statement.RETURN_GENERATED_KEYS))`. + */ +sealed trait ReturningClauseSupported extends ReturningCapability + +object ReturningNotSupported extends ReturningNotSupported +object ReturningSingleFieldSupported extends ReturningSingleFieldSupported +object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported +object ReturningClauseSupported extends ReturningClauseSupported + +trait Capabilities { + def idiomReturningCapability: ReturningCapability +} + +trait CanReturnClause extends Capabilities { + override def idiomReturningCapability: ReturningClauseSupported = + ReturningClauseSupported +} + +trait CanReturnField extends Capabilities { + override def idiomReturningCapability: ReturningSingleFieldSupported = + ReturningSingleFieldSupported +} + +trait CanReturnMultiField extends Capabilities { + override def idiomReturningCapability: ReturningMultipleFieldSupported = + ReturningMultipleFieldSupported +} + +trait CannotReturn extends Capabilities { + override def idiomReturningCapability: ReturningNotSupported = + ReturningNotSupported +} diff --git a/src/main/scala/minisql/dsl.scala b/src/main/scala/minisql/dsl.scala index 43cc9c0..ace3d8f 100644 --- a/src/main/scala/minisql/dsl.scala +++ b/src/main/scala/minisql/dsl.scala @@ -29,20 +29,6 @@ private inline def transform[A, B](inline q1: Quoted)( inline def query[E](inline table: String): EntityQuery[E] = Entity(table, Nil) -inline def compile(inline x: Ast): Option[String] = ${ - compileImpl('{ x }) -} - -private def compileImpl( - x: Expr[Ast] -)(using Quotes): Expr[Option[String]] = { - import quotes.reflect.* - x.value match { - case Some(xv) => '{ Some(${ Expr(xv.toString()) }) } - case None => '{ None } - } -} - extension [A, B](inline f1: A => B) { private inline def param0 = parsing.parseParamAt(f1, 0) private inline def body = parsing.parseBody(f1) @@ -54,6 +40,6 @@ extension [A1, A2, B](inline f1: (A1, A2) => B) { private inline def body = parsing.parseBody(f1) } -case class Foo(id: Int) +def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException() -inline def queryFooId = query[Foo]("foo").map(_.id) +class NonQuotedException extends Exception("Cannot be used at runtime") diff --git a/src/main/scala/minisql/idiom/Idiom.scala b/src/main/scala/minisql/idiom/Idiom.scala new file mode 100644 index 0000000..7e0cd01 --- /dev/null +++ b/src/main/scala/minisql/idiom/Idiom.scala @@ -0,0 +1,23 @@ +package minisql.idiom + +import minisql.NamingStrategy +import minisql.ast._ +import minisql.context.Capabilities + +trait Idiom extends Capabilities { + + def emptySetContainsToken(field: Token): Token = StringToken("FALSE") + + def defaultAutoGeneratedToken(field: Token): Token = StringToken( + "DEFAULT VALUES" + ) + + def liftingPlaceholder(index: Int): String + + def translate(ast: Ast)(implicit naming: NamingStrategy): (Ast, Statement) + + def format(queryString: String): String = queryString + + def prepareForProbing(string: String): String + +} diff --git a/src/main/scala/minisql/idiom/LoadNaming.scala b/src/main/scala/minisql/idiom/LoadNaming.scala new file mode 100644 index 0000000..2405080 --- /dev/null +++ b/src/main/scala/minisql/idiom/LoadNaming.scala @@ -0,0 +1,28 @@ +package minisql.idiom + +import scala.util.Try +import scala.quoted._ +import minisql.NamingStrategy +import minisql.util.CollectTry +import minisql.util.LoadObject +import minisql.CompositeNamingStrategy + +object LoadNaming { + + def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry { + strategies[C].map(LoadObject[NamingStrategy](_)) + }.map(NamingStrategy(_)) + + private def strategies[C](using Quotes, Type[C]) = { + import quotes.reflect.* + val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy] + val ct = TypeRepr.of[C] + if (isComposite) { + ct.typeArgs.filterNot { t => + t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing] + } + } else { + List(ct) + } + } +} diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala new file mode 100644 index 0000000..a3e8902 --- /dev/null +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -0,0 +1,68 @@ +package minisql.idiom + +import minisql.ast._ +import minisql.util.Interleave +import minisql.idiom.StatementInterpolator._ +import scala.annotation.tailrec +import scala.collection.immutable.{Map => SMap} + +object ReifyStatement { + + def apply( + liftingPlaceholder: Int => String, + emptySetContainsToken: Token => Token, + statement: Statement, + liftMap: SMap[String, (Any, Any)] + ): (String, List[ScalarValueLift]) = { + val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) + token2string(expanded, liftingPlaceholder) + } + + private def token2string( + token: Token, + liftingPlaceholder: Int => String + ): (String, List[ScalarValueLift]) = { + + val liftBuilder = List.newBuilder[ScalarValueLift] + val sqlBuilder = StringBuilder() + @tailrec + def loop( + workList: Seq[Token], + liftingSize: Int + ): Unit = workList match { + case Seq() => () + case head +: tail => + head match { + case StringToken(s2) => + sqlBuilder ++= s2 + loop(tail, liftingSize) + case SetContainsToken(a, op, b) => + loop( + stmt"$a $op ($b)" +: tail, + liftingSize + ) + case ScalarLiftToken(lift: ScalarValueLift) => + sqlBuilder ++= liftingPlaceholder(liftingSize) + liftBuilder += lift + loop(tail, liftingSize + 1) + case ScalarLiftToken(o) => + throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}") + case Statement(tokens) => + loop( + tokens.foldRight(tail)(_ +: _), + liftingSize + ) + } + } + loop(Vector(token), 0) + sqlBuilder.toString() -> liftBuilder.result() + } + + private def expandLiftings( + statement: Statement, + emptySetContainsToken: Token => Token, + liftMap: SMap[String, (Any, Any)] + ): (Token) = { + statement + } +} diff --git a/src/main/scala/minisql/idiom/Statement.scala b/src/main/scala/minisql/idiom/Statement.scala new file mode 100644 index 0000000..987ee2f --- /dev/null +++ b/src/main/scala/minisql/idiom/Statement.scala @@ -0,0 +1,47 @@ +package minisql.idiom + +import scala.quoted._ +import minisql.ast._ + +sealed trait Token + +case class StringToken(string: String) extends Token { + override def toString = string +} + +case class ScalarLiftToken(lift: ScalarLift) extends Token { + override def toString = s"lift(${lift.name})" +} + +case class Statement(tokens: List[Token]) extends Token { + override def toString = tokens.mkString +} + +case class SetContainsToken(a: Token, op: Token, b: Token) extends Token { + override def toString = s"${a.toString} ${op.toString} (${b.toString})" +} + +object Statement { + given ToExpr[Statement] with { + def apply(t: Statement)(using Quotes): Expr[Statement] = { + '{ Statement(${ Expr(t.tokens) }) } + } + } +} + +object Token { + + given ToExpr[Token] with { + def apply(t: Token)(using Quotes): Expr[Token] = { + t match { + case StringToken(s) => + '{ StringToken(${ Expr(s) }) } + case ScalarLiftToken(l) => + '{ ScalarLiftToken(${ Expr(l) }) } + case SetContainsToken(a, op, b) => + '{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) } + case s: Statement => Expr(s) + } + } + } +} diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala new file mode 100644 index 0000000..b732da1 --- /dev/null +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -0,0 +1,146 @@ +package minisql.idiom + +import minisql.ast._ +import minisql.util.Interleave +import minisql.util.Messages._ + +import scala.collection.mutable.ListBuffer + +object StatementInterpolator { + + trait Tokenizer[T] { + def token(v: T): Token + } + + object Tokenizer { + def apply[T](f: T => Token) = new Tokenizer[T] { + def token(v: T) = f(v) + } + def withFallback[T]( + fallback: Tokenizer[T] => Tokenizer[T] + )(pf: PartialFunction[T, Token]) = + new Tokenizer[T] { + private val stable = fallback(this) + override def token(v: T) = pf.applyOrElse(v, stable.token) + } + } + + implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) { + def token = tokenizer.token(v) + } + + implicit def stringTokenizer: Tokenizer[String] = + Tokenizer[String] { + case string => StringToken(string) + } + + implicit def liftTokenizer: Tokenizer[Lift] = + Tokenizer[Lift] { + case lift: ScalarLift => ScalarLiftToken(lift) + case lift => + fail( + s"Can't tokenize a non-scalar lifting. ${lift.name}\n" + + s"\n" + + s"This might happen because:\n" + + s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" + + s" to `Some[A]` or `None.type`. For example:\n" + + s" run(query[Users].update(_.optionalField -> lift(Some(value))))" + + s" In that case, make sure the type is `Option`:\n" + + s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" + + s" or\n" + + s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" + + s"\n" + + s"* You are trying to insert or update whole Embedded case class. For example:\n" + + s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" + + s" In that case, make sure you are updating individual columns, for example:\n" + + s" run(query[Users].update(\n" + + s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" + + s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" + + s" ))" + ) + } + + implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) + implicit def statementTokenizer: Tokenizer[Statement] = + Tokenizer[Statement](identity) + implicit def stringTokenTokenizer: Tokenizer[StringToken] = + Tokenizer[StringToken](identity) + implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = + Tokenizer[ScalarLiftToken](identity) + + extension [T](list: List[T]) { + def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = { + val l1 = list.map(_.token) + val l2 = List.fill(l1.size - 1)(StringToken(sep)) + Statement(Interleave(l1, l2)) + } + } + + implicit def listTokenizer[T](implicit + tokenize: Tokenizer[T] + ): Tokenizer[List[T]] = + Tokenizer[List[T]] { + case list => list.mkStmt() + } + + extension (sc: StringContext) { + + def flatten(tokens: List[Token]): List[Token] = { + + def unestStatements(tokens: List[Token]): List[Token] = { + tokens.flatMap { + case Statement(innerTokens) => unestStatements(innerTokens) + case token => token :: Nil + } + } + + def mergeStringTokens(tokens: List[Token]): List[Token] = { + val (resultBuilder, leftTokens) = + tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) { + case ((builder, acc), stringToken: StringToken) => + val str = stringToken.string + if (str.nonEmpty) + acc += stringToken.string + (builder, acc) + case ((builder, prev), b) if prev.isEmpty => + (builder += b.token, prev) + case ((builder, prev), b) /* if prev.nonEmpty */ => + builder += StringToken(prev.result().mkString) + builder += b.token + (builder, new ListBuffer[String]) + } + if (leftTokens.nonEmpty) + resultBuilder += StringToken(leftTokens.result().mkString) + resultBuilder.result() + } + + (unestStatements) + .andThen(mergeStringTokens) + .apply(tokens) + } + + def checkLengths( + args: scala.collection.Seq[Any], + parts: Seq[String] + ): Unit = + if (parts.length != args.length + 1) + throw new IllegalArgumentException( + "wrong number of arguments (" + args.length + + ") for interpolated string with " + parts.length + " parts" + ) + + def stmt(args: Token*): Statement = { + checkLengths(args, sc.parts) + val partsIterator = sc.parts.iterator + val argsIterator = args.iterator + val bldr = List.newBuilder[Token] + bldr += StringToken(partsIterator.next()) + while (argsIterator.hasNext) { + bldr += argsIterator.next() + bldr += StringToken(partsIterator.next()) + } + val tokens = flatten(bldr.result()) + Statement(tokens) + } + } +} diff --git a/src/main/scala/minisql/norm/AdHocReduction.scala b/src/main/scala/minisql/norm/AdHocReduction.scala new file mode 100644 index 0000000..10ad9a8 --- /dev/null +++ b/src/main/scala/minisql/norm/AdHocReduction.scala @@ -0,0 +1,52 @@ +package minisql.norm + +import minisql.ast.BinaryOperation +import minisql.ast.BooleanOperator +import minisql.ast.Filter +import minisql.ast.FlatMap +import minisql.ast.Map +import minisql.ast.Query +import minisql.ast.Union +import minisql.ast.UnionAll + +object AdHocReduction { + + def unapply(q: Query) = + q match { + + // --------------------------- + // *.filter + + // a.filter(b => c).filter(d => e) => + // a.filter(b => c && e[d := b]) + case Filter(Filter(a, b, c), d, e) => + val er = BetaReduction(e, d -> b) + Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er))) + + // --------------------------- + // flatMap.* + + // a.flatMap(b => c).map(d => e) => + // a.flatMap(b => c.map(d => e)) + case Map(FlatMap(a, b, c), d, e) => + Some(FlatMap(a, b, Map(c, d, e))) + + // a.flatMap(b => c).filter(d => e) => + // a.flatMap(b => c.filter(d => e)) + case Filter(FlatMap(a, b, c), d, e) => + Some(FlatMap(a, b, Filter(c, d, e))) + + // a.flatMap(b => c.union(d)) + // a.flatMap(b => c).union(a.flatMap(b => d)) + case FlatMap(a, b, Union(c, d)) => + Some(Union(FlatMap(a, b, c), FlatMap(a, b, d))) + + // a.flatMap(b => c.unionAll(d)) + // a.flatMap(b => c).unionAll(a.flatMap(b => d)) + case FlatMap(a, b, UnionAll(c, d)) => + Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d))) + + case other => None + } + +} diff --git a/src/main/scala/minisql/norm/ApplyMap.scala b/src/main/scala/minisql/norm/ApplyMap.scala new file mode 100644 index 0000000..e5ddb0c --- /dev/null +++ b/src/main/scala/minisql/norm/ApplyMap.scala @@ -0,0 +1,160 @@ +package minisql.norm + +import minisql.ast._ + +object ApplyMap { + + private def isomorphic(e: Ast, c: Ast, alias: Ident) = + BetaReduction(e, alias -> c) == c + + object InfixedTailOperation { + + def hasImpureInfix(ast: Ast) = + CollectAst(ast) { + case i @ Infix(_, _, false, _) => i + }.nonEmpty + + def unapply(ast: Ast): Option[Ast] = + ast match { + case cc: CaseClass if hasImpureInfix(cc) => Some(cc) + case tup: Tuple if hasImpureInfix(tup) => Some(tup) + case p: Property if hasImpureInfix(p) => Some(p) + case b: BinaryOperation if hasImpureInfix(b) => Some(b) + case u: UnaryOperation if hasImpureInfix(u) => Some(u) + case i @ Infix(_, _, false, _) => Some(i) + case _ => None + } + } + + object MapWithoutInfixes { + def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = + ast match { + case Map(a, b, InfixedTailOperation(c)) => None + case Map(a, b, c) => Some((a, b, c)) + case _ => None + } + } + + object DetachableMap { + def unapply(ast: Ast): Option[(Ast, Ident, Ast)] = + ast match { + case Map(a: GroupBy, b, c) => None + case Map(a, b, InfixedTailOperation(c)) => None + case Map(a, b, c) => Some((a, b, c)) + case _ => None + } + } + + def unapply(q: Query): Option[Query] = + q match { + + case Map(a: GroupBy, b, c) if (b == c) => None + case Map(a: DistinctOn, b, c) => None + case Map(a: Nested, b, c) if (b == c) => None + case Nested(DetachableMap(a: Join, b, c)) => None + + // map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) => + // map(i => (i.i, i.l)).distinct + case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) => + Some(Distinct(Map(a, b, c))) + + // a.map(b => c).map(d => e) => + // a.map(b => e[d := c]) + case before @ Map(MapWithoutInfixes(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(Map(a, b, er)) + + // a.map(b => b) => + // a + case Map(a: Query, b, c) if (b == c) => + Some(a) + + // a.map(b => c).flatMap(d => e) => + // a.flatMap(b => e[d := c]) + case FlatMap(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(FlatMap(a, b, er)) + + // a.map(b => c).filter(d => e) => + // a.filter(b => e[d := c]).map(b => c) + case Filter(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + Some(Map(Filter(a, b, er), b, c)) + + // a.map(b => c).sortBy(d => e) => + // a.sortBy(b => e[d := c]).map(b => c) + case SortBy(DetachableMap(a, b, c), d, e, f) => + val er = BetaReduction(e, d -> c) + Some(Map(SortBy(a, b, er, f), b, c)) + + // a.map(b => c).sortBy(d => e).distinct => + // a.sortBy(b => e[d := c]).map(b => c).distinct + case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) => + val er = BetaReduction(e, d -> c) + Some(Distinct(Map(SortBy(a, b, er, f), b, c))) + + // a.map(b => c).groupBy(d => e) => + // a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c))) + case GroupBy(DetachableMap(a, b, c), d, e) => + val er = BetaReduction(e, d -> c) + val x = Ident("x") + val x1 = Property( + Ident("x"), + "_1" + ) // These introduced property should not be renamed + val x2 = Property(Ident("x"), "_2") // due to any naming convention. + val body = Tuple(List(x1, Map(x2, b, c))) + Some(Map(GroupBy(a, b, er), x, body)) + + // a.map(b => c).drop(d) => + // a.drop(d).map(b => c) + case Drop(DetachableMap(a, b, c), d) => + Some(Map(Drop(a, d), b, c)) + + // a.map(b => c).take(d) => + // a.drop(d).map(b => c) + case Take(DetachableMap(a, b, c), d) => + Some(Map(Take(a, d), b, c)) + + // a.map(b => c).nested => + // a.nested.map(b => c) + case Nested(DetachableMap(a, b, c)) => + Some(Map(Nested(a), b, c)) + + // a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on) + // a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2])) + case Join( + tpe, + DetachableMap(a, b, c), + DetachableMap(d, e, f), + iA, + iB, + on + ) => + val onr = BetaReduction(on, iA -> c, iB -> f) + val t = Ident("t") + val t1 = BetaReduction(c, b -> Property(t, "_1")) + val t2 = BetaReduction(f, e -> Property(t, "_2")) + Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2)))) + + // a.*join(b.map(c => d)).on((iA, iB) => on) + // a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2])) + case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) => + val onr = BetaReduction(on, iB -> d) + val t = Ident("t") + val t1 = Property(t, "_1") + val t2 = BetaReduction(d, c -> Property(t, "_2")) + Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2)))) + + // a.map(b => c).*join(d).on((iA, iB) => on) + // a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2)) + case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) => + val onr = BetaReduction(on, iA -> c) + val t = Ident("t") + val t1 = BetaReduction(c, b -> Property(t, "_1")) + val t2 = Property(t, "_2") + Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2)))) + + case other => None + } +} diff --git a/src/main/scala/minisql/norm/AttachToEntity.scala b/src/main/scala/minisql/norm/AttachToEntity.scala new file mode 100644 index 0000000..e5608e9 --- /dev/null +++ b/src/main/scala/minisql/norm/AttachToEntity.scala @@ -0,0 +1,48 @@ +package minisql.norm + +import minisql.util.Messages.fail +import minisql.ast._ + +object AttachToEntity { + + private object IsEntity { + def unapply(q: Ast): Option[Ast] = + q match { + case q: Entity => Some(q) + case q: Infix => Some(q) + case _ => None + } + } + + def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)( + q: Ast + ): Ast = + q match { + + case Map(IsEntity(a), b, c) => Map(f(a, b), b, c) + case FlatMap(IsEntity(a), b, c) => FlatMap(f(a, b), b, c) + case ConcatMap(IsEntity(a), b, c) => ConcatMap(f(a, b), b, c) + case Filter(IsEntity(a), b, c) => Filter(f(a, b), b, c) + case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d) + case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c) + + case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join | + _: FlatJoin => + f(q, alias.getOrElse(Ident("x"))) + + case Map(a: Query, b, c) => Map(apply(f, Some(b))(a), b, c) + case FlatMap(a: Query, b, c) => FlatMap(apply(f, Some(b))(a), b, c) + case ConcatMap(a: Query, b, c) => ConcatMap(apply(f, Some(b))(a), b, c) + case Filter(a: Query, b, c) => Filter(apply(f, Some(b))(a), b, c) + case SortBy(a: Query, b, c, d) => SortBy(apply(f, Some(b))(a), b, c, d) + case Take(a: Query, b) => Take(apply(f, alias)(a), b) + case Drop(a: Query, b) => Drop(apply(f, alias)(a), b) + case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a)) + case Distinct(a: Query) => Distinct(apply(f, alias)(a)) + case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c) + + case IsEntity(q) => f(q, alias.getOrElse(Ident("x"))) + + case other => fail(s"Can't find an 'Entity' in '$q'") + } +} diff --git a/src/main/scala/minisql/util/BetaReduction.scala b/src/main/scala/minisql/norm/BetaReduction.scala similarity index 99% rename from src/main/scala/minisql/util/BetaReduction.scala rename to src/main/scala/minisql/norm/BetaReduction.scala index 0940564..868b021 100644 --- a/src/main/scala/minisql/util/BetaReduction.scala +++ b/src/main/scala/minisql/norm/BetaReduction.scala @@ -1,6 +1,6 @@ -package minisql.util +package minisql.norm -import minisql.ast.* +import minisql.ast._ import scala.collection.immutable.{Map => IMap} case class BetaReduction(replacements: Replacements) diff --git a/src/main/scala/minisql/norm/ConcatBehavior.scala b/src/main/scala/minisql/norm/ConcatBehavior.scala new file mode 100644 index 0000000..3547b1d --- /dev/null +++ b/src/main/scala/minisql/norm/ConcatBehavior.scala @@ -0,0 +1,7 @@ +package minisql.norm + +trait ConcatBehavior +object ConcatBehavior { + case object AnsiConcat extends ConcatBehavior + case object NonAnsiConcat extends ConcatBehavior +} diff --git a/src/main/scala/minisql/norm/EqualityBehavior.scala b/src/main/scala/minisql/norm/EqualityBehavior.scala new file mode 100644 index 0000000..ee1ff38 --- /dev/null +++ b/src/main/scala/minisql/norm/EqualityBehavior.scala @@ -0,0 +1,7 @@ +package minisql.norm + +trait EqualityBehavior +object EqualityBehavior { + case object AnsiEquality extends EqualityBehavior + case object NonAnsiEquality extends EqualityBehavior +} diff --git a/src/main/scala/minisql/norm/ExpandReturning.scala b/src/main/scala/minisql/norm/ExpandReturning.scala new file mode 100644 index 0000000..32a886f --- /dev/null +++ b/src/main/scala/minisql/norm/ExpandReturning.scala @@ -0,0 +1,74 @@ +package minisql.norm + +import minisql.ReturnAction.ReturnColumns +import minisql.{NamingStrategy, ReturnAction} +import minisql.ast._ +import minisql.context.{ + ReturningClauseSupported, + ReturningMultipleFieldSupported, + ReturningNotSupported, + ReturningSingleFieldSupported +} +import minisql.idiom.{Idiom, Statement} + +/** + * Take the `.returning` part in a query that contains it and return the array + * of columns representing of the returning seccovtion with any other operations + * etc... that they might contain. + */ +object ExpandReturning { + + def applyMap( + returning: ReturningAction + )(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = { + val initialExpand = ExpandReturning.apply(returning)(idiom, naming) + + idiom.idiomReturningCapability match { + case ReturningClauseSupported => + ReturnAction.ReturnRecord + case ReturningMultipleFieldSupported => + ReturnColumns(initialExpand.map { + case (ast, statement) => f(ast, statement) + }) + case ReturningSingleFieldSupported => + if (initialExpand.length == 1) + ReturnColumns(initialExpand.map { + case (ast, statement) => f(ast, statement) + }) + else + throw new IllegalArgumentException( + s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified." + ) + case ReturningNotSupported => + throw new IllegalArgumentException( + s"RETURNING columns are not allowed in the ${idiom} dialect." + ) + } + } + + def apply( + returning: ReturningAction + )(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = { + val ReturningAction(_, alias, properties) = returning: @unchecked + + // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1)))) + // => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))) + val dePropertized = + Transform(properties) { + case `alias` => ExternalIdent(alias.name) + } + + val aliasName = alias.name + + // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))) + // => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))) + val deTuplified = dePropertized match { + case Tuple(values) => values + case CaseClass(values) => values.map(_._2) + case other => List(other) + } + + implicit val namingStrategy: NamingStrategy = naming + deTuplified.map(v => idiom.translate(v)) + } +} diff --git a/src/main/scala/minisql/norm/FlattenOptionOperation.scala b/src/main/scala/minisql/norm/FlattenOptionOperation.scala new file mode 100644 index 0000000..50ba857 --- /dev/null +++ b/src/main/scala/minisql/norm/FlattenOptionOperation.scala @@ -0,0 +1,108 @@ +package minisql.norm + +import minisql.ast.* +import minisql.ast.Implicits.* +import minisql.norm.ConcatBehavior.NonAnsiConcat + +class FlattenOptionOperation(concatBehavior: ConcatBehavior) + extends StatelessTransformer { + + private def emptyOrNot(b: Boolean, ast: Ast) = + if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast) + + def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) = + apply(BetaReduction(body, alias -> ast)) + + def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = { + val reduced = BetaReduction(body, alias -> ast) + apply((IsNullCheck(ast) +||+ reduced): Ast) + } + + def containsNonFallthroughElement(ast: Ast) = + CollectAst(ast) { + case If(_, _, _) => true + case Infix(_, _, _, _) => true + case BinaryOperation(_, StringOperator.`concat`, _) + if (concatBehavior == NonAnsiConcat) => + true + }.nonEmpty + + override def apply(ast: Ast): Ast = + ast match { + + case OptionTableFlatMap(ast, alias, body) => + uncheckedReduction(ast, alias, body) + + case OptionTableMap(ast, alias, body) => + uncheckedReduction(ast, alias, body) + + case OptionTableExists(ast, alias, body) => + uncheckedReduction(ast, alias, body) + + case OptionTableForall(ast, alias, body) => + uncheckedForall(ast, alias, body) + + case OptionFlatten(ast) => + apply(ast) + + case OptionSome(ast) => + apply(ast) + + case OptionApply(ast) => + apply(ast) + + case OptionOrNull(ast) => + apply(ast) + + case OptionGetOrNull(ast) => + apply(ast) + + case OptionNone => NullValue + + case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) => + apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast) + + case OptionGetOrElse(ast, body) => + apply(If(IsNotNullCheck(ast), ast, body)) + + case OptionFlatMap(ast, alias, body) => + if (containsNonFallthroughElement(body)) { + val reduced = BetaReduction(body, alias -> ast) + apply(IfExistElseNull(ast, reduced)) + } else { + uncheckedReduction(ast, alias, body) + } + + case OptionMap(ast, alias, body) => + if (containsNonFallthroughElement(body)) { + val reduced = BetaReduction(body, alias -> ast) + apply(IfExistElseNull(ast, reduced)) + } else { + uncheckedReduction(ast, alias, body) + } + + case OptionForall(ast, alias, body) => + if (containsNonFallthroughElement(body)) { + val reduction = BetaReduction(body, alias -> ast) + apply( + (IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast + ) + } else { + uncheckedForall(ast, alias, body) + } + + case OptionExists(ast, alias, body) => + if (containsNonFallthroughElement(body)) { + val reduction = BetaReduction(body, alias -> ast) + apply((IsNotNullCheck(ast) +&&+ reduction): Ast) + } else { + uncheckedReduction(ast, alias, body) + } + + case OptionContains(ast, body) => + apply((ast +==+ body): Ast) + + case other => + super.apply(other) + } +} diff --git a/src/main/scala/minisql/norm/NestImpureMappedInfix.scala b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala new file mode 100644 index 0000000..789efba --- /dev/null +++ b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala @@ -0,0 +1,76 @@ +package minisql.norm + +import minisql.ast._ + +/** + * A problem occurred in the original way infixes were done in that it was + * assumed that infix clauses represented pure functions. While this is true of + * many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others + * e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations + * that are done in `ApplyMap` on standard AST `Map` clauses cannot be done + * therefore additional safety checks were introduced there in order to assure + * this does not happen. In addition to this however, it is necessary to add + * this normalization step which inserts `Nested` AST elements in every map that + * contains impure infix. See more information and examples in #1534. + */ +object NestImpureMappedInfix extends StatelessTransformer { + + // Are there any impure infixes that exist inside the specified ASTs + def hasInfix(asts: Ast*): Boolean = + asts.exists(ast => + CollectAst(ast) { + case i @ Infix(_, _, false, _) => i + }.nonEmpty + ) + + // Continue exploring into the Map to see if there are additional impure infix clauses inside. + private def applyInside(m: Map) = + Map(apply(m.query), m.alias, m.body) + + override def apply(ast: Ast): Ast = + ast match { + // If there is already a nested clause inside the map, there is no reason to insert another one + case Nested(Map(inner, a, b)) => + Nested(Map(apply(inner), a, b)) + + case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m) + Map( + Nested(applyInside(m)), + x, + CaseClass(values.map { + case (name, _) => + ( + name, + Property(x, name) + ) // mappings of nested-query case class properties should not be renamed + }) + ) + + case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) => + Map( + Nested(applyInside(m)), + x, + Tuple(values.zipWithIndex.map { + case (_, i) => + Property( + x, + s"_${i + 1}" + ) // mappings of nested-query tuple properties should not be renamed + }) + ) + + case m @ Map(_, x, i @ Infix(_, _, false, _)) => + Map(Nested(applyInside(m)), x, Property(x, "_1")) + + case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) => + Map(Nested(applyInside(m)), x, Property(x, "_1")) + + case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) => + Map(Nested(applyInside(m)), x, Property(x, "_1")) + + case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) => + Map(Nested(applyInside(m)), x, Property(x, "_1")) + + case other => super.apply(other) + } +} diff --git a/src/main/scala/minisql/norm/Normalize.scala b/src/main/scala/minisql/norm/Normalize.scala new file mode 100644 index 0000000..d3d8d67 --- /dev/null +++ b/src/main/scala/minisql/norm/Normalize.scala @@ -0,0 +1,51 @@ +package minisql.norm + +import minisql.ast.Ast +import minisql.ast.Query +import minisql.ast.StatelessTransformer +import minisql.norm.capture.AvoidCapture +import minisql.ast.Action +import minisql.util.Messages.trace +import minisql.util.Messages.TraceType.Normalizations + +import scala.annotation.tailrec + +object Normalize extends StatelessTransformer { + + override def apply(q: Ast): Ast = + super.apply(BetaReduction(q)) + + override def apply(q: Action): Action = + NormalizeReturning(super.apply(q)) + + override def apply(q: Query): Query = + norm(AvoidCapture(q)) + + private def traceNorm[T](label: String) = + trace[T](s"${label} (Normalize)", 1, Normalizations) + + @tailrec + private def norm(q: Query): Query = + q match { + case NormalizeNestedStructures(query) => + traceNorm("NormalizeNestedStructures")(query) + norm(query) + case ApplyMap(query) => + traceNorm("ApplyMap")(query) + norm(query) + case SymbolicReduction(query) => + traceNorm("SymbolicReduction")(query) + norm(query) + case AdHocReduction(query) => + traceNorm("AdHocReduction")(query) + norm(query) + case OrderTerms(query) => + traceNorm("OrderTerms")(query) + norm(query) + case NormalizeAggregationIdent(query) => + traceNorm("NormalizeAggregationIdent")(query) + norm(query) + case other => + other + } +} diff --git a/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala new file mode 100644 index 0000000..c451ee4 --- /dev/null +++ b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala @@ -0,0 +1,29 @@ +package minisql.norm + +import minisql.ast._ + +object NormalizeAggregationIdent { + + def unapply(q: Query) = + q match { + + // a => a.b.map(x => x.c).agg => + // a => a.b.map(a => a.c).agg + case Aggregation( + op, + Map( + p @ Property(i: Ident, _), + mi, + Property.Opinionated(_: Ident, n, renameable, visibility) + ) + ) if i != mi => + Some( + Aggregation( + op, + Map(p, i, Property.Opinionated(i, n, renameable, visibility)) + ) + ) // in example aove, if c in x.c is fixed c in a.c should also be + + case _ => None + } +} diff --git a/src/main/scala/minisql/norm/NormalizeNestedStructures.scala b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala new file mode 100644 index 0000000..603c411 --- /dev/null +++ b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala @@ -0,0 +1,47 @@ +package minisql.norm + +import minisql.ast._ + +object NormalizeNestedStructures { + + def unapply(q: Query): Option[Query] = + q match { + case e: Entity => None + case Map(a, b, c) => apply(a, c)(Map(_, b, _)) + case FlatMap(a, b, c) => apply(a, c)(FlatMap(_, b, _)) + case ConcatMap(a, b, c) => apply(a, c)(ConcatMap(_, b, _)) + case Filter(a, b, c) => apply(a, c)(Filter(_, b, _)) + case SortBy(a, b, c, d) => apply(a, c)(SortBy(_, b, _, d)) + case GroupBy(a, b, c) => apply(a, c)(GroupBy(_, b, _)) + case Aggregation(a, b) => apply(b)(Aggregation(a, _)) + case Take(a, b) => apply(a, b)(Take.apply) + case Drop(a, b) => apply(a, b)(Drop.apply) + case Union(a, b) => apply(a, b)(Union.apply) + case UnionAll(a, b) => apply(a, b)(UnionAll.apply) + case Distinct(a) => apply(a)(Distinct.apply) + case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _)) + case Nested(a) => apply(a)(Nested.apply) + case FlatJoin(t, a, iA, on) => + (Normalize(a), Normalize(on)) match { + case (`a`, `on`) => None + case (a, on) => Some(FlatJoin(t, a, iA, on)) + } + case Join(t, a, b, iA, iB, on) => + (Normalize(a), Normalize(b), Normalize(on)) match { + case (`a`, `b`, `on`) => None + case (a, b, on) => Some(Join(t, a, b, iA, iB, on)) + } + } + + private def apply(a: Ast)(f: Ast => Query) = + (Normalize(a)) match { + case (`a`) => None + case (a) => Some(f(a)) + } + + private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) = + (Normalize(a), Normalize(b)) match { + case (`a`, `b`) => None + case (a, b) => Some(f(a, b)) + } +} diff --git a/src/main/scala/minisql/norm/NormalizeReturning.scala b/src/main/scala/minisql/norm/NormalizeReturning.scala new file mode 100644 index 0000000..43fc241 --- /dev/null +++ b/src/main/scala/minisql/norm/NormalizeReturning.scala @@ -0,0 +1,154 @@ +package minisql.norm + +import minisql.ast._ +import minisql.norm.capture.AvoidAliasConflict + +/** + * When actions are used with a `.returning` clause, remove the columns used in + * the returning clause from the action. E.g. for `insert(Person(id, + * name)).returning(_.id)` remove the `id` column from the original insert. + */ +object NormalizeReturning { + + def apply(e: Action): Action = { + e match { + case ReturningGenerated(a: Action, alias, body) => + // De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action. + // This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres. + // For example: + // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max)) + // Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally + // be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case. + // In order to fix this, we need to change `t` into a different alias. + val newBody = dealiasBody(body, alias) + ReturningGenerated(apply(a, newBody, alias), alias, newBody) + + // For a regular return clause, do not need to exclude assignments from insertion however, we still + // need to de-alias the Action body in case conflicts result. For example the following query: + // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max)) + // would incorrectly be interpreted as: + // INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t + // whereas it should be: + // INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1) + case Returning(a: Action, alias, body) => + val newBody = dealiasBody(body, alias) + Returning(a, alias, newBody) + + case _ => e + } + } + + /** + * In some situations, a query can exist inside of a `returning` clause. In + * this case, we need to rename if the aliases used in that query override the + * alias used in the `returning` clause otherwise they will be treated as + * returning-clause aliases ExpandReturning (i.e. they will become + * ExternalAlias instances) and later be tokenized incorrectly. + */ + private def dealiasBody(body: Ast, alias: Ident): Ast = + Transform(body) { + case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias)) + } + + private def apply(e: Action, body: Ast, returningIdent: Ident): Action = + e match { + case Insert(query, assignments) => + Insert(query, filterReturnedColumn(assignments, body, returningIdent)) + case Update(query, assignments) => + Update(query, filterReturnedColumn(assignments, body, returningIdent)) + case OnConflict(a: Action, target, act) => + OnConflict(apply(a, body, returningIdent), target, act) + case _ => e + } + + private def filterReturnedColumn( + assignments: List[Assignment], + column: Ast, + returningIdent: Ident + ): List[Assignment] = + assignments.flatMap(filterReturnedColumn(_, column, returningIdent)) + + /** + * In situations like Property(Property(ident, foo), bar) pull out the + * inner-most ident + */ + object NestedProperty { + def unapply(ast: Property): Option[Ast] = { + ast match { + case p @ Property(subAst, _) => Some(innerMost(subAst)) + } + } + + private def innerMost(ast: Ast): Ast = ast match { + case Property(inner, _) => innerMost(inner) + case other => other + } + } + + /** + * Remove the specified column from the assignment. For example, in a query + * like `insert(Person(id, name)).returning(r => r.id)` we need to remove the + * `id` column from the insertion. The value of the `column:Ast` in this case + * will be `Property(Ident(r), id)` and the values fo the assignment `p1` + * property will typically be `v.id` and `v.name` (the `v` variable is a + * default used for `insert` queries). + */ + private def filterReturnedColumn( + assignment: Assignment, + body: Ast, + returningIdent: Ident + ): Option[Assignment] = + assignment match { + case Assignment(_, p1: Property, _) => { + // Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but + // if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2)) + // or it can even be a query since queries are allowed to be in return sections e.g: + // query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max)) + // In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment + // in order to know what to exclude. + val matchedProps = + CollectAst(body) { + // case prop @ NestedProperty(`returningIdent`) => prop + case prop @ NestedProperty(Ident(name)) + if (name == returningIdent.name) => + prop + case prop @ NestedProperty(ExternalIdent(name)) + if (name == returningIdent.name) => + prop + } + + if ( + matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp)) + ) + None + else + Some(assignment) + } + case assignment => Some(assignment) + } + + object SomeIdent { + def unapply(ast: Ast): Option[Ast] = + ast match { + case id: Ident => Some(id) + case id: ExternalIdent => Some(id) + case _ => None + } + } + + /** + * Is it the same property (but possibly of a different identity). E.g. + * `p.foo.bar` and `v.foo.bar` + */ + private def isSameProperties(p1: Property, p2: Property): Boolean = + (p1.ast, p2.ast) match { + case (SomeIdent(_), SomeIdent(_)) => + p1.name == p2.name + // If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the + // outer properties are the same before moving on to the inner ones. + case (pp1: Property, pp2: Property) if (p1.name == p2.name) => + isSameProperties(pp1, pp2) + case _ => + false + } +} diff --git a/src/main/scala/minisql/norm/OrderTerms.scala b/src/main/scala/minisql/norm/OrderTerms.scala new file mode 100644 index 0000000..22422fa --- /dev/null +++ b/src/main/scala/minisql/norm/OrderTerms.scala @@ -0,0 +1,29 @@ +package minisql.norm + +import minisql.ast._ + +object OrderTerms { + + def unapply(q: Query) = + q match { + + case Take(Map(a: GroupBy, b, c), d) => None + + // a.sortBy(b => c).filter(d => e) => + // a.filter(d => e).sortBy(b => c) + case Filter(SortBy(a, b, c, d), e, f) => + Some(SortBy(Filter(a, e, f), b, c, d)) + + // a.flatMap(b => c).take(n).map(d => e) => + // a.flatMap(b => c).map(d => e).take(n) + case Map(Take(fm: FlatMap, n), ma, mb) => + Some(Take(Map(fm, ma, mb), n)) + + // a.flatMap(b => c).drop(n).map(d => e) => + // a.flatMap(b => c).map(d => e).drop(n) + case Map(Drop(fm: FlatMap, n), ma, mb) => + Some(Drop(Map(fm, ma, mb), n)) + + case other => None + } +} diff --git a/src/main/scala/minisql/norm/RenameProperties.scala b/src/main/scala/minisql/norm/RenameProperties.scala new file mode 100644 index 0000000..53d135c --- /dev/null +++ b/src/main/scala/minisql/norm/RenameProperties.scala @@ -0,0 +1,491 @@ +package minisql.norm + +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.Visible +import minisql.ast._ +import minisql.util.Interpolator + +object RenameProperties extends StatelessTransformer { + val interp = new Interpolator(3) + import interp._ + def traceDifferent(one: Any, two: Any) = + if (one != two) + trace"Replaced $one with $two".andLog() + else + trace"Replacements did not match".andLog() + + override def apply(q: Query): Query = + applySchemaOnly(q) + + override def apply(q: Action): Action = + applySchema(q) match { + case (q, schema) => q + } + + override def apply(e: Operation): Operation = + e match { + case UnaryOperation(o, c: Query) => + UnaryOperation(o, applySchemaOnly(apply(c))) + case _ => super.apply(e) + } + + private def applySchema(q: Action): (Action, Schema) = + q match { + case Insert(q: Query, assignments) => + applySchema(q, assignments, Insert.apply) + case Update(q: Query, assignments) => + applySchema(q, assignments, Update.apply) + case Delete(q: Query) => + applySchema(q) match { + case (q, schema) => (Delete(q), schema) + } + case Returning(action: Action, alias, body) => + applySchema(action) match { + case (action, schema) => + val replace = + trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val bodyr = BetaReduction(body, replace*) + traceDifferent(body, bodyr) + (Returning(action, alias, bodyr), schema) + } + case ReturningGenerated(action: Action, alias, body) => + applySchema(action) match { + case (action, schema) => + val replace = + trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val bodyr = BetaReduction(body, replace*) + traceDifferent(body, bodyr) + (ReturningGenerated(action, alias, bodyr), schema) + } + case OnConflict(a: Action, target, act) => + applySchema(a) match { + case (action, schema) => + val targetR = target match { + case OnConflict.Properties(props) => + val propsR = props.map { prop => + val replace = + trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn` + replacements( + prop.ast, + schema + ) // A BetaReduction on a Property will always give back a Property + BetaReduction(prop, replace*).asInstanceOf[Property] + } + traceDifferent(props, propsR) + OnConflict.Properties(propsR) + case OnConflict.NoTarget => target + } + val actR = act match { + case OnConflict.Update(assignments) => + OnConflict.Update(replaceAssignments(assignments, schema)) + case _ => act + } + (OnConflict(action, targetR, actR), schema) + } + case q => (q, TupleSchema.empty) + } + + private def replaceAssignments( + a: List[Assignment], + schema: Schema + ): List[Assignment] = + a.map { + case Assignment(alias, prop, value) => + val replace = + trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn` + replacements(alias, schema) + val propR = BetaReduction(prop, replace*) + traceDifferent(prop, propR) + val valueR = BetaReduction(value, replace*) + traceDifferent(value, valueR) + Assignment(alias, propR, valueR) + } + + private def applySchema( + q: Query, + a: List[Assignment], + f: (Query, List[Assignment]) => Action + ): (Action, Schema) = + applySchema(q) match { + case (q, schema) => + (f(q, replaceAssignments(a, schema)), schema) + } + + private def applySchemaOnly(q: Query): Query = + applySchema(q) match { + case (q, _) => q + } + + object TupleIndex { + def unapply(s: String): Option[Int] = + if (s.matches("_[0-9]*")) + Some(s.drop(1).toInt - 1) + else + None + } + + sealed trait Schema { + def lookup(property: List[String]): Option[Schema] = + (property, this) match { + case (Nil, schema) => + trace"Nil at $property returning " `andReturn` + Some(schema) + case (path, e @ EntitySchema(_)) => + trace"Entity at $path returning " `andReturn` + Some(e.subSchemaOrEmpty(path)) + case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) => + trace"Case class at $property returning " `andReturn` + props(head).lookup(tail) + case (TupleIndex(idx) :: tail, TupleSchema(values)) + if values.contains(idx) => + trace"Tuple at at $property returning " `andReturn` + values(idx).lookup(tail) + case _ => + trace"Nothing found at $property returning " `andReturn` + None + } + } + + // Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...)) + object PropertyMatroshka { + + def traverse(initial: Property): Option[(Ident, List[String])] = + initial match { + // If it's a nested-property walk inside and append the name to the result (if something is returned) + case Property(inner: Property, name) => + traverse(inner).map { case (id, list) => (id, list :+ name) } + // If it's a property with ident in the core, return that + case Property(id: Ident, name) => + Some((id, List(name))) + // Otherwise an ident property is not inside so don't return anything + case _ => + None + } + + def unapply(ast: Ast): Option[(Ident, List[String])] = + ast match { + case p: Property => traverse(p) + case _ => None + } + + } + + def protractSchema( + body: Ast, + ident: Ident, + schema: Schema + ): Option[Schema] = { + + def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] = + body match { + // if any values yield a sub-schema which is not an entity, recurse into that + case cc @ CaseClass(values) => + trace"Protracting CaseClass $cc into new schema:" `andReturn` + CaseClassSchema( + values.collect { + case (name, innerBody @ HierarchicalAstEntity()) => + (name, protractSchemaRecurse(innerBody, schema)) + // pass the schema into a recursive call an extract from it when we non tuple/caseclass element + case (name, innerBody @ PropertyMatroshka(`ident`, path)) => + (name, protractSchemaRecurse(innerBody, schema)) + // we have reached an ident i.e. recurse to pass the current schema into the case class + case (name, `ident`) => + (name, protractSchemaRecurse(ident, schema)) + }.collect { + case (name, Some(subSchema)) => (name, subSchema) + } + ).notEmpty + case tup @ Tuple(values) => + trace"Protracting Tuple $tup into new schema:" `andReturn` + TupleSchema + .fromIndexes( + values.zipWithIndex.collect { + case (innerBody @ HierarchicalAstEntity(), index) => + (index, protractSchemaRecurse(innerBody, schema)) + // pass the schema into a recursive call an extract from it when we non tuple/caseclass element + case (innerBody @ PropertyMatroshka(`ident`, path), index) => + (index, protractSchemaRecurse(innerBody, schema)) + // we have reached an ident i.e. recurse to pass the current schema into the tuple + case (`ident`, index) => + (index, protractSchemaRecurse(ident, schema)) + }.collect { + case (index, Some(subSchema)) => (index, subSchema) + } + ) + .notEmpty + + case prop @ PropertyMatroshka(`ident`, path) => + trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn` + schema match { + // case e: EntitySchema => Some(e) + case _ => schema.lookup(path) + } + case `ident` => + trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn` + Some(schema) + case other => + trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn` + Some(schema) + } + + protractSchemaRecurse(body, schema) + } + + case object EmptySchema extends Schema + case class EntitySchema(e: Entity) extends Schema { + def noAliases = e.properties.isEmpty + + private def subSchema(path: List[String]) = + EntitySchema( + Entity( + s"sub-${e.name}", + e.properties.flatMap { + case PropertyAlias(aliasPath, alias) => + if (aliasPath == path) + List(PropertyAlias(aliasPath, alias)) + else if (aliasPath.startsWith(path)) + List(PropertyAlias(aliasPath.diff(path), alias)) + else + List() + } + ) + ) + + def subSchemaOrEmpty(path: List[String]): Schema = + trace"Creating sub-schema for entity $e at path $path will be" andReturn { + val sub = subSchema(path) + if (sub.noAliases) EmptySchema else sub + } + + } + case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ ) + extends Schema { + def list = m.toList.sortBy(_._1) + def notEmpty = + if (this.m.nonEmpty) Some(this) else None + } + case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema { + def list = m.toList + def notEmpty = + if (this.m.nonEmpty) Some(this) else None + } + object CaseClassSchema { + def apply(property: String, value: Schema): CaseClassSchema = + CaseClassSchema(collection.Map(property -> value)) + def apply(list: List[(String, Schema)]): CaseClassSchema = + CaseClassSchema(list.toMap) + } + + object TupleSchema { + def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema = + TupleSchema(schemas.toMap) + + def apply(schemas: List[Schema]): TupleSchema = + TupleSchema(schemas.zipWithIndex.map(_.swap).toMap) + + def apply(index: Int, schema: Schema): TupleSchema = + TupleSchema(collection.Map(index -> schema)) + + def empty: TupleSchema = TupleSchema(List.empty) + } + + object HierarchicalAstEntity { + def unapply(ast: Ast): Boolean = + ast match { + case cc: CaseClass => true + case tup: Tuple => true + case _ => false + } + } + + private def applySchema(q: Query): (Query, Schema) = { + q match { + + // Don't understand why this is needed.... + case Map(q: Query, x, p) => + applySchema(q) match { + case (q, subSchema) => + val replace = + trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn` + replacements(x, subSchema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + + val schema = + trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` { + protractSchema(prr, x, subSchema) + }.getOrElse(EmptySchema) + + (Map(q, x, prr), schema) + } + + case e: Entity => (e, EntitySchema(e)) + case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply) + case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o)) + case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply) + case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _)) + case Take(q: Query, n) => applySchema(q, Take(_, n)) + case Drop(q: Query, n) => applySchema(q, Drop(_, n)) + case Nested(q: Query) => applySchema(q, Nested.apply) + case Distinct(q: Query) => applySchema(q, Distinct.apply) + case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on)) + + case FlatMap(q: Query, x, p) => + applySchema(q, x, p, FlatMap.apply) match { + case (FlatMap(q, x, p: Query), oldSchema) => + val (pr, newSchema) = applySchema(p) + (FlatMap(q, x, pr), newSchema) + case (flatMap, oldSchema) => + (flatMap, TupleSchema.empty) + } + + case ConcatMap(q: Query, x, p) => + applySchema(q, x, p, ConcatMap.apply) match { + case (ConcatMap(q, x, p: Query), oldSchema) => + val (pr, newSchema) = applySchema(p) + (ConcatMap(q, x, pr), newSchema) + case (concatMap, oldSchema) => + (concatMap, TupleSchema.empty) + } + + case Join(typ, a: Query, b: Query, iA, iB, on) => + (applySchema(a), applySchema(b)) match { + case ((a, schemaA), (b, schemaB)) => + val combinedReplacements = + trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn { + val replaceA = replacements(iA, schemaA) + val replaceB = replacements(iB, schemaB) + replaceA ++ replaceB + } + val onr = BetaReduction(on, combinedReplacements*) + traceDifferent(on, onr) + (Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB))) + } + + case FlatJoin(typ, a: Query, iA, on) => + applySchema(a) match { + case (a, schemaA) => + val replaceA = + trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn` + replacements(iA, schemaA) + val onr = BetaReduction(on, replaceA*) + traceDifferent(on, onr) + (FlatJoin(typ, a, iA, onr), schemaA) + } + + case Map(q: Operation, x, p) if x == p => + (Map(apply(q), x, p), TupleSchema.empty) + + case Map(Infix(parts, params, pure, paren), x, p) => + val transformed = + params.map { + case q: Query => + val (qr, schema) = applySchema(q) + traceDifferent(q, qr) + (qr, Some(schema)) + case q => + (q, None) + } + + val schema = + transformed.collect { + case (_, Some(schema)) => schema + } match { + case e :: Nil => e + case ls => TupleSchema(ls) + } + val replace = + trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn` + replacements(x, schema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + + (Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema) + + case q => + (q, TupleSchema.empty) + } + } + + private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) = + applySchema(ast) match { + case (ast, schema) => + (f(ast), schema) + } + + private def applySchema[T]( + q: Query, + x: Ident, + p: Ast, + f: (Ast, Ident, Ast) => T + ): (T, Schema) = + applySchema(q) match { + case (q, schema) => + val replace = + trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn` + replacements(x, schema) + val pr = BetaReduction(p, replace*) + traceDifferent(p, pr) + val prr = apply(pr) + traceDifferent(pr, prr) + (f(q, x, prr), schema) + } + + private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] = + schema match { + // The entity renameable property should already have been marked as Fixed + case EntitySchema(Entity(entity, properties)) => + // trace"%4 Entity Schema: " andReturn + properties.flatMap { + // A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias) + // or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit. + // In this case, we want to rename the properties based on the property aliases as well as mark + // them Fixed since they should not be renamed based on + // the naming strategy wherever they are tokenized (e.g. in SqlIdiom) + case PropertyAlias(path, alias) => + def apply(base: Ast, path: List[String]): Ast = + path match { + case Nil => base + case head :: tail => apply(Property(base, head), tail) + } + List( + apply(base, path) -> Property.Opinionated( + base, + alias, + Fixed, + Visible + ) // Hidden properties cannot be renamed + ) + } + case tup: TupleSchema => + // trace"%4 Tuple Schema: " andReturn + tup.list.flatMap { + case (idx, value) => + replacements( + // Should not matter whether property is fixed or variable here + // since beta reduction ignores that + Property(base, s"_${idx + 1}"), + value + ) + } + case cc: CaseClassSchema => + // trace"%4 CaseClass Schema: " andReturn + cc.list.flatMap { + case (property, value) => + replacements( + // Should not matter whether property is fixed or variable here + // since beta reduction ignores that + Property(base, property), + value + ) + } + // Do nothing if it is an empty schema + case EmptySchema => List() + } +} diff --git a/src/main/scala/minisql/util/Replacements.scala b/src/main/scala/minisql/norm/Replacements.scala similarity index 98% rename from src/main/scala/minisql/util/Replacements.scala rename to src/main/scala/minisql/norm/Replacements.scala index f0982e2..4b4a955 100644 --- a/src/main/scala/minisql/util/Replacements.scala +++ b/src/main/scala/minisql/norm/Replacements.scala @@ -1,4 +1,4 @@ -package minisql.util +package minisql.norm import minisql.ast.Ast import scala.collection.immutable.Map diff --git a/src/main/scala/minisql/norm/SimplifyNullChecks.scala b/src/main/scala/minisql/norm/SimplifyNullChecks.scala new file mode 100644 index 0000000..a49b949 --- /dev/null +++ b/src/main/scala/minisql/norm/SimplifyNullChecks.scala @@ -0,0 +1,124 @@ +package minisql.norm + +import minisql.ast.* +import minisql.norm.EqualityBehavior.AnsiEquality + +/** + * Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in + * `FlattenOptionOperation` in order to resolve #1053, as well as to support + * non-ansi compliant string concatenation as outlined in #1295, large + * conditional composites became common. For example:
 case class
+ * Holder(value:Option[String])
+ *
+ * // The following statement query[Holder].map(h => h.value.map(_ + "foo")) //
+ * Will yield the following result SELECT CASE WHEN h.value IS NOT NULL THEN
+ * h.value || 'foo' ELSE null END FROM Holder h 
Now, let's add a + * getOrElse statement to the clause that requires an additional + * wrapped null check. We cannot rely on there being a map call + * beforehand since we could be reading value as a nullable field + * directly from the database).
 // The following statement
+ * query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar")) // Yields the
+ * following result: SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value
+ * \|| 'foo' ELSE null END IS NOT NULL THEN CASE WHEN h.value IS NOT NULL THEN
+ * h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h 
+ * This of course is highly redundant and can be reduced to simply:
+ * SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL THEN
+ * h.value || 'foo' ELSE 'bar' END FROM Holder h 
This reduction is + * done by the "Center Rule." There are some other simplification rules as well. + * Note how we are force to null-check both `h.value` as well as `(h.value || + * 'foo')` because a user may use `Option[T].flatMap` and explicitly transform a + * particular value to `null`. + */ +class SimplifyNullChecks(equalityBehavior: EqualityBehavior) + extends StatelessTransformer { + + override def apply(ast: Ast): Ast = { + import minisql.ast.Implicits.* + ast match { + // Center rule + case IfExist( + IfExistElseNull(condA, thenA), + IfExistElseNull(condB, thenB), + otherwise + ) if (condA == condB && thenA == thenB) => + apply( + If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise) + ) + + // Left hand rule + case IfExist(IfExistElseNull(check, affirm), value, otherwise) => + apply( + If( + IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm), + value, + otherwise + ) + ) + + // Right hand rule + case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) => + apply( + If( + IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond), + innerThen, + NullValue + ) + ) + + case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined( + Optional(b) + ) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1))) + if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) => + apply(exp) + + case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional( + a1 + ) `== or !=` Optional(_))) + if (a == a1 && equalityBehavior == AnsiEquality) => + apply(exp) + case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional( + _ + ) `== or !=` Optional(b1))) + if (b == b1 && equalityBehavior == AnsiEquality) => + apply(exp) + + case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other => + apply(other) + case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other => + apply(other) + case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) => + apply(other) + case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) => + apply(other) + + case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left) + case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right) + case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left) + case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right) + + case other => + super.apply(other) + } + } + + object `== or !=` { + def unapply(ast: Ast): Option[(Ast, Ast)] = ast match { + case a +==+ b => Some((a, b)) + case a +!=+ b => Some((a, b)) + case _ => None + } + } + + /** + * Simple extractor that looks inside of an optional values to see if the + * thing inside can be pulled out. If not, it just returns whatever element it + * can find. + */ + object Optional { + def unapply(a: Ast): Option[Ast] = a match { + case OptionApply(value) => Some(value) + case OptionSome(value) => Some(value) + case value => Some(value) + } + } +} diff --git a/src/main/scala/minisql/norm/SymbolicReduction.scala b/src/main/scala/minisql/norm/SymbolicReduction.scala new file mode 100644 index 0000000..d7e8965 --- /dev/null +++ b/src/main/scala/minisql/norm/SymbolicReduction.scala @@ -0,0 +1,38 @@ +package minisql.norm + +import minisql.ast.Filter +import minisql.ast.FlatMap +import minisql.ast.Query +import minisql.ast.Union +import minisql.ast.UnionAll + +object SymbolicReduction { + + def unapply(q: Query) = + q match { + + // a.filter(b => c).flatMap(d => e.$) => + // a.flatMap(d => e.filter(_ => c[b := d]).$) + case FlatMap(Filter(a, b, c), d, e: Query) => + val cr = BetaReduction(c, b -> d) + val er = AttachToEntity(Filter(_, _, cr))(e) + Some(FlatMap(a, d, er)) + + // a.flatMap(b => c).flatMap(d => e) => + // a.flatMap(b => c.flatMap(d => e)) + case FlatMap(FlatMap(a, b, c), d, e) => + Some(FlatMap(a, b, FlatMap(c, d, e))) + + // a.union(b).flatMap(c => d) + // a.flatMap(c => d).union(b.flatMap(c => d)) + case FlatMap(Union(a, b), c, d) => + Some(Union(FlatMap(a, c, d), FlatMap(b, c, d))) + + // a.unionAll(b).flatMap(c => d) + // a.flatMap(c => d).unionAll(b.flatMap(c => d)) + case FlatMap(UnionAll(a, b), c, d) => + Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d))) + + case other => None + } +} diff --git a/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala new file mode 100644 index 0000000..ed45c62 --- /dev/null +++ b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala @@ -0,0 +1,174 @@ +package minisql.norm.capture + +import minisql.ast.{ + Entity, + Filter, + FlatJoin, + FlatMap, + GroupBy, + Ident, + Join, + Map, + Query, + SortBy, + StatefulTransformer, + _ +} +import minisql.norm.{BetaReduction, Normalize} +import scala.collection.immutable.Set + +private[minisql] case class AvoidAliasConflict(state: Set[Ident]) + extends StatefulTransformer[Set[Ident]] { + + object Unaliased { + + private def isUnaliased(q: Ast): Boolean = + q match { + case Nested(q: Query) => isUnaliased(q) + case Take(q: Query, _) => isUnaliased(q) + case Drop(q: Query, _) => isUnaliased(q) + case Aggregation(_, q: Query) => isUnaliased(q) + case Distinct(q: Query) => isUnaliased(q) + case _: Entity | _: Infix => true + case _ => false + } + + def unapply(q: Ast): Option[Ast] = + q match { + case q if (isUnaliased(q)) => Some(q) + case _ => None + } + } + + override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) = + q match { + + case FlatMap(Unaliased(q), x, p) => + apply(x, p)(FlatMap(q, _, _)) + + case ConcatMap(Unaliased(q), x, p) => + apply(x, p)(ConcatMap(q, _, _)) + + case Map(Unaliased(q), x, p) => + apply(x, p)(Map(q, _, _)) + + case Filter(Unaliased(q), x, p) => + apply(x, p)(Filter(q, _, _)) + + case SortBy(Unaliased(q), x, p, o) => + apply(x, p)(SortBy(q, _, _, o)) + + case GroupBy(Unaliased(q), x, p) => + apply(x, p)(GroupBy(q, _, _)) + + case DistinctOn(Unaliased(q), x, p) => + apply(x, p)(DistinctOn(q, _, _)) + + case Join(t, a, b, iA, iB, o) => + val (ar, art) = apply(a) + val (br, brt) = art.apply(b) + val freshA = freshIdent(iA, brt.state) + val freshB = freshIdent(iB, brt.state + freshA) + val or = BetaReduction(o, iA -> freshA, iB -> freshB) + val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or) + (Join(t, ar, br, freshA, freshB, orr), orrt) + + case FlatJoin(t, a, iA, o) => + val (ar, art) = apply(a) + val freshA = freshIdent(iA) + val or = BetaReduction(o, iA -> freshA) + val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or) + (FlatJoin(t, ar, freshA, orr), orrt) + + case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter | + _: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop | + _: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested => + super.apply(q) + } + + private def apply(x: Ident, p: Ast)( + f: (Ident, Ast) => Query + ): (Query, StatefulTransformer[Set[Ident]]) = { + val fresh = freshIdent(x) + val pr = BetaReduction(p, x -> fresh) + val (prr, t) = AvoidAliasConflict(state + fresh)(pr) + (f(fresh, prr), t) + } + + private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = { + def loop(x: Ident, n: Int): Ident = { + val fresh = Ident(s"${x.name}$n") + if (!state.contains(fresh)) + fresh + else + loop(x, n + 1) + } + if (!state.contains(x)) + x + else + loop(x, 1) + } + + /** + * Sometimes we need to change the variables in a function because they will + * might conflict with some variable further up in the macro. Right now, this + * only happens when you do something like this: val q = quote { (v: + * Foo) => query[Foo].insert(v) } run(q(lift(v))) Since 'v' is used by + * actionMeta in order to map keys to values for insertion, using it as a + * function argument messes up the output SQL like so: INSERT INTO + * MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?) + * Therefore, we need to have a method to remove such conflicting variables + * from Function ASTs + */ + private def applyFunction(f: Function): Function = { + val (newBody, _, newParams) = + f.params.foldLeft((f.body, state, List[Ident]())) { + case ((body, state, newParams), param) => { + val fresh = freshIdent(param) + val pr = BetaReduction(body, param -> fresh) + val (prr, t) = AvoidAliasConflict(state + fresh)(pr) + (prr, t.state, newParams :+ fresh) + } + } + Function(newParams, newBody) + } + + private def applyForeach(f: Foreach): Foreach = { + val fresh = freshIdent(f.alias) + val pr = BetaReduction(f.body, f.alias -> fresh) + val (prr, _) = AvoidAliasConflict(state + fresh)(pr) + Foreach(f.query, fresh, prr) + } +} + +private[minisql] object AvoidAliasConflict { + + def apply(q: Query): Query = + AvoidAliasConflict(Set[Ident]())(q) match { + case (q, _) => q + } + + /** + * Make sure query parameters do not collide with paramters of a AST function. + * Do this by walkning through the function's subtree and transforming and + * queries encountered. + */ + def sanitizeVariables( + f: Function, + dangerousVariables: Set[Ident] + ): Function = { + AvoidAliasConflict(dangerousVariables).applyFunction(f) + } + + /** Same is `sanitizeVariables` but for Foreach * */ + def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = { + AvoidAliasConflict(dangerousVariables).applyForeach(f) + } + + def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = { + AvoidAliasConflict(dangerousVariables).apply(q) match { + // Propagate aliasing changes to the rest of the query + case (q, _) => Normalize(q) + } + } +} diff --git a/src/main/scala/minisql/norm/capture/AvoidCapture.scala b/src/main/scala/minisql/norm/capture/AvoidCapture.scala new file mode 100644 index 0000000..788b551 --- /dev/null +++ b/src/main/scala/minisql/norm/capture/AvoidCapture.scala @@ -0,0 +1,9 @@ +package minisql.norm.capture + +import minisql.ast.Query + +object AvoidCapture { + + def apply(q: Query): Query = + Dealias(AvoidAliasConflict(q)) +} diff --git a/src/main/scala/minisql/norm/capture/Dealias.scala b/src/main/scala/minisql/norm/capture/Dealias.scala new file mode 100644 index 0000000..64a56ce --- /dev/null +++ b/src/main/scala/minisql/norm/capture/Dealias.scala @@ -0,0 +1,72 @@ +package minisql.norm.capture + +import minisql.ast._ +import minisql.norm.BetaReduction + +case class Dealias(state: Option[Ident]) + extends StatefulTransformer[Option[Ident]] { + + override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) = + q match { + case FlatMap(a, b, c) => + dealias(a, b, c)(FlatMap.apply) match { + case (FlatMap(a, b, c), _) => + val (cn, cnt) = apply(c) + (FlatMap(a, b, cn), cnt) + } + case ConcatMap(a, b, c) => + dealias(a, b, c)(ConcatMap.apply) match { + case (ConcatMap(a, b, c), _) => + val (cn, cnt) = apply(c) + (ConcatMap(a, b, cn), cnt) + } + case Map(a, b, c) => + dealias(a, b, c)(Map.apply) + case Filter(a, b, c) => + dealias(a, b, c)(Filter.apply) + case SortBy(a, b, c, d) => + dealias(a, b, c)(SortBy(_, _, _, d)) + case GroupBy(a, b, c) => + dealias(a, b, c)(GroupBy.apply) + case DistinctOn(a, b, c) => + dealias(a, b, c)(DistinctOn.apply) + case Take(a, b) => + val (an, ant) = apply(a) + (Take(an, b), ant) + case Drop(a, b) => + val (an, ant) = apply(a) + (Drop(an, b), ant) + case Union(a, b) => + val (an, _) = apply(a) + val (bn, _) = apply(b) + (Union(an, bn), Dealias(None)) + case UnionAll(a, b) => + val (an, _) = apply(a) + val (bn, _) = apply(b) + (UnionAll(an, bn), Dealias(None)) + case Join(t, a, b, iA, iB, o) => + val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _)) + val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _)) + (Join(t, an, bn, iAn, iBn, onn), Dealias(None)) + case FlatJoin(t, a, iA, o) => + val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _)) + (FlatJoin(t, an, iAn, on), Dealias(Some(iA))) + case _: Entity | _: Distinct | _: Aggregation | _: Nested => + (q, Dealias(None)) + } + + private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) = + apply(a) match { + case (an, t @ Dealias(Some(alias))) => + (f(an, alias, BetaReduction(c, b -> alias)), t) + case other => + (f(a, b, c), Dealias(Some(b))) + } +} + +object Dealias { + def apply(query: Query) = + new Dealias(None)(query) match { + case (q, _) => q + } +} diff --git a/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala new file mode 100644 index 0000000..7c95f7c --- /dev/null +++ b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala @@ -0,0 +1,98 @@ +package minisql.norm.capture + +import minisql.ast.* + +/** + * Walk through any Queries that a returning clause has and replace Ident of the + * returning variable with ExternalIdent so that in later steps involving filter + * simplification, it will not be mistakenly dealiased with a potential shadow. + * Take this query for instance:
 query[TestEntity]
+ * .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r =>
+ * (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) ) 
+ * The returning clause has an alias `Ident("r")` as well as the first filter + * clause. These two filters will be combined into one at which point the + * meaning of `r.i` in the 2nd filter will be confused for the first filter's + * alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this + * vunerable `r.i` in the second filter clause to an `ExternalIdent` before any + * of the simplifications are done. + * + * Note that we only want to do this for Queries inside of a `Returning` clause + * body. Other places where this needs to be done (e.g. in a Tuple that + * `Returning` returns) are done in `ExpandReturning`. + */ +private[minisql] case class DemarcateExternalAliases(externalIdent: Ident) + extends StatelessTransformer { + + def applyNonOverride(idents: Ident*)(ast: Ast) = + if (idents.forall(_ != externalIdent)) apply(ast) + else ast + + override def apply(ast: Ast): Ast = ast match { + + case FlatMap(q, i, b) => + FlatMap(apply(q), i, applyNonOverride(i)(b)) + + case ConcatMap(q, i, b) => + ConcatMap(apply(q), i, applyNonOverride(i)(b)) + + case Map(q, i, b) => + Map(apply(q), i, applyNonOverride(i)(b)) + + case Filter(q, i, b) => + Filter(apply(q), i, applyNonOverride(i)(b)) + + case SortBy(q, i, p, o) => + SortBy(apply(q), i, applyNonOverride(i)(p), o) + + case GroupBy(q, i, b) => + GroupBy(apply(q), i, applyNonOverride(i)(b)) + + case DistinctOn(q, i, b) => + DistinctOn(apply(q), i, applyNonOverride(i)(b)) + + case Join(t, a, b, iA, iB, o) => + Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o)) + + case FlatJoin(t, a, iA, o) => + FlatJoin(t, a, iA, applyNonOverride(iA)(o)) + + case p @ Property.Opinionated( + id @ Ident(_), + value, + renameable, + visibility + ) => + if (id == externalIdent) + Property.Opinionated( + ExternalIdent(externalIdent.name), + value, + renameable, + visibility + ) + else + p + + case other => + super.apply(other) + } +} + +object DemarcateExternalAliases { + + private def demarcateQueriesInBody(id: Ident, body: Ast) = + Transform(body) { + // Apply to the AST defined apply method about, not to the superclass method that takes Query + case q: Query => + new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast]) + } + + def apply(ast: Ast): Ast = ast match { + case Returning(a, id, body) => + Returning(a, id, demarcateQueriesInBody(id, body)) + case ReturningGenerated(a, id, body) => + val d = demarcateQueriesInBody(id, body) + ReturningGenerated(a, id, demarcateQueriesInBody(id, body)) + case other => + other + } +} diff --git a/src/main/scala/minisql/parsing/BlockParsing.scala b/src/main/scala/minisql/parsing/BlockParsing.scala new file mode 100644 index 0000000..ae8722c --- /dev/null +++ b/src/main/scala/minisql/parsing/BlockParsing.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +type SParser[X] = + (q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]] + +private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using + Quotes +): SParser[ast.Ast] = { + + import quotes.reflect.* + + @annotation.nowarn + lazy val valDefParser: SParser[ast.Val] = { + case ValDef(n, _, Some(b)) => + val body = astParser(b.asExpr) + '{ ast.Val(ast.Ident(${ Expr(n) }), $body) } + + } + valDefParser +} + +private[parsing] def blockParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + lazy val statementParser = statementParsing(astParser) + + termParser { + case Block(Nil, t) => astParser(t.asExpr) + case b @ Block(st, t) => + val asts = (st :+ t).map { + case e if e.isExpr => astParser(e.asExpr) + case `statementParser`(x) => x + case o => + report.errorAndAbort(s"Cannot parse statement: ${o.show}") + } + if (asts.size > 1) { + '{ ast.Block(${ Expr.ofList(asts) }) } + } else asts(0) + + } +} diff --git a/src/main/scala/minisql/parsing/BoxingParsing.scala b/src/main/scala/minisql/parsing/BoxingParsing.scala new file mode 100644 index 0000000..d7b7f31 --- /dev/null +++ b/src/main/scala/minisql/parsing/BoxingParsing.scala @@ -0,0 +1,31 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def boxingParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + case '{ BigDecimal.int2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.long2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.double2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v) + case '{ Predef.byte2Byte($v) } => astParser(v) + case '{ Predef.short2Short($v) } => astParser(v) + case '{ Predef.char2Character($v) } => astParser(v) + case '{ Predef.int2Integer($v) } => astParser(v) + case '{ Predef.long2Long($v) } => astParser(v) + case '{ Predef.float2Float($v) } => astParser(v) + case '{ Predef.double2Double($v) } => astParser(v) + case '{ Predef.boolean2Boolean($v) } => astParser(v) + case '{ Predef.augmentString($v) } => astParser(v) + case '{ Predef.Byte2byte($v) } => astParser(v) + case '{ Predef.Short2short($v) } => astParser(v) + case '{ Predef.Character2char($v) } => astParser(v) + case '{ Predef.Integer2int($v) } => astParser(v) + case '{ Predef.Long2long($v) } => astParser(v) + case '{ Predef.Float2float($v) } => astParser(v) + case '{ Predef.Double2double($v) } => astParser(v) + case '{ Predef.Boolean2boolean($v) } => astParser(v) + +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala new file mode 100644 index 0000000..7b173db --- /dev/null +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -0,0 +1,13 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Infix] = { + + import quotes.reflect.* + ??? +} diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala new file mode 100644 index 0000000..c01df89 --- /dev/null +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import scala.quoted.* +import minisql.ParamEncoder +import minisql.ast +import minisql.dsl.* + +private[parsing] def liftParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Lift] = { + case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } => + import quotes.reflect.* + val name = x.asTerm.symbol.fullName + val liftId = x.asTerm.symbol.owner.fullName + "@" + name + '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) } +} diff --git a/src/main/scala/minisql/parsing/OperationParsing.scala b/src/main/scala/minisql/parsing/OperationParsing.scala new file mode 100644 index 0000000..df93010 --- /dev/null +++ b/src/main/scala/minisql/parsing/OperationParsing.scala @@ -0,0 +1,113 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.{ + EqualityOperator, + StringOperator, + NumericOperator, + BooleanOperator +} +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def operationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Operation] = { + import quotes.reflect.* + + def isNumeric(t: TypeRepr) = { + t <:< TypeRepr.of[Int] + || t <:< TypeRepr.of[Long] + || t <:< TypeRepr.of[Byte] + || t <:< TypeRepr.of[Float] + || t <:< TypeRepr.of[Double] + || t <:< TypeRepr.of[java.math.BigDecimal] + || t <:< TypeRepr.of[scala.math.BigDecimal] + } + + def parseBinary( + left: Expr[Any], + right: Expr[Any], + op: Expr[ast.BinaryOperator] + ) = { + val leftE = astParser(left) + val rightE = astParser(right) + '{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) } + } + + def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = { + val base = astParser(expr) + '{ ast.UnaryOperation($op, $base) } + + } + + val universalOpParser: Parser[ast.BinaryOperation] = termParser { + case Apply(Select(leftT, UniversalOp(op)), List(rightT)) => + parseBinary(leftT.asExpr, rightT.asExpr, op) + } + + val stringOpParser: Parser[ast.Operation] = { + case '{ ($x: String) + ($y: String) } => + parseBinary(x, y, '{ StringOperator.concat }) + case '{ ($x: String).startsWith($y) } => + parseBinary(x, y, '{ StringOperator.startsWith }) + case '{ ($x: String).split($y) } => + parseBinary(x, y, '{ StringOperator.split }) + case '{ ($x: String).toUpperCase } => + parseUnary(x, '{ StringOperator.toUpperCase }) + case '{ ($x: String).toLowerCase } => + parseUnary(x, '{ StringOperator.toLowerCase }) + case '{ ($x: String).toLong } => + parseUnary(x, '{ StringOperator.toLong }) + case '{ ($x: String).toInt } => + parseUnary(x, '{ StringOperator.toInt }) + } + + val numericOpParser = termParser { + case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) => + parseBinary(lt.asExpr, rt.asExpr, op) + case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) => + val leftExpr = astParser(leftTerm.asExpr) + '{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) } + + } + + val booleanOpParser: Parser[ast.Operation] = { + case '{ ($x: Boolean) && $y } => + parseBinary(x, y, '{ BooleanOperator.&& }) + case '{ ($x: Boolean) || $y } => + parseBinary(x, y, '{ BooleanOperator.|| }) + case '{ !($x: Boolean) } => + parseUnary(x, '{ BooleanOperator.! }) + } + + universalOpParser + .orElse(stringOpParser) + .orElse(numericOpParser) + .orElse(booleanOpParser) +} + +private object UniversalOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "==" | "equals" => Some('{ EqualityOperator.== }) + case "!=" => Some('{ EqualityOperator.!= }) + case _ => None + } +} + +private object NumericOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "+" => Some('{ NumericOperator.+ }) + case "-" => Some('{ NumericOperator.- }) + case "*" => Some('{ NumericOperator.* }) + case "/" => Some('{ NumericOperator./ }) + case ">" => Some('{ NumericOperator.> }) + case ">=" => Some('{ NumericOperator.>= }) + case "<" => Some('{ NumericOperator.< }) + case "<=" => Some('{ NumericOperator.<= }) + case "%" => Some('{ NumericOperator.% }) + case _ => None + } +} diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala new file mode 100644 index 0000000..4bc7265 --- /dev/null +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.Ast +import scala.quoted.* + +private[minisql] inline def parseParamAt[F]( + inline f: F, + inline n: Int +): ast.Ident = ${ + parseParamAt('f, 'n) +} + +private[minisql] inline def parseBody[X]( + inline f: X +): ast.Ast = ${ + parseBody('f) +} + +private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using + Quotes +): Expr[ast.Ident] = { + + import quotes.reflect.* + + val pIdx = n.value.getOrElse( + report.errorAndAbort(s"Param index ${n.show} is not know") + ) + extractTerm(f.asTerm) match { + case Lambda(vals, _) => + vals(pIdx) match { + case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) } + } + } +} + +private[minisql] def parseBody[X]( + x: Expr[X] +)(using Quotes): Expr[Ast] = { + import quotes.reflect.* + x.asTerm match { + case Lambda(vals, body) => + Parsing.parseExpr(body.asExpr) + case o => + report.errorAndAbort(s"Can only parse function") + } +} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala new file mode 100644 index 0000000..463112f --- /dev/null +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -0,0 +1,139 @@ +package minisql.parsing + +import minisql.ast +import minisql.context.{ReturningMultipleFieldSupported, _} +import minisql.norm.BetaReduction +import minisql.norm.capture.AvoidAliasConflict +import minisql.idiom.Idiom +import scala.annotation.tailrec +import minisql.ast.Implicits._ +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.{Hidden, Visible} +import minisql.util.Interleave +import scala.quoted.* + +type Parser[A] = PartialFunction[Expr[Any], Expr[A]] + +private def termParser[A](using q: Quotes)( + pf: PartialFunction[q.reflect.Term, Expr[A]] +): Parser[A] = { + import quotes.reflect._ + { + case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm) + } +} + +private def parser[A]( + f: PartialFunction[Expr[Any], Expr[A]] +)(using Quotes): Parser[A] = { + case e if f.isDefinedAt(e) => f(e) +} + +private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { + import quotes.reflect.* + def unwrapTerm(t: Term): Term = t match { + case Inlined(_, _, o) => unwrapTerm(o) + case Block(Nil, last) => last + case Typed(t, _) => + unwrapTerm(t) + case Select(t, "$asInstanceOf$") => + unwrapTerm(t) + case TypeApply(t, _) => + unwrapTerm(t) + case o => o + } + unwrapTerm(x) +} + +private[minisql] object Parsing { + + def parseExpr( + expr: Expr[?] + )(using q: Quotes): Expr[ast.Ast] = { + + import q.reflect._ + + def unwrapped( + f: Parser[ast.Ast] + ): Parser[ast.Ast] = { + case expr => + val t = expr.asTerm + f(extractTerm(t).asExpr) + } + + lazy val astParser: Parser[ast.Ast] = + unwrapped { + typedParser + .orElse(propertyParser) + .orElse(liftParser) + .orElse(identParser) + .orElse(valueParser) + .orElse(operationParser) + .orElse(constantParser) + .orElse(blockParser) + .orElse(boxingParser) + .orElse(ifParser) + .orElse(traversableOperationParser) + .orElse(patMatchParser) + .orElse(infixParser) + .orElse { + case o => + val str = scala.util.Try(o.show).getOrElse("") + report.errorAndAbort( + s"cannot parse ${str}", + o.asTerm.pos + ) + } + } + + lazy val typedParser: Parser[ast.Ast] = termParser { + case (Typed(t, _)) => + astParser(t.asExpr) + } + + lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser) + + lazy val valueParser: Parser[ast.Value] = valueParsing(astParser) + + lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser) + + lazy val constantParser: Parser[ast.Constant] = termParser { + case Literal(x) => + '{ ast.Constant(${ Literal(x).asExpr }) } + } + + lazy val identParser: Parser[ast.Ident] = termParser { + case x @ Ident(n) if x.symbol.isValDef => + '{ ast.Ident(${ Expr(n) }) } + } + + lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser) + + lazy val operationParser: Parser[ast.Operation] = operationParsing( + astParser + ) + + lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser) + + lazy val ifParser: Parser[ast.If] = { + case '{ if ($a) $b else $c } => + '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } + + } + lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) + + lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) + + lazy val traversableOperationParser: Parser[ast.IterableOperation] = + traversableOperationParsing(astParser) + + astParser(expr) + } + + private[minisql] inline def parse[A]( + inline a: A + ): ast.Ast = ${ + parseExpr('a) + } + +} diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala new file mode 100644 index 0000000..2db7652 --- /dev/null +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -0,0 +1,49 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def patMatchParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + termParser { + // Val defs that showd pattern variables will cause error + case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) => + val bm = binds.zipWithIndex.map { + case (Bind(n, ident), idx) => + n -> Select.unique(t, s"_${idx + 1}") + }.toMap + val tm = new TreeMap { + override def transformTerm(tree: Term)(owner: Symbol): Term = { + tree match { + case Ident(n) => bm(n) + case o => super.transformTerm(o)(owner) + } + } + } + val newBody = tm.transformTree(body)(e.symbol) + astParser(newBody.asExpr) + } + +} + +object IsTupleUnapply { + + def unapply(using + Quotes + )(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = { + import quotes.reflect.* + def isTupleNUnapply(x: Term) = { + val fn = x.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply") + } + t match { + case Unapply(m, _, binds) if isTupleNUnapply(m) => + Some(binds) + case _ => None + } + } +} diff --git a/src/main/scala/minisql/parsing/PropertyParsing.scala b/src/main/scala/minisql/parsing/PropertyParsing.scala new file mode 100644 index 0000000..292f281 --- /dev/null +++ b/src/main/scala/minisql/parsing/PropertyParsing.scala @@ -0,0 +1,30 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def propertyParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Property] = { + import quotes.reflect.* + + def isAccessor(s: Select) = { + s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name) + } + + val parseApply: Parser[ast.Property] = termParser { + case m @ Select(base, n) if isAccessor(m) => + val obj = astParser(base.asExpr) + '{ ast.Property($obj, ${ Expr(n) }) } + } + + val parseOptionGet: Parser[ast.Property] = { + case '{ ($e: Option[t]).get } => + report.errorAndAbort( + "Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead." + ) + } + parseApply.orElse(parseOptionGet) + +} diff --git a/src/main/scala/minisql/parsing/TraversableOperationParsing.scala b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala new file mode 100644 index 0000000..8f50098 --- /dev/null +++ b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private def traversableOperationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.IterableOperation] = { + case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } => + '{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) } + case '{ ($s: Set[e]).contains($i) } => + '{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) } + case '{ ($s: Seq[e]).contains($i) } => + '{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) } + +} diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala new file mode 100644 index 0000000..6c2fb9e --- /dev/null +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -0,0 +1,71 @@ +package minisql +package parsing + +import scala.quoted._ + +private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using + Quotes +): Parser[ast.Value] = { + + import quotes.reflect.* + + val parseTupleApply: Parser[ast.Tuple] = { + case IsTupleApply(args) => + val t = args.map(astParser) + '{ ast.Tuple(${ Expr.ofList(t) }) } + } + + val parseAssocTuple: Parser[ast.Tuple] = { + case '{ ($x: tx) -> ($y: ty) } => + '{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) } + } + + parseTupleApply.orElse(parseAssocTuple) +} + +private[minisql] object IsTupleApply { + + def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = { + + import quotes.reflect.* + + def isTupleNApply(t: Term) = { + val fn = t.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.apply") + } + + def isTupleXXLApply(t: Term) = { + t.symbol.fullName == "scala.runtime.TupleXXL$.apply" + } + + extractTerm(e.asTerm) match { + // TupleN(0-22).apply + case Apply(b, args) if isTupleNApply(b) => + Some(args.map(_.asExpr)) + case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) => + t.asExpr match { + case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } => + Some(args) + } + case o => + None + } + } +} + +private[parsing] object IsTuple2 { + + def unapply(using + Quotes + )(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] = + t match { + case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2)) + case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2)) + case _ => None + } + + def unapply(using + Quotes + )(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] = + unapply(t.asExpr) +} diff --git a/src/main/scala/minisql/util/Interpolator.scala b/src/main/scala/minisql/util/Interpolator.scala index 55e32cd..c63e984 100644 --- a/src/main/scala/minisql/util/Interpolator.scala +++ b/src/main/scala/minisql/util/Interpolator.scala @@ -11,20 +11,25 @@ import scala.util.matching.Regex class Interpolator( defaultIndent: Int = 0, qprint: AstPrinter = AstPrinter(), - out: PrintStream = System.out, + out: PrintStream = System.out ) { + + extension (sc: StringContext) { + def trace(elements: Any*) = new Traceable(sc, elements) + } + class Traceable(sc: StringContext, elementsSeq: Seq[Any]) { private val elementPrefix = "| " private sealed trait PrintElement private case class Str(str: String, first: Boolean) extends PrintElement - private case class Elem(value: String) extends PrintElement - private case object Separator extends PrintElement + private case class Elem(value: String) extends PrintElement + private case object Separator extends PrintElement private def generateStringForCommand(value: Any, indent: Int) = { val objectString = qprint(value) - val oneLine = objectString.fitsOnOneLine + val oneLine = objectString.fitsOnOneLine oneLine match { case true => s"${indent.prefix}> ${objectString}" case false => @@ -42,7 +47,7 @@ class Interpolator( private def readBuffers() = { def orZero(i: Int): Int = if (i < 0) 0 else i - val parts = sc.parts.iterator.toList + val parts = sc.parts.iterator.toList val elements = elementsSeq.toList.map(qprint(_)) val (firstStr, explicitIndent) = readFirst(parts.head) diff --git a/src/main/scala/minisql/util/Message.scala b/src/main/scala/minisql/util/Message.scala new file mode 100644 index 0000000..f748456 --- /dev/null +++ b/src/main/scala/minisql/util/Message.scala @@ -0,0 +1,75 @@ +package minisql.util + +import minisql.AstPrinter +import minisql.idiom.Idiom +import minisql.util.IndentUtil._ + +object Messages { + + private def variable(propName: String, envName: String, default: String) = + Option(System.getProperty(propName)) + .orElse(sys.env.get(envName)) + .getOrElse(default) + + private[util] val prettyPrint = + variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean + private[util] val debugEnabled = + variable("quill.macro.log", "quill_macro_log", "true").toBoolean + private[util] val traceEnabled = + variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean + private[util] val traceColors = + variable("quill.trace.color", "quill_trace_color,", "false").toBoolean + private[util] val traceOpinions = + variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean + private[util] val traceAstSimple = variable( + "quill.trace.ast.simple", + "quill_trace_ast_simple", + "false" + ).toBoolean + private[minisql] val cacheDynamicQueries = variable( + "quill.query.cacheDaynamic", + "query_query_cacheDaynamic", + "true" + ).toBoolean + private[util] val traces: List[TraceType] = + variable("quill.trace.types", "quill_trace_types", "standard") + .split(",") + .toList + .map(_.trim) + .flatMap(trace => + TraceType.values.filter(traceType => trace == traceType.value) + ) + + def tracesEnabled(tt: TraceType) = + traceEnabled && traces.contains(tt) + + sealed trait TraceType { def value: String } + object TraceType { + case object Normalizations extends TraceType { val value = "norm" } + case object Standard extends TraceType { val value = "standard" } + case object NestedQueryExpansion extends TraceType { val value = "nest" } + + def values: List[TraceType] = + List(Standard, Normalizations, NestedQueryExpansion) + } + + val qprint = AstPrinter() + + def fail(msg: String) = + throw new IllegalStateException(msg) + + def trace[T]( + label: String, + numIndent: Int = 0, + traceType: TraceType = TraceType.Standard + ) = + (v: T) => { + val indent = (0 to numIndent).map(_ => "").mkString(" ") + if (tracesEnabled(traceType)) + println(s"$indent$label\n${{ + if (traceColors) qprint.apply(v) + else qprint.apply(v) + }.split("\n").map(s"$indent " + _).mkString("\n")}") + v + } +}