diff --git a/src/main/scala/minisql/Parser.scala b/src/main/scala/minisql/Parser.scala
deleted file mode 100644
index 3b21d90..0000000
--- a/src/main/scala/minisql/Parser.scala
+++ /dev/null
@@ -1,127 +0,0 @@
-package minisql.parsing
-
-import minisql.ast
-import minisql.ast.Ast
-import scala.quoted.*
-
-type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]
-
-private[minisql] inline def parseParamAt[F](
- inline f: F,
- inline n: Int
-): ast.Ident = ${
- parseParamAt('f, 'n)
-}
-
-private[minisql] inline def parseBody[X](
- inline f: X
-): ast.Ast = ${
- parseBody('f)
-}
-
-private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
- Quotes
-): Expr[ast.Ident] = {
-
- import quotes.reflect.*
-
- val pIdx = n.value.getOrElse(
- report.errorAndAbort(s"Param index ${n.show} is not know")
- )
- extractTerm(f.asTerm) match {
- case Lambda(vals, _) =>
- vals(pIdx) match {
- case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
- }
- }
-}
-
-private[minisql] def parseBody[X](
- x: Expr[X]
-)(using Quotes): Expr[Ast] = {
- import quotes.reflect.*
- extractTerm(x.asTerm) match {
- case Lambda(vals, body) =>
- astParser(body.asExpr)
- case o =>
- report.errorAndAbort(s"Can only parse function")
- }
-}
-private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
- import quotes.reflect.*
- x.asTerm.tpe match {
- case t if t <:< TypeRepr.of[Int] => true
- case t if t <:< TypeRepr.of[Long] => true
- case t if t <:< TypeRepr.of[Float] => true
- case t if t <:< TypeRepr.of[Double] => true
- case t if t <:< TypeRepr.of[BigDecimal] => true
- case t if t <:< TypeRepr.of[BigInt] => true
- case t if t <:< TypeRepr.of[java.lang.Integer] => true
- case t if t <:< TypeRepr.of[java.lang.Long] => true
- case t if t <:< TypeRepr.of[java.lang.Float] => true
- case t if t <:< TypeRepr.of[java.lang.Double] => true
- case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
- case _ => false
- }
-}
-
-private def identParser(using Quotes): Parser[ast.Ident] = {
- import quotes.reflect.*
- { (x: Expr[?]) =>
- extractTerm(x.asTerm) match {
- case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
- case _ => None
- }
- }.unlift
-
-}
-
-private lazy val astParser: Quotes ?=> Parser[Ast] = {
- identParser.orElse(propertyParser(astParser))
-}
-
-private object IsPropertySelect {
- def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
- import quotes.reflect.*
- x.asTerm match {
- case Select(x, n) => Some(x.asExpr, n)
- case _ => None
- }
- }
-}
-
-def propertyParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.Property] = {
- case IsPropertySelect(expr, n) =>
- '{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
-}
-
-def optionOperationParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.OptionOperation] = {
- case '{ ($x: Option[t]).isEmpty } =>
- '{ ast.OptionIsEmpty(${ astParser(x) }) }
-}
-
-def binaryOperationParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.BinaryOperation] = {
- ???
-}
-
-private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
- import quotes.reflect.*
- def unwrapTerm(t: Term): Term = t match {
- case Inlined(_, _, o) => unwrapTerm(o)
- case Block(Nil, last) => last
- case Typed(t, _) =>
- unwrapTerm(t)
- case Select(t, "$asInstanceOf$") =>
- unwrapTerm(t)
- case TypeApply(t, _) =>
- unwrapTerm(t)
- case o => o
- }
- unwrapTerm(x)
-}
diff --git a/src/main/scala/minisql/ReturnAction.scala b/src/main/scala/minisql/ReturnAction.scala
new file mode 100644
index 0000000..63df7a0
--- /dev/null
+++ b/src/main/scala/minisql/ReturnAction.scala
@@ -0,0 +1,7 @@
+package minisql
+
+enum ReturnAction {
+ case ReturnNothing
+ case ReturnColumns(columns: List[String])
+ case ReturnRecord
+}
diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala
index 35feb70..48e3fae 100644
--- a/src/main/scala/minisql/ast/Ast.scala
+++ b/src/main/scala/minisql/ast/Ast.scala
@@ -1,6 +1,7 @@
package minisql.ast
import minisql.NamingStrategy
+import minisql.ParamEncoder
import scala.quoted.*
@@ -378,14 +379,20 @@ sealed trait ScalarLift extends Lift
case class ScalarValueLift(
name: String,
- liftId: String
+ liftId: String,
+ value: Option[(Any, ParamEncoder[?])]
) extends ScalarLift
+case class ScalarQueryLift(
+ name: String,
+ liftId: String
+) extends ScalarLift {}
+
object ScalarLift {
given ToExpr[ScalarLift] with {
def apply(l: ScalarLift)(using Quotes) = l match {
- case ScalarValueLift(n, id) =>
- '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) }
+ case ScalarValueLift(n, id, v) =>
+ '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) }
}
}
}
diff --git a/src/main/scala/minisql/ast/AstOps.scala b/src/main/scala/minisql/ast/AstOps.scala
new file mode 100644
index 0000000..e0069ae
--- /dev/null
+++ b/src/main/scala/minisql/ast/AstOps.scala
@@ -0,0 +1,94 @@
+package minisql.ast
+object Implicits {
+ implicit class AstOps(body: Ast) {
+ private[minisql] def +||+(other: Ast) =
+ BinaryOperation(body, BooleanOperator.`||`, other)
+ private[minisql] def +&&+(other: Ast) =
+ BinaryOperation(body, BooleanOperator.`&&`, other)
+ private[minisql] def +==+(other: Ast) =
+ BinaryOperation(body, EqualityOperator.`==`, other)
+ private[minisql] def +!=+(other: Ast) =
+ BinaryOperation(body, EqualityOperator.`!=`, other)
+ }
+}
+
+object +||+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object +&&+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+val EqOp = EqualityOperator.`==`
+val NeqOp = EqualityOperator.`!=`
+
+object +==+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, EqOp, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object +!=+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, NeqOp, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object IsNotNullCheck {
+ def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue)
+
+ def unapply(ast: Ast): Option[Ast] = {
+ ast match {
+ case BinaryOperation(cond, NeqOp, NullValue) => Some(cond)
+ case _ => None
+ }
+ }
+}
+
+object IsNullCheck {
+ def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue)
+
+ def unapply(ast: Ast): Option[Ast] = {
+ ast match {
+ case BinaryOperation(cond, EqOp, NullValue) => Some(cond)
+ case _ => None
+ }
+ }
+}
+
+object IfExistElseNull {
+ def apply(exists: Ast, `then`: Ast) =
+ If(IsNotNullCheck(exists), `then`, NullValue)
+
+ def unapply(ast: Ast) = ast match {
+ case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t))
+ case _ => None
+ }
+}
+
+object IfExist {
+ def apply(exists: Ast, `then`: Ast, otherwise: Ast) =
+ If(IsNotNullCheck(exists), `then`, otherwise)
+
+ def unapply(ast: Ast) = ast match {
+ case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e))
+ case _ => None
+ }
+}
diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala
index b66e4b9..26694fb 100644
--- a/src/main/scala/minisql/ast/FromExprs.scala
+++ b/src/main/scala/minisql/ast/FromExprs.scala
@@ -45,8 +45,9 @@ private given FromExpr[Infix] with {
private given FromExpr[ScalarValueLift] with {
def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] =
x match {
- case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } =>
- Some(ScalarValueLift(n, id))
+ case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } =>
+      // we don't care about the value here; a little tricky
+ Some(ScalarValueLift(n, id, null))
}
}
diff --git a/src/main/scala/minisql/context/ReturnFieldCapability.scala b/src/main/scala/minisql/context/ReturnFieldCapability.scala
new file mode 100644
index 0000000..cadbb78
--- /dev/null
+++ b/src/main/scala/minisql/context/ReturnFieldCapability.scala
@@ -0,0 +1,64 @@
+package minisql.context
+
+sealed trait ReturningCapability
+
+/**
+ * Data cannot be returned from Insert/Update/etc... clauses in the target database.
+ */
+sealed trait ReturningNotSupported extends ReturningCapability
+
+/**
+ * Returning a single field from Insert/Update/etc... clauses is supported. This
+ * is the case for most common databases e.g. MySQL, Sqlite, and H2 (although as
+ * of h2database/h2database#1972 this may change; see #1496 regarding this).
+ * Typically this needs to be set up in the JDBC
+ * `connection.prepareStatement(sql, Array("returnColumn"))`.
+ */
+sealed trait ReturningSingleFieldSupported extends ReturningCapability
+
+/**
+ * Returning multiple columns from Insert/Update/etc... clauses is supported.
+ * This generally means that columns besides auto-incrementing ones can be
+ * returned. This is supported by Oracle. In JDBC, the following is done:
+ * `connection.prepareStatement(sql, Array("column1, column2, ..."))`.
+ */
+sealed trait ReturningMultipleFieldSupported extends ReturningCapability
+
+/**
+ * An actual `RETURNING` clause is supported in the SQL dialect of the specified
+ * database e.g. Postgres. This typically means that columns returned from
+ * Insert/Update/etc... clauses can have other database operations done on them
+ * such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In
+ * JDBC, the following is done: `connection.prepareStatement(sql,
+ * Statement.RETURN_GENERATED_KEYS))`.
+ */
+sealed trait ReturningClauseSupported extends ReturningCapability
+
+object ReturningNotSupported extends ReturningNotSupported
+object ReturningSingleFieldSupported extends ReturningSingleFieldSupported
+object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported
+object ReturningClauseSupported extends ReturningClauseSupported
+
+trait Capabilities {
+ def idiomReturningCapability: ReturningCapability
+}
+
+trait CanReturnClause extends Capabilities {
+ override def idiomReturningCapability: ReturningClauseSupported =
+ ReturningClauseSupported
+}
+
+trait CanReturnField extends Capabilities {
+ override def idiomReturningCapability: ReturningSingleFieldSupported =
+ ReturningSingleFieldSupported
+}
+
+trait CanReturnMultiField extends Capabilities {
+ override def idiomReturningCapability: ReturningMultipleFieldSupported =
+ ReturningMultipleFieldSupported
+}
+
+trait CannotReturn extends Capabilities {
+ override def idiomReturningCapability: ReturningNotSupported =
+ ReturningNotSupported
+}
diff --git a/src/main/scala/minisql/dsl.scala b/src/main/scala/minisql/dsl.scala
index 43cc9c0..ace3d8f 100644
--- a/src/main/scala/minisql/dsl.scala
+++ b/src/main/scala/minisql/dsl.scala
@@ -29,20 +29,6 @@ private inline def transform[A, B](inline q1: Quoted)(
inline def query[E](inline table: String): EntityQuery[E] =
Entity(table, Nil)
-inline def compile(inline x: Ast): Option[String] = ${
- compileImpl('{ x })
-}
-
-private def compileImpl(
- x: Expr[Ast]
-)(using Quotes): Expr[Option[String]] = {
- import quotes.reflect.*
- x.value match {
- case Some(xv) => '{ Some(${ Expr(xv.toString()) }) }
- case None => '{ None }
- }
-}
-
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def body = parsing.parseBody(f1)
@@ -54,6 +40,6 @@ extension [A1, A2, B](inline f1: (A1, A2) => B) {
private inline def body = parsing.parseBody(f1)
}
-case class Foo(id: Int)
+def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()
-inline def queryFooId = query[Foo]("foo").map(_.id)
+class NonQuotedException extends Exception("Cannot be used at runtime")
diff --git a/src/main/scala/minisql/idiom/Idiom.scala b/src/main/scala/minisql/idiom/Idiom.scala
new file mode 100644
index 0000000..7e0cd01
--- /dev/null
+++ b/src/main/scala/minisql/idiom/Idiom.scala
@@ -0,0 +1,23 @@
+package minisql.idiom
+
+import minisql.NamingStrategy
+import minisql.ast._
+import minisql.context.Capabilities
+
+trait Idiom extends Capabilities {
+
+ def emptySetContainsToken(field: Token): Token = StringToken("FALSE")
+
+ def defaultAutoGeneratedToken(field: Token): Token = StringToken(
+ "DEFAULT VALUES"
+ )
+
+ def liftingPlaceholder(index: Int): String
+
+ def translate(ast: Ast)(implicit naming: NamingStrategy): (Ast, Statement)
+
+ def format(queryString: String): String = queryString
+
+ def prepareForProbing(string: String): String
+
+}
diff --git a/src/main/scala/minisql/idiom/LoadNaming.scala b/src/main/scala/minisql/idiom/LoadNaming.scala
new file mode 100644
index 0000000..2405080
--- /dev/null
+++ b/src/main/scala/minisql/idiom/LoadNaming.scala
@@ -0,0 +1,28 @@
+package minisql.idiom
+
+import scala.util.Try
+import scala.quoted._
+import minisql.NamingStrategy
+import minisql.util.CollectTry
+import minisql.util.LoadObject
+import minisql.CompositeNamingStrategy
+
+object LoadNaming {
+
+ def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry {
+ strategies[C].map(LoadObject[NamingStrategy](_))
+ }.map(NamingStrategy(_))
+
+ private def strategies[C](using Quotes, Type[C]) = {
+ import quotes.reflect.*
+ val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy]
+ val ct = TypeRepr.of[C]
+ if (isComposite) {
+ ct.typeArgs.filterNot { t =>
+ t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing]
+ }
+ } else {
+ List(ct)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala
new file mode 100644
index 0000000..a3e8902
--- /dev/null
+++ b/src/main/scala/minisql/idiom/ReifyStatement.scala
@@ -0,0 +1,68 @@
+package minisql.idiom
+
+import minisql.ast._
+import minisql.util.Interleave
+import minisql.idiom.StatementInterpolator._
+import scala.annotation.tailrec
+import scala.collection.immutable.{Map => SMap}
+
+object ReifyStatement {
+
+ def apply(
+ liftingPlaceholder: Int => String,
+ emptySetContainsToken: Token => Token,
+ statement: Statement,
+ liftMap: SMap[String, (Any, Any)]
+ ): (String, List[ScalarValueLift]) = {
+ val expanded = expandLiftings(statement, emptySetContainsToken, liftMap)
+ token2string(expanded, liftingPlaceholder)
+ }
+
+ private def token2string(
+ token: Token,
+ liftingPlaceholder: Int => String
+ ): (String, List[ScalarValueLift]) = {
+
+ val liftBuilder = List.newBuilder[ScalarValueLift]
+ val sqlBuilder = StringBuilder()
+ @tailrec
+ def loop(
+ workList: Seq[Token],
+ liftingSize: Int
+ ): Unit = workList match {
+ case Seq() => ()
+ case head +: tail =>
+ head match {
+ case StringToken(s2) =>
+ sqlBuilder ++= s2
+ loop(tail, liftingSize)
+ case SetContainsToken(a, op, b) =>
+ loop(
+ stmt"$a $op ($b)" +: tail,
+ liftingSize
+ )
+ case ScalarLiftToken(lift: ScalarValueLift) =>
+ sqlBuilder ++= liftingPlaceholder(liftingSize)
+ liftBuilder += lift
+ loop(tail, liftingSize + 1)
+ case ScalarLiftToken(o) =>
+ throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}")
+ case Statement(tokens) =>
+ loop(
+ tokens.foldRight(tail)(_ +: _),
+ liftingSize
+ )
+ }
+ }
+ loop(Vector(token), 0)
+ sqlBuilder.toString() -> liftBuilder.result()
+ }
+
+ private def expandLiftings(
+ statement: Statement,
+ emptySetContainsToken: Token => Token,
+ liftMap: SMap[String, (Any, Any)]
+ ): (Token) = {
+ statement
+ }
+}
diff --git a/src/main/scala/minisql/idiom/Statement.scala b/src/main/scala/minisql/idiom/Statement.scala
new file mode 100644
index 0000000..987ee2f
--- /dev/null
+++ b/src/main/scala/minisql/idiom/Statement.scala
@@ -0,0 +1,47 @@
+package minisql.idiom
+
+import scala.quoted._
+import minisql.ast._
+
+sealed trait Token
+
+case class StringToken(string: String) extends Token {
+ override def toString = string
+}
+
+case class ScalarLiftToken(lift: ScalarLift) extends Token {
+ override def toString = s"lift(${lift.name})"
+}
+
+case class Statement(tokens: List[Token]) extends Token {
+ override def toString = tokens.mkString
+}
+
+case class SetContainsToken(a: Token, op: Token, b: Token) extends Token {
+ override def toString = s"${a.toString} ${op.toString} (${b.toString})"
+}
+
+object Statement {
+ given ToExpr[Statement] with {
+ def apply(t: Statement)(using Quotes): Expr[Statement] = {
+ '{ Statement(${ Expr(t.tokens) }) }
+ }
+ }
+}
+
+object Token {
+
+ given ToExpr[Token] with {
+ def apply(t: Token)(using Quotes): Expr[Token] = {
+ t match {
+ case StringToken(s) =>
+ '{ StringToken(${ Expr(s) }) }
+ case ScalarLiftToken(l) =>
+ '{ ScalarLiftToken(${ Expr(l) }) }
+ case SetContainsToken(a, op, b) =>
+ '{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) }
+ case s: Statement => Expr(s)
+ }
+ }
+ }
+}
diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala
new file mode 100644
index 0000000..b732da1
--- /dev/null
+++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala
@@ -0,0 +1,146 @@
+package minisql.idiom
+
+import minisql.ast._
+import minisql.util.Interleave
+import minisql.util.Messages._
+
+import scala.collection.mutable.ListBuffer
+
+object StatementInterpolator {
+
+ trait Tokenizer[T] {
+ def token(v: T): Token
+ }
+
+ object Tokenizer {
+ def apply[T](f: T => Token) = new Tokenizer[T] {
+ def token(v: T) = f(v)
+ }
+ def withFallback[T](
+ fallback: Tokenizer[T] => Tokenizer[T]
+ )(pf: PartialFunction[T, Token]) =
+ new Tokenizer[T] {
+ private val stable = fallback(this)
+ override def token(v: T) = pf.applyOrElse(v, stable.token)
+ }
+ }
+
+ implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) {
+ def token = tokenizer.token(v)
+ }
+
+ implicit def stringTokenizer: Tokenizer[String] =
+ Tokenizer[String] {
+ case string => StringToken(string)
+ }
+
+ implicit def liftTokenizer: Tokenizer[Lift] =
+ Tokenizer[Lift] {
+ case lift: ScalarLift => ScalarLiftToken(lift)
+ case lift =>
+ fail(
+ s"Can't tokenize a non-scalar lifting. ${lift.name}\n" +
+ s"\n" +
+ s"This might happen because:\n" +
+ s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" +
+ s" to `Some[A]` or `None.type`. For example:\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Some(value))))" +
+ s" In that case, make sure the type is `Option`:\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" +
+ s" or\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" +
+ s"\n" +
+ s"* You are trying to insert or update whole Embedded case class. For example:\n" +
+ s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" +
+ s" In that case, make sure you are updating individual columns, for example:\n" +
+ s" run(query[Users].update(\n" +
+ s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" +
+ s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" +
+ s" ))"
+ )
+ }
+
+ implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity)
+ implicit def statementTokenizer: Tokenizer[Statement] =
+ Tokenizer[Statement](identity)
+ implicit def stringTokenTokenizer: Tokenizer[StringToken] =
+ Tokenizer[StringToken](identity)
+ implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] =
+ Tokenizer[ScalarLiftToken](identity)
+
+ extension [T](list: List[T]) {
+ def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = {
+ val l1 = list.map(_.token)
+ val l2 = List.fill(l1.size - 1)(StringToken(sep))
+ Statement(Interleave(l1, l2))
+ }
+ }
+
+ implicit def listTokenizer[T](implicit
+ tokenize: Tokenizer[T]
+ ): Tokenizer[List[T]] =
+ Tokenizer[List[T]] {
+ case list => list.mkStmt()
+ }
+
+ extension (sc: StringContext) {
+
+ def flatten(tokens: List[Token]): List[Token] = {
+
+ def unestStatements(tokens: List[Token]): List[Token] = {
+ tokens.flatMap {
+ case Statement(innerTokens) => unestStatements(innerTokens)
+ case token => token :: Nil
+ }
+ }
+
+ def mergeStringTokens(tokens: List[Token]): List[Token] = {
+ val (resultBuilder, leftTokens) =
+ tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) {
+ case ((builder, acc), stringToken: StringToken) =>
+ val str = stringToken.string
+ if (str.nonEmpty)
+ acc += stringToken.string
+ (builder, acc)
+ case ((builder, prev), b) if prev.isEmpty =>
+ (builder += b.token, prev)
+ case ((builder, prev), b) /* if prev.nonEmpty */ =>
+ builder += StringToken(prev.result().mkString)
+ builder += b.token
+ (builder, new ListBuffer[String])
+ }
+ if (leftTokens.nonEmpty)
+ resultBuilder += StringToken(leftTokens.result().mkString)
+ resultBuilder.result()
+ }
+
+ (unestStatements)
+ .andThen(mergeStringTokens)
+ .apply(tokens)
+ }
+
+ def checkLengths(
+ args: scala.collection.Seq[Any],
+ parts: Seq[String]
+ ): Unit =
+ if (parts.length != args.length + 1)
+ throw new IllegalArgumentException(
+ "wrong number of arguments (" + args.length
+ + ") for interpolated string with " + parts.length + " parts"
+ )
+
+ def stmt(args: Token*): Statement = {
+ checkLengths(args, sc.parts)
+ val partsIterator = sc.parts.iterator
+ val argsIterator = args.iterator
+ val bldr = List.newBuilder[Token]
+ bldr += StringToken(partsIterator.next())
+ while (argsIterator.hasNext) {
+ bldr += argsIterator.next()
+ bldr += StringToken(partsIterator.next())
+ }
+ val tokens = flatten(bldr.result())
+ Statement(tokens)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/AdHocReduction.scala b/src/main/scala/minisql/norm/AdHocReduction.scala
new file mode 100644
index 0000000..10ad9a8
--- /dev/null
+++ b/src/main/scala/minisql/norm/AdHocReduction.scala
@@ -0,0 +1,52 @@
+package minisql.norm
+
+import minisql.ast.BinaryOperation
+import minisql.ast.BooleanOperator
+import minisql.ast.Filter
+import minisql.ast.FlatMap
+import minisql.ast.Map
+import minisql.ast.Query
+import minisql.ast.Union
+import minisql.ast.UnionAll
+
+object AdHocReduction {
+
+ def unapply(q: Query) =
+ q match {
+
+ // ---------------------------
+ // *.filter
+
+ // a.filter(b => c).filter(d => e) =>
+ // a.filter(b => c && e[d := b])
+ case Filter(Filter(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> b)
+ Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er)))
+
+ // ---------------------------
+ // flatMap.*
+
+ // a.flatMap(b => c).map(d => e) =>
+ // a.flatMap(b => c.map(d => e))
+ case Map(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, Map(c, d, e)))
+
+ // a.flatMap(b => c).filter(d => e) =>
+ // a.flatMap(b => c.filter(d => e))
+ case Filter(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, Filter(c, d, e)))
+
+ // a.flatMap(b => c.union(d))
+ // a.flatMap(b => c).union(a.flatMap(b => d))
+ case FlatMap(a, b, Union(c, d)) =>
+ Some(Union(FlatMap(a, b, c), FlatMap(a, b, d)))
+
+ // a.flatMap(b => c.unionAll(d))
+ // a.flatMap(b => c).unionAll(a.flatMap(b => d))
+ case FlatMap(a, b, UnionAll(c, d)) =>
+ Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d)))
+
+ case other => None
+ }
+
+}
diff --git a/src/main/scala/minisql/norm/ApplyMap.scala b/src/main/scala/minisql/norm/ApplyMap.scala
new file mode 100644
index 0000000..e5ddb0c
--- /dev/null
+++ b/src/main/scala/minisql/norm/ApplyMap.scala
@@ -0,0 +1,160 @@
+package minisql.norm
+
+import minisql.ast._
+
+object ApplyMap {
+
+ private def isomorphic(e: Ast, c: Ast, alias: Ident) =
+ BetaReduction(e, alias -> c) == c
+
+ object InfixedTailOperation {
+
+ def hasImpureInfix(ast: Ast) =
+ CollectAst(ast) {
+ case i @ Infix(_, _, false, _) => i
+ }.nonEmpty
+
+ def unapply(ast: Ast): Option[Ast] =
+ ast match {
+ case cc: CaseClass if hasImpureInfix(cc) => Some(cc)
+ case tup: Tuple if hasImpureInfix(tup) => Some(tup)
+ case p: Property if hasImpureInfix(p) => Some(p)
+ case b: BinaryOperation if hasImpureInfix(b) => Some(b)
+ case u: UnaryOperation if hasImpureInfix(u) => Some(u)
+ case i @ Infix(_, _, false, _) => Some(i)
+ case _ => None
+ }
+ }
+
+ object MapWithoutInfixes {
+ def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
+ ast match {
+ case Map(a, b, InfixedTailOperation(c)) => None
+ case Map(a, b, c) => Some((a, b, c))
+ case _ => None
+ }
+ }
+
+ object DetachableMap {
+ def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
+ ast match {
+ case Map(a: GroupBy, b, c) => None
+ case Map(a, b, InfixedTailOperation(c)) => None
+ case Map(a, b, c) => Some((a, b, c))
+ case _ => None
+ }
+ }
+
+ def unapply(q: Query): Option[Query] =
+ q match {
+
+ case Map(a: GroupBy, b, c) if (b == c) => None
+ case Map(a: DistinctOn, b, c) => None
+ case Map(a: Nested, b, c) if (b == c) => None
+ case Nested(DetachableMap(a: Join, b, c)) => None
+
+ // map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) =>
+ // map(i => (i.i, i.l)).distinct
+ case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) =>
+ Some(Distinct(Map(a, b, c)))
+
+ // a.map(b => c).map(d => e) =>
+ // a.map(b => e[d := c])
+ case before @ Map(MapWithoutInfixes(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(a, b, er))
+
+ // a.map(b => b) =>
+ // a
+ case Map(a: Query, b, c) if (b == c) =>
+ Some(a)
+
+ // a.map(b => c).flatMap(d => e) =>
+ // a.flatMap(b => e[d := c])
+ case FlatMap(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(FlatMap(a, b, er))
+
+ // a.map(b => c).filter(d => e) =>
+ // a.filter(b => e[d := c]).map(b => c)
+ case Filter(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(Filter(a, b, er), b, c))
+
+ // a.map(b => c).sortBy(d => e) =>
+ // a.sortBy(b => e[d := c]).map(b => c)
+ case SortBy(DetachableMap(a, b, c), d, e, f) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(SortBy(a, b, er, f), b, c))
+
+ // a.map(b => c).sortBy(d => e).distinct =>
+ // a.sortBy(b => e[d := c]).map(b => c).distinct
+ case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Distinct(Map(SortBy(a, b, er, f), b, c)))
+
+ // a.map(b => c).groupBy(d => e) =>
+ // a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c)))
+ case GroupBy(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ val x = Ident("x")
+ val x1 = Property(
+ Ident("x"),
+ "_1"
+ ) // These introduced property should not be renamed
+ val x2 = Property(Ident("x"), "_2") // due to any naming convention.
+ val body = Tuple(List(x1, Map(x2, b, c)))
+ Some(Map(GroupBy(a, b, er), x, body))
+
+ // a.map(b => c).drop(d) =>
+ // a.drop(d).map(b => c)
+ case Drop(DetachableMap(a, b, c), d) =>
+ Some(Map(Drop(a, d), b, c))
+
+ // a.map(b => c).take(d) =>
+ // a.drop(d).map(b => c)
+ case Take(DetachableMap(a, b, c), d) =>
+ Some(Map(Take(a, d), b, c))
+
+ // a.map(b => c).nested =>
+ // a.nested.map(b => c)
+ case Nested(DetachableMap(a, b, c)) =>
+ Some(Map(Nested(a), b, c))
+
+ // a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on)
+ // a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2]))
+ case Join(
+ tpe,
+ DetachableMap(a, b, c),
+ DetachableMap(d, e, f),
+ iA,
+ iB,
+ on
+ ) =>
+ val onr = BetaReduction(on, iA -> c, iB -> f)
+ val t = Ident("t")
+ val t1 = BetaReduction(c, b -> Property(t, "_1"))
+ val t2 = BetaReduction(f, e -> Property(t, "_2"))
+ Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2))))
+
+ // a.*join(b.map(c => d)).on((iA, iB) => on)
+ // a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2]))
+ case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) =>
+ val onr = BetaReduction(on, iB -> d)
+ val t = Ident("t")
+ val t1 = Property(t, "_1")
+ val t2 = BetaReduction(d, c -> Property(t, "_2"))
+ Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2))))
+
+ // a.map(b => c).*join(d).on((iA, iB) => on)
+ // a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2))
+ case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) =>
+ val onr = BetaReduction(on, iA -> c)
+ val t = Ident("t")
+ val t1 = BetaReduction(c, b -> Property(t, "_1"))
+ val t2 = Property(t, "_2")
+ Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2))))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/AttachToEntity.scala b/src/main/scala/minisql/norm/AttachToEntity.scala
new file mode 100644
index 0000000..e5608e9
--- /dev/null
+++ b/src/main/scala/minisql/norm/AttachToEntity.scala
@@ -0,0 +1,48 @@
+package minisql.norm
+
+import minisql.util.Messages.fail
+import minisql.ast._
+
+object AttachToEntity {
+
+ private object IsEntity {
+ def unapply(q: Ast): Option[Ast] =
+ q match {
+ case q: Entity => Some(q)
+ case q: Infix => Some(q)
+ case _ => None
+ }
+ }
+
+ def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
+ q: Ast
+ ): Ast =
+ q match {
+
+ case Map(IsEntity(a), b, c) => Map(f(a, b), b, c)
+ case FlatMap(IsEntity(a), b, c) => FlatMap(f(a, b), b, c)
+ case ConcatMap(IsEntity(a), b, c) => ConcatMap(f(a, b), b, c)
+ case Filter(IsEntity(a), b, c) => Filter(f(a, b), b, c)
+ case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d)
+ case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)
+
+ case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
+ _: FlatJoin =>
+ f(q, alias.getOrElse(Ident("x")))
+
+ case Map(a: Query, b, c) => Map(apply(f, Some(b))(a), b, c)
+ case FlatMap(a: Query, b, c) => FlatMap(apply(f, Some(b))(a), b, c)
+ case ConcatMap(a: Query, b, c) => ConcatMap(apply(f, Some(b))(a), b, c)
+ case Filter(a: Query, b, c) => Filter(apply(f, Some(b))(a), b, c)
+ case SortBy(a: Query, b, c, d) => SortBy(apply(f, Some(b))(a), b, c, d)
+ case Take(a: Query, b) => Take(apply(f, alias)(a), b)
+ case Drop(a: Query, b) => Drop(apply(f, alias)(a), b)
+ case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a))
+ case Distinct(a: Query) => Distinct(apply(f, alias)(a))
+ case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)
+
+ case IsEntity(q) => f(q, alias.getOrElse(Ident("x")))
+
+ case other => fail(s"Can't find an 'Entity' in '$q'")
+ }
+}
diff --git a/src/main/scala/minisql/util/BetaReduction.scala b/src/main/scala/minisql/norm/BetaReduction.scala
similarity index 99%
rename from src/main/scala/minisql/util/BetaReduction.scala
rename to src/main/scala/minisql/norm/BetaReduction.scala
index 0940564..868b021 100644
--- a/src/main/scala/minisql/util/BetaReduction.scala
+++ b/src/main/scala/minisql/norm/BetaReduction.scala
@@ -1,6 +1,6 @@
-package minisql.util
+package minisql.norm
-import minisql.ast.*
+import minisql.ast._
import scala.collection.immutable.{Map => IMap}
case class BetaReduction(replacements: Replacements)
diff --git a/src/main/scala/minisql/norm/ConcatBehavior.scala b/src/main/scala/minisql/norm/ConcatBehavior.scala
new file mode 100644
index 0000000..3547b1d
--- /dev/null
+++ b/src/main/scala/minisql/norm/ConcatBehavior.scala
@@ -0,0 +1,7 @@
+package minisql.norm
+
+trait ConcatBehavior
+object ConcatBehavior {
+ case object AnsiConcat extends ConcatBehavior
+ case object NonAnsiConcat extends ConcatBehavior
+}
diff --git a/src/main/scala/minisql/norm/EqualityBehavior.scala b/src/main/scala/minisql/norm/EqualityBehavior.scala
new file mode 100644
index 0000000..ee1ff38
--- /dev/null
+++ b/src/main/scala/minisql/norm/EqualityBehavior.scala
@@ -0,0 +1,7 @@
+package minisql.norm
+
+trait EqualityBehavior
+object EqualityBehavior {
+ case object AnsiEquality extends EqualityBehavior
+ case object NonAnsiEquality extends EqualityBehavior
+}
diff --git a/src/main/scala/minisql/norm/ExpandReturning.scala b/src/main/scala/minisql/norm/ExpandReturning.scala
new file mode 100644
index 0000000..32a886f
--- /dev/null
+++ b/src/main/scala/minisql/norm/ExpandReturning.scala
@@ -0,0 +1,74 @@
+package minisql.norm
+
+import minisql.ReturnAction.ReturnColumns
+import minisql.{NamingStrategy, ReturnAction}
+import minisql.ast._
+import minisql.context.{
+ ReturningClauseSupported,
+ ReturningMultipleFieldSupported,
+ ReturningNotSupported,
+ ReturningSingleFieldSupported
+}
+import minisql.idiom.{Idiom, Statement}
+
+/**
+ * Take the `.returning` part in a query that contains it and return the array
+ * of columns representing the returning section with any other operations
+ * etc... that they might contain.
+ */
+object ExpandReturning {
+
+ def applyMap(
+ returning: ReturningAction
+ )(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = {
+ val initialExpand = ExpandReturning.apply(returning)(idiom, naming)
+
+ idiom.idiomReturningCapability match {
+ case ReturningClauseSupported =>
+ ReturnAction.ReturnRecord
+ case ReturningMultipleFieldSupported =>
+ ReturnColumns(initialExpand.map {
+ case (ast, statement) => f(ast, statement)
+ })
+ case ReturningSingleFieldSupported =>
+ if (initialExpand.length == 1)
+ ReturnColumns(initialExpand.map {
+ case (ast, statement) => f(ast, statement)
+ })
+ else
+ throw new IllegalArgumentException(
+ s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
+ )
+ case ReturningNotSupported =>
+ throw new IllegalArgumentException(
+ s"RETURNING columns are not allowed in the ${idiom} dialect."
+ )
+ }
+ }
+
+ def apply(
+ returning: ReturningAction
+ )(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
+ val ReturningAction(_, alias, properties) = returning: @unchecked
+
+ // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
+ // => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
+ val dePropertized =
+ Transform(properties) {
+ case `alias` => ExternalIdent(alias.name)
+ }
+
+ val aliasName = alias.name
+
+ // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
+ // => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
+ val deTuplified = dePropertized match {
+ case Tuple(values) => values
+ case CaseClass(values) => values.map(_._2)
+ case other => List(other)
+ }
+
+ implicit val namingStrategy: NamingStrategy = naming
+ deTuplified.map(v => idiom.translate(v))
+ }
+}
diff --git a/src/main/scala/minisql/norm/FlattenOptionOperation.scala b/src/main/scala/minisql/norm/FlattenOptionOperation.scala
new file mode 100644
index 0000000..50ba857
--- /dev/null
+++ b/src/main/scala/minisql/norm/FlattenOptionOperation.scala
@@ -0,0 +1,108 @@
+package minisql.norm
+
+import minisql.ast.*
+import minisql.ast.Implicits.*
+import minisql.norm.ConcatBehavior.NonAnsiConcat
+
+class FlattenOptionOperation(concatBehavior: ConcatBehavior)
+ extends StatelessTransformer {
+
+ private def emptyOrNot(b: Boolean, ast: Ast) =
+ if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)
+
+ def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
+ apply(BetaReduction(body, alias -> ast))
+
+ def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply((IsNullCheck(ast) +||+ reduced): Ast)
+ }
+
+ def containsNonFallthroughElement(ast: Ast) =
+ CollectAst(ast) {
+ case If(_, _, _) => true
+ case Infix(_, _, _, _) => true
+ case BinaryOperation(_, StringOperator.`concat`, _)
+ if (concatBehavior == NonAnsiConcat) =>
+ true
+ }.nonEmpty
+
+ override def apply(ast: Ast): Ast =
+ ast match {
+
+ case OptionTableFlatMap(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableMap(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableExists(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableForall(ast, alias, body) =>
+ uncheckedForall(ast, alias, body)
+
+ case OptionFlatten(ast) =>
+ apply(ast)
+
+ case OptionSome(ast) =>
+ apply(ast)
+
+ case OptionApply(ast) =>
+ apply(ast)
+
+ case OptionOrNull(ast) =>
+ apply(ast)
+
+ case OptionGetOrNull(ast) =>
+ apply(ast)
+
+ case OptionNone => NullValue
+
+ case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
+ apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)
+
+ case OptionGetOrElse(ast, body) =>
+ apply(If(IsNotNullCheck(ast), ast, body))
+
+ case OptionFlatMap(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply(IfExistElseNull(ast, reduced))
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionMap(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply(IfExistElseNull(ast, reduced))
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionForall(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduction = BetaReduction(body, alias -> ast)
+ apply(
+ (IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast
+ )
+ } else {
+ uncheckedForall(ast, alias, body)
+ }
+
+ case OptionExists(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduction = BetaReduction(body, alias -> ast)
+ apply((IsNotNullCheck(ast) +&&+ reduction): Ast)
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionContains(ast, body) =>
+ apply((ast +==+ body): Ast)
+
+ case other =>
+ super.apply(other)
+ }
+}
diff --git a/src/main/scala/minisql/norm/NestImpureMappedInfix.scala b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala
new file mode 100644
index 0000000..789efba
--- /dev/null
+++ b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala
@@ -0,0 +1,76 @@
+package minisql.norm
+
+import minisql.ast._
+
+/**
+ * A problem occurred in the original way infixes were done in that it was
+ * assumed that infix clauses represented pure functions. While this is true of
+ * many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others
+ * e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations
+ * that are done in `ApplyMap` on standard AST `Map` clauses cannot be done
+ * therefore additional safety checks were introduced there in order to assure
+ * this does not happen. In addition to this however, it is necessary to add
+ * this normalization step which inserts `Nested` AST elements in every map that
+ * contains impure infix. See more information and examples in #1534.
+ */
+object NestImpureMappedInfix extends StatelessTransformer {
+
+ // Are there any impure infixes that exist inside the specified ASTs
+ def hasInfix(asts: Ast*): Boolean =
+ asts.exists(ast =>
+ CollectAst(ast) {
+ case i @ Infix(_, _, false, _) => i
+ }.nonEmpty
+ )
+
+ // Continue exploring into the Map to see if there are additional impure infix clauses inside.
+ private def applyInside(m: Map) =
+ Map(apply(m.query), m.alias, m.body)
+
+ override def apply(ast: Ast): Ast =
+ ast match {
+ // If there is already a nested clause inside the map, there is no reason to insert another one
+ case Nested(Map(inner, a, b)) =>
+ Nested(Map(apply(inner), a, b))
+
+ case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m)
+ Map(
+ Nested(applyInside(m)),
+ x,
+ CaseClass(values.map {
+ case (name, _) =>
+ (
+ name,
+ Property(x, name)
+ ) // mappings of nested-query case class properties should not be renamed
+ })
+ )
+
+ case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
+ Map(
+ Nested(applyInside(m)),
+ x,
+ Tuple(values.zipWithIndex.map {
+ case (_, i) =>
+ Property(
+ x,
+ s"_${i + 1}"
+ ) // mappings of nested-query tuple properties should not be renamed
+ })
+ )
+
+ case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case other => super.apply(other)
+ }
+}
diff --git a/src/main/scala/minisql/norm/Normalize.scala b/src/main/scala/minisql/norm/Normalize.scala
new file mode 100644
index 0000000..d3d8d67
--- /dev/null
+++ b/src/main/scala/minisql/norm/Normalize.scala
@@ -0,0 +1,51 @@
+package minisql.norm
+
+import minisql.ast.Ast
+import minisql.ast.Query
+import minisql.ast.StatelessTransformer
+import minisql.norm.capture.AvoidCapture
+import minisql.ast.Action
+import minisql.util.Messages.trace
+import minisql.util.Messages.TraceType.Normalizations
+
+import scala.annotation.tailrec
+
+object Normalize extends StatelessTransformer {
+
+ override def apply(q: Ast): Ast =
+ super.apply(BetaReduction(q))
+
+ override def apply(q: Action): Action =
+ NormalizeReturning(super.apply(q))
+
+ override def apply(q: Query): Query =
+ norm(AvoidCapture(q))
+
+ private def traceNorm[T](label: String) =
+ trace[T](s"${label} (Normalize)", 1, Normalizations)
+
+ @tailrec
+ private def norm(q: Query): Query =
+ q match {
+ case NormalizeNestedStructures(query) =>
+ traceNorm("NormalizeNestedStructures")(query)
+ norm(query)
+ case ApplyMap(query) =>
+ traceNorm("ApplyMap")(query)
+ norm(query)
+ case SymbolicReduction(query) =>
+ traceNorm("SymbolicReduction")(query)
+ norm(query)
+ case AdHocReduction(query) =>
+ traceNorm("AdHocReduction")(query)
+ norm(query)
+ case OrderTerms(query) =>
+ traceNorm("OrderTerms")(query)
+ norm(query)
+ case NormalizeAggregationIdent(query) =>
+ traceNorm("NormalizeAggregationIdent")(query)
+ norm(query)
+ case other =>
+ other
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala
new file mode 100644
index 0000000..c451ee4
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala
@@ -0,0 +1,29 @@
+package minisql.norm
+
+import minisql.ast._
+
+object NormalizeAggregationIdent {
+
+ def unapply(q: Query) =
+ q match {
+
+ // a => a.b.map(x => x.c).agg =>
+ // a => a.b.map(a => a.c).agg
+ case Aggregation(
+ op,
+ Map(
+ p @ Property(i: Ident, _),
+ mi,
+ Property.Opinionated(_: Ident, n, renameable, visibility)
+ )
+ ) if i != mi =>
+ Some(
+ Aggregation(
+ op,
+ Map(p, i, Property.Opinionated(i, n, renameable, visibility))
+ )
+ ) // in example above, if c in x.c is fixed, c in a.c should also be
+
+ case _ => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeNestedStructures.scala b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala
new file mode 100644
index 0000000..603c411
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala
@@ -0,0 +1,47 @@
+package minisql.norm
+
+import minisql.ast._
+
+object NormalizeNestedStructures {
+
+ def unapply(q: Query): Option[Query] =
+ q match {
+ case e: Entity => None
+ case Map(a, b, c) => apply(a, c)(Map(_, b, _))
+ case FlatMap(a, b, c) => apply(a, c)(FlatMap(_, b, _))
+ case ConcatMap(a, b, c) => apply(a, c)(ConcatMap(_, b, _))
+ case Filter(a, b, c) => apply(a, c)(Filter(_, b, _))
+ case SortBy(a, b, c, d) => apply(a, c)(SortBy(_, b, _, d))
+ case GroupBy(a, b, c) => apply(a, c)(GroupBy(_, b, _))
+ case Aggregation(a, b) => apply(b)(Aggregation(a, _))
+ case Take(a, b) => apply(a, b)(Take.apply)
+ case Drop(a, b) => apply(a, b)(Drop.apply)
+ case Union(a, b) => apply(a, b)(Union.apply)
+ case UnionAll(a, b) => apply(a, b)(UnionAll.apply)
+ case Distinct(a) => apply(a)(Distinct.apply)
+ case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
+ case Nested(a) => apply(a)(Nested.apply)
+ case FlatJoin(t, a, iA, on) =>
+ (Normalize(a), Normalize(on)) match {
+ case (`a`, `on`) => None
+ case (a, on) => Some(FlatJoin(t, a, iA, on))
+ }
+ case Join(t, a, b, iA, iB, on) =>
+ (Normalize(a), Normalize(b), Normalize(on)) match {
+ case (`a`, `b`, `on`) => None
+ case (a, b, on) => Some(Join(t, a, b, iA, iB, on))
+ }
+ }
+
+ private def apply(a: Ast)(f: Ast => Query) =
+ (Normalize(a)) match {
+ case (`a`) => None
+ case (a) => Some(f(a))
+ }
+
+ private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
+ (Normalize(a), Normalize(b)) match {
+ case (`a`, `b`) => None
+ case (a, b) => Some(f(a, b))
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeReturning.scala b/src/main/scala/minisql/norm/NormalizeReturning.scala
new file mode 100644
index 0000000..43fc241
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeReturning.scala
@@ -0,0 +1,154 @@
+package minisql.norm
+
+import minisql.ast._
+import minisql.norm.capture.AvoidAliasConflict
+
+/**
+ * When actions are used with a `.returning` clause, remove the columns used in
+ * the returning clause from the action. E.g. for `insert(Person(id,
+ * name)).returning(_.id)` remove the `id` column from the original insert.
+ */
+object NormalizeReturning {
+
+ def apply(e: Action): Action = {
+ e match {
+ case ReturningGenerated(a: Action, alias, body) =>
+ // De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action.
+ // This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres.
+ // For example:
+ // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
+ // Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally
+ // be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case.
+ // In order to fix this, we need to change `t` into a different alias.
+ val newBody = dealiasBody(body, alias)
+ ReturningGenerated(apply(a, newBody, alias), alias, newBody)
+
+ // For a regular return clause, do not need to exclude assignments from insertion however, we still
+ // need to de-alias the Action body in case conflicts result. For example the following query:
+ // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
+ // would incorrectly be interpreted as:
+ // INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t
+ // whereas it should be:
+ // INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
+ case Returning(a: Action, alias, body) =>
+ val newBody = dealiasBody(body, alias)
+ Returning(a, alias, newBody)
+
+ case _ => e
+ }
+ }
+
+ /**
+ * In some situations, a query can exist inside of a `returning` clause. In
+ * this case, we need to rename if the aliases used in that query override the
+ * alias used in the `returning` clause otherwise they will be treated as
+ * returning-clause aliases ExpandReturning (i.e. they will become
+ * ExternalAlias instances) and later be tokenized incorrectly.
+ */
+ private def dealiasBody(body: Ast, alias: Ident): Ast =
+ Transform(body) {
+ case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
+ }
+
+ private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
+ e match {
+ case Insert(query, assignments) =>
+ Insert(query, filterReturnedColumn(assignments, body, returningIdent))
+ case Update(query, assignments) =>
+ Update(query, filterReturnedColumn(assignments, body, returningIdent))
+ case OnConflict(a: Action, target, act) =>
+ OnConflict(apply(a, body, returningIdent), target, act)
+ case _ => e
+ }
+
+ private def filterReturnedColumn(
+ assignments: List[Assignment],
+ column: Ast,
+ returningIdent: Ident
+ ): List[Assignment] =
+ assignments.flatMap(filterReturnedColumn(_, column, returningIdent))
+
+ /**
+ * In situations like Property(Property(ident, foo), bar) pull out the
+ * inner-most ident
+ */
+ object NestedProperty {
+ def unapply(ast: Property): Option[Ast] = {
+ ast match {
+ case p @ Property(subAst, _) => Some(innerMost(subAst))
+ }
+ }
+
+ private def innerMost(ast: Ast): Ast = ast match {
+ case Property(inner, _) => innerMost(inner)
+ case other => other
+ }
+ }
+
+ /**
+ * Remove the specified column from the assignment. For example, in a query
+ * like `insert(Person(id, name)).returning(r => r.id)` we need to remove the
+ * `id` column from the insertion. The value of the `column:Ast` in this case
+ * will be `Property(Ident(r), id)` and the values fo the assignment `p1`
+ * property will typically be `v.id` and `v.name` (the `v` variable is a
+ * default used for `insert` queries).
+ */
+ private def filterReturnedColumn(
+ assignment: Assignment,
+ body: Ast,
+ returningIdent: Ident
+ ): Option[Assignment] =
+ assignment match {
+ case Assignment(_, p1: Property, _) => {
+ // Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but
+ // if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2))
+ // or it can even be a query since queries are allowed to be in return sections e.g:
+ // query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max))
+ // In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment
+ // in order to know what to exclude.
+ val matchedProps =
+ CollectAst(body) {
+ // case prop @ NestedProperty(`returningIdent`) => prop
+ case prop @ NestedProperty(Ident(name))
+ if (name == returningIdent.name) =>
+ prop
+ case prop @ NestedProperty(ExternalIdent(name))
+ if (name == returningIdent.name) =>
+ prop
+ }
+
+ if (
+ matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))
+ )
+ None
+ else
+ Some(assignment)
+ }
+ case assignment => Some(assignment)
+ }
+
+ object SomeIdent {
+ def unapply(ast: Ast): Option[Ast] =
+ ast match {
+ case id: Ident => Some(id)
+ case id: ExternalIdent => Some(id)
+ case _ => None
+ }
+ }
+
+ /**
+ * Is it the same property (but possibly of a different identity). E.g.
+ * `p.foo.bar` and `v.foo.bar`
+ */
+ private def isSameProperties(p1: Property, p2: Property): Boolean =
+ (p1.ast, p2.ast) match {
+ case (SomeIdent(_), SomeIdent(_)) =>
+ p1.name == p2.name
+ // If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the
+ // outer properties are the same before moving on to the inner ones.
+ case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
+ isSameProperties(pp1, pp2)
+ case _ =>
+ false
+ }
+}
diff --git a/src/main/scala/minisql/norm/OrderTerms.scala b/src/main/scala/minisql/norm/OrderTerms.scala
new file mode 100644
index 0000000..22422fa
--- /dev/null
+++ b/src/main/scala/minisql/norm/OrderTerms.scala
@@ -0,0 +1,29 @@
+package minisql.norm
+
+import minisql.ast._
+
+object OrderTerms {
+
+ def unapply(q: Query) =
+ q match {
+
+ case Take(Map(a: GroupBy, b, c), d) => None
+
+ // a.sortBy(b => c).filter(d => e) =>
+ // a.filter(d => e).sortBy(b => c)
+ case Filter(SortBy(a, b, c, d), e, f) =>
+ Some(SortBy(Filter(a, e, f), b, c, d))
+
+ // a.flatMap(b => c).take(n).map(d => e) =>
+ // a.flatMap(b => c).map(d => e).take(n)
+ case Map(Take(fm: FlatMap, n), ma, mb) =>
+ Some(Take(Map(fm, ma, mb), n))
+
+ // a.flatMap(b => c).drop(n).map(d => e) =>
+ // a.flatMap(b => c).map(d => e).drop(n)
+ case Map(Drop(fm: FlatMap, n), ma, mb) =>
+ Some(Drop(Map(fm, ma, mb), n))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/RenameProperties.scala b/src/main/scala/minisql/norm/RenameProperties.scala
new file mode 100644
index 0000000..53d135c
--- /dev/null
+++ b/src/main/scala/minisql/norm/RenameProperties.scala
@@ -0,0 +1,491 @@
+package minisql.norm
+
+import minisql.ast.Renameable.Fixed
+import minisql.ast.Visibility.Visible
+import minisql.ast._
+import minisql.util.Interpolator
+
+object RenameProperties extends StatelessTransformer {
+ val interp = new Interpolator(3)
+ import interp._
+ def traceDifferent(one: Any, two: Any) =
+ if (one != two)
+ trace"Replaced $one with $two".andLog()
+ else
+ trace"Replacements did not match".andLog()
+
+ override def apply(q: Query): Query =
+ applySchemaOnly(q)
+
+ override def apply(q: Action): Action =
+ applySchema(q) match {
+ case (q, schema) => q
+ }
+
+ override def apply(e: Operation): Operation =
+ e match {
+ case UnaryOperation(o, c: Query) =>
+ UnaryOperation(o, applySchemaOnly(apply(c)))
+ case _ => super.apply(e)
+ }
+
+ private def applySchema(q: Action): (Action, Schema) =
+ q match {
+ case Insert(q: Query, assignments) =>
+ applySchema(q, assignments, Insert.apply)
+ case Update(q: Query, assignments) =>
+ applySchema(q, assignments, Update.apply)
+ case Delete(q: Query) =>
+ applySchema(q) match {
+ case (q, schema) => (Delete(q), schema)
+ }
+ case Returning(action: Action, alias, body) =>
+ applySchema(action) match {
+ case (action, schema) =>
+ val replace =
+ trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val bodyr = BetaReduction(body, replace*)
+ traceDifferent(body, bodyr)
+ (Returning(action, alias, bodyr), schema)
+ }
+ case ReturningGenerated(action: Action, alias, body) =>
+ applySchema(action) match {
+ case (action, schema) =>
+ val replace =
+ trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val bodyr = BetaReduction(body, replace*)
+ traceDifferent(body, bodyr)
+ (ReturningGenerated(action, alias, bodyr), schema)
+ }
+ case OnConflict(a: Action, target, act) =>
+ applySchema(a) match {
+ case (action, schema) =>
+ val targetR = target match {
+ case OnConflict.Properties(props) =>
+ val propsR = props.map { prop =>
+ val replace =
+ trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn`
+ replacements(
+ prop.ast,
+ schema
+ ) // A BetaReduction on a Property will always give back a Property
+ BetaReduction(prop, replace*).asInstanceOf[Property]
+ }
+ traceDifferent(props, propsR)
+ OnConflict.Properties(propsR)
+ case OnConflict.NoTarget => target
+ }
+ val actR = act match {
+ case OnConflict.Update(assignments) =>
+ OnConflict.Update(replaceAssignments(assignments, schema))
+ case _ => act
+ }
+ (OnConflict(action, targetR, actR), schema)
+ }
+ case q => (q, TupleSchema.empty)
+ }
+
+ private def replaceAssignments(
+ a: List[Assignment],
+ schema: Schema
+ ): List[Assignment] =
+ a.map {
+ case Assignment(alias, prop, value) =>
+ val replace =
+ trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val propR = BetaReduction(prop, replace*)
+ traceDifferent(prop, propR)
+ val valueR = BetaReduction(value, replace*)
+ traceDifferent(value, valueR)
+ Assignment(alias, propR, valueR)
+ }
+
+ private def applySchema(
+ q: Query,
+ a: List[Assignment],
+ f: (Query, List[Assignment]) => Action
+ ): (Action, Schema) =
+ applySchema(q) match {
+ case (q, schema) =>
+ (f(q, replaceAssignments(a, schema)), schema)
+ }
+
+ private def applySchemaOnly(q: Query): Query =
+ applySchema(q) match {
+ case (q, _) => q
+ }
+
+ object TupleIndex {
+ def unapply(s: String): Option[Int] =
+ if (s.matches("_[0-9]*"))
+ Some(s.drop(1).toInt - 1)
+ else
+ None
+ }
+
+ sealed trait Schema {
+ def lookup(property: List[String]): Option[Schema] =
+ (property, this) match {
+ case (Nil, schema) =>
+ trace"Nil at $property returning " `andReturn`
+ Some(schema)
+ case (path, e @ EntitySchema(_)) =>
+ trace"Entity at $path returning " `andReturn`
+ Some(e.subSchemaOrEmpty(path))
+ case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) =>
+ trace"Case class at $property returning " `andReturn`
+ props(head).lookup(tail)
+ case (TupleIndex(idx) :: tail, TupleSchema(values))
+ if values.contains(idx) =>
+ trace"Tuple at at $property returning " `andReturn`
+ values(idx).lookup(tail)
+ case _ =>
+ trace"Nothing found at $property returning " `andReturn`
+ None
+ }
+ }
+
+ // Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...))
+ object PropertyMatroshka {
+
+ def traverse(initial: Property): Option[(Ident, List[String])] =
+ initial match {
+ // If it's a nested-property walk inside and append the name to the result (if something is returned)
+ case Property(inner: Property, name) =>
+ traverse(inner).map { case (id, list) => (id, list :+ name) }
+ // If it's a property with ident in the core, return that
+ case Property(id: Ident, name) =>
+ Some((id, List(name)))
+ // Otherwise an ident property is not inside so don't return anything
+ case _ =>
+ None
+ }
+
+ def unapply(ast: Ast): Option[(Ident, List[String])] =
+ ast match {
+ case p: Property => traverse(p)
+ case _ => None
+ }
+
+ }
+
+ def protractSchema(
+ body: Ast,
+ ident: Ident,
+ schema: Schema
+ ): Option[Schema] = {
+
+ def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] =
+ body match {
+ // if any values yield a sub-schema which is not an entity, recurse into that
+ case cc @ CaseClass(values) =>
+ trace"Protracting CaseClass $cc into new schema:" `andReturn`
+ CaseClassSchema(
+ values.collect {
+ case (name, innerBody @ HierarchicalAstEntity()) =>
+ (name, protractSchemaRecurse(innerBody, schema))
+ // pass the schema into a recursive call and extract from it when we reach a non-tuple/case-class element
+ case (name, innerBody @ PropertyMatroshka(`ident`, path)) =>
+ (name, protractSchemaRecurse(innerBody, schema))
+ // we have reached an ident i.e. recurse to pass the current schema into the case class
+ case (name, `ident`) =>
+ (name, protractSchemaRecurse(ident, schema))
+ }.collect {
+ case (name, Some(subSchema)) => (name, subSchema)
+ }
+ ).notEmpty
+ case tup @ Tuple(values) =>
+ trace"Protracting Tuple $tup into new schema:" `andReturn`
+ TupleSchema
+ .fromIndexes(
+ values.zipWithIndex.collect {
+ case (innerBody @ HierarchicalAstEntity(), index) =>
+ (index, protractSchemaRecurse(innerBody, schema))
+ // pass the schema into a recursive call and extract from it when we reach a non-tuple/case-class element
+ case (innerBody @ PropertyMatroshka(`ident`, path), index) =>
+ (index, protractSchemaRecurse(innerBody, schema))
+ // we have reached an ident i.e. recurse to pass the current schema into the tuple
+ case (`ident`, index) =>
+ (index, protractSchemaRecurse(ident, schema))
+ }.collect {
+ case (index, Some(subSchema)) => (index, subSchema)
+ }
+ )
+ .notEmpty
+
+ case prop @ PropertyMatroshka(`ident`, path) =>
+ trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn`
+ schema match {
+ // case e: EntitySchema => Some(e)
+ case _ => schema.lookup(path)
+ }
+ case `ident` =>
+ trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn`
+ Some(schema)
+ case other =>
+ trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn`
+ Some(schema)
+ }
+
+ protractSchemaRecurse(body, schema)
+ }
+
+ case object EmptySchema extends Schema
+ case class EntitySchema(e: Entity) extends Schema {
+ def noAliases = e.properties.isEmpty
+
+ private def subSchema(path: List[String]) =
+ EntitySchema(
+ Entity(
+ s"sub-${e.name}",
+ e.properties.flatMap {
+ case PropertyAlias(aliasPath, alias) =>
+ if (aliasPath == path)
+ List(PropertyAlias(aliasPath, alias))
+ else if (aliasPath.startsWith(path))
+ List(PropertyAlias(aliasPath.diff(path), alias))
+ else
+ List()
+ }
+ )
+ )
+
+ def subSchemaOrEmpty(path: List[String]): Schema =
+ trace"Creating sub-schema for entity $e at path $path will be" andReturn {
+ val sub = subSchema(path)
+ if (sub.noAliases) EmptySchema else sub
+ }
+
+ }
+ case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ )
+ extends Schema {
+ def list = m.toList.sortBy(_._1)
+ def notEmpty =
+ if (this.m.nonEmpty) Some(this) else None
+ }
+ case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema {
+ def list = m.toList
+ def notEmpty =
+ if (this.m.nonEmpty) Some(this) else None
+ }
+ object CaseClassSchema {
+ def apply(property: String, value: Schema): CaseClassSchema =
+ CaseClassSchema(collection.Map(property -> value))
+ def apply(list: List[(String, Schema)]): CaseClassSchema =
+ CaseClassSchema(list.toMap)
+ }
+
+ object TupleSchema {
+ def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema =
+ TupleSchema(schemas.toMap)
+
+ def apply(schemas: List[Schema]): TupleSchema =
+ TupleSchema(schemas.zipWithIndex.map(_.swap).toMap)
+
+ def apply(index: Int, schema: Schema): TupleSchema =
+ TupleSchema(collection.Map(index -> schema))
+
+ def empty: TupleSchema = TupleSchema(List.empty)
+ }
+
+ object HierarchicalAstEntity {
+ def unapply(ast: Ast): Boolean =
+ ast match {
+ case cc: CaseClass => true
+ case tup: Tuple => true
+ case _ => false
+ }
+ }
+
+ private def applySchema(q: Query): (Query, Schema) = {
+ q match {
+
+ // Don't understand why this is needed....
+ case Map(q: Query, x, p) =>
+ applySchema(q) match {
+ case (q, subSchema) =>
+ val replace =
+ trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn`
+ replacements(x, subSchema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+
+ val schema =
+ trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` {
+ protractSchema(prr, x, subSchema)
+ }.getOrElse(EmptySchema)
+
+ (Map(q, x, prr), schema)
+ }
+
+ case e: Entity => (e, EntitySchema(e))
+ case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply)
+ case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
+ case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply)
+ case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _))
+ case Take(q: Query, n) => applySchema(q, Take(_, n))
+ case Drop(q: Query, n) => applySchema(q, Drop(_, n))
+ case Nested(q: Query) => applySchema(q, Nested.apply)
+ case Distinct(q: Query) => applySchema(q, Distinct.apply)
+ case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on))
+
+ case FlatMap(q: Query, x, p) =>
+ applySchema(q, x, p, FlatMap.apply) match {
+ case (FlatMap(q, x, p: Query), oldSchema) =>
+ val (pr, newSchema) = applySchema(p)
+ (FlatMap(q, x, pr), newSchema)
+ case (flatMap, oldSchema) =>
+ (flatMap, TupleSchema.empty)
+ }
+
+ case ConcatMap(q: Query, x, p) =>
+ applySchema(q, x, p, ConcatMap.apply) match {
+ case (ConcatMap(q, x, p: Query), oldSchema) =>
+ val (pr, newSchema) = applySchema(p)
+ (ConcatMap(q, x, pr), newSchema)
+ case (concatMap, oldSchema) =>
+ (concatMap, TupleSchema.empty)
+ }
+
+ case Join(typ, a: Query, b: Query, iA, iB, on) =>
+ (applySchema(a), applySchema(b)) match {
+ case ((a, schemaA), (b, schemaB)) =>
+ val combinedReplacements =
+ trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn {
+ val replaceA = replacements(iA, schemaA)
+ val replaceB = replacements(iB, schemaB)
+ replaceA ++ replaceB
+ }
+ val onr = BetaReduction(on, combinedReplacements*)
+ traceDifferent(on, onr)
+ (Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB)))
+ }
+
+ case FlatJoin(typ, a: Query, iA, on) =>
+ applySchema(a) match {
+ case (a, schemaA) =>
+ val replaceA =
+ trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn`
+ replacements(iA, schemaA)
+ val onr = BetaReduction(on, replaceA*)
+ traceDifferent(on, onr)
+ (FlatJoin(typ, a, iA, onr), schemaA)
+ }
+
+ case Map(q: Operation, x, p) if x == p =>
+ (Map(apply(q), x, p), TupleSchema.empty)
+
+ case Map(Infix(parts, params, pure, paren), x, p) =>
+ val transformed =
+ params.map {
+ case q: Query =>
+ val (qr, schema) = applySchema(q)
+ traceDifferent(q, qr)
+ (qr, Some(schema))
+ case q =>
+ (q, None)
+ }
+
+ val schema =
+ transformed.collect {
+ case (_, Some(schema)) => schema
+ } match {
+ case e :: Nil => e
+ case ls => TupleSchema(ls)
+ }
+ val replace =
+ trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
+ replacements(x, schema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+
+ (Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema)
+
+ case q =>
+ (q, TupleSchema.empty)
+ }
+ }
+
+ private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) =
+ applySchema(ast) match {
+ case (ast, schema) =>
+ (f(ast), schema)
+ }
+
+ private def applySchema[T](
+ q: Query,
+ x: Ident,
+ p: Ast,
+ f: (Ast, Ident, Ast) => T
+ ): (T, Schema) =
+ applySchema(q) match {
+ case (q, schema) =>
+ val replace =
+ trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
+ replacements(x, schema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+ (f(q, x, prr), schema)
+ }
+
+ private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] =
+ schema match {
+ // The entity renameable property should already have been marked as Fixed
+ case EntitySchema(Entity(entity, properties)) =>
+ // trace"%4 Entity Schema: " andReturn
+ properties.flatMap {
+ // A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias)
+ // or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit.
+ // In this case, we want to rename the properties based on the property aliases as well as mark
+ // them Fixed since they should not be renamed based on
+ // the naming strategy wherever they are tokenized (e.g. in SqlIdiom)
+ case PropertyAlias(path, alias) =>
+ def apply(base: Ast, path: List[String]): Ast =
+ path match {
+ case Nil => base
+ case head :: tail => apply(Property(base, head), tail)
+ }
+ List(
+ apply(base, path) -> Property.Opinionated(
+ base,
+ alias,
+ Fixed,
+ Visible
+ ) // Hidden properties cannot be renamed
+ )
+ }
+ case tup: TupleSchema =>
+ // trace"%4 Tuple Schema: " andReturn
+ tup.list.flatMap {
+ case (idx, value) =>
+ replacements(
+ // Should not matter whether property is fixed or variable here
+ // since beta reduction ignores that
+ Property(base, s"_${idx + 1}"),
+ value
+ )
+ }
+ case cc: CaseClassSchema =>
+ // trace"%4 CaseClass Schema: " andReturn
+ cc.list.flatMap {
+ case (property, value) =>
+ replacements(
+ // Should not matter whether property is fixed or variable here
+ // since beta reduction ignores that
+ Property(base, property),
+ value
+ )
+ }
+ // Do nothing if it is an empty schema
+ case EmptySchema => List()
+ }
+}
diff --git a/src/main/scala/minisql/util/Replacements.scala b/src/main/scala/minisql/norm/Replacements.scala
similarity index 98%
rename from src/main/scala/minisql/util/Replacements.scala
rename to src/main/scala/minisql/norm/Replacements.scala
index f0982e2..4b4a955 100644
--- a/src/main/scala/minisql/util/Replacements.scala
+++ b/src/main/scala/minisql/norm/Replacements.scala
@@ -1,4 +1,4 @@
-package minisql.util
+package minisql.norm
import minisql.ast.Ast
import scala.collection.immutable.Map
diff --git a/src/main/scala/minisql/norm/SimplifyNullChecks.scala b/src/main/scala/minisql/norm/SimplifyNullChecks.scala
new file mode 100644
index 0000000..a49b949
--- /dev/null
+++ b/src/main/scala/minisql/norm/SimplifyNullChecks.scala
@@ -0,0 +1,124 @@
+package minisql.norm
+
+import minisql.ast.*
+import minisql.norm.EqualityBehavior.AnsiEquality
+
+/**
+ * Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in
+ * `FlattenOptionOperation` in order to resolve #1053, as well as to support
+ * non-ansi compliant string concatenation as outlined in #1295, large
+ * conditional composites became common. For example:
+ *
+ * case class Holder(value: Option[String])
+ *
+ * // The following statement
+ * query[Holder].map(h => h.value.map(_ + "foo"))
+ * // Will yield the following result
+ * SELECT CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END FROM Holder h
+ *
+ * Now, let's add a `getOrElse` statement to the clause that requires an
+ * additional wrapped null check. We cannot rely on there being a `map` call
+ * beforehand since we could be reading `value` as a nullable field directly
+ * from the database.
+ *
+ * // The following statement
+ * query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar"))
+ * // Yields the following result:
+ * SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END IS NOT NULL
+ * THEN CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h
+ *
+ * This of course is highly redundant and can be reduced to simply:
+ *
+ * SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL
+ * THEN h.value || 'foo' ELSE 'bar' END FROM Holder h
+ *
+ * This reduction is done by the "Center Rule." There are some other
+ * simplification rules as well. Note how we are forced to null-check both
+ * `h.value` as well as `(h.value || 'foo')` because a user may use
+ * `Option[T].flatMap` and explicitly transform a particular value to `null`.
+ */
+class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
+ extends StatelessTransformer {
+
+ override def apply(ast: Ast): Ast = {
+ import minisql.ast.Implicits.*
+ ast match {
+ // Center rule
+ case IfExist(
+ IfExistElseNull(condA, thenA),
+ IfExistElseNull(condB, thenB),
+ otherwise
+ ) if (condA == condB && thenA == thenB) =>
+ apply(
+ If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
+ )
+
+ // Left hand rule
+ case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
+ apply(
+ If(
+ IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
+ value,
+ otherwise
+ )
+ )
+
+ // Right hand rule
+ case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
+ apply(
+ If(
+ IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
+ innerThen,
+ NullValue
+ )
+ )
+
+ case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
+ Optional(b)
+ ) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
+ if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+
+ case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
+ a1
+ ) `== or !=` Optional(_)))
+ if (a == a1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+ case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
+ _
+ ) `== or !=` Optional(b1)))
+ if (b == b1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+
+ case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
+ apply(other)
+ case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
+ apply(other)
+ case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
+ apply(other)
+ case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
+ apply(other)
+
+ case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left)
+ case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
+ case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left)
+ case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)
+
+ case other =>
+ super.apply(other)
+ }
+ }
+
+ object `== or !=` {
+ def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
+ case a +==+ b => Some((a, b))
+ case a +!=+ b => Some((a, b))
+ case _ => None
+ }
+ }
+
+ /**
+ * Simple extractor that looks inside of an optional values to see if the
+ * thing inside can be pulled out. If not, it just returns whatever element it
+ * can find.
+ */
+ object Optional {
+ def unapply(a: Ast): Option[Ast] = a match {
+ case OptionApply(value) => Some(value)
+ case OptionSome(value) => Some(value)
+ case value => Some(value)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/SymbolicReduction.scala b/src/main/scala/minisql/norm/SymbolicReduction.scala
new file mode 100644
index 0000000..d7e8965
--- /dev/null
+++ b/src/main/scala/minisql/norm/SymbolicReduction.scala
@@ -0,0 +1,38 @@
+package minisql.norm
+
+import minisql.ast.Filter
+import minisql.ast.FlatMap
+import minisql.ast.Query
+import minisql.ast.Union
+import minisql.ast.UnionAll
+
+object SymbolicReduction {
+
+ def unapply(q: Query) =
+ q match {
+
+ // a.filter(b => c).flatMap(d => e.$) =>
+ // a.flatMap(d => e.filter(_ => c[b := d]).$)
+ case FlatMap(Filter(a, b, c), d, e: Query) =>
+ val cr = BetaReduction(c, b -> d)
+ val er = AttachToEntity(Filter(_, _, cr))(e)
+ Some(FlatMap(a, d, er))
+
+ // a.flatMap(b => c).flatMap(d => e) =>
+ // a.flatMap(b => c.flatMap(d => e))
+ case FlatMap(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, FlatMap(c, d, e)))
+
+ // a.union(b).flatMap(c => d)
+ // a.flatMap(c => d).union(b.flatMap(c => d))
+ case FlatMap(Union(a, b), c, d) =>
+ Some(Union(FlatMap(a, c, d), FlatMap(b, c, d)))
+
+ // a.unionAll(b).flatMap(c => d)
+ // a.flatMap(c => d).unionAll(b.flatMap(c => d))
+ case FlatMap(UnionAll(a, b), c, d) =>
+ Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d)))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala
new file mode 100644
index 0000000..ed45c62
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala
@@ -0,0 +1,174 @@
+package minisql.norm.capture
+
+import minisql.ast.{
+ Entity,
+ Filter,
+ FlatJoin,
+ FlatMap,
+ GroupBy,
+ Ident,
+ Join,
+ Map,
+ Query,
+ SortBy,
+ StatefulTransformer,
+ _
+}
+import minisql.norm.{BetaReduction, Normalize}
+import scala.collection.immutable.Set
+
+private[minisql] case class AvoidAliasConflict(state: Set[Ident])
+ extends StatefulTransformer[Set[Ident]] {
+
+ object Unaliased {
+
+ private def isUnaliased(q: Ast): Boolean =
+ q match {
+ case Nested(q: Query) => isUnaliased(q)
+ case Take(q: Query, _) => isUnaliased(q)
+ case Drop(q: Query, _) => isUnaliased(q)
+ case Aggregation(_, q: Query) => isUnaliased(q)
+ case Distinct(q: Query) => isUnaliased(q)
+ case _: Entity | _: Infix => true
+ case _ => false
+ }
+
+ def unapply(q: Ast): Option[Ast] =
+ q match {
+ case q if (isUnaliased(q)) => Some(q)
+ case _ => None
+ }
+ }
+
+ override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
+ q match {
+
+ case FlatMap(Unaliased(q), x, p) =>
+ apply(x, p)(FlatMap(q, _, _))
+
+ case ConcatMap(Unaliased(q), x, p) =>
+ apply(x, p)(ConcatMap(q, _, _))
+
+ case Map(Unaliased(q), x, p) =>
+ apply(x, p)(Map(q, _, _))
+
+ case Filter(Unaliased(q), x, p) =>
+ apply(x, p)(Filter(q, _, _))
+
+ case SortBy(Unaliased(q), x, p, o) =>
+ apply(x, p)(SortBy(q, _, _, o))
+
+ case GroupBy(Unaliased(q), x, p) =>
+ apply(x, p)(GroupBy(q, _, _))
+
+ case DistinctOn(Unaliased(q), x, p) =>
+ apply(x, p)(DistinctOn(q, _, _))
+
+ case Join(t, a, b, iA, iB, o) =>
+ val (ar, art) = apply(a)
+ val (br, brt) = art.apply(b)
+ val freshA = freshIdent(iA, brt.state)
+ val freshB = freshIdent(iB, brt.state + freshA)
+ val or = BetaReduction(o, iA -> freshA, iB -> freshB)
+ val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
+ (Join(t, ar, br, freshA, freshB, orr), orrt)
+
+ case FlatJoin(t, a, iA, o) =>
+ val (ar, art) = apply(a)
+ val freshA = freshIdent(iA)
+ val or = BetaReduction(o, iA -> freshA)
+ val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
+ (FlatJoin(t, ar, freshA, orr), orrt)
+
+ case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
+ _: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
+ _: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
+ super.apply(q)
+ }
+
+ private def apply(x: Ident, p: Ast)(
+ f: (Ident, Ast) => Query
+ ): (Query, StatefulTransformer[Set[Ident]]) = {
+ val fresh = freshIdent(x)
+ val pr = BetaReduction(p, x -> fresh)
+ val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
+ (f(fresh, prr), t)
+ }
+
+ private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
+ def loop(x: Ident, n: Int): Ident = {
+ val fresh = Ident(s"${x.name}$n")
+ if (!state.contains(fresh))
+ fresh
+ else
+ loop(x, n + 1)
+ }
+ if (!state.contains(x))
+ x
+ else
+ loop(x, 1)
+ }
+
+ /**
+ * Sometimes we need to change the variables in a function because they
+ * might conflict with some variable further up in the macro. Right now, this
+ * only happens when you do something like this:
+ *
+ * val q = quote { (v: Foo) => query[Foo].insert(v) }
+ * run(q(lift(v)))
+ *
+ * Since 'v' is used by actionMeta in order to map keys to values for
+ * insertion, using it as a function argument messes up the output SQL like
+ * so: INSERT INTO MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of
+ * (?,?,?,?). Therefore, we need to have a method to remove such conflicting
+ * variables from Function ASTs.
+ */
+ private def applyFunction(f: Function): Function = {
+ val (newBody, _, newParams) =
+ f.params.foldLeft((f.body, state, List[Ident]())) {
+ case ((body, state, newParams), param) => {
+ val fresh = freshIdent(param)
+ val pr = BetaReduction(body, param -> fresh)
+ val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
+ (prr, t.state, newParams :+ fresh)
+ }
+ }
+ Function(newParams, newBody)
+ }
+
+ private def applyForeach(f: Foreach): Foreach = {
+ val fresh = freshIdent(f.alias)
+ val pr = BetaReduction(f.body, f.alias -> fresh)
+ val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
+ Foreach(f.query, fresh, prr)
+ }
+}
+
+private[minisql] object AvoidAliasConflict {
+
+ def apply(q: Query): Query =
+ AvoidAliasConflict(Set[Ident]())(q) match {
+ case (q, _) => q
+ }
+
+ /**
+ * Make sure query parameters do not collide with parameters of an AST
+ * function. Do this by walking through the function's subtree and
+ * transforming any queries encountered.
+ */
+ def sanitizeVariables(
+ f: Function,
+ dangerousVariables: Set[Ident]
+ ): Function = {
+ AvoidAliasConflict(dangerousVariables).applyFunction(f)
+ }
+
+ /** Same as `sanitizeVariables` but for Foreach. */
+ def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
+ AvoidAliasConflict(dangerousVariables).applyForeach(f)
+ }
+
+ def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
+ AvoidAliasConflict(dangerousVariables).apply(q) match {
+ // Propagate aliasing changes to the rest of the query
+ case (q, _) => Normalize(q)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/AvoidCapture.scala b/src/main/scala/minisql/norm/capture/AvoidCapture.scala
new file mode 100644
index 0000000..788b551
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/AvoidCapture.scala
@@ -0,0 +1,9 @@
+package minisql.norm.capture
+
+import minisql.ast.Query
+
+object AvoidCapture {
+
+ def apply(q: Query): Query =
+ Dealias(AvoidAliasConflict(q))
+}
diff --git a/src/main/scala/minisql/norm/capture/Dealias.scala b/src/main/scala/minisql/norm/capture/Dealias.scala
new file mode 100644
index 0000000..64a56ce
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/Dealias.scala
@@ -0,0 +1,72 @@
+package minisql.norm.capture
+
+import minisql.ast._
+import minisql.norm.BetaReduction
+
+case class Dealias(state: Option[Ident])
+ extends StatefulTransformer[Option[Ident]] {
+
+ override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) =
+ q match {
+ case FlatMap(a, b, c) =>
+ dealias(a, b, c)(FlatMap.apply) match {
+ case (FlatMap(a, b, c), _) =>
+ val (cn, cnt) = apply(c)
+ (FlatMap(a, b, cn), cnt)
+ }
+ case ConcatMap(a, b, c) =>
+ dealias(a, b, c)(ConcatMap.apply) match {
+ case (ConcatMap(a, b, c), _) =>
+ val (cn, cnt) = apply(c)
+ (ConcatMap(a, b, cn), cnt)
+ }
+ case Map(a, b, c) =>
+ dealias(a, b, c)(Map.apply)
+ case Filter(a, b, c) =>
+ dealias(a, b, c)(Filter.apply)
+ case SortBy(a, b, c, d) =>
+ dealias(a, b, c)(SortBy(_, _, _, d))
+ case GroupBy(a, b, c) =>
+ dealias(a, b, c)(GroupBy.apply)
+ case DistinctOn(a, b, c) =>
+ dealias(a, b, c)(DistinctOn.apply)
+ case Take(a, b) =>
+ val (an, ant) = apply(a)
+ (Take(an, b), ant)
+ case Drop(a, b) =>
+ val (an, ant) = apply(a)
+ (Drop(an, b), ant)
+ case Union(a, b) =>
+ val (an, _) = apply(a)
+ val (bn, _) = apply(b)
+ (Union(an, bn), Dealias(None))
+ case UnionAll(a, b) =>
+ val (an, _) = apply(a)
+ val (bn, _) = apply(b)
+ (UnionAll(an, bn), Dealias(None))
+ case Join(t, a, b, iA, iB, o) =>
+ val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
+ val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
+ (Join(t, an, bn, iAn, iBn, onn), Dealias(None))
+ case FlatJoin(t, a, iA, o) =>
+ val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
+ (FlatJoin(t, an, iAn, on), Dealias(Some(iA)))
+ case _: Entity | _: Distinct | _: Aggregation | _: Nested =>
+ (q, Dealias(None))
+ }
+
+ private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) =
+ apply(a) match {
+ case (an, t @ Dealias(Some(alias))) =>
+ (f(an, alias, BetaReduction(c, b -> alias)), t)
+ case other =>
+ (f(a, b, c), Dealias(Some(b)))
+ }
+}
+
+object Dealias {
+ def apply(query: Query) =
+ new Dealias(None)(query) match {
+ case (q, _) => q
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala
new file mode 100644
index 0000000..7c95f7c
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala
@@ -0,0 +1,98 @@
+package minisql.norm.capture
+
+import minisql.ast.*
+
+/**
+ * Walk through any Queries that a returning clause has and replace Ident of the
+ * returning variable with ExternalIdent so that in later steps involving filter
+ * simplification, it will not be mistakenly dealiased with a potential shadow.
+ * Take this query for instance:
query[TestEntity] + * .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r => + * (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) )+ * The returning clause has an alias `Ident("r")` as well as the first filter + * clause. These two filters will be combined into one at which point the + * meaning of `r.i` in the 2nd filter will be confused for the first filter's + * alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this + * vunerable `r.i` in the second filter clause to an `ExternalIdent` before any + * of the simplifications are done. + * + * Note that we only want to do this for Queries inside of a `Returning` clause + * body. Other places where this needs to be done (e.g. in a Tuple that + * `Returning` returns) are done in `ExpandReturning`. + */ +private[minisql] case class DemarcateExternalAliases(externalIdent: Ident) + extends StatelessTransformer { + + def applyNonOverride(idents: Ident*)(ast: Ast) = + if (idents.forall(_ != externalIdent)) apply(ast) + else ast + + override def apply(ast: Ast): Ast = ast match { + + case FlatMap(q, i, b) => + FlatMap(apply(q), i, applyNonOverride(i)(b)) + + case ConcatMap(q, i, b) => + ConcatMap(apply(q), i, applyNonOverride(i)(b)) + + case Map(q, i, b) => + Map(apply(q), i, applyNonOverride(i)(b)) + + case Filter(q, i, b) => + Filter(apply(q), i, applyNonOverride(i)(b)) + + case SortBy(q, i, p, o) => + SortBy(apply(q), i, applyNonOverride(i)(p), o) + + case GroupBy(q, i, b) => + GroupBy(apply(q), i, applyNonOverride(i)(b)) + + case DistinctOn(q, i, b) => + DistinctOn(apply(q), i, applyNonOverride(i)(b)) + + case Join(t, a, b, iA, iB, o) => + Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o)) + + case FlatJoin(t, a, iA, o) => + FlatJoin(t, a, iA, applyNonOverride(iA)(o)) + + case p @ Property.Opinionated( + id @ Ident(_), + value, + renameable, + visibility + ) => + if (id == externalIdent) + Property.Opinionated( + ExternalIdent(externalIdent.name), + value, + 
renameable, + visibility + ) + else + p + + case other => + super.apply(other) + } +} + +object DemarcateExternalAliases { + + private def demarcateQueriesInBody(id: Ident, body: Ast) = + Transform(body) { + // Apply to the AST defined apply method about, not to the superclass method that takes Query + case q: Query => + new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast]) + } + + def apply(ast: Ast): Ast = ast match { + case Returning(a, id, body) => + Returning(a, id, demarcateQueriesInBody(id, body)) + case ReturningGenerated(a, id, body) => + val d = demarcateQueriesInBody(id, body) + ReturningGenerated(a, id, demarcateQueriesInBody(id, body)) + case other => + other + } +} diff --git a/src/main/scala/minisql/parsing/BlockParsing.scala b/src/main/scala/minisql/parsing/BlockParsing.scala new file mode 100644 index 0000000..ae8722c --- /dev/null +++ b/src/main/scala/minisql/parsing/BlockParsing.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +type SParser[X] = + (q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]] + +private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using + Quotes +): SParser[ast.Ast] = { + + import quotes.reflect.* + + @annotation.nowarn + lazy val valDefParser: SParser[ast.Val] = { + case ValDef(n, _, Some(b)) => + val body = astParser(b.asExpr) + '{ ast.Val(ast.Ident(${ Expr(n) }), $body) } + + } + valDefParser +} + +private[parsing] def blockParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + lazy val statementParser = statementParsing(astParser) + + termParser { + case Block(Nil, t) => astParser(t.asExpr) + case b @ Block(st, t) => + val asts = (st :+ t).map { + case e if e.isExpr => astParser(e.asExpr) + case `statementParser`(x) => x + case o => + report.errorAndAbort(s"Cannot parse statement: ${o.show}") + } + if (asts.size > 1) { + '{ ast.Block(${ Expr.ofList(asts) }) } + } else asts(0) + + } +} 
diff --git a/src/main/scala/minisql/parsing/BoxingParsing.scala b/src/main/scala/minisql/parsing/BoxingParsing.scala new file mode 100644 index 0000000..d7b7f31 --- /dev/null +++ b/src/main/scala/minisql/parsing/BoxingParsing.scala @@ -0,0 +1,31 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def boxingParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + case '{ BigDecimal.int2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.long2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.double2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v) + case '{ Predef.byte2Byte($v) } => astParser(v) + case '{ Predef.short2Short($v) } => astParser(v) + case '{ Predef.char2Character($v) } => astParser(v) + case '{ Predef.int2Integer($v) } => astParser(v) + case '{ Predef.long2Long($v) } => astParser(v) + case '{ Predef.float2Float($v) } => astParser(v) + case '{ Predef.double2Double($v) } => astParser(v) + case '{ Predef.boolean2Boolean($v) } => astParser(v) + case '{ Predef.augmentString($v) } => astParser(v) + case '{ Predef.Byte2byte($v) } => astParser(v) + case '{ Predef.Short2short($v) } => astParser(v) + case '{ Predef.Character2char($v) } => astParser(v) + case '{ Predef.Integer2int($v) } => astParser(v) + case '{ Predef.Long2long($v) } => astParser(v) + case '{ Predef.Float2float($v) } => astParser(v) + case '{ Predef.Double2double($v) } => astParser(v) + case '{ Predef.Boolean2boolean($v) } => astParser(v) + +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala new file mode 100644 index 0000000..7b173db --- /dev/null +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -0,0 +1,13 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): 
Parser[ast.Infix] = { + + import quotes.reflect.* + ??? +} diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala new file mode 100644 index 0000000..c01df89 --- /dev/null +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import scala.quoted.* +import minisql.ParamEncoder +import minisql.ast +import minisql.dsl.* + +private[parsing] def liftParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Lift] = { + case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } => + import quotes.reflect.* + val name = x.asTerm.symbol.fullName + val liftId = x.asTerm.symbol.owner.fullName + "@" + name + '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) } +} diff --git a/src/main/scala/minisql/parsing/OperationParsing.scala b/src/main/scala/minisql/parsing/OperationParsing.scala new file mode 100644 index 0000000..df93010 --- /dev/null +++ b/src/main/scala/minisql/parsing/OperationParsing.scala @@ -0,0 +1,113 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.{ + EqualityOperator, + StringOperator, + NumericOperator, + BooleanOperator +} +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def operationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Operation] = { + import quotes.reflect.* + + def isNumeric(t: TypeRepr) = { + t <:< TypeRepr.of[Int] + || t <:< TypeRepr.of[Long] + || t <:< TypeRepr.of[Byte] + || t <:< TypeRepr.of[Float] + || t <:< TypeRepr.of[Double] + || t <:< TypeRepr.of[java.math.BigDecimal] + || t <:< TypeRepr.of[scala.math.BigDecimal] + } + + def parseBinary( + left: Expr[Any], + right: Expr[Any], + op: Expr[ast.BinaryOperator] + ) = { + val leftE = astParser(left) + val rightE = astParser(right) + '{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) } + } + + def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = { + val base = astParser(expr) + '{ 
ast.UnaryOperation($op, $base) } + + } + + val universalOpParser: Parser[ast.BinaryOperation] = termParser { + case Apply(Select(leftT, UniversalOp(op)), List(rightT)) => + parseBinary(leftT.asExpr, rightT.asExpr, op) + } + + val stringOpParser: Parser[ast.Operation] = { + case '{ ($x: String) + ($y: String) } => + parseBinary(x, y, '{ StringOperator.concat }) + case '{ ($x: String).startsWith($y) } => + parseBinary(x, y, '{ StringOperator.startsWith }) + case '{ ($x: String).split($y) } => + parseBinary(x, y, '{ StringOperator.split }) + case '{ ($x: String).toUpperCase } => + parseUnary(x, '{ StringOperator.toUpperCase }) + case '{ ($x: String).toLowerCase } => + parseUnary(x, '{ StringOperator.toLowerCase }) + case '{ ($x: String).toLong } => + parseUnary(x, '{ StringOperator.toLong }) + case '{ ($x: String).toInt } => + parseUnary(x, '{ StringOperator.toInt }) + } + + val numericOpParser = termParser { + case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) => + parseBinary(lt.asExpr, rt.asExpr, op) + case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) => + val leftExpr = astParser(leftTerm.asExpr) + '{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) } + + } + + val booleanOpParser: Parser[ast.Operation] = { + case '{ ($x: Boolean) && $y } => + parseBinary(x, y, '{ BooleanOperator.&& }) + case '{ ($x: Boolean) || $y } => + parseBinary(x, y, '{ BooleanOperator.|| }) + case '{ !($x: Boolean) } => + parseUnary(x, '{ BooleanOperator.! 
}) + } + + universalOpParser + .orElse(stringOpParser) + .orElse(numericOpParser) + .orElse(booleanOpParser) +} + +private object UniversalOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "==" | "equals" => Some('{ EqualityOperator.== }) + case "!=" => Some('{ EqualityOperator.!= }) + case _ => None + } +} + +private object NumericOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "+" => Some('{ NumericOperator.+ }) + case "-" => Some('{ NumericOperator.- }) + case "*" => Some('{ NumericOperator.* }) + case "/" => Some('{ NumericOperator./ }) + case ">" => Some('{ NumericOperator.> }) + case ">=" => Some('{ NumericOperator.>= }) + case "<" => Some('{ NumericOperator.< }) + case "<=" => Some('{ NumericOperator.<= }) + case "%" => Some('{ NumericOperator.% }) + case _ => None + } +} diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala new file mode 100644 index 0000000..4bc7265 --- /dev/null +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.Ast +import scala.quoted.* + +private[minisql] inline def parseParamAt[F]( + inline f: F, + inline n: Int +): ast.Ident = ${ + parseParamAt('f, 'n) +} + +private[minisql] inline def parseBody[X]( + inline f: X +): ast.Ast = ${ + parseBody('f) +} + +private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using + Quotes +): Expr[ast.Ident] = { + + import quotes.reflect.* + + val pIdx = n.value.getOrElse( + report.errorAndAbort(s"Param index ${n.show} is not know") + ) + extractTerm(f.asTerm) match { + case Lambda(vals, _) => + vals(pIdx) match { + case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) } + } + } +} + +private[minisql] def parseBody[X]( + x: Expr[X] +)(using Quotes): Expr[Ast] = { + import quotes.reflect.* + x.asTerm match { + case Lambda(vals, body) => + Parsing.parseExpr(body.asExpr) 
+ case o => + report.errorAndAbort(s"Can only parse function") + } +} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala new file mode 100644 index 0000000..463112f --- /dev/null +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -0,0 +1,139 @@ +package minisql.parsing + +import minisql.ast +import minisql.context.{ReturningMultipleFieldSupported, _} +import minisql.norm.BetaReduction +import minisql.norm.capture.AvoidAliasConflict +import minisql.idiom.Idiom +import scala.annotation.tailrec +import minisql.ast.Implicits._ +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.{Hidden, Visible} +import minisql.util.Interleave +import scala.quoted.* + +type Parser[A] = PartialFunction[Expr[Any], Expr[A]] + +private def termParser[A](using q: Quotes)( + pf: PartialFunction[q.reflect.Term, Expr[A]] +): Parser[A] = { + import quotes.reflect._ + { + case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm) + } +} + +private def parser[A]( + f: PartialFunction[Expr[Any], Expr[A]] +)(using Quotes): Parser[A] = { + case e if f.isDefinedAt(e) => f(e) +} + +private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { + import quotes.reflect.* + def unwrapTerm(t: Term): Term = t match { + case Inlined(_, _, o) => unwrapTerm(o) + case Block(Nil, last) => last + case Typed(t, _) => + unwrapTerm(t) + case Select(t, "$asInstanceOf$") => + unwrapTerm(t) + case TypeApply(t, _) => + unwrapTerm(t) + case o => o + } + unwrapTerm(x) +} + +private[minisql] object Parsing { + + def parseExpr( + expr: Expr[?] 
+ )(using q: Quotes): Expr[ast.Ast] = { + + import q.reflect._ + + def unwrapped( + f: Parser[ast.Ast] + ): Parser[ast.Ast] = { + case expr => + val t = expr.asTerm + f(extractTerm(t).asExpr) + } + + lazy val astParser: Parser[ast.Ast] = + unwrapped { + typedParser + .orElse(propertyParser) + .orElse(liftParser) + .orElse(identParser) + .orElse(valueParser) + .orElse(operationParser) + .orElse(constantParser) + .orElse(blockParser) + .orElse(boxingParser) + .orElse(ifParser) + .orElse(traversableOperationParser) + .orElse(patMatchParser) + .orElse(infixParser) + .orElse { + case o => + val str = scala.util.Try(o.show).getOrElse("") + report.errorAndAbort( + s"cannot parse ${str}", + o.asTerm.pos + ) + } + } + + lazy val typedParser: Parser[ast.Ast] = termParser { + case (Typed(t, _)) => + astParser(t.asExpr) + } + + lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser) + + lazy val valueParser: Parser[ast.Value] = valueParsing(astParser) + + lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser) + + lazy val constantParser: Parser[ast.Constant] = termParser { + case Literal(x) => + '{ ast.Constant(${ Literal(x).asExpr }) } + } + + lazy val identParser: Parser[ast.Ident] = termParser { + case x @ Ident(n) if x.symbol.isValDef => + '{ ast.Ident(${ Expr(n) }) } + } + + lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser) + + lazy val operationParser: Parser[ast.Operation] = operationParsing( + astParser + ) + + lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser) + + lazy val ifParser: Parser[ast.If] = { + case '{ if ($a) $b else $c } => + '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } + + } + lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) + + lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) + + lazy val traversableOperationParser: Parser[ast.IterableOperation] = + traversableOperationParsing(astParser) + + astParser(expr) + } + + private[minisql] inline 
def parse[A]( + inline a: A + ): ast.Ast = ${ + parseExpr('a) + } + +} diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala new file mode 100644 index 0000000..2db7652 --- /dev/null +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -0,0 +1,49 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def patMatchParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + termParser { + // Val defs that showd pattern variables will cause error + case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) => + val bm = binds.zipWithIndex.map { + case (Bind(n, ident), idx) => + n -> Select.unique(t, s"_${idx + 1}") + }.toMap + val tm = new TreeMap { + override def transformTerm(tree: Term)(owner: Symbol): Term = { + tree match { + case Ident(n) => bm(n) + case o => super.transformTerm(o)(owner) + } + } + } + val newBody = tm.transformTree(body)(e.symbol) + astParser(newBody.asExpr) + } + +} + +object IsTupleUnapply { + + def unapply(using + Quotes + )(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = { + import quotes.reflect.* + def isTupleNUnapply(x: Term) = { + val fn = x.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply") + } + t match { + case Unapply(m, _, binds) if isTupleNUnapply(m) => + Some(binds) + case _ => None + } + } +} diff --git a/src/main/scala/minisql/parsing/PropertyParsing.scala b/src/main/scala/minisql/parsing/PropertyParsing.scala new file mode 100644 index 0000000..292f281 --- /dev/null +++ b/src/main/scala/minisql/parsing/PropertyParsing.scala @@ -0,0 +1,30 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def propertyParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Property] = { + import quotes.reflect.* + + def isAccessor(s: Select) = { + 
s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name) + } + + val parseApply: Parser[ast.Property] = termParser { + case m @ Select(base, n) if isAccessor(m) => + val obj = astParser(base.asExpr) + '{ ast.Property($obj, ${ Expr(n) }) } + } + + val parseOptionGet: Parser[ast.Property] = { + case '{ ($e: Option[t]).get } => + report.errorAndAbort( + "Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead." + ) + } + parseApply.orElse(parseOptionGet) + +} diff --git a/src/main/scala/minisql/parsing/TraversableOperationParsing.scala b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala new file mode 100644 index 0000000..8f50098 --- /dev/null +++ b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private def traversableOperationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.IterableOperation] = { + case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } => + '{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) } + case '{ ($s: Set[e]).contains($i) } => + '{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) } + case '{ ($s: Seq[e]).contains($i) } => + '{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) } + +} diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala new file mode 100644 index 0000000..6c2fb9e --- /dev/null +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -0,0 +1,71 @@ +package minisql +package parsing + +import scala.quoted._ + +private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using + Quotes +): Parser[ast.Value] = { + + import quotes.reflect.* + + val parseTupleApply: Parser[ast.Tuple] = { + case IsTupleApply(args) => + val t = args.map(astParser) + '{ ast.Tuple(${ Expr.ofList(t) }) } + } + + val parseAssocTuple: Parser[ast.Tuple] = { + case '{ ($x: tx) 
-> ($y: ty) } => + '{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) } + } + + parseTupleApply.orElse(parseAssocTuple) +} + +private[minisql] object IsTupleApply { + + def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = { + + import quotes.reflect.* + + def isTupleNApply(t: Term) = { + val fn = t.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.apply") + } + + def isTupleXXLApply(t: Term) = { + t.symbol.fullName == "scala.runtime.TupleXXL$.apply" + } + + extractTerm(e.asTerm) match { + // TupleN(0-22).apply + case Apply(b, args) if isTupleNApply(b) => + Some(args.map(_.asExpr)) + case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) => + t.asExpr match { + case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } => + Some(args) + } + case o => + None + } + } +} + +private[parsing] object IsTuple2 { + + def unapply(using + Quotes + )(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] = + t match { + case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2)) + case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2)) + case _ => None + } + + def unapply(using + Quotes + )(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] = + unapply(t.asExpr) +} diff --git a/src/main/scala/minisql/util/Interpolator.scala b/src/main/scala/minisql/util/Interpolator.scala index 55e32cd..c63e984 100644 --- a/src/main/scala/minisql/util/Interpolator.scala +++ b/src/main/scala/minisql/util/Interpolator.scala @@ -11,20 +11,25 @@ import scala.util.matching.Regex class Interpolator( defaultIndent: Int = 0, qprint: AstPrinter = AstPrinter(), - out: PrintStream = System.out, + out: PrintStream = System.out ) { + + extension (sc: StringContext) { + def trace(elements: Any*) = new Traceable(sc, elements) + } + class Traceable(sc: StringContext, elementsSeq: Seq[Any]) { private val elementPrefix = "| " private sealed trait PrintElement private case class Str(str: String, first: Boolean) extends PrintElement - private case class Elem(value: String) 
extends PrintElement - private case object Separator extends PrintElement + private case class Elem(value: String) extends PrintElement + private case object Separator extends PrintElement private def generateStringForCommand(value: Any, indent: Int) = { val objectString = qprint(value) - val oneLine = objectString.fitsOnOneLine + val oneLine = objectString.fitsOnOneLine oneLine match { case true => s"${indent.prefix}> ${objectString}" case false => @@ -42,7 +47,7 @@ class Interpolator( private def readBuffers() = { def orZero(i: Int): Int = if (i < 0) 0 else i - val parts = sc.parts.iterator.toList + val parts = sc.parts.iterator.toList val elements = elementsSeq.toList.map(qprint(_)) val (firstStr, explicitIndent) = readFirst(parts.head) diff --git a/src/main/scala/minisql/util/Message.scala b/src/main/scala/minisql/util/Message.scala new file mode 100644 index 0000000..f748456 --- /dev/null +++ b/src/main/scala/minisql/util/Message.scala @@ -0,0 +1,75 @@ +package minisql.util + +import minisql.AstPrinter +import minisql.idiom.Idiom +import minisql.util.IndentUtil._ + +object Messages { + + private def variable(propName: String, envName: String, default: String) = + Option(System.getProperty(propName)) + .orElse(sys.env.get(envName)) + .getOrElse(default) + + private[util] val prettyPrint = + variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean + private[util] val debugEnabled = + variable("quill.macro.log", "quill_macro_log", "true").toBoolean + private[util] val traceEnabled = + variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean + private[util] val traceColors = + variable("quill.trace.color", "quill_trace_color,", "false").toBoolean + private[util] val traceOpinions = + variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean + private[util] val traceAstSimple = variable( + "quill.trace.ast.simple", + "quill_trace_ast_simple", + "false" + ).toBoolean + private[minisql] val cacheDynamicQueries = 
variable( + "quill.query.cacheDaynamic", + "query_query_cacheDaynamic", + "true" + ).toBoolean + private[util] val traces: List[TraceType] = + variable("quill.trace.types", "quill_trace_types", "standard") + .split(",") + .toList + .map(_.trim) + .flatMap(trace => + TraceType.values.filter(traceType => trace == traceType.value) + ) + + def tracesEnabled(tt: TraceType) = + traceEnabled && traces.contains(tt) + + sealed trait TraceType { def value: String } + object TraceType { + case object Normalizations extends TraceType { val value = "norm" } + case object Standard extends TraceType { val value = "standard" } + case object NestedQueryExpansion extends TraceType { val value = "nest" } + + def values: List[TraceType] = + List(Standard, Normalizations, NestedQueryExpansion) + } + + val qprint = AstPrinter() + + def fail(msg: String) = + throw new IllegalStateException(msg) + + def trace[T]( + label: String, + numIndent: Int = 0, + traceType: TraceType = TraceType.Standard + ) = + (v: T) => { + val indent = (0 to numIndent).map(_ => "").mkString(" ") + if (tracesEnabled(traceType)) + println(s"$indent$label\n${{ + if (traceColors) qprint.apply(v) + else qprint.apply(v) + }.split("\n").map(s"$indent " + _).mkString("\n")}") + v + } +}