diff --git a/README.md b/README.md
index de83bb9..6f82dea 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
大部分场景不用在 `macro` 对 Ast 进行复杂模式匹配来分析代码。
-## 核心思路 使用 inline 和 `FromExpr` 代替大部分 parsing 工作
+## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作
`FromExpr` 是 `scala3` 内置的 typeclass,用来获取编译期值 。
diff --git a/build.sbt b/build.sbt
index 0a952af..2dc3b9c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,8 +1,9 @@
name := "minisql"
-scalaVersion := "3.6.2"
+scalaVersion := "3.5.2"
libraryDependencies ++= Seq(
+ "org.scalameta" %% "munit" % "1.0.3" % Test
)
scalacOptions ++= Seq("-experimental", "-language:experimental.namedTuples")
diff --git a/src/main/scala/minisql/Meta.scala b/src/main/scala/minisql/Meta.scala
new file mode 100644
index 0000000..38d8fb9
--- /dev/null
+++ b/src/main/scala/minisql/Meta.scala
@@ -0,0 +1,3 @@
+package minisql
+
+type QueryMeta
diff --git a/src/main/scala/minisql/ParamEncoder.scala b/src/main/scala/minisql/ParamEncoder.scala
index a55c0a3..05ef348 100644
--- a/src/main/scala/minisql/ParamEncoder.scala
+++ b/src/main/scala/minisql/ParamEncoder.scala
@@ -1,8 +1,23 @@
package minisql
+import scala.util.Try
+
trait ParamEncoder[E] {
type Stmt
def setParam(s: Stmt, idx: Int, v: E): Unit
}
+
+trait ColumnDecoder[X] {
+
+ type DBRow
+
+ def decode(row: DBRow, idx: Int): Try[X]
+}
+
+object ColumnDecoder {
+ type Aux[R, X] = ColumnDecoder[X] {
+ type DBRow = R
+ }
+}
diff --git a/src/main/scala/minisql/Parser.scala b/src/main/scala/minisql/Parser.scala
deleted file mode 100644
index 3b21d90..0000000
--- a/src/main/scala/minisql/Parser.scala
+++ /dev/null
@@ -1,127 +0,0 @@
-package minisql.parsing
-
-import minisql.ast
-import minisql.ast.Ast
-import scala.quoted.*
-
-type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]
-
-private[minisql] inline def parseParamAt[F](
- inline f: F,
- inline n: Int
-): ast.Ident = ${
- parseParamAt('f, 'n)
-}
-
-private[minisql] inline def parseBody[X](
- inline f: X
-): ast.Ast = ${
- parseBody('f)
-}
-
-private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
- Quotes
-): Expr[ast.Ident] = {
-
- import quotes.reflect.*
-
- val pIdx = n.value.getOrElse(
- report.errorAndAbort(s"Param index ${n.show} is not know")
- )
- extractTerm(f.asTerm) match {
- case Lambda(vals, _) =>
- vals(pIdx) match {
- case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
- }
- }
-}
-
-private[minisql] def parseBody[X](
- x: Expr[X]
-)(using Quotes): Expr[Ast] = {
- import quotes.reflect.*
- extractTerm(x.asTerm) match {
- case Lambda(vals, body) =>
- astParser(body.asExpr)
- case o =>
- report.errorAndAbort(s"Can only parse function")
- }
-}
-private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
- import quotes.reflect.*
- x.asTerm.tpe match {
- case t if t <:< TypeRepr.of[Int] => true
- case t if t <:< TypeRepr.of[Long] => true
- case t if t <:< TypeRepr.of[Float] => true
- case t if t <:< TypeRepr.of[Double] => true
- case t if t <:< TypeRepr.of[BigDecimal] => true
- case t if t <:< TypeRepr.of[BigInt] => true
- case t if t <:< TypeRepr.of[java.lang.Integer] => true
- case t if t <:< TypeRepr.of[java.lang.Long] => true
- case t if t <:< TypeRepr.of[java.lang.Float] => true
- case t if t <:< TypeRepr.of[java.lang.Double] => true
- case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
- case _ => false
- }
-}
-
-private def identParser(using Quotes): Parser[ast.Ident] = {
- import quotes.reflect.*
- { (x: Expr[?]) =>
- extractTerm(x.asTerm) match {
- case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
- case _ => None
- }
- }.unlift
-
-}
-
-private lazy val astParser: Quotes ?=> Parser[Ast] = {
- identParser.orElse(propertyParser(astParser))
-}
-
-private object IsPropertySelect {
- def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
- import quotes.reflect.*
- x.asTerm match {
- case Select(x, n) => Some(x.asExpr, n)
- case _ => None
- }
- }
-}
-
-def propertyParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.Property] = {
- case IsPropertySelect(expr, n) =>
- '{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
-}
-
-def optionOperationParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.OptionOperation] = {
- case '{ ($x: Option[t]).isEmpty } =>
- '{ ast.OptionIsEmpty(${ astParser(x) }) }
-}
-
-def binaryOperationParser(
- astParser: => Parser[Ast]
-)(using Quotes): Parser[ast.BinaryOperation] = {
- ???
-}
-
-private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
- import quotes.reflect.*
- def unwrapTerm(t: Term): Term = t match {
- case Inlined(_, _, o) => unwrapTerm(o)
- case Block(Nil, last) => last
- case Typed(t, _) =>
- unwrapTerm(t)
- case Select(t, "$asInstanceOf$") =>
- unwrapTerm(t)
- case TypeApply(t, _) =>
- unwrapTerm(t)
- case o => o
- }
- unwrapTerm(x)
-}
diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala
new file mode 100644
index 0000000..da7008e
--- /dev/null
+++ b/src/main/scala/minisql/Quoted.scala
@@ -0,0 +1,101 @@
+package minisql
+
+import minisql.*
+import minisql.idiom.*
+import minisql.parsing.*
+import minisql.util.*
+import minisql.ast.{Ast, Entity, Map, Property, Ident, Filter, given}
+import scala.quoted.*
+import scala.compiletime.*
+import scala.compiletime.ops.string.*
+import scala.collection.immutable.{Map => IMap}
+
+opaque type Quoted <: Ast = Ast
+
+opaque type Query[E] <: Quoted = Quoted
+
+opaque type EntityQuery[E] <: Query[E] = Query[E]
+
+object EntityQuery {
+ extension [E](inline e: EntityQuery[E]) {
+ inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
+ transform(e)(f)(Map.apply)
+ }
+
+ inline def filter(inline f: E => Boolean): EntityQuery[E] = {
+ transform(e)(f)(Filter.apply)
+ }
+ }
+}
+
+private inline def transform[A, B](inline q1: Quoted)(
+ inline f: A => B
+)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
+ fast(q1, f.param0, f.body)
+}
+
+inline def query[E](inline table: String): EntityQuery[E] =
+ Entity(table, Nil)
+
+extension [A, B](inline f1: A => B) {
+ private inline def param0 = parsing.parseParamAt(f1, 0)
+ private inline def body = parsing.parseBody(f1)
+}
+
+extension [A1, A2, B](inline f1: (A1, A2) => B) {
+ private inline def param0 = parsing.parseParamAt(f1, 0)
+ private inline def param1 = parsing.parseParamAt(f1, 1)
+ private inline def body = parsing.parseBody(f1)
+}
+
+def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()
+
+class NonQuotedException extends Exception("Cannot be used at runtime")
+
+private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] =
+ ${
+ compileTimeAstImpl('q)
+ }
+
+private def compileTimeAstImpl(e: Expr[Quoted])(using
+ Quotes
+): Expr[Option[String]] = {
+ import quotes.reflect.*
+ e.value match {
+ case Some(v) => '{ Some(${ Expr(v.toString()) }) }
+ case None => '{ None }
+ }
+}
+
+private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy](
+ inline q: Quoted,
+ inline idiom: I,
+ inline naming: N
+): Statement = ${ compileImpl[I, N]('q, 'idiom, 'naming) }
+
+private def compileImpl[I <: Idiom, N <: NamingStrategy](
+ q: Expr[Quoted],
+ idiom: Expr[I],
+ n: Expr[N]
+)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
+ import quotes.reflect.*
+ q.value match {
+ case Some(ast) =>
+ val idiom = LoadObject[I].getOrElse(
+ report.errorAndAbort(s"Idiom not known at compile")
+ )
+
+ val naming = LoadNaming
+ .static[N]
+ .getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile"))
+
+ val stmt = idiom.translate(ast)(using naming)
+ Expr(stmt._2)
+ case None =>
+ report.info("Dynamic Query")
+ '{
+ $idiom.translate($q)(using $n)._2
+ }
+
+ }
+}
diff --git a/src/main/scala/minisql/ReturnAction.scala b/src/main/scala/minisql/ReturnAction.scala
new file mode 100644
index 0000000..63df7a0
--- /dev/null
+++ b/src/main/scala/minisql/ReturnAction.scala
@@ -0,0 +1,7 @@
+package minisql
+
+enum ReturnAction {
+ case ReturnNothing
+ case ReturnColumns(columns: List[String])
+ case ReturnRecord
+}
diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala
index 35feb70..52446e3 100644
--- a/src/main/scala/minisql/ast/Ast.scala
+++ b/src/main/scala/minisql/ast/Ast.scala
@@ -1,6 +1,7 @@
package minisql.ast
import minisql.NamingStrategy
+import minisql.ParamEncoder
import scala.quoted.*
@@ -378,14 +379,21 @@ sealed trait ScalarLift extends Lift
case class ScalarValueLift(
name: String,
- liftId: String
+ liftId: String,
+ value: Option[(Any, ParamEncoder[?])]
+) extends ScalarLift
+
+case class ScalarQueryLift(
+ name: String,
+ liftId: String,
+ value: Option[(Seq[Any], ParamEncoder[?])]
) extends ScalarLift
object ScalarLift {
given ToExpr[ScalarLift] with {
def apply(l: ScalarLift)(using Quotes) = l match {
- case ScalarValueLift(n, id) =>
- '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) }
+ case ScalarValueLift(n, id, v) =>
+ '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) }
}
}
}
diff --git a/src/main/scala/minisql/ast/AstOps.scala b/src/main/scala/minisql/ast/AstOps.scala
new file mode 100644
index 0000000..e0069ae
--- /dev/null
+++ b/src/main/scala/minisql/ast/AstOps.scala
@@ -0,0 +1,94 @@
+package minisql.ast
+object Implicits {
+ implicit class AstOps(body: Ast) {
+ private[minisql] def +||+(other: Ast) =
+ BinaryOperation(body, BooleanOperator.`||`, other)
+ private[minisql] def +&&+(other: Ast) =
+ BinaryOperation(body, BooleanOperator.`&&`, other)
+ private[minisql] def +==+(other: Ast) =
+ BinaryOperation(body, EqualityOperator.`==`, other)
+ private[minisql] def +!=+(other: Ast) =
+ BinaryOperation(body, EqualityOperator.`!=`, other)
+ }
+}
+
+object +||+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object +&&+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+val EqOp = EqualityOperator.`==`
+val NeqOp = EqualityOperator.`!=`
+
+object +==+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, EqOp, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object +!=+ {
+ def unapply(a: Ast): Option[(Ast, Ast)] = {
+ a match {
+ case BinaryOperation(one, NeqOp, two) => Some((one, two))
+ case _ => None
+ }
+ }
+}
+
+object IsNotNullCheck {
+ def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue)
+
+ def unapply(ast: Ast): Option[Ast] = {
+ ast match {
+ case BinaryOperation(cond, NeqOp, NullValue) => Some(cond)
+ case _ => None
+ }
+ }
+}
+
+object IsNullCheck {
+ def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue)
+
+ def unapply(ast: Ast): Option[Ast] = {
+ ast match {
+ case BinaryOperation(cond, EqOp, NullValue) => Some(cond)
+ case _ => None
+ }
+ }
+}
+
+object IfExistElseNull {
+ def apply(exists: Ast, `then`: Ast) =
+ If(IsNotNullCheck(exists), `then`, NullValue)
+
+ def unapply(ast: Ast) = ast match {
+ case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t))
+ case _ => None
+ }
+}
+
+object IfExist {
+ def apply(exists: Ast, `then`: Ast, otherwise: Ast) =
+ If(IsNotNullCheck(exists), `then`, otherwise)
+
+ def unapply(ast: Ast) = ast match {
+ case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e))
+ case _ => None
+ }
+}
diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala
index b66e4b9..9f70b0d 100644
--- a/src/main/scala/minisql/ast/FromExprs.scala
+++ b/src/main/scala/minisql/ast/FromExprs.scala
@@ -45,8 +45,9 @@ private given FromExpr[Infix] with {
private given FromExpr[ScalarValueLift] with {
def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] =
x match {
- case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } =>
- Some(ScalarValueLift(n, id))
+ case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } =>
+ // don't cared about value here, a little tricky
+ Some(ScalarValueLift(n, id, null))
}
}
@@ -122,6 +123,13 @@ private given FromExpr[Query] with {
Some(FlatMap(b, id, body))
case '{ ConcatMap(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
Some(ConcatMap(b, id, body))
+ case '{
+ val x: Ast = ${ Expr(b) }
+ val y: Ident = ${ Expr(id) }
+ val z: Ast = ${ Expr(body) }
+ ConcatMap(x, y, z)
+ } =>
+ Some(ConcatMap(b, id, body))
case '{ Drop(${ Expr(b) }, ${ Expr(n) }) } =>
Some(Drop(b, n))
case '{ Take(${ Expr(b) }, ${ Expr[Ast](n) }) } =>
@@ -129,7 +137,7 @@ private given FromExpr[Query] with {
case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } =>
Some(SortBy(b, p, s, o))
case o =>
- println(s"Cannot extract ${o.show}")
+ println(s"Cannot extract ${o}")
None
}
}
@@ -145,6 +153,7 @@ private given FromExpr[BinaryOperator] with {
case '{ NumericOperator.- } => Some(NumericOperator.-)
case '{ NumericOperator.* } => Some(NumericOperator.*)
case '{ NumericOperator./ } => Some(NumericOperator./)
+ case '{ NumericOperator.> } => Some(NumericOperator.>)
case '{ StringOperator.split } => Some(StringOperator.split)
case '{ StringOperator.startsWith } => Some(StringOperator.startsWith)
case '{ StringOperator.concat } => Some(StringOperator.concat)
diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala
new file mode 100644
index 0000000..af33d80
--- /dev/null
+++ b/src/main/scala/minisql/context/Context.scala
@@ -0,0 +1,90 @@
+package minisql.context
+
+import scala.deriving.*
+import scala.compiletime.*
+import scala.util.Try
+import minisql.util.*
+import minisql.idiom.{Idiom, Statement, ReifyStatement}
+import minisql.{NamingStrategy, ParamEncoder}
+import minisql.ColumnDecoder
+import minisql.ast.{Ast, ScalarValueLift, CollectAst}
+
+trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
+
+ val idiom: I
+ val naming: NamingStrategy
+
+ type DBStatement
+ type DBRow
+ type DBResultSet
+
+ trait RowExtract[A] {
+ def extract(row: DBRow): Try[A]
+ }
+
+ object RowExtract {
+
+ private class ExtractorImpl[A](
+ decoders: IArray[Any],
+ m: Mirror.ProductOf[A]
+ ) extends RowExtract[A] {
+ def extract(row: DBRow): Try[A] = {
+ val decodedFields = decoders.zipWithIndex.traverse {
+ case (d, i) =>
+ d.asInstanceOf[Decoder[?]].decode(row, i)
+ }
+ decodedFields.map { vs =>
+ m.fromProduct(Tuple.fromIArray(vs))
+ }
+ }
+ }
+
+ inline given [P <: Product](using m: Mirror.ProductOf[P]): RowExtract[P] = {
+ val decoders = summonAll[Tuple.Map[m.MirroredElemTypes, Decoder]]
+ ExtractorImpl(decoders.toIArray.asInstanceOf, m)
+ }
+ }
+
+ type Encoder[X] = ParamEncoder[X] {
+ type Stmt = DBStatement
+ }
+
+ type Decoder[X] = ColumnDecoder.Aux[DBRow, X]
+
+ type DBIO[E] = (
+ sql: String,
+ params: List[(Any, Encoder[?])],
+ mapper: Iterable[DBRow] => Try[E]
+ )
+
+ extension (ast: Ast) {
+ private def liftMap = {
+ val lifts = CollectAst.byType[ScalarValueLift](ast)
+ lifts.map(l => l.liftId -> l.value.get).toMap
+ }
+ }
+
+ extension (stmt: Statement) {
+ def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) =
+ ReifyStatement(
+ idiom.liftingPlaceholder,
+ idiom.emptySetContainsToken,
+ stmt,
+ liftMap
+ )
+ }
+
+ inline def io[E](
+ inline q: minisql.Query[E]
+ )(using r: RowExtract[E]): DBIO[IArray[E]] = {
+ val lifts = q.liftMap
+ val stmt = minisql.compile(q, idiom, naming)
+ val (sql, params) = stmt.expand(lifts)
+ (
+ sql = sql,
+ params = params.map(_.value.get.asInstanceOf),
+ mapper = (rows) => rows.traverse(r.extract)
+ )
+ }
+
+}
diff --git a/src/main/scala/minisql/context/MirrorContext.scala b/src/main/scala/minisql/context/MirrorContext.scala
new file mode 100644
index 0000000..c1be053
--- /dev/null
+++ b/src/main/scala/minisql/context/MirrorContext.scala
@@ -0,0 +1,15 @@
+package minisql
+
+import minisql.context.mirror.*
+
+class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy](
+ val idiom: Idiom,
+ val naming: Naming
+) extends context.Context[Idiom, Naming] {
+
+ type DBRow = Row
+
+ type DBResultSet = Iterable[DBRow]
+
+ type DBStatement = IArray[Any]
+}
diff --git a/src/main/scala/minisql/context/ReturnFieldCapability.scala b/src/main/scala/minisql/context/ReturnFieldCapability.scala
new file mode 100644
index 0000000..cadbb78
--- /dev/null
+++ b/src/main/scala/minisql/context/ReturnFieldCapability.scala
@@ -0,0 +1,64 @@
+package minisql.context
+
+sealed trait ReturningCapability
+
+/**
+ * Data cannot be returned Insert/Update/etc... clauses in the target database.
+ */
+sealed trait ReturningNotSupported extends ReturningCapability
+
+/**
+ * Returning a single field from Insert/Update/etc... clauses is supported. This
+ * is the most common databases e.g. MySQL, Sqlite, and H2 (although as of
+ * h2database/h2database#1972 this may change. See #1496 regarding this.
+ * Typically this needs to be setup in the JDBC
+ * `connection.prepareStatement(sql, Array("returnColumn"))`.
+ */
+sealed trait ReturningSingleFieldSupported extends ReturningCapability
+
+/**
+ * Returning multiple columns from Insert/Update/etc... clauses is supported.
+ * This generally means that columns besides auto-incrementing ones can be
+ * returned. This is supported by Oracle. In JDBC, the following is done:
+ * `connection.prepareStatement(sql, Array("column1, column2, ..."))`.
+ */
+sealed trait ReturningMultipleFieldSupported extends ReturningCapability
+
+/**
+ * An actual `RETURNING` clause is supported in the SQL dialect of the specified
+ * database e.g. Postgres. this typically means that columns returned from
+ * Insert/Update/etc... clauses can have other database operations done on them
+ * such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In
+ * JDBC, the following is done: `connection.prepareStatement(sql,
+ * Statement.RETURN_GENERATED_KEYS))`.
+ */
+sealed trait ReturningClauseSupported extends ReturningCapability
+
+object ReturningNotSupported extends ReturningNotSupported
+object ReturningSingleFieldSupported extends ReturningSingleFieldSupported
+object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported
+object ReturningClauseSupported extends ReturningClauseSupported
+
+trait Capabilities {
+ def idiomReturningCapability: ReturningCapability
+}
+
+trait CanReturnClause extends Capabilities {
+ override def idiomReturningCapability: ReturningClauseSupported =
+ ReturningClauseSupported
+}
+
+trait CanReturnField extends Capabilities {
+ override def idiomReturningCapability: ReturningSingleFieldSupported =
+ ReturningSingleFieldSupported
+}
+
+trait CanReturnMultiField extends Capabilities {
+ override def idiomReturningCapability: ReturningMultipleFieldSupported =
+ ReturningMultipleFieldSupported
+}
+
+trait CannotReturn extends Capabilities {
+ override def idiomReturningCapability: ReturningNotSupported =
+ ReturningNotSupported
+}
diff --git a/src/main/scala/minisql/context/mirror.scala b/src/main/scala/minisql/context/mirror.scala
new file mode 100644
index 0000000..cd3b725
--- /dev/null
+++ b/src/main/scala/minisql/context/mirror.scala
@@ -0,0 +1,35 @@
+package minisql.context.mirror
+
+import minisql.{MirrorContext, NamingStrategy}
+import minisql.idiom.Idiom
+import minisql.util.Messages.fail
+import scala.reflect.ClassTag
+
+/**
+* No extra class defined
+*/
+opaque type Row = IArray[Any] *: EmptyTuple
+
+extension (r: Row) {
+
+ def data: IArray[Any] = r._1
+
+ def add(value: Any): Row = (r.data :+ value) *: EmptyTuple
+
+ def apply[T](idx: Int)(using t: ClassTag[T]): T = {
+ r.data(idx) match {
+ case v: T => v
+ case other =>
+ fail(
+ s"Invalid column type. Expected '${t.runtimeClass}', but got '$other'"
+ )
+ }
+ }
+}
+
+trait MirrorCodecs[I <: Idiom, N <: NamingStrategy] {
+ this: MirrorContext[I, N] =>
+
+ given byteEncoder: Encoder[Byte]
+
+}
diff --git a/src/main/scala/minisql/dsl.scala b/src/main/scala/minisql/dsl.scala
deleted file mode 100644
index 43cc9c0..0000000
--- a/src/main/scala/minisql/dsl.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-package minisql.dsl
-
-import minisql.*
-import minisql.parsing.*
-import minisql.ast.{Ast, Entity, Map, Property, Ident, given}
-import scala.quoted.*
-import scala.compiletime.*
-import scala.compiletime.ops.string.*
-import scala.collection.immutable.{Map => IMap}
-
-opaque type Quoted <: Ast = Ast
-
-opaque type Query[E] <: Quoted = Quoted
-
-opaque type EntityQuery[E] <: Query[E] = Query[E]
-
-extension [E](inline e: EntityQuery[E]) {
- inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
- transform(e)(f)(Map.apply)
- }
-}
-
-private inline def transform[A, B](inline q1: Quoted)(
- inline f: A => B
-)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
- fast(q1, f.param0, f.body)
-}
-
-inline def query[E](inline table: String): EntityQuery[E] =
- Entity(table, Nil)
-
-inline def compile(inline x: Ast): Option[String] = ${
- compileImpl('{ x })
-}
-
-private def compileImpl(
- x: Expr[Ast]
-)(using Quotes): Expr[Option[String]] = {
- import quotes.reflect.*
- x.value match {
- case Some(xv) => '{ Some(${ Expr(xv.toString()) }) }
- case None => '{ None }
- }
-}
-
-extension [A, B](inline f1: A => B) {
- private inline def param0 = parsing.parseParamAt(f1, 0)
- private inline def body = parsing.parseBody(f1)
-}
-
-extension [A1, A2, B](inline f1: (A1, A2) => B) {
- private inline def param0 = parsing.parseParamAt(f1, 0)
- private inline def param1 = parsing.parseParamAt(f1, 1)
- private inline def body = parsing.parseBody(f1)
-}
-
-case class Foo(id: Int)
-
-inline def queryFooId = query[Foo]("foo").map(_.id)
diff --git a/src/main/scala/minisql/idiom/Idiom.scala b/src/main/scala/minisql/idiom/Idiom.scala
new file mode 100644
index 0000000..43b6110
--- /dev/null
+++ b/src/main/scala/minisql/idiom/Idiom.scala
@@ -0,0 +1,23 @@
+package minisql.idiom
+
+import minisql.NamingStrategy
+import minisql.ast._
+import minisql.context.Capabilities
+
+trait Idiom extends Capabilities {
+
+ def emptySetContainsToken(field: Token): Token = StringToken("FALSE")
+
+ def defaultAutoGeneratedToken(field: Token): Token = StringToken(
+ "DEFAULT VALUES"
+ )
+
+ def liftingPlaceholder(index: Int): String
+
+ def translate(ast: Ast)(using naming: NamingStrategy): (Ast, Statement)
+
+ def format(queryString: String): String = queryString
+
+ def prepareForProbing(string: String): String
+
+}
diff --git a/src/main/scala/minisql/idiom/LoadNaming.scala b/src/main/scala/minisql/idiom/LoadNaming.scala
new file mode 100644
index 0000000..2405080
--- /dev/null
+++ b/src/main/scala/minisql/idiom/LoadNaming.scala
@@ -0,0 +1,28 @@
+package minisql.idiom
+
+import scala.util.Try
+import scala.quoted._
+import minisql.NamingStrategy
+import minisql.util.CollectTry
+import minisql.util.LoadObject
+import minisql.CompositeNamingStrategy
+
+object LoadNaming {
+
+ def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry {
+ strategies[C].map(LoadObject[NamingStrategy](_))
+ }.map(NamingStrategy(_))
+
+ private def strategies[C](using Quotes, Type[C]) = {
+ import quotes.reflect.*
+ val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy]
+ val ct = TypeRepr.of[C]
+ if (isComposite) {
+ ct.typeArgs.filterNot { t =>
+ t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing]
+ }
+ } else {
+ List(ct)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala
new file mode 100644
index 0000000..88aab8c
--- /dev/null
+++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala
@@ -0,0 +1,355 @@
+package minisql
+
+import minisql.ast.Renameable.{ByStrategy, Fixed}
+import minisql.ast.Visibility.Hidden
+import minisql.ast._
+import minisql.context.CanReturnClause
+import minisql.idiom.{Idiom, SetContainsToken, Statement}
+import minisql.idiom.StatementInterpolator.*
+import minisql.norm.Normalize
+import minisql.util.Interleave
+
+object MirrorIdiom extends MirrorIdiom
+class MirrorIdiom extends MirrorIdiomBase with CanReturnClause
+
+object MirrorIdiomPrinting extends MirrorIdiom {
+ override def distinguishHidden: Boolean = true
+}
+
+trait MirrorIdiomBase extends Idiom {
+
+ def distinguishHidden: Boolean = false
+
+ override def prepareForProbing(string: String) = string
+
+ override def liftingPlaceholder(index: Int): String = "?"
+
+ override def translate(
+ ast: Ast
+ )(implicit naming: NamingStrategy): (Ast, Statement) = {
+ val normalizedAst = Normalize(ast)
+ (normalizedAst, stmt"${normalizedAst.token}")
+ }
+
+ implicit def astTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Ast] = Tokenizer[Ast] {
+ case ast: Query => ast.token
+ case ast: Function => ast.token
+ case ast: Value => ast.token
+ case ast: Operation => ast.token
+ case ast: Action => ast.token
+ case ast: Ident => ast.token
+ case ast: ExternalIdent => ast.token
+ case ast: Property => ast.token
+ case ast: Infix => ast.token
+ case ast: OptionOperation => ast.token
+ case ast: IterableOperation => ast.token
+ case ast: Dynamic => ast.token
+ case ast: If => ast.token
+ case ast: Block => ast.token
+ case ast: Val => ast.token
+ case ast: Ordering => ast.token
+ case ast: Lift => ast.token
+ case ast: Assignment => ast.token
+ case ast: OnConflict.Excluded => ast.token
+ case ast: OnConflict.Existing => ast.token
+ }
+
+ implicit def ifTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[If] = Tokenizer[If] {
+ case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}"
+ }
+
+ implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] {
+ case Dynamic(tree) => stmt"${tree.toString.token}"
+ }
+
+ implicit def blockTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Block] = Tokenizer[Block] {
+ case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }"
+ }
+
+ implicit def valTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Val] = Tokenizer[Val] {
+ case Val(name, body) => stmt"val ${name.token} = ${body.token}"
+ }
+
+ implicit def queryTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Query] = Tokenizer[Query] {
+
+ case Entity.Opinionated(name, Nil, renameable) =>
+ stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token})"
+
+ case Entity.Opinionated(name, prop, renameable) =>
+ val properties =
+ prop.map(p => stmt"""_.${p.path.mkStmt(".")} -> "${p.alias.token}"""")
+ stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token}, ${properties.token})"
+
+ case Filter(source, alias, body) =>
+ stmt"${source.token}.filter(${alias.token} => ${body.token})"
+
+ case Map(source, alias, body) =>
+ stmt"${source.token}.map(${alias.token} => ${body.token})"
+
+ case FlatMap(source, alias, body) =>
+ stmt"${source.token}.flatMap(${alias.token} => ${body.token})"
+
+ case ConcatMap(source, alias, body) =>
+ stmt"${source.token}.concatMap(${alias.token} => ${body.token})"
+
+ case SortBy(source, alias, body, ordering) =>
+ stmt"${source.token}.sortBy(${alias.token} => ${body.token})(${ordering.token})"
+
+ case GroupBy(source, alias, body) =>
+ stmt"${source.token}.groupBy(${alias.token} => ${body.token})"
+
+ case Aggregation(op, ast) =>
+ stmt"${scopedTokenizer(ast)}.${op.token}"
+
+ case Take(source, n) =>
+ stmt"${source.token}.take(${n.token})"
+
+ case Drop(source, n) =>
+ stmt"${source.token}.drop(${n.token})"
+
+ case Union(a, b) =>
+ stmt"${a.token}.union(${b.token})"
+
+ case UnionAll(a, b) =>
+ stmt"${a.token}.unionAll(${b.token})"
+
+ case Join(t, a, b, iA, iB, on) =>
+ stmt"${a.token}.${t.token}(${b.token}).on((${iA.token}, ${iB.token}) => ${on.token})"
+
+ case FlatJoin(t, a, iA, on) =>
+ stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})"
+
+ case Distinct(a) =>
+ stmt"${a.token}.distinct"
+
+ case DistinctOn(source, alias, body) =>
+ stmt"${source.token}.distinctOn(${alias.token} => ${body.token})"
+
+ case Nested(a) =>
+ stmt"${a.token}.nested"
+ }
+
+ implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] {
+ case TupleOrdering(elems) => stmt"Ord(${elems.token})"
+ case Asc => stmt"Ord.asc"
+ case Desc => stmt"Ord.desc"
+ case AscNullsFirst => stmt"Ord.ascNullsFirst"
+ case DescNullsFirst => stmt"Ord.descNullsFirst"
+ case AscNullsLast => stmt"Ord.ascNullsLast"
+ case DescNullsLast => stmt"Ord.descNullsLast"
+ }
+
+ implicit def optionOperationTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
+ case OptionTableFlatMap(ast, alias, body) =>
+ stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
+ case OptionTableMap(ast, alias, body) =>
+ stmt"${ast.token}.map((${alias.token}) => ${body.token})"
+ case OptionTableExists(ast, alias, body) =>
+ stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
+ case OptionTableForall(ast, alias, body) =>
+ stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
+ case OptionFlatten(ast) => stmt"${ast.token}.flatten"
+ case OptionGetOrElse(ast, body) =>
+ stmt"${ast.token}.getOrElse(${body.token})"
+ case OptionFlatMap(ast, alias, body) =>
+ stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
+ case OptionMap(ast, alias, body) =>
+ stmt"${ast.token}.map((${alias.token}) => ${body.token})"
+ case OptionForall(ast, alias, body) =>
+ stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
+ case OptionExists(ast, alias, body) =>
+ stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
+ case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
+ case OptionIsEmpty(ast) => stmt"${ast.token}.isEmpty"
+ case OptionNonEmpty(ast) => stmt"${ast.token}.nonEmpty"
+ case OptionIsDefined(ast) => stmt"${ast.token}.isDefined"
+ case OptionSome(ast) => stmt"Some(${ast.token})"
+ case OptionApply(ast) => stmt"Option(${ast.token})"
+ case OptionOrNull(ast) => stmt"${ast.token}.orNull"
+ case OptionGetOrNull(ast) => stmt"${ast.token}.getOrNull"
+ case OptionNone => stmt"None"
+ }
+
+ implicit def traversableOperationTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] {
+ case MapContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
+ case SetContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
+ case ListContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
+ }
+
+ implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
+ case InnerJoin => stmt"join"
+ case LeftJoin => stmt"leftJoin"
+ case RightJoin => stmt"rightJoin"
+ case FullJoin => stmt"fullJoin"
+ }
+
+ implicit def functionTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Function] = Tokenizer[Function] {
+ case Function(params, body) => stmt"(${params.token}) => ${body.token}"
+ }
+
+ implicit def operationTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Operation] = Tokenizer[Operation] {
+ case UnaryOperation(op: PrefixUnaryOperator, ast) =>
+ stmt"${op.token}${scopedTokenizer(ast)}"
+ case UnaryOperation(op: PostfixUnaryOperator, ast) =>
+ stmt"${scopedTokenizer(ast)}.${op.token}"
+ case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
+ SetContainsToken(scopedTokenizer(b), op.token, a.token)
+ case BinaryOperation(a, op, b) =>
+ stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
+ case FunctionApply(function, values) =>
+ stmt"${scopedTokenizer(function)}.apply(${values.token})"
+ }
+
+ implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] {
+ case o => stmt"${o.toString.token}"
+ }
+
+ def tokenizeName(name: String, renameable: Renameable) =
+ renameable match {
+ case ByStrategy => name
+ case Fixed => s"`${name}`"
+ }
+
+ def bracketIfHidden(name: String, visibility: Visibility) =
+ (distinguishHidden, visibility) match {
+ case (true, Hidden) => s"[$name]"
+ case _ => name
+ }
+
+ implicit def propertyTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Property] = Tokenizer[Property] {
+ case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) =>
+ stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
+ case Property.Opinionated(ref, name, renameable, visibility) =>
+ stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
+ }
+
+ implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] {
+ case Constant(v: String) => stmt""""${v.token}""""
+ case Constant(()) => stmt"{}"
+ case Constant(v) => stmt"${v.toString.token}"
+ case NullValue => stmt"null"
+ case Tuple(values) => stmt"(${values.token})"
+ case CaseClass(values) =>
+ stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})"
+ }
+
+ implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] {
+ case Ident.Opinionated(name, visibility) =>
+ stmt"${bracketIfHidden(name, visibility).token}"
+ }
+
+ implicit val typeTokenizer: Tokenizer[ExternalIdent] =
+ Tokenizer[ExternalIdent] {
+ case e => stmt"${e.name.token}"
+ }
+
+ implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] =
+ Tokenizer[OnConflict.Excluded] {
+ case OnConflict.Excluded(ident) => stmt"${ident.token}"
+ }
+
+ implicit val existingTokenizer: Tokenizer[OnConflict.Existing] =
+ Tokenizer[OnConflict.Existing] {
+ case OnConflict.Existing(ident) => stmt"${ident.token}"
+ }
+
+ implicit def actionTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Action] = Tokenizer[Action] {
+ case Update(query, assignments) =>
+ stmt"${query.token}.update(${assignments.token})"
+ case Insert(query, assignments) =>
+ stmt"${query.token}.insert(${assignments.token})"
+ case Delete(query) => stmt"${query.token}.delete"
+ case Returning(query, alias, body) =>
+ stmt"${query.token}.returning((${alias.token}) => ${body.token})"
+ case ReturningGenerated(query, alias, body) =>
+ stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})"
+ case Foreach(query, alias, body) =>
+ stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
+ case c: OnConflict => stmt"${c.token}"
+ }
+
+ implicit def conflictTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[OnConflict] = {
+
+ def targetProps(l: List[Property]) = l.map(p =>
+ Transform(p) {
+ case Ident(_) => Ident("_")
+ }
+ )
+
+ implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] =
+ Tokenizer[OnConflict.Target] {
+ case OnConflict.NoTarget => stmt""
+ case OnConflict.Properties(props) =>
+ val listTokens = listTokenizer(astTokenizer).token(props)
+ stmt"(${listTokens})"
+ }
+
+ val updateAssignsTokenizer = Tokenizer[Assignment] {
+ case Assignment(i, p, v) =>
+ stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
+ }
+
+ Tokenizer[OnConflict] {
+ case OnConflict(i, t, OnConflict.Update(assign)) =>
+ stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
+ case OnConflict(i, t, OnConflict.Ignore) =>
+ stmt"${i.token}.onConflictIgnore${t.token}"
+ }
+ }
+
+ implicit def assignmentTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Assignment] = Tokenizer[Assignment] {
+ case Assignment(ident, property, value) =>
+ stmt"${ident.token} => ${property.token} -> ${value.token}"
+ }
+
+ implicit def infixTokenizer(implicit
+ liftTokenizer: Tokenizer[Lift]
+ ): Tokenizer[Infix] = Tokenizer[Infix] {
+ case Infix(parts, params, _, _) =>
+ def tokenParam(ast: Ast) =
+ ast match {
+ case ast: Ident => stmt"$$${ast.token}"
+ case other => stmt"$${${ast.token}}"
+ }
+
+ val pt = parts.map(_.token)
+ val pr = params.map(tokenParam)
+ val body = Statement(Interleave(pt, pr))
+ stmt"""infix"${body.token}""""
+ }
+
+ private def scopedTokenizer(
+ ast: Ast
+ )(implicit liftTokenizer: Tokenizer[Lift]) =
+ ast match {
+ case _: Function => stmt"(${ast.token})"
+ case _: BinaryOperation => stmt"(${ast.token})"
+ case other => ast.token
+ }
+}
diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala
new file mode 100644
index 0000000..7a4a07a
--- /dev/null
+++ b/src/main/scala/minisql/idiom/ReifyStatement.scala
@@ -0,0 +1,100 @@
+package minisql.idiom
+
+import minisql.ParamEncoder
+import minisql.ast.*
+import minisql.util.Interleave
+import minisql.idiom.StatementInterpolator.*
+import scala.annotation.tailrec
+import scala.collection.immutable.{Map => SMap}
+
+object ReifyStatement {
+
+ def apply(
+ liftingPlaceholder: Int => String,
+ emptySetContainsToken: Token => Token,
+ statement: Statement,
+ liftMap: SMap[String, (Any, ParamEncoder[?])]
+ ): (String, List[ScalarValueLift]) = {
+ val expanded = expandLiftings(statement, emptySetContainsToken, liftMap)
+ token2string(expanded, liftingPlaceholder)
+ }
+
+ private def token2string(
+ token: Token,
+ liftingPlaceholder: Int => String
+ ): (String, List[ScalarValueLift]) = {
+
+ val liftBuilder = List.newBuilder[ScalarValueLift]
+ val sqlBuilder = StringBuilder()
+ @tailrec
+ def loop(
+ workList: Seq[Token],
+ liftingSize: Int
+ ): Unit = workList match {
+ case Seq() => ()
+ case head +: tail =>
+ head match {
+ case StringToken(s2) =>
+ sqlBuilder ++= s2
+ loop(tail, liftingSize)
+ case SetContainsToken(a, op, b) =>
+ loop(
+ stmt"$a $op ($b)" +: tail,
+ liftingSize
+ )
+ case ScalarLiftToken(lift: ScalarValueLift) =>
+ sqlBuilder ++= liftingPlaceholder(liftingSize)
+ liftBuilder += lift
+ loop(tail, liftingSize + 1)
+ case ScalarLiftToken(o) =>
+ throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}")
+ case Statement(tokens) =>
+ loop(
+ tokens.foldRight(tail)(_ +: _),
+ liftingSize
+ )
+ }
+ }
+ loop(Vector(token), 0)
+ sqlBuilder.toString() -> liftBuilder.result()
+ }
+
+ private def expandLiftings(
+ statement: Statement,
+ emptySetContainsToken: Token => Token,
+ liftMap: SMap[String, (Any, ParamEncoder[?])]
+ ): (Token) = {
+ Statement {
+ val lb = List.newBuilder[Token]
+ statement.tokens.foldLeft(lb) {
+ case (
+ tokens,
+ SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift))
+ ) =>
+ val (lv, le) = liftMap(lift.liftId)
+ lv.asInstanceOf[Iterable[Any]].toVector match {
+ case Vector() => tokens += emptySetContainsToken(a)
+ case values =>
+ val liftings = values.zipWithIndex.map {
+ case (v, i) =>
+ ScalarLiftToken(
+ ScalarValueLift(
+ s"${lift.name}[${i}]",
+ s"${lift.liftId}[${i}]",
+ Some(v -> le)
+ )
+ )
+ }
+ val separators = Vector.fill(liftings.size - 1)(StringToken(", "))
+ (tokens += stmt"$a $op (") ++= Interleave(
+ liftings,
+ separators
+ ) += StringToken(")")
+ }
+ case (tokens, token) =>
+ tokens += token
+ }
+ lb.result()
+ }
+ }
+}
diff --git a/src/main/scala/minisql/idiom/Statement.scala b/src/main/scala/minisql/idiom/Statement.scala
new file mode 100644
index 0000000..987ee2f
--- /dev/null
+++ b/src/main/scala/minisql/idiom/Statement.scala
@@ -0,0 +1,47 @@
+package minisql.idiom
+
+import scala.quoted._
+import minisql.ast._
+
+sealed trait Token
+
+case class StringToken(string: String) extends Token {
+ override def toString = string
+}
+
+case class ScalarLiftToken(lift: ScalarLift) extends Token {
+ override def toString = s"lift(${lift.name})"
+}
+
+case class Statement(tokens: List[Token]) extends Token {
+ override def toString = tokens.mkString
+}
+
+case class SetContainsToken(a: Token, op: Token, b: Token) extends Token {
+ override def toString = s"${a.toString} ${op.toString} (${b.toString})"
+}
+
+object Statement {
+ given ToExpr[Statement] with {
+ def apply(t: Statement)(using Quotes): Expr[Statement] = {
+ '{ Statement(${ Expr(t.tokens) }) }
+ }
+ }
+}
+
+object Token {
+
+ given ToExpr[Token] with {
+ def apply(t: Token)(using Quotes): Expr[Token] = {
+ t match {
+ case StringToken(s) =>
+ '{ StringToken(${ Expr(s) }) }
+ case ScalarLiftToken(l) =>
+ '{ ScalarLiftToken(${ Expr(l) }) }
+ case SetContainsToken(a, op, b) =>
+ '{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) }
+ case s: Statement => Expr(s)
+ }
+ }
+ }
+}
diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala
new file mode 100644
index 0000000..3aa4d26
--- /dev/null
+++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala
@@ -0,0 +1,152 @@
+package minisql.idiom
+
+import minisql.ast._
+import minisql.util.Interleave
+import minisql.util.Messages._
+
+import scala.collection.mutable.ListBuffer
+
+object StatementInterpolator {
+
+ trait Tokenizer[T] {
+ extension (v: T) {
+ def token: Token
+ }
+ }
+
+ object Tokenizer {
+ def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] {
+ extension (v: T) {
+ def token: Token = f(v)
+ }
+ }
+ def withFallback[T](
+ fallback: Tokenizer[T] => Tokenizer[T]
+ )(pf: PartialFunction[T, Token]) =
+ new Tokenizer[T] {
+ extension (v: T) {
+ private def stable = fallback(this)
+ override def token = pf.applyOrElse(v, stable.token)
+ }
+ }
+ }
+
+ implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) {
+ def token = tokenizer.token(v)
+ }
+
+ implicit def stringTokenizer: Tokenizer[String] =
+ Tokenizer[String] {
+ case string => StringToken(string)
+ }
+
+ implicit def liftTokenizer: Tokenizer[Lift] =
+ Tokenizer[Lift] {
+ case lift: ScalarLift => ScalarLiftToken(lift)
+ case lift =>
+ fail(
+ s"Can't tokenize a non-scalar lifting. ${lift.name}\n" +
+ s"\n" +
+ s"This might happen because:\n" +
+ s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" +
+ s" to `Some[A]` or `None.type`. For example:\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Some(value))))" +
+ s" In that case, make sure the type is `Option`:\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" +
+ s" or\n" +
+ s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" +
+ s"\n" +
+ s"* You are trying to insert or update whole Embedded case class. For example:\n" +
+ s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" +
+ s" In that case, make sure you are updating individual columns, for example:\n" +
+ s" run(query[Users].update(\n" +
+ s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" +
+ s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" +
+ s" ))"
+ )
+ }
+
+ implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity)
+ implicit def statementTokenizer: Tokenizer[Statement] =
+ Tokenizer[Statement](identity)
+ implicit def stringTokenTokenizer: Tokenizer[StringToken] =
+ Tokenizer[StringToken](identity)
+ implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] =
+ Tokenizer[ScalarLiftToken](identity)
+
+ extension [T](list: List[T]) {
+ def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = {
+ val l1 = list.map(_.token)
+ val l2 = List.fill(l1.size - 1)(StringToken(sep))
+ Statement(Interleave(l1, l2))
+ }
+ }
+
+ implicit def listTokenizer[T](implicit
+ tokenize: Tokenizer[T]
+ ): Tokenizer[List[T]] =
+ Tokenizer[List[T]] {
+ case list => list.mkStmt()
+ }
+
+ extension (sc: StringContext) {
+
+ def flatten(tokens: List[Token]): List[Token] = {
+
+ def unestStatements(tokens: List[Token]): List[Token] = {
+ tokens.flatMap {
+ case Statement(innerTokens) => unestStatements(innerTokens)
+ case token => token :: Nil
+ }
+ }
+
+ def mergeStringTokens(tokens: List[Token]): List[Token] = {
+ val (resultBuilder, leftTokens) =
+ tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) {
+ case ((builder, acc), stringToken: StringToken) =>
+ val str = stringToken.string
+ if (str.nonEmpty)
+ acc += stringToken.string
+ (builder, acc)
+ case ((builder, prev), b) if prev.isEmpty =>
+ (builder += b.token, prev)
+ case ((builder, prev), b) /* if prev.nonEmpty */ =>
+ builder += StringToken(prev.result().mkString)
+ builder += b.token
+ (builder, new ListBuffer[String])
+ }
+ if (leftTokens.nonEmpty)
+ resultBuilder += StringToken(leftTokens.result().mkString)
+ resultBuilder.result()
+ }
+
+ (unestStatements)
+ .andThen(mergeStringTokens)
+ .apply(tokens)
+ }
+
+ def checkLengths(
+ args: scala.collection.Seq[Any],
+ parts: Seq[String]
+ ): Unit =
+ if (parts.length != args.length + 1)
+ throw new IllegalArgumentException(
+ "wrong number of arguments (" + args.length
+ + ") for interpolated string with " + parts.length + " parts"
+ )
+
+ def stmt(args: Token*): Statement = {
+ checkLengths(args, sc.parts)
+ val partsIterator = sc.parts.iterator
+ val argsIterator = args.iterator
+ val bldr = List.newBuilder[Token]
+ bldr += StringToken(partsIterator.next())
+ while (argsIterator.hasNext) {
+ bldr += argsIterator.next()
+ bldr += StringToken(partsIterator.next())
+ }
+ val tokens = flatten(bldr.result())
+ Statement(tokens)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/AdHocReduction.scala b/src/main/scala/minisql/norm/AdHocReduction.scala
new file mode 100644
index 0000000..10ad9a8
--- /dev/null
+++ b/src/main/scala/minisql/norm/AdHocReduction.scala
@@ -0,0 +1,52 @@
+package minisql.norm
+
+import minisql.ast.BinaryOperation
+import minisql.ast.BooleanOperator
+import minisql.ast.Filter
+import minisql.ast.FlatMap
+import minisql.ast.Map
+import minisql.ast.Query
+import minisql.ast.Union
+import minisql.ast.UnionAll
+
+object AdHocReduction {
+
+ def unapply(q: Query) =
+ q match {
+
+ // ---------------------------
+ // *.filter
+
+ // a.filter(b => c).filter(d => e) =>
+ // a.filter(b => c && e[d := b])
+ case Filter(Filter(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> b)
+ Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er)))
+
+ // ---------------------------
+ // flatMap.*
+
+ // a.flatMap(b => c).map(d => e) =>
+ // a.flatMap(b => c.map(d => e))
+ case Map(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, Map(c, d, e)))
+
+ // a.flatMap(b => c).filter(d => e) =>
+ // a.flatMap(b => c.filter(d => e))
+ case Filter(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, Filter(c, d, e)))
+
+ // a.flatMap(b => c.union(d))
+ // a.flatMap(b => c).union(a.flatMap(b => d))
+ case FlatMap(a, b, Union(c, d)) =>
+ Some(Union(FlatMap(a, b, c), FlatMap(a, b, d)))
+
+ // a.flatMap(b => c.unionAll(d))
+ // a.flatMap(b => c).unionAll(a.flatMap(b => d))
+ case FlatMap(a, b, UnionAll(c, d)) =>
+ Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d)))
+
+ case other => None
+ }
+
+}
diff --git a/src/main/scala/minisql/norm/ApplyMap.scala b/src/main/scala/minisql/norm/ApplyMap.scala
new file mode 100644
index 0000000..e5ddb0c
--- /dev/null
+++ b/src/main/scala/minisql/norm/ApplyMap.scala
@@ -0,0 +1,160 @@
+package minisql.norm
+
+import minisql.ast._
+
+object ApplyMap {
+
+ private def isomorphic(e: Ast, c: Ast, alias: Ident) =
+ BetaReduction(e, alias -> c) == c
+
+ object InfixedTailOperation {
+
+ def hasImpureInfix(ast: Ast) =
+ CollectAst(ast) {
+ case i @ Infix(_, _, false, _) => i
+ }.nonEmpty
+
+ def unapply(ast: Ast): Option[Ast] =
+ ast match {
+ case cc: CaseClass if hasImpureInfix(cc) => Some(cc)
+ case tup: Tuple if hasImpureInfix(tup) => Some(tup)
+ case p: Property if hasImpureInfix(p) => Some(p)
+ case b: BinaryOperation if hasImpureInfix(b) => Some(b)
+ case u: UnaryOperation if hasImpureInfix(u) => Some(u)
+ case i @ Infix(_, _, false, _) => Some(i)
+ case _ => None
+ }
+ }
+
+ object MapWithoutInfixes {
+ def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
+ ast match {
+ case Map(a, b, InfixedTailOperation(c)) => None
+ case Map(a, b, c) => Some((a, b, c))
+ case _ => None
+ }
+ }
+
+ object DetachableMap {
+ def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
+ ast match {
+ case Map(a: GroupBy, b, c) => None
+ case Map(a, b, InfixedTailOperation(c)) => None
+ case Map(a, b, c) => Some((a, b, c))
+ case _ => None
+ }
+ }
+
+ def unapply(q: Query): Option[Query] =
+ q match {
+
+ case Map(a: GroupBy, b, c) if (b == c) => None
+ case Map(a: DistinctOn, b, c) => None
+ case Map(a: Nested, b, c) if (b == c) => None
+ case Nested(DetachableMap(a: Join, b, c)) => None
+
+ // map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) =>
+ // map(i => (i.i, i.l)).distinct
+ case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) =>
+ Some(Distinct(Map(a, b, c)))
+
+ // a.map(b => c).map(d => e) =>
+ // a.map(b => e[d := c])
+ case before @ Map(MapWithoutInfixes(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(a, b, er))
+
+ // a.map(b => b) =>
+ // a
+ case Map(a: Query, b, c) if (b == c) =>
+ Some(a)
+
+ // a.map(b => c).flatMap(d => e) =>
+ // a.flatMap(b => e[d := c])
+ case FlatMap(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(FlatMap(a, b, er))
+
+ // a.map(b => c).filter(d => e) =>
+ // a.filter(b => e[d := c]).map(b => c)
+ case Filter(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(Filter(a, b, er), b, c))
+
+ // a.map(b => c).sortBy(d => e) =>
+ // a.sortBy(b => e[d := c]).map(b => c)
+ case SortBy(DetachableMap(a, b, c), d, e, f) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Map(SortBy(a, b, er, f), b, c))
+
+ // a.map(b => c).sortBy(d => e).distinct =>
+ // a.sortBy(b => e[d := c]).map(b => c).distinct
+ case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) =>
+ val er = BetaReduction(e, d -> c)
+ Some(Distinct(Map(SortBy(a, b, er, f), b, c)))
+
+ // a.map(b => c).groupBy(d => e) =>
+ // a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c)))
+ case GroupBy(DetachableMap(a, b, c), d, e) =>
+ val er = BetaReduction(e, d -> c)
+ val x = Ident("x")
+ val x1 = Property(
+ Ident("x"),
+ "_1"
+ ) // These introduced property should not be renamed
+ val x2 = Property(Ident("x"), "_2") // due to any naming convention.
+ val body = Tuple(List(x1, Map(x2, b, c)))
+ Some(Map(GroupBy(a, b, er), x, body))
+
+ // a.map(b => c).drop(d) =>
+ // a.drop(d).map(b => c)
+ case Drop(DetachableMap(a, b, c), d) =>
+ Some(Map(Drop(a, d), b, c))
+
+ // a.map(b => c).take(d) =>
+ // a.drop(d).map(b => c)
+ case Take(DetachableMap(a, b, c), d) =>
+ Some(Map(Take(a, d), b, c))
+
+ // a.map(b => c).nested =>
+ // a.nested.map(b => c)
+ case Nested(DetachableMap(a, b, c)) =>
+ Some(Map(Nested(a), b, c))
+
+ // a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on)
+ // a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2]))
+ case Join(
+ tpe,
+ DetachableMap(a, b, c),
+ DetachableMap(d, e, f),
+ iA,
+ iB,
+ on
+ ) =>
+ val onr = BetaReduction(on, iA -> c, iB -> f)
+ val t = Ident("t")
+ val t1 = BetaReduction(c, b -> Property(t, "_1"))
+ val t2 = BetaReduction(f, e -> Property(t, "_2"))
+ Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2))))
+
+ // a.*join(b.map(c => d)).on((iA, iB) => on)
+ // a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2]))
+ case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) =>
+ val onr = BetaReduction(on, iB -> d)
+ val t = Ident("t")
+ val t1 = Property(t, "_1")
+ val t2 = BetaReduction(d, c -> Property(t, "_2"))
+ Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2))))
+
+ // a.map(b => c).*join(d).on((iA, iB) => on)
+ // a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2))
+ case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) =>
+ val onr = BetaReduction(on, iA -> c)
+ val t = Ident("t")
+ val t1 = BetaReduction(c, b -> Property(t, "_1"))
+ val t2 = Property(t, "_2")
+ Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2))))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/AttachToEntity.scala b/src/main/scala/minisql/norm/AttachToEntity.scala
new file mode 100644
index 0000000..e5608e9
--- /dev/null
+++ b/src/main/scala/minisql/norm/AttachToEntity.scala
@@ -0,0 +1,48 @@
+package minisql.norm
+
+import minisql.util.Messages.fail
+import minisql.ast._
+
+object AttachToEntity {
+
+ private object IsEntity {
+ def unapply(q: Ast): Option[Ast] =
+ q match {
+ case q: Entity => Some(q)
+ case q: Infix => Some(q)
+ case _ => None
+ }
+ }
+
+ def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
+ q: Ast
+ ): Ast =
+ q match {
+
+ case Map(IsEntity(a), b, c) => Map(f(a, b), b, c)
+ case FlatMap(IsEntity(a), b, c) => FlatMap(f(a, b), b, c)
+ case ConcatMap(IsEntity(a), b, c) => ConcatMap(f(a, b), b, c)
+ case Filter(IsEntity(a), b, c) => Filter(f(a, b), b, c)
+ case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d)
+ case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)
+
+ case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
+ _: FlatJoin =>
+ f(q, alias.getOrElse(Ident("x")))
+
+ case Map(a: Query, b, c) => Map(apply(f, Some(b))(a), b, c)
+ case FlatMap(a: Query, b, c) => FlatMap(apply(f, Some(b))(a), b, c)
+ case ConcatMap(a: Query, b, c) => ConcatMap(apply(f, Some(b))(a), b, c)
+ case Filter(a: Query, b, c) => Filter(apply(f, Some(b))(a), b, c)
+ case SortBy(a: Query, b, c, d) => SortBy(apply(f, Some(b))(a), b, c, d)
+ case Take(a: Query, b) => Take(apply(f, alias)(a), b)
+ case Drop(a: Query, b) => Drop(apply(f, alias)(a), b)
+ case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a))
+ case Distinct(a: Query) => Distinct(apply(f, alias)(a))
+ case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)
+
+ case IsEntity(q) => f(q, alias.getOrElse(Ident("x")))
+
+ case other => fail(s"Can't find an 'Entity' in '$q'")
+ }
+}
diff --git a/src/main/scala/minisql/util/BetaReduction.scala b/src/main/scala/minisql/norm/BetaReduction.scala
similarity index 99%
rename from src/main/scala/minisql/util/BetaReduction.scala
rename to src/main/scala/minisql/norm/BetaReduction.scala
index 0940564..868b021 100644
--- a/src/main/scala/minisql/util/BetaReduction.scala
+++ b/src/main/scala/minisql/norm/BetaReduction.scala
@@ -1,6 +1,6 @@
-package minisql.util
+package minisql.norm
-import minisql.ast.*
+import minisql.ast._
import scala.collection.immutable.{Map => IMap}
case class BetaReduction(replacements: Replacements)
diff --git a/src/main/scala/minisql/norm/ConcatBehavior.scala b/src/main/scala/minisql/norm/ConcatBehavior.scala
new file mode 100644
index 0000000..3547b1d
--- /dev/null
+++ b/src/main/scala/minisql/norm/ConcatBehavior.scala
@@ -0,0 +1,7 @@
+package minisql.norm
+
+trait ConcatBehavior
+object ConcatBehavior {
+ case object AnsiConcat extends ConcatBehavior
+ case object NonAnsiConcat extends ConcatBehavior
+}
diff --git a/src/main/scala/minisql/norm/EqualityBehavior.scala b/src/main/scala/minisql/norm/EqualityBehavior.scala
new file mode 100644
index 0000000..ee1ff38
--- /dev/null
+++ b/src/main/scala/minisql/norm/EqualityBehavior.scala
@@ -0,0 +1,7 @@
+package minisql.norm
+
+trait EqualityBehavior
+object EqualityBehavior {
+ case object AnsiEquality extends EqualityBehavior
+ case object NonAnsiEquality extends EqualityBehavior
+}
diff --git a/src/main/scala/minisql/norm/ExpandReturning.scala b/src/main/scala/minisql/norm/ExpandReturning.scala
new file mode 100644
index 0000000..32a886f
--- /dev/null
+++ b/src/main/scala/minisql/norm/ExpandReturning.scala
@@ -0,0 +1,74 @@
+package minisql.norm
+
+import minisql.ReturnAction.ReturnColumns
+import minisql.{NamingStrategy, ReturnAction}
+import minisql.ast._
+import minisql.context.{
+ ReturningClauseSupported,
+ ReturningMultipleFieldSupported,
+ ReturningNotSupported,
+ ReturningSingleFieldSupported
+}
+import minisql.idiom.{Idiom, Statement}
+
+/**
+ * Take the `.returning` part in a query that contains it and return the array
+ * of columns representing of the returning seccovtion with any other operations
+ * etc... that they might contain.
+ */
+object ExpandReturning {
+
+ def applyMap(
+ returning: ReturningAction
+ )(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = {
+ val initialExpand = ExpandReturning.apply(returning)(idiom, naming)
+
+ idiom.idiomReturningCapability match {
+ case ReturningClauseSupported =>
+ ReturnAction.ReturnRecord
+ case ReturningMultipleFieldSupported =>
+ ReturnColumns(initialExpand.map {
+ case (ast, statement) => f(ast, statement)
+ })
+ case ReturningSingleFieldSupported =>
+ if (initialExpand.length == 1)
+ ReturnColumns(initialExpand.map {
+ case (ast, statement) => f(ast, statement)
+ })
+ else
+ throw new IllegalArgumentException(
+ s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
+ )
+ case ReturningNotSupported =>
+ throw new IllegalArgumentException(
+ s"RETURNING columns are not allowed in the ${idiom} dialect."
+ )
+ }
+ }
+
+ def apply(
+ returning: ReturningAction
+ )(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
+ val ReturningAction(_, alias, properties) = returning: @unchecked
+
+ // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
+ // => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
+ val dePropertized =
+ Transform(properties) {
+ case `alias` => ExternalIdent(alias.name)
+ }
+
+ val aliasName = alias.name
+
+ // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
+ // => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
+ val deTuplified = dePropertized match {
+ case Tuple(values) => values
+ case CaseClass(values) => values.map(_._2)
+ case other => List(other)
+ }
+
+ implicit val namingStrategy: NamingStrategy = naming
+ deTuplified.map(v => idiom.translate(v))
+ }
+}
diff --git a/src/main/scala/minisql/norm/FlattenOptionOperation.scala b/src/main/scala/minisql/norm/FlattenOptionOperation.scala
new file mode 100644
index 0000000..50ba857
--- /dev/null
+++ b/src/main/scala/minisql/norm/FlattenOptionOperation.scala
@@ -0,0 +1,108 @@
+package minisql.norm
+
+import minisql.ast.*
+import minisql.ast.Implicits.*
+import minisql.norm.ConcatBehavior.NonAnsiConcat
+
+class FlattenOptionOperation(concatBehavior: ConcatBehavior)
+ extends StatelessTransformer {
+
+ private def emptyOrNot(b: Boolean, ast: Ast) =
+ if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)
+
+ def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
+ apply(BetaReduction(body, alias -> ast))
+
+ def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply((IsNullCheck(ast) +||+ reduced): Ast)
+ }
+
+ def containsNonFallthroughElement(ast: Ast) =
+ CollectAst(ast) {
+ case If(_, _, _) => true
+ case Infix(_, _, _, _) => true
+ case BinaryOperation(_, StringOperator.`concat`, _)
+ if (concatBehavior == NonAnsiConcat) =>
+ true
+ }.nonEmpty
+
+ override def apply(ast: Ast): Ast =
+ ast match {
+
+ case OptionTableFlatMap(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableMap(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableExists(ast, alias, body) =>
+ uncheckedReduction(ast, alias, body)
+
+ case OptionTableForall(ast, alias, body) =>
+ uncheckedForall(ast, alias, body)
+
+ case OptionFlatten(ast) =>
+ apply(ast)
+
+ case OptionSome(ast) =>
+ apply(ast)
+
+ case OptionApply(ast) =>
+ apply(ast)
+
+ case OptionOrNull(ast) =>
+ apply(ast)
+
+ case OptionGetOrNull(ast) =>
+ apply(ast)
+
+ case OptionNone => NullValue
+
+ case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
+ apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)
+
+ case OptionGetOrElse(ast, body) =>
+ apply(If(IsNotNullCheck(ast), ast, body))
+
+ case OptionFlatMap(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply(IfExistElseNull(ast, reduced))
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionMap(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduced = BetaReduction(body, alias -> ast)
+ apply(IfExistElseNull(ast, reduced))
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionForall(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduction = BetaReduction(body, alias -> ast)
+ apply(
+ (IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast
+ )
+ } else {
+ uncheckedForall(ast, alias, body)
+ }
+
+ case OptionExists(ast, alias, body) =>
+ if (containsNonFallthroughElement(body)) {
+ val reduction = BetaReduction(body, alias -> ast)
+ apply((IsNotNullCheck(ast) +&&+ reduction): Ast)
+ } else {
+ uncheckedReduction(ast, alias, body)
+ }
+
+ case OptionContains(ast, body) =>
+ apply((ast +==+ body): Ast)
+
+ case other =>
+ super.apply(other)
+ }
+}
diff --git a/src/main/scala/minisql/norm/NestImpureMappedInfix.scala b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala
new file mode 100644
index 0000000..789efba
--- /dev/null
+++ b/src/main/scala/minisql/norm/NestImpureMappedInfix.scala
@@ -0,0 +1,76 @@
+package minisql.norm
+
+import minisql.ast._
+
+/**
+ * A problem occurred in the original way infixes were done in that it was
+ * assumed that infix clauses represented pure functions. While this is true of
+ * many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others
+ * e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations
+ * that are done in `ApplyMap` on standard AST `Map` clauses cannot be done
+ * therefore additional safety checks were introduced there in order to assure
+ * this does not happen. In addition to this however, it is necessary to add
+ * this normalization step which inserts `Nested` AST elements in every map that
+ * contains impure infix. See more information and examples in #1534.
+ */
+object NestImpureMappedInfix extends StatelessTransformer {
+
+ // Are there any impure infixes that exist inside the specified ASTs
+ def hasInfix(asts: Ast*): Boolean =
+ asts.exists(ast =>
+ CollectAst(ast) {
+ case i @ Infix(_, _, false, _) => i
+ }.nonEmpty
+ )
+
+ // Continue exploring into the Map to see if there are additional impure infix clauses inside.
+ private def applyInside(m: Map) =
+ Map(apply(m.query), m.alias, m.body)
+
+ override def apply(ast: Ast): Ast =
+ ast match {
+ // If there is already a nested clause inside the map, there is no reason to insert another one
+ case Nested(Map(inner, a, b)) =>
+ Nested(Map(apply(inner), a, b))
+
+ case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m)
+ Map(
+ Nested(applyInside(m)),
+ x,
+ CaseClass(values.map {
+ case (name, _) =>
+ (
+ name,
+ Property(x, name)
+ ) // mappings of nested-query case class properties should not be renamed
+ })
+ )
+
+ case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
+ Map(
+ Nested(applyInside(m)),
+ x,
+ Tuple(values.zipWithIndex.map {
+ case (_, i) =>
+ Property(
+ x,
+ s"_${i + 1}"
+ ) // mappings of nested-query tuple properties should not be renamed
+ })
+ )
+
+ case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
+ Map(Nested(applyInside(m)), x, Property(x, "_1"))
+
+ case other => super.apply(other)
+ }
+}
diff --git a/src/main/scala/minisql/norm/Normalize.scala b/src/main/scala/minisql/norm/Normalize.scala
new file mode 100644
index 0000000..d3d8d67
--- /dev/null
+++ b/src/main/scala/minisql/norm/Normalize.scala
@@ -0,0 +1,51 @@
+package minisql.norm
+
+import minisql.ast.Ast
+import minisql.ast.Query
+import minisql.ast.StatelessTransformer
+import minisql.norm.capture.AvoidCapture
+import minisql.ast.Action
+import minisql.util.Messages.trace
+import minisql.util.Messages.TraceType.Normalizations
+
+import scala.annotation.tailrec
+
+object Normalize extends StatelessTransformer {
+
+ override def apply(q: Ast): Ast =
+ super.apply(BetaReduction(q))
+
+ override def apply(q: Action): Action =
+ NormalizeReturning(super.apply(q))
+
+ override def apply(q: Query): Query =
+ norm(AvoidCapture(q))
+
+ private def traceNorm[T](label: String) =
+ trace[T](s"${label} (Normalize)", 1, Normalizations)
+
+ @tailrec
+ private def norm(q: Query): Query =
+ q match {
+ case NormalizeNestedStructures(query) =>
+ traceNorm("NormalizeNestedStructures")(query)
+ norm(query)
+ case ApplyMap(query) =>
+ traceNorm("ApplyMap")(query)
+ norm(query)
+ case SymbolicReduction(query) =>
+ traceNorm("SymbolicReduction")(query)
+ norm(query)
+ case AdHocReduction(query) =>
+ traceNorm("AdHocReduction")(query)
+ norm(query)
+ case OrderTerms(query) =>
+ traceNorm("OrderTerms")(query)
+ norm(query)
+ case NormalizeAggregationIdent(query) =>
+ traceNorm("NormalizeAggregationIdent")(query)
+ norm(query)
+ case other =>
+ other
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala
new file mode 100644
index 0000000..c451ee4
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeAggregationIdent.scala
@@ -0,0 +1,29 @@
+package minisql.norm
+
+import minisql.ast._
+
+object NormalizeAggregationIdent {
+
+ def unapply(q: Query) =
+ q match {
+
+ // a => a.b.map(x => x.c).agg =>
+ // a => a.b.map(a => a.c).agg
+ case Aggregation(
+ op,
+ Map(
+ p @ Property(i: Ident, _),
+ mi,
+ Property.Opinionated(_: Ident, n, renameable, visibility)
+ )
+ ) if i != mi =>
+ Some(
+ Aggregation(
+ op,
+ Map(p, i, Property.Opinionated(i, n, renameable, visibility))
+ )
+ ) // in example aove, if c in x.c is fixed c in a.c should also be
+
+ case _ => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeNestedStructures.scala b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala
new file mode 100644
index 0000000..603c411
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeNestedStructures.scala
@@ -0,0 +1,47 @@
+package minisql.norm
+
+import minisql.ast._
+
+object NormalizeNestedStructures {
+
+ def unapply(q: Query): Option[Query] =
+ q match {
+ case e: Entity => None
+ case Map(a, b, c) => apply(a, c)(Map(_, b, _))
+ case FlatMap(a, b, c) => apply(a, c)(FlatMap(_, b, _))
+ case ConcatMap(a, b, c) => apply(a, c)(ConcatMap(_, b, _))
+ case Filter(a, b, c) => apply(a, c)(Filter(_, b, _))
+ case SortBy(a, b, c, d) => apply(a, c)(SortBy(_, b, _, d))
+ case GroupBy(a, b, c) => apply(a, c)(GroupBy(_, b, _))
+ case Aggregation(a, b) => apply(b)(Aggregation(a, _))
+ case Take(a, b) => apply(a, b)(Take.apply)
+ case Drop(a, b) => apply(a, b)(Drop.apply)
+ case Union(a, b) => apply(a, b)(Union.apply)
+ case UnionAll(a, b) => apply(a, b)(UnionAll.apply)
+ case Distinct(a) => apply(a)(Distinct.apply)
+ case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
+ case Nested(a) => apply(a)(Nested.apply)
+ case FlatJoin(t, a, iA, on) =>
+ (Normalize(a), Normalize(on)) match {
+ case (`a`, `on`) => None
+ case (a, on) => Some(FlatJoin(t, a, iA, on))
+ }
+ case Join(t, a, b, iA, iB, on) =>
+ (Normalize(a), Normalize(b), Normalize(on)) match {
+ case (`a`, `b`, `on`) => None
+ case (a, b, on) => Some(Join(t, a, b, iA, iB, on))
+ }
+ }
+
+ private def apply(a: Ast)(f: Ast => Query) =
+ (Normalize(a)) match {
+ case (`a`) => None
+ case (a) => Some(f(a))
+ }
+
+ private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
+ (Normalize(a), Normalize(b)) match {
+ case (`a`, `b`) => None
+ case (a, b) => Some(f(a, b))
+ }
+}
diff --git a/src/main/scala/minisql/norm/NormalizeReturning.scala b/src/main/scala/minisql/norm/NormalizeReturning.scala
new file mode 100644
index 0000000..43fc241
--- /dev/null
+++ b/src/main/scala/minisql/norm/NormalizeReturning.scala
@@ -0,0 +1,154 @@
+package minisql.norm
+
+import minisql.ast._
+import minisql.norm.capture.AvoidAliasConflict
+
+/**
+ * When actions are used with a `.returning` clause, remove the columns used in
+ * the returning clause from the action. E.g. for `insert(Person(id,
+ * name)).returning(_.id)` remove the `id` column from the original insert.
+ */
+object NormalizeReturning {
+
+ def apply(e: Action): Action = {
+ e match {
+ case ReturningGenerated(a: Action, alias, body) =>
+ // De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action.
+ // This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres.
+ // For example:
+ // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
+ // Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally
+ // be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case.
+ // In order to fix this, we need to change `t` into a different alias.
+ val newBody = dealiasBody(body, alias)
+ ReturningGenerated(apply(a, newBody, alias), alias, newBody)
+
+ // For a regular return clause, do not need to exclude assignments from insertion however, we still
+ // need to de-alias the Action body in case conflicts result. For example the following query:
+ // query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
+ // would incorrectly be interpreted as:
+ // INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t
+ // whereas it should be:
+ // INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
+ case Returning(a: Action, alias, body) =>
+ val newBody = dealiasBody(body, alias)
+ Returning(a, alias, newBody)
+
+ case _ => e
+ }
+ }
+
+ /**
+ * In some situations, a query can exist inside of a `returning` clause. In
+ * this case, we need to rename if the aliases used in that query override the
+ * alias used in the `returning` clause otherwise they will be treated as
+ * returning-clause aliases ExpandReturning (i.e. they will become
+ * ExternalAlias instances) and later be tokenized incorrectly.
+ */
+ private def dealiasBody(body: Ast, alias: Ident): Ast =
+ Transform(body) {
+ case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
+ }
+
+ private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
+ e match {
+ case Insert(query, assignments) =>
+ Insert(query, filterReturnedColumn(assignments, body, returningIdent))
+ case Update(query, assignments) =>
+ Update(query, filterReturnedColumn(assignments, body, returningIdent))
+ case OnConflict(a: Action, target, act) =>
+ OnConflict(apply(a, body, returningIdent), target, act)
+ case _ => e
+ }
+
+ private def filterReturnedColumn(
+ assignments: List[Assignment],
+ column: Ast,
+ returningIdent: Ident
+ ): List[Assignment] =
+ assignments.flatMap(filterReturnedColumn(_, column, returningIdent))
+
+ /**
+ * In situations like Property(Property(ident, foo), bar) pull out the
+ * inner-most ident
+ */
+ object NestedProperty {
+ def unapply(ast: Property): Option[Ast] = {
+ ast match {
+ case p @ Property(subAst, _) => Some(innerMost(subAst))
+ }
+ }
+
+ private def innerMost(ast: Ast): Ast = ast match {
+ case Property(inner, _) => innerMost(inner)
+ case other => other
+ }
+ }
+
+ /**
+ * Remove the specified column from the assignment. For example, in a query
+ * like `insert(Person(id, name)).returning(r => r.id)` we need to remove the
+ * `id` column from the insertion. The value of the `column:Ast` in this case
+ * will be `Property(Ident(r), id)` and the values fo the assignment `p1`
+ * property will typically be `v.id` and `v.name` (the `v` variable is a
+ * default used for `insert` queries).
+ */
+ private def filterReturnedColumn(
+ assignment: Assignment,
+ body: Ast,
+ returningIdent: Ident
+ ): Option[Assignment] =
+ assignment match {
+ case Assignment(_, p1: Property, _) => {
+ // Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but
+ // if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2))
+ // or it can even be a query since queries are allowed to be in return sections e.g:
+ // query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max))
+ // In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment
+ // in order to know what to exclude.
+ val matchedProps =
+ CollectAst(body) {
+ // case prop @ NestedProperty(`returningIdent`) => prop
+ case prop @ NestedProperty(Ident(name))
+ if (name == returningIdent.name) =>
+ prop
+ case prop @ NestedProperty(ExternalIdent(name))
+ if (name == returningIdent.name) =>
+ prop
+ }
+
+ if (
+ matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))
+ )
+ None
+ else
+ Some(assignment)
+ }
+ case assignment => Some(assignment)
+ }
+
+ object SomeIdent {
+ def unapply(ast: Ast): Option[Ast] =
+ ast match {
+ case id: Ident => Some(id)
+ case id: ExternalIdent => Some(id)
+ case _ => None
+ }
+ }
+
+ /**
+ * Is it the same property (but possibly of a different identity). E.g.
+ * `p.foo.bar` and `v.foo.bar`
+ */
+ private def isSameProperties(p1: Property, p2: Property): Boolean =
+ (p1.ast, p2.ast) match {
+ case (SomeIdent(_), SomeIdent(_)) =>
+ p1.name == p2.name
+ // If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the
+ // outer properties are the same before moving on to the inner ones.
+ case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
+ isSameProperties(pp1, pp2)
+ case _ =>
+ false
+ }
+}
diff --git a/src/main/scala/minisql/norm/OrderTerms.scala b/src/main/scala/minisql/norm/OrderTerms.scala
new file mode 100644
index 0000000..22422fa
--- /dev/null
+++ b/src/main/scala/minisql/norm/OrderTerms.scala
@@ -0,0 +1,29 @@
+package minisql.norm
+
+import minisql.ast._
+
+object OrderTerms {
+
+ def unapply(q: Query) =
+ q match {
+
+ case Take(Map(a: GroupBy, b, c), d) => None
+
+ // a.sortBy(b => c).filter(d => e) =>
+ // a.filter(d => e).sortBy(b => c)
+ case Filter(SortBy(a, b, c, d), e, f) =>
+ Some(SortBy(Filter(a, e, f), b, c, d))
+
+ // a.flatMap(b => c).take(n).map(d => e) =>
+ // a.flatMap(b => c).map(d => e).take(n)
+ case Map(Take(fm: FlatMap, n), ma, mb) =>
+ Some(Take(Map(fm, ma, mb), n))
+
+ // a.flatMap(b => c).drop(n).map(d => e) =>
+ // a.flatMap(b => c).map(d => e).drop(n)
+ case Map(Drop(fm: FlatMap, n), ma, mb) =>
+ Some(Drop(Map(fm, ma, mb), n))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/RenameProperties.scala b/src/main/scala/minisql/norm/RenameProperties.scala
new file mode 100644
index 0000000..53d135c
--- /dev/null
+++ b/src/main/scala/minisql/norm/RenameProperties.scala
@@ -0,0 +1,491 @@
+package minisql.norm
+
+import minisql.ast.Renameable.Fixed
+import minisql.ast.Visibility.Visible
+import minisql.ast._
+import minisql.util.Interpolator
+
+object RenameProperties extends StatelessTransformer {
+ val interp = new Interpolator(3)
+ import interp._
+ def traceDifferent(one: Any, two: Any) =
+ if (one != two)
+ trace"Replaced $one with $two".andLog()
+ else
+ trace"Replacements did not match".andLog()
+
+ override def apply(q: Query): Query =
+ applySchemaOnly(q)
+
+ override def apply(q: Action): Action =
+ applySchema(q) match {
+ case (q, schema) => q
+ }
+
+ override def apply(e: Operation): Operation =
+ e match {
+ case UnaryOperation(o, c: Query) =>
+ UnaryOperation(o, applySchemaOnly(apply(c)))
+ case _ => super.apply(e)
+ }
+
+ private def applySchema(q: Action): (Action, Schema) =
+ q match {
+ case Insert(q: Query, assignments) =>
+ applySchema(q, assignments, Insert.apply)
+ case Update(q: Query, assignments) =>
+ applySchema(q, assignments, Update.apply)
+ case Delete(q: Query) =>
+ applySchema(q) match {
+ case (q, schema) => (Delete(q), schema)
+ }
+ case Returning(action: Action, alias, body) =>
+ applySchema(action) match {
+ case (action, schema) =>
+ val replace =
+ trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val bodyr = BetaReduction(body, replace*)
+ traceDifferent(body, bodyr)
+ (Returning(action, alias, bodyr), schema)
+ }
+ case ReturningGenerated(action: Action, alias, body) =>
+ applySchema(action) match {
+ case (action, schema) =>
+ val replace =
+ trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val bodyr = BetaReduction(body, replace*)
+ traceDifferent(body, bodyr)
+ (ReturningGenerated(action, alias, bodyr), schema)
+ }
+ case OnConflict(a: Action, target, act) =>
+ applySchema(a) match {
+ case (action, schema) =>
+ val targetR = target match {
+ case OnConflict.Properties(props) =>
+ val propsR = props.map { prop =>
+ val replace =
+ trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn`
+ replacements(
+ prop.ast,
+ schema
+ ) // A BetaReduction on a Property will always give back a Property
+ BetaReduction(prop, replace*).asInstanceOf[Property]
+ }
+ traceDifferent(props, propsR)
+ OnConflict.Properties(propsR)
+ case OnConflict.NoTarget => target
+ }
+ val actR = act match {
+ case OnConflict.Update(assignments) =>
+ OnConflict.Update(replaceAssignments(assignments, schema))
+ case _ => act
+ }
+ (OnConflict(action, targetR, actR), schema)
+ }
+ case q => (q, TupleSchema.empty)
+ }
+
+ private def replaceAssignments(
+ a: List[Assignment],
+ schema: Schema
+ ): List[Assignment] =
+ a.map {
+ case Assignment(alias, prop, value) =>
+ val replace =
+ trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn`
+ replacements(alias, schema)
+ val propR = BetaReduction(prop, replace*)
+ traceDifferent(prop, propR)
+ val valueR = BetaReduction(value, replace*)
+ traceDifferent(value, valueR)
+ Assignment(alias, propR, valueR)
+ }
+
+ private def applySchema(
+ q: Query,
+ a: List[Assignment],
+ f: (Query, List[Assignment]) => Action
+ ): (Action, Schema) =
+ applySchema(q) match {
+ case (q, schema) =>
+ (f(q, replaceAssignments(a, schema)), schema)
+ }
+
+ private def applySchemaOnly(q: Query): Query =
+ applySchema(q) match {
+ case (q, _) => q
+ }
+
+ object TupleIndex {
+ def unapply(s: String): Option[Int] =
+ if (s.matches("_[0-9]*"))
+ Some(s.drop(1).toInt - 1)
+ else
+ None
+ }
+
+ sealed trait Schema {
+ def lookup(property: List[String]): Option[Schema] =
+ (property, this) match {
+ case (Nil, schema) =>
+ trace"Nil at $property returning " `andReturn`
+ Some(schema)
+ case (path, e @ EntitySchema(_)) =>
+ trace"Entity at $path returning " `andReturn`
+ Some(e.subSchemaOrEmpty(path))
+ case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) =>
+ trace"Case class at $property returning " `andReturn`
+ props(head).lookup(tail)
+ case (TupleIndex(idx) :: tail, TupleSchema(values))
+ if values.contains(idx) =>
+ trace"Tuple at at $property returning " `andReturn`
+ values(idx).lookup(tail)
+ case _ =>
+ trace"Nothing found at $property returning " `andReturn`
+ None
+ }
+ }
+
+ // Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...))
+ object PropertyMatroshka {
+
+ def traverse(initial: Property): Option[(Ident, List[String])] =
+ initial match {
+ // If it's a nested-property walk inside and append the name to the result (if something is returned)
+ case Property(inner: Property, name) =>
+ traverse(inner).map { case (id, list) => (id, list :+ name) }
+ // If it's a property with ident in the core, return that
+ case Property(id: Ident, name) =>
+ Some((id, List(name)))
+ // Otherwise an ident property is not inside so don't return anything
+ case _ =>
+ None
+ }
+
+ def unapply(ast: Ast): Option[(Ident, List[String])] =
+ ast match {
+ case p: Property => traverse(p)
+ case _ => None
+ }
+
+ }
+
+ def protractSchema(
+ body: Ast,
+ ident: Ident,
+ schema: Schema
+ ): Option[Schema] = {
+
+ def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] =
+ body match {
+ // if any values yield a sub-schema which is not an entity, recurse into that
+ case cc @ CaseClass(values) =>
+ trace"Protracting CaseClass $cc into new schema:" `andReturn`
+ CaseClassSchema(
+ values.collect {
+ case (name, innerBody @ HierarchicalAstEntity()) =>
+ (name, protractSchemaRecurse(innerBody, schema))
+ // pass the schema into a recursive call an extract from it when we non tuple/caseclass element
+ case (name, innerBody @ PropertyMatroshka(`ident`, path)) =>
+ (name, protractSchemaRecurse(innerBody, schema))
+ // we have reached an ident i.e. recurse to pass the current schema into the case class
+ case (name, `ident`) =>
+ (name, protractSchemaRecurse(ident, schema))
+ }.collect {
+ case (name, Some(subSchema)) => (name, subSchema)
+ }
+ ).notEmpty
+ case tup @ Tuple(values) =>
+ trace"Protracting Tuple $tup into new schema:" `andReturn`
+ TupleSchema
+ .fromIndexes(
+ values.zipWithIndex.collect {
+ case (innerBody @ HierarchicalAstEntity(), index) =>
+ (index, protractSchemaRecurse(innerBody, schema))
+ // pass the schema into a recursive call an extract from it when we non tuple/caseclass element
+ case (innerBody @ PropertyMatroshka(`ident`, path), index) =>
+ (index, protractSchemaRecurse(innerBody, schema))
+ // we have reached an ident i.e. recurse to pass the current schema into the tuple
+ case (`ident`, index) =>
+ (index, protractSchemaRecurse(ident, schema))
+ }.collect {
+ case (index, Some(subSchema)) => (index, subSchema)
+ }
+ )
+ .notEmpty
+
+ case prop @ PropertyMatroshka(`ident`, path) =>
+ trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn`
+ schema match {
+ // case e: EntitySchema => Some(e)
+ case _ => schema.lookup(path)
+ }
+ case `ident` =>
+ trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn`
+ Some(schema)
+ case other =>
+ trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn`
+ Some(schema)
+ }
+
+ protractSchemaRecurse(body, schema)
+ }
+
+ case object EmptySchema extends Schema
+ case class EntitySchema(e: Entity) extends Schema {
+ def noAliases = e.properties.isEmpty
+
+ private def subSchema(path: List[String]) =
+ EntitySchema(
+ Entity(
+ s"sub-${e.name}",
+ e.properties.flatMap {
+ case PropertyAlias(aliasPath, alias) =>
+ if (aliasPath == path)
+ List(PropertyAlias(aliasPath, alias))
+ else if (aliasPath.startsWith(path))
+ List(PropertyAlias(aliasPath.diff(path), alias))
+ else
+ List()
+ }
+ )
+ )
+
+ def subSchemaOrEmpty(path: List[String]): Schema =
+ trace"Creating sub-schema for entity $e at path $path will be" andReturn {
+ val sub = subSchema(path)
+ if (sub.noAliases) EmptySchema else sub
+ }
+
+ }
+ case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ )
+ extends Schema {
+ def list = m.toList.sortBy(_._1)
+ def notEmpty =
+ if (this.m.nonEmpty) Some(this) else None
+ }
+ case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema {
+ def list = m.toList
+ def notEmpty =
+ if (this.m.nonEmpty) Some(this) else None
+ }
+ object CaseClassSchema {
+ def apply(property: String, value: Schema): CaseClassSchema =
+ CaseClassSchema(collection.Map(property -> value))
+ def apply(list: List[(String, Schema)]): CaseClassSchema =
+ CaseClassSchema(list.toMap)
+ }
+
+ object TupleSchema {
+ def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema =
+ TupleSchema(schemas.toMap)
+
+ def apply(schemas: List[Schema]): TupleSchema =
+ TupleSchema(schemas.zipWithIndex.map(_.swap).toMap)
+
+ def apply(index: Int, schema: Schema): TupleSchema =
+ TupleSchema(collection.Map(index -> schema))
+
+ def empty: TupleSchema = TupleSchema(List.empty)
+ }
+
+ object HierarchicalAstEntity {
+ def unapply(ast: Ast): Boolean =
+ ast match {
+ case cc: CaseClass => true
+ case tup: Tuple => true
+ case _ => false
+ }
+ }
+
+ private def applySchema(q: Query): (Query, Schema) = {
+ q match {
+
+ // Don't understand why this is needed....
+ case Map(q: Query, x, p) =>
+ applySchema(q) match {
+ case (q, subSchema) =>
+ val replace =
+ trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn`
+ replacements(x, subSchema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+
+ val schema =
+ trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` {
+ protractSchema(prr, x, subSchema)
+ }.getOrElse(EmptySchema)
+
+ (Map(q, x, prr), schema)
+ }
+
+ case e: Entity => (e, EntitySchema(e))
+ case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply)
+ case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
+ case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply)
+ case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _))
+ case Take(q: Query, n) => applySchema(q, Take(_, n))
+ case Drop(q: Query, n) => applySchema(q, Drop(_, n))
+ case Nested(q: Query) => applySchema(q, Nested.apply)
+ case Distinct(q: Query) => applySchema(q, Distinct.apply)
+ case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on))
+
+ case FlatMap(q: Query, x, p) =>
+ applySchema(q, x, p, FlatMap.apply) match {
+ case (FlatMap(q, x, p: Query), oldSchema) =>
+ val (pr, newSchema) = applySchema(p)
+ (FlatMap(q, x, pr), newSchema)
+ case (flatMap, oldSchema) =>
+ (flatMap, TupleSchema.empty)
+ }
+
+ case ConcatMap(q: Query, x, p) =>
+ applySchema(q, x, p, ConcatMap.apply) match {
+ case (ConcatMap(q, x, p: Query), oldSchema) =>
+ val (pr, newSchema) = applySchema(p)
+ (ConcatMap(q, x, pr), newSchema)
+ case (concatMap, oldSchema) =>
+ (concatMap, TupleSchema.empty)
+ }
+
+ case Join(typ, a: Query, b: Query, iA, iB, on) =>
+ (applySchema(a), applySchema(b)) match {
+ case ((a, schemaA), (b, schemaB)) =>
+ val combinedReplacements =
+ trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn {
+ val replaceA = replacements(iA, schemaA)
+ val replaceB = replacements(iB, schemaB)
+ replaceA ++ replaceB
+ }
+ val onr = BetaReduction(on, combinedReplacements*)
+ traceDifferent(on, onr)
+ (Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB)))
+ }
+
+ case FlatJoin(typ, a: Query, iA, on) =>
+ applySchema(a) match {
+ case (a, schemaA) =>
+ val replaceA =
+ trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn`
+ replacements(iA, schemaA)
+ val onr = BetaReduction(on, replaceA*)
+ traceDifferent(on, onr)
+ (FlatJoin(typ, a, iA, onr), schemaA)
+ }
+
+ case Map(q: Operation, x, p) if x == p =>
+ (Map(apply(q), x, p), TupleSchema.empty)
+
+ case Map(Infix(parts, params, pure, paren), x, p) =>
+ val transformed =
+ params.map {
+ case q: Query =>
+ val (qr, schema) = applySchema(q)
+ traceDifferent(q, qr)
+ (qr, Some(schema))
+ case q =>
+ (q, None)
+ }
+
+ val schema =
+ transformed.collect {
+ case (_, Some(schema)) => schema
+ } match {
+ case e :: Nil => e
+ case ls => TupleSchema(ls)
+ }
+ val replace =
+ trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
+ replacements(x, schema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+
+ (Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema)
+
+ case q =>
+ (q, TupleSchema.empty)
+ }
+ }
+
+ private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) =
+ applySchema(ast) match {
+ case (ast, schema) =>
+ (f(ast), schema)
+ }
+
+ private def applySchema[T](
+ q: Query,
+ x: Ident,
+ p: Ast,
+ f: (Ast, Ident, Ast) => T
+ ): (T, Schema) =
+ applySchema(q) match {
+ case (q, schema) =>
+ val replace =
+ trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
+ replacements(x, schema)
+ val pr = BetaReduction(p, replace*)
+ traceDifferent(p, pr)
+ val prr = apply(pr)
+ traceDifferent(pr, prr)
+ (f(q, x, prr), schema)
+ }
+
+ private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] =
+ schema match {
+ // The entity renameable property should already have been marked as Fixed
+ case EntitySchema(Entity(entity, properties)) =>
+ // trace"%4 Entity Schema: " andReturn
+ properties.flatMap {
+ // A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias)
+ // or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit.
+ // In this case, we want to rename the properties based on the property aliases as well as mark
+ // them Fixed since they should not be renamed based on
+ // the naming strategy wherever they are tokenized (e.g. in SqlIdiom)
+ case PropertyAlias(path, alias) =>
+ def apply(base: Ast, path: List[String]): Ast =
+ path match {
+ case Nil => base
+ case head :: tail => apply(Property(base, head), tail)
+ }
+ List(
+ apply(base, path) -> Property.Opinionated(
+ base,
+ alias,
+ Fixed,
+ Visible
+ ) // Hidden properties cannot be renamed
+ )
+ }
+ case tup: TupleSchema =>
+ // trace"%4 Tuple Schema: " andReturn
+ tup.list.flatMap {
+ case (idx, value) =>
+ replacements(
+ // Should not matter whether property is fixed or variable here
+ // since beta reduction ignores that
+ Property(base, s"_${idx + 1}"),
+ value
+ )
+ }
+ case cc: CaseClassSchema =>
+ // trace"%4 CaseClass Schema: " andReturn
+ cc.list.flatMap {
+ case (property, value) =>
+ replacements(
+ // Should not matter whether property is fixed or variable here
+ // since beta reduction ignores that
+ Property(base, property),
+ value
+ )
+ }
+ // Do nothing if it is an empty schema
+ case EmptySchema => List()
+ }
+}
diff --git a/src/main/scala/minisql/util/Replacements.scala b/src/main/scala/minisql/norm/Replacements.scala
similarity index 98%
rename from src/main/scala/minisql/util/Replacements.scala
rename to src/main/scala/minisql/norm/Replacements.scala
index f0982e2..4b4a955 100644
--- a/src/main/scala/minisql/util/Replacements.scala
+++ b/src/main/scala/minisql/norm/Replacements.scala
@@ -1,4 +1,4 @@
-package minisql.util
+package minisql.norm
import minisql.ast.Ast
import scala.collection.immutable.Map
diff --git a/src/main/scala/minisql/norm/SimplifyNullChecks.scala b/src/main/scala/minisql/norm/SimplifyNullChecks.scala
new file mode 100644
index 0000000..a49b949
--- /dev/null
+++ b/src/main/scala/minisql/norm/SimplifyNullChecks.scala
@@ -0,0 +1,124 @@
+package minisql.norm
+
+import minisql.ast.*
+import minisql.norm.EqualityBehavior.AnsiEquality
+
+/**
+ * Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in
+ * `FlattenOptionOperation` in order to resolve #1053, as well as to support
+ * non-ansi compliant string concatenation as outlined in #1295, large
+ * conditional composites became common. For example:
Now, let's add a
+ * case class
+ * Holder(value:Option[String])
+ *
+ * // The following statement query[Holder].map(h => h.value.map(_ + "foo")) //
+ * Will yield the following result SELECT CASE WHEN h.value IS NOT NULL THEN
+ * h.value || 'foo' ELSE null END FROM Holder h
getOrElse
statement to the clause that requires an additional
+ * wrapped null check. We cannot rely on there being a map
call
+ * beforehand since we could be reading value
as a nullable field
+ * directly from the database).
+ * This of course is highly redundant and can be reduced to simply: // The following statement
+ * query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar")) // Yields the
+ * following result: SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value
+ * \|| 'foo' ELSE null END IS NOT NULL THEN CASE WHEN h.value IS NOT NULL THEN
+ * h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h
This reduction is
+ * done by the "Center Rule." There are some other simplification rules as well.
+ * Note how we are force to null-check both `h.value` as well as `(h.value ||
+ * 'foo')` because a user may use `Option[T].flatMap` and explicitly transform a
+ * particular value to `null`.
+ */
+class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
+ extends StatelessTransformer {
+
+ override def apply(ast: Ast): Ast = {
+ import minisql.ast.Implicits.*
+ ast match {
+ // Center rule
+ case IfExist(
+ IfExistElseNull(condA, thenA),
+ IfExistElseNull(condB, thenB),
+ otherwise
+ ) if (condA == condB && thenA == thenB) =>
+ apply(
+ If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
+ )
+
+ // Left hand rule
+ case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
+ apply(
+ If(
+ IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
+ value,
+ otherwise
+ )
+ )
+
+ // Right hand rule
+ case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
+ apply(
+ If(
+ IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
+ innerThen,
+ NullValue
+ )
+ )
+
+ case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
+ Optional(b)
+ ) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
+ if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+
+ case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
+ a1
+ ) `== or !=` Optional(_)))
+ if (a == a1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+ case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
+ _
+ ) `== or !=` Optional(b1)))
+ if (b == b1 && equalityBehavior == AnsiEquality) =>
+ apply(exp)
+
+ case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
+ apply(other)
+ case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
+ apply(other)
+ case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
+ apply(other)
+ case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
+ apply(other)
+
+ case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left)
+ case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
+ case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left)
+ case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)
+
+ case other =>
+ super.apply(other)
+ }
+ }
+
+ object `== or !=` {
+ def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
+ case a +==+ b => Some((a, b))
+ case a +!=+ b => Some((a, b))
+ case _ => None
+ }
+ }
+
+ /**
+ * Simple extractor that looks inside of an optional values to see if the
+ * thing inside can be pulled out. If not, it just returns whatever element it
+ * can find.
+ */
+ object Optional {
+ def unapply(a: Ast): Option[Ast] = a match {
+ case OptionApply(value) => Some(value)
+ case OptionSome(value) => Some(value)
+ case value => Some(value)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/SymbolicReduction.scala b/src/main/scala/minisql/norm/SymbolicReduction.scala
new file mode 100644
index 0000000..d7e8965
--- /dev/null
+++ b/src/main/scala/minisql/norm/SymbolicReduction.scala
@@ -0,0 +1,38 @@
+package minisql.norm
+
+import minisql.ast.Filter
+import minisql.ast.FlatMap
+import minisql.ast.Query
+import minisql.ast.Union
+import minisql.ast.UnionAll
+
+object SymbolicReduction {
+
+ def unapply(q: Query) =
+ q match {
+
+ // a.filter(b => c).flatMap(d => e.$) =>
+ // a.flatMap(d => e.filter(_ => c[b := d]).$)
+ case FlatMap(Filter(a, b, c), d, e: Query) =>
+ val cr = BetaReduction(c, b -> d)
+ val er = AttachToEntity(Filter(_, _, cr))(e)
+ Some(FlatMap(a, d, er))
+
+ // a.flatMap(b => c).flatMap(d => e) =>
+ // a.flatMap(b => c.flatMap(d => e))
+ case FlatMap(FlatMap(a, b, c), d, e) =>
+ Some(FlatMap(a, b, FlatMap(c, d, e)))
+
+ // a.union(b).flatMap(c => d)
+ // a.flatMap(c => d).union(b.flatMap(c => d))
+ case FlatMap(Union(a, b), c, d) =>
+ Some(Union(FlatMap(a, c, d), FlatMap(b, c, d)))
+
+ // a.unionAll(b).flatMap(c => d)
+ // a.flatMap(c => d).unionAll(b.flatMap(c => d))
+ case FlatMap(UnionAll(a, b), c, d) =>
+ Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d)))
+
+ case other => None
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala
new file mode 100644
index 0000000..ed45c62
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala
@@ -0,0 +1,174 @@
+package minisql.norm.capture
+
+import minisql.ast.{
+ Entity,
+ Filter,
+ FlatJoin,
+ FlatMap,
+ GroupBy,
+ Ident,
+ Join,
+ Map,
+ Query,
+ SortBy,
+ StatefulTransformer,
+ _
+}
+import minisql.norm.{BetaReduction, Normalize}
+import scala.collection.immutable.Set
+
+private[minisql] case class AvoidAliasConflict(state: Set[Ident])
+ extends StatefulTransformer[Set[Ident]] {
+
+ object Unaliased {
+
+ private def isUnaliased(q: Ast): Boolean =
+ q match {
+ case Nested(q: Query) => isUnaliased(q)
+ case Take(q: Query, _) => isUnaliased(q)
+ case Drop(q: Query, _) => isUnaliased(q)
+ case Aggregation(_, q: Query) => isUnaliased(q)
+ case Distinct(q: Query) => isUnaliased(q)
+ case _: Entity | _: Infix => true
+ case _ => false
+ }
+
+ def unapply(q: Ast): Option[Ast] =
+ q match {
+ case q if (isUnaliased(q)) => Some(q)
+ case _ => None
+ }
+ }
+
+ override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
+ q match {
+
+ case FlatMap(Unaliased(q), x, p) =>
+ apply(x, p)(FlatMap(q, _, _))
+
+ case ConcatMap(Unaliased(q), x, p) =>
+ apply(x, p)(ConcatMap(q, _, _))
+
+ case Map(Unaliased(q), x, p) =>
+ apply(x, p)(Map(q, _, _))
+
+ case Filter(Unaliased(q), x, p) =>
+ apply(x, p)(Filter(q, _, _))
+
+ case SortBy(Unaliased(q), x, p, o) =>
+ apply(x, p)(SortBy(q, _, _, o))
+
+ case GroupBy(Unaliased(q), x, p) =>
+ apply(x, p)(GroupBy(q, _, _))
+
+ case DistinctOn(Unaliased(q), x, p) =>
+ apply(x, p)(DistinctOn(q, _, _))
+
+ case Join(t, a, b, iA, iB, o) =>
+ val (ar, art) = apply(a)
+ val (br, brt) = art.apply(b)
+ val freshA = freshIdent(iA, brt.state)
+ val freshB = freshIdent(iB, brt.state + freshA)
+ val or = BetaReduction(o, iA -> freshA, iB -> freshB)
+ val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
+ (Join(t, ar, br, freshA, freshB, orr), orrt)
+
+ case FlatJoin(t, a, iA, o) =>
+ val (ar, art) = apply(a)
+ val freshA = freshIdent(iA)
+ val or = BetaReduction(o, iA -> freshA)
+ val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
+ (FlatJoin(t, ar, freshA, orr), orrt)
+
+ case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
+ _: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
+ _: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
+ super.apply(q)
+ }
+
+ private def apply(x: Ident, p: Ast)(
+ f: (Ident, Ast) => Query
+ ): (Query, StatefulTransformer[Set[Ident]]) = {
+ val fresh = freshIdent(x)
+ val pr = BetaReduction(p, x -> fresh)
+ val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
+ (f(fresh, prr), t)
+ }
+
+ private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
+ def loop(x: Ident, n: Int): Ident = {
+ val fresh = Ident(s"${x.name}$n")
+ if (!state.contains(fresh))
+ fresh
+ else
+ loop(x, n + 1)
+ }
+ if (!state.contains(x))
+ x
+ else
+ loop(x, 1)
+ }
+
+ /**
+ * Sometimes we need to change the variables in a function because they will
+ * might conflict with some variable further up in the macro. Right now, this
+ * only happens when you do something like this:
+ * SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL THEN
+ * h.value || 'foo' ELSE 'bar' END FROM Holder h
val q = quote { (v:
+ * Foo) => query[Foo].insert(v) } run(q(lift(v)))
Since 'v' is used by
+ * actionMeta in order to map keys to values for insertion, using it as a
+ * function argument messes up the output SQL like so: INSERT INTO
+ * MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?)
+ * Therefore, we need to have a method to remove such conflicting variables
+ * from Function ASTs
+ */
+ private def applyFunction(f: Function): Function = {
+ val (newBody, _, newParams) =
+ f.params.foldLeft((f.body, state, List[Ident]())) {
+ case ((body, state, newParams), param) => {
+ val fresh = freshIdent(param)
+ val pr = BetaReduction(body, param -> fresh)
+ val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
+ (prr, t.state, newParams :+ fresh)
+ }
+ }
+ Function(newParams, newBody)
+ }
+
+ private def applyForeach(f: Foreach): Foreach = {
+ val fresh = freshIdent(f.alias)
+ val pr = BetaReduction(f.body, f.alias -> fresh)
+ val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
+ Foreach(f.query, fresh, prr)
+ }
+}
+
+private[minisql] object AvoidAliasConflict {
+
+ def apply(q: Query): Query =
+ AvoidAliasConflict(Set[Ident]())(q) match {
+ case (q, _) => q
+ }
+
+ /**
+ * Make sure query parameters do not collide with paramters of a AST function.
+ * Do this by walkning through the function's subtree and transforming and
+ * queries encountered.
+ */
+ def sanitizeVariables(
+ f: Function,
+ dangerousVariables: Set[Ident]
+ ): Function = {
+ AvoidAliasConflict(dangerousVariables).applyFunction(f)
+ }
+
+ /** Same is `sanitizeVariables` but for Foreach * */
+ def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
+ AvoidAliasConflict(dangerousVariables).applyForeach(f)
+ }
+
+ def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
+ AvoidAliasConflict(dangerousVariables).apply(q) match {
+ // Propagate aliasing changes to the rest of the query
+ case (q, _) => Normalize(q)
+ }
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/AvoidCapture.scala b/src/main/scala/minisql/norm/capture/AvoidCapture.scala
new file mode 100644
index 0000000..788b551
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/AvoidCapture.scala
@@ -0,0 +1,9 @@
+package minisql.norm.capture
+
+import minisql.ast.Query
+
+object AvoidCapture {
+
+ def apply(q: Query): Query =
+ Dealias(AvoidAliasConflict(q))
+}
diff --git a/src/main/scala/minisql/norm/capture/Dealias.scala b/src/main/scala/minisql/norm/capture/Dealias.scala
new file mode 100644
index 0000000..64a56ce
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/Dealias.scala
@@ -0,0 +1,72 @@
+package minisql.norm.capture
+
+import minisql.ast._
+import minisql.norm.BetaReduction
+
+case class Dealias(state: Option[Ident])
+ extends StatefulTransformer[Option[Ident]] {
+
+ override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) =
+ q match {
+ case FlatMap(a, b, c) =>
+ dealias(a, b, c)(FlatMap.apply) match {
+ case (FlatMap(a, b, c), _) =>
+ val (cn, cnt) = apply(c)
+ (FlatMap(a, b, cn), cnt)
+ }
+ case ConcatMap(a, b, c) =>
+ dealias(a, b, c)(ConcatMap.apply) match {
+ case (ConcatMap(a, b, c), _) =>
+ val (cn, cnt) = apply(c)
+ (ConcatMap(a, b, cn), cnt)
+ }
+ case Map(a, b, c) =>
+ dealias(a, b, c)(Map.apply)
+ case Filter(a, b, c) =>
+ dealias(a, b, c)(Filter.apply)
+ case SortBy(a, b, c, d) =>
+ dealias(a, b, c)(SortBy(_, _, _, d))
+ case GroupBy(a, b, c) =>
+ dealias(a, b, c)(GroupBy.apply)
+ case DistinctOn(a, b, c) =>
+ dealias(a, b, c)(DistinctOn.apply)
+ case Take(a, b) =>
+ val (an, ant) = apply(a)
+ (Take(an, b), ant)
+ case Drop(a, b) =>
+ val (an, ant) = apply(a)
+ (Drop(an, b), ant)
+ case Union(a, b) =>
+ val (an, _) = apply(a)
+ val (bn, _) = apply(b)
+ (Union(an, bn), Dealias(None))
+ case UnionAll(a, b) =>
+ val (an, _) = apply(a)
+ val (bn, _) = apply(b)
+ (UnionAll(an, bn), Dealias(None))
+ case Join(t, a, b, iA, iB, o) =>
+ val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
+ val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
+ (Join(t, an, bn, iAn, iBn, onn), Dealias(None))
+ case FlatJoin(t, a, iA, o) =>
+ val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
+ (FlatJoin(t, an, iAn, on), Dealias(Some(iA)))
+ case _: Entity | _: Distinct | _: Aggregation | _: Nested =>
+ (q, Dealias(None))
+ }
+
+ private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) =
+ apply(a) match {
+ case (an, t @ Dealias(Some(alias))) =>
+ (f(an, alias, BetaReduction(c, b -> alias)), t)
+ case other =>
+ (f(a, b, c), Dealias(Some(b)))
+ }
+}
+
+object Dealias {
+ def apply(query: Query) =
+ new Dealias(None)(query) match {
+ case (q, _) => q
+ }
+}
diff --git a/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala
new file mode 100644
index 0000000..7c95f7c
--- /dev/null
+++ b/src/main/scala/minisql/norm/capture/DemarcateExternalAliases.scala
@@ -0,0 +1,98 @@
+package minisql.norm.capture
+
+import minisql.ast.*
+
+/**
+ * Walk through any Queries that a returning clause has and replace Ident of the
+ * returning variable with ExternalIdent so that in later steps involving filter
+ * simplification, it will not be mistakenly dealiased with a potential shadow.
+ * Take this query for instance:
query[TestEntity] + * .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r => + * (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) )+ * The returning clause has an alias `Ident("r")` as well as the first filter + * clause. These two filters will be combined into one at which point the + * meaning of `r.i` in the 2nd filter will be confused for the first filter's + * alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this + * vunerable `r.i` in the second filter clause to an `ExternalIdent` before any + * of the simplifications are done. + * + * Note that we only want to do this for Queries inside of a `Returning` clause + * body. Other places where this needs to be done (e.g. in a Tuple that + * `Returning` returns) are done in `ExpandReturning`. + */ +private[minisql] case class DemarcateExternalAliases(externalIdent: Ident) + extends StatelessTransformer { + + def applyNonOverride(idents: Ident*)(ast: Ast) = + if (idents.forall(_ != externalIdent)) apply(ast) + else ast + + override def apply(ast: Ast): Ast = ast match { + + case FlatMap(q, i, b) => + FlatMap(apply(q), i, applyNonOverride(i)(b)) + + case ConcatMap(q, i, b) => + ConcatMap(apply(q), i, applyNonOverride(i)(b)) + + case Map(q, i, b) => + Map(apply(q), i, applyNonOverride(i)(b)) + + case Filter(q, i, b) => + Filter(apply(q), i, applyNonOverride(i)(b)) + + case SortBy(q, i, p, o) => + SortBy(apply(q), i, applyNonOverride(i)(p), o) + + case GroupBy(q, i, b) => + GroupBy(apply(q), i, applyNonOverride(i)(b)) + + case DistinctOn(q, i, b) => + DistinctOn(apply(q), i, applyNonOverride(i)(b)) + + case Join(t, a, b, iA, iB, o) => + Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o)) + + case FlatJoin(t, a, iA, o) => + FlatJoin(t, a, iA, applyNonOverride(iA)(o)) + + case p @ Property.Opinionated( + id @ Ident(_), + value, + renameable, + visibility + ) => + if (id == externalIdent) + Property.Opinionated( + ExternalIdent(externalIdent.name), + value, + renameable, + visibility + ) + else + p + + case other => + super.apply(other) + } +} + +object DemarcateExternalAliases { + + private def demarcateQueriesInBody(id: Ident, body: Ast) = + Transform(body) { + // Apply to the AST defined apply method about, not to the superclass method that takes Query + case q: Query => + new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast]) + } + + def apply(ast: Ast): Ast = ast match { + case Returning(a, id, body) => + Returning(a, id, demarcateQueriesInBody(id, body)) + case ReturningGenerated(a, id, body) => + val d = demarcateQueriesInBody(id, body) + ReturningGenerated(a, id, demarcateQueriesInBody(id, body)) + case other => + other + } +} diff --git a/src/main/scala/minisql/parsing/BlockParsing.scala b/src/main/scala/minisql/parsing/BlockParsing.scala new file mode 100644 index 0000000..ae8722c --- /dev/null +++ b/src/main/scala/minisql/parsing/BlockParsing.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +type SParser[X] = + (q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]] + +private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using + Quotes +): SParser[ast.Ast] = { + + import quotes.reflect.* + + @annotation.nowarn + lazy val valDefParser: SParser[ast.Val] = { + case ValDef(n, _, Some(b)) => + val body = astParser(b.asExpr) + '{ ast.Val(ast.Ident(${ Expr(n) }), $body) } + + } + valDefParser +} + +private[parsing] def blockParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + lazy val statementParser = statementParsing(astParser) + + termParser { + case Block(Nil, t) => astParser(t.asExpr) + case b @ Block(st, t) => + val asts = (st :+ t).map { + case e if e.isExpr => astParser(e.asExpr) + case `statementParser`(x) => x + case o => + report.errorAndAbort(s"Cannot parse statement: ${o.show}") + } + if (asts.size > 1) { + '{ ast.Block(${ Expr.ofList(asts) }) } + } else asts(0) + + } +} diff --git a/src/main/scala/minisql/parsing/BoxingParsing.scala b/src/main/scala/minisql/parsing/BoxingParsing.scala new file mode 100644 index 0000000..d7b7f31 --- /dev/null +++ b/src/main/scala/minisql/parsing/BoxingParsing.scala @@ -0,0 +1,31 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def boxingParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + case '{ BigDecimal.int2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.long2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.double2bigDecimal($v) } => astParser(v) + case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v) + case '{ Predef.byte2Byte($v) } => astParser(v) + case '{ Predef.short2Short($v) } => astParser(v) + case '{ Predef.char2Character($v) } => astParser(v) + case '{ Predef.int2Integer($v) } => astParser(v) + case '{ Predef.long2Long($v) } => astParser(v) + case '{ Predef.float2Float($v) } => astParser(v) + case '{ Predef.double2Double($v) } => astParser(v) + case '{ Predef.boolean2Boolean($v) } => astParser(v) + case '{ Predef.augmentString($v) } => astParser(v) + case '{ Predef.Byte2byte($v) } => astParser(v) + case '{ Predef.Short2short($v) } => astParser(v) + case '{ Predef.Character2char($v) } => astParser(v) + case '{ Predef.Integer2int($v) } => astParser(v) + case '{ Predef.Long2long($v) } => astParser(v) + case '{ Predef.Float2float($v) } => astParser(v) + case '{ Predef.Double2double($v) } => astParser(v) + case '{ Predef.Boolean2boolean($v) } => astParser(v) + +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala new file mode 100644 index 0000000..7b173db --- /dev/null +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -0,0 +1,13 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Infix] = { + + import quotes.reflect.* + ??? +} diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala new file mode 100644 index 0000000..9a0f32b --- /dev/null +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import scala.quoted.* +import minisql.ParamEncoder +import minisql.ast +import minisql.* + +private[parsing] def liftParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Lift] = { + case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } => + import quotes.reflect.* + val name = x.asTerm.symbol.fullName + val liftId = x.asTerm.symbol.owner.fullName + "@" + name + '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) } +} diff --git a/src/main/scala/minisql/parsing/OperationParsing.scala b/src/main/scala/minisql/parsing/OperationParsing.scala new file mode 100644 index 0000000..4425dd5 --- /dev/null +++ b/src/main/scala/minisql/parsing/OperationParsing.scala @@ -0,0 +1,113 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.{ + EqualityOperator, + StringOperator, + NumericOperator, + BooleanOperator +} +import minisql.* +import scala.quoted._ + +private[parsing] def operationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Operation] = { + import quotes.reflect.* + + def isNumeric(t: TypeRepr) = { + t <:< TypeRepr.of[Int] + || t <:< TypeRepr.of[Long] + || t <:< TypeRepr.of[Byte] + || t <:< TypeRepr.of[Float] + || t <:< TypeRepr.of[Double] + || t <:< TypeRepr.of[java.math.BigDecimal] + || t <:< TypeRepr.of[scala.math.BigDecimal] + } + + def parseBinary( + left: Expr[Any], + right: Expr[Any], + op: Expr[ast.BinaryOperator] + ) = { + val leftE = astParser(left) + val rightE = astParser(right) + '{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) } + } + + def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = { + val base = astParser(expr) + '{ ast.UnaryOperation($op, $base) } + + } + + val universalOpParser: Parser[ast.BinaryOperation] = termParser { + case Apply(Select(leftT, UniversalOp(op)), List(rightT)) => + parseBinary(leftT.asExpr, rightT.asExpr, op) + } + + val stringOpParser: Parser[ast.Operation] = { + case '{ ($x: String) + ($y: String) } => + parseBinary(x, y, '{ StringOperator.concat }) + case '{ ($x: String).startsWith($y) } => + parseBinary(x, y, '{ StringOperator.startsWith }) + case '{ ($x: String).split($y) } => + parseBinary(x, y, '{ StringOperator.split }) + case '{ ($x: String).toUpperCase } => + parseUnary(x, '{ StringOperator.toUpperCase }) + case '{ ($x: String).toLowerCase } => + parseUnary(x, '{ StringOperator.toLowerCase }) + case '{ ($x: String).toLong } => + parseUnary(x, '{ StringOperator.toLong }) + case '{ ($x: String).toInt } => + parseUnary(x, '{ StringOperator.toInt }) + } + + val numericOpParser = termParser { + case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) => + parseBinary(lt.asExpr, rt.asExpr, op) + case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) => + val leftExpr = astParser(leftTerm.asExpr) + '{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) } + + } + + val booleanOpParser: Parser[ast.Operation] = { + case '{ ($x: Boolean) && $y } => + parseBinary(x, y, '{ BooleanOperator.&& }) + case '{ ($x: Boolean) || $y } => + parseBinary(x, y, '{ BooleanOperator.|| }) + case '{ !($x: Boolean) } => + parseUnary(x, '{ BooleanOperator.! }) + } + + universalOpParser + .orElse(stringOpParser) + .orElse(numericOpParser) + .orElse(booleanOpParser) +} + +private object UniversalOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "==" | "equals" => Some('{ EqualityOperator.== }) + case "!=" => Some('{ EqualityOperator.!= }) + case _ => None + } +} + +private object NumericOp { + def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] = + op match { + case "+" => Some('{ NumericOperator.+ }) + case "-" => Some('{ NumericOperator.- }) + case "*" => Some('{ NumericOperator.* }) + case "/" => Some('{ NumericOperator./ }) + case ">" => Some('{ NumericOperator.> }) + case ">=" => Some('{ NumericOperator.>= }) + case "<" => Some('{ NumericOperator.< }) + case "<=" => Some('{ NumericOperator.<= }) + case "%" => Some('{ NumericOperator.% }) + case _ => None + } +} diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala new file mode 100644 index 0000000..91bfbc0 --- /dev/null +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -0,0 +1,47 @@ +package minisql.parsing + +import minisql.ast +import minisql.ast.Ast +import scala.quoted.* + +private[minisql] inline def parseParamAt[F]( + inline f: F, + inline n: Int +): ast.Ident = ${ + parseParamAt('f, 'n) +} + +private[minisql] inline def parseBody[X]( + inline f: X +): ast.Ast = ${ + parseBody('f) +} + +private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using + Quotes +): Expr[ast.Ident] = { + + import quotes.reflect.* + + val pIdx = n.value.getOrElse( + report.errorAndAbort(s"Param index ${n.show} is not know") + ) + extractTerm(f.asTerm) match { + case Lambda(vals, _) => + vals(pIdx) match { + case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) } + } + } +} + +private[minisql] def parseBody[X]( + x: Expr[X] +)(using Quotes): Expr[Ast] = { + import quotes.reflect.* + extractTerm(x.asTerm) match { + case Lambda(vals, body) => + Parsing.parseExpr(body.asExpr) + case o => + report.errorAndAbort(s"Can only parse function") + } +} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala new file mode 100644 index 0000000..07da46a --- /dev/null +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -0,0 +1,139 @@ +package minisql.parsing + +import minisql.ast +import minisql.context.{ReturningMultipleFieldSupported, _} +import minisql.norm.BetaReduction +import minisql.norm.capture.AvoidAliasConflict +import minisql.idiom.Idiom +import scala.annotation.tailrec +import minisql.ast.Implicits._ +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.{Hidden, Visible} +import minisql.util.Interleave +import scala.quoted.* + +type Parser[A] = PartialFunction[Expr[Any], Expr[A]] + +private def termParser[A](using q: Quotes)( + pf: PartialFunction[q.reflect.Term, Expr[A]] +): Parser[A] = { + import quotes.reflect._ + { + case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm) + } +} + +private def parser[A]( + f: PartialFunction[Expr[Any], Expr[A]] +)(using Quotes): Parser[A] = { + case e if f.isDefinedAt(e) => f(e) +} + +private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = { + import quotes.reflect.* + def unwrapTerm(t: Term): Term = t match { + case Inlined(_, _, o) => unwrapTerm(o) + case Block(Nil, last) => last + case Typed(t, _) => + unwrapTerm(t) + case Select(t, "$asInstanceOf$") => + unwrapTerm(t) + case TypeApply(t, _) => + unwrapTerm(t) + case o => o + } + unwrapTerm(x) +} + +private[minisql] object Parsing { + + def parseExpr( + expr: Expr[?] + )(using q: Quotes): Expr[ast.Ast] = { + + import q.reflect._ + + def unwrapped( + f: Parser[ast.Ast] + ): Parser[ast.Ast] = { + case expr => + val t = expr.asTerm + f(extractTerm(t).asExpr) + } + + lazy val astParser: Parser[ast.Ast] = + unwrapped { + typedParser + .orElse(propertyParser) + .orElse(liftParser) + .orElse(identParser) + .orElse(valueParser) + .orElse(operationParser) + .orElse(constantParser) + .orElse(blockParser) + .orElse(boxingParser) + .orElse(ifParser) + .orElse(traversableOperationParser) + .orElse(patMatchParser) + // .orElse(infixParser) + .orElse { + case o => + val str = scala.util.Try(o.show).getOrElse("") + report.errorAndAbort( + s"cannot parse ${str}", + o.asTerm.pos + ) + } + } + + lazy val typedParser: Parser[ast.Ast] = termParser { + case (Typed(t, _)) => + astParser(t.asExpr) + } + + lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser) + + lazy val valueParser: Parser[ast.Value] = valueParsing(astParser) + + lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser) + + lazy val constantParser: Parser[ast.Constant] = termParser { + case Literal(x) => + '{ ast.Constant(${ Literal(x).asExpr }) } + } + + lazy val identParser: Parser[ast.Ident] = termParser { + case x @ Ident(n) if x.symbol.isValDef => + '{ ast.Ident(${ Expr(n) }) } + } + + lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser) + + lazy val operationParser: Parser[ast.Operation] = operationParsing( + astParser + ) + + lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser) + + lazy val ifParser: Parser[ast.If] = { + case '{ if ($a) $b else $c } => + '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) } + + } + lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser) + + // lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser) + + lazy val traversableOperationParser: Parser[ast.IterableOperation] = + traversableOperationParsing(astParser) + + astParser(expr) + } + + private[minisql] inline def parse[A]( + inline a: A + ): ast.Ast = ${ + parseExpr('a) + } + +} diff --git a/src/main/scala/minisql/parsing/PatMatchParsing.scala b/src/main/scala/minisql/parsing/PatMatchParsing.scala new file mode 100644 index 0000000..2db7652 --- /dev/null +++ b/src/main/scala/minisql/parsing/PatMatchParsing.scala @@ -0,0 +1,49 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private[parsing] def patMatchParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Ast] = { + + import quotes.reflect.* + + termParser { + // Val defs that showd pattern variables will cause error + case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) => + val bm = binds.zipWithIndex.map { + case (Bind(n, ident), idx) => + n -> Select.unique(t, s"_${idx + 1}") + }.toMap + val tm = new TreeMap { + override def transformTerm(tree: Term)(owner: Symbol): Term = { + tree match { + case Ident(n) => bm(n) + case o => super.transformTerm(o)(owner) + } + } + } + val newBody = tm.transformTree(body)(e.symbol) + astParser(newBody.asExpr) + } + +} + +object IsTupleUnapply { + + def unapply(using + Quotes + )(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = { + import quotes.reflect.* + def isTupleNUnapply(x: Term) = { + val fn = x.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply") + } + t match { + case Unapply(m, _, binds) if isTupleNUnapply(m) => + Some(binds) + case _ => None + } + } +} diff --git a/src/main/scala/minisql/parsing/PropertyParsing.scala b/src/main/scala/minisql/parsing/PropertyParsing.scala new file mode 100644 index 0000000..292f281 --- /dev/null +++ b/src/main/scala/minisql/parsing/PropertyParsing.scala @@ -0,0 +1,30 @@ +package minisql.parsing + +import minisql.ast +import minisql.dsl.* +import scala.quoted._ + +private[parsing] def propertyParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Property] = { + import quotes.reflect.* + + def isAccessor(s: Select) = { + s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name) + } + + val parseApply: Parser[ast.Property] = termParser { + case m @ Select(base, n) if isAccessor(m) => + val obj = astParser(base.asExpr) + '{ ast.Property($obj, ${ Expr(n) }) } + } + + val parseOptionGet: Parser[ast.Property] = { + case '{ ($e: Option[t]).get } => + report.errorAndAbort( + "Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead." + ) + } + parseApply.orElse(parseOptionGet) + +} diff --git a/src/main/scala/minisql/parsing/TraversableOperationParsing.scala b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala new file mode 100644 index 0000000..8f50098 --- /dev/null +++ b/src/main/scala/minisql/parsing/TraversableOperationParsing.scala @@ -0,0 +1,16 @@ +package minisql.parsing + +import minisql.ast +import scala.quoted.* + +private def traversableOperationParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.IterableOperation] = { + case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } => + '{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) } + case '{ ($s: Set[e]).contains($i) } => + '{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) } + case '{ ($s: Seq[e]).contains($i) } => + '{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) } + +} diff --git a/src/main/scala/minisql/parsing/ValueParsing.scala b/src/main/scala/minisql/parsing/ValueParsing.scala new file mode 100644 index 0000000..6c2fb9e --- /dev/null +++ b/src/main/scala/minisql/parsing/ValueParsing.scala @@ -0,0 +1,71 @@ +package minisql +package parsing + +import scala.quoted._ + +private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using + Quotes +): Parser[ast.Value] = { + + import quotes.reflect.* + + val parseTupleApply: Parser[ast.Tuple] = { + case IsTupleApply(args) => + val t = args.map(astParser) + '{ ast.Tuple(${ Expr.ofList(t) }) } + } + + val parseAssocTuple: Parser[ast.Tuple] = { + case '{ ($x: tx) -> ($y: ty) } => + '{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) } + } + + parseTupleApply.orElse(parseAssocTuple) +} + +private[minisql] object IsTupleApply { + + def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = { + + import quotes.reflect.* + + def isTupleNApply(t: Term) = { + val fn = t.symbol.fullName + fn.startsWith("scala.Tuple") && fn.endsWith("$.apply") + } + + def isTupleXXLApply(t: Term) = { + t.symbol.fullName == "scala.runtime.TupleXXL$.apply" + } + + extractTerm(e.asTerm) match { + // TupleN(0-22).apply + case Apply(b, args) if isTupleNApply(b) => + Some(args.map(_.asExpr)) + case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) => + t.asExpr match { + case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } => + Some(args) + } + case o => + None + } + } +} + +private[parsing] object IsTuple2 { + + def unapply(using + Quotes + )(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] = + t match { + case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2)) + case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2)) + case _ => None + } + + def unapply(using + Quotes + )(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] = + unapply(t.asExpr) +} diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index f0ee506..6b466be 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -1,6 +1,24 @@ package minisql.util -import scala.util.Try +import scala.util.* + +extension [A](xs: Iterable[A]) { + private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = { + val out = IArray.newBuilder[Any] + var left: Option[Throwable] = None + xs.foreach { (v) => + if (!left.isDefined) { + f(v) match { + case Failure(e) => + left = Some(e) + case Success(r) => + out += r + } + } + } + left.toLeft(out.result().asInstanceOf).toTry + } +} object CollectTry { def apply[T](list: List[Try[T]]): Try[List[T]] = diff --git a/src/main/scala/minisql/util/Interpolator.scala b/src/main/scala/minisql/util/Interpolator.scala index 55e32cd..c63e984 100644 --- a/src/main/scala/minisql/util/Interpolator.scala +++ b/src/main/scala/minisql/util/Interpolator.scala @@ -11,20 +11,25 @@ import scala.util.matching.Regex class Interpolator( defaultIndent: Int = 0, qprint: AstPrinter = AstPrinter(), - out: PrintStream = System.out, + out: PrintStream = System.out ) { + + extension (sc: StringContext) { + def trace(elements: Any*) = new Traceable(sc, elements) + } + class Traceable(sc: StringContext, elementsSeq: Seq[Any]) { private val elementPrefix = "| " private sealed trait PrintElement private case class Str(str: String, first: Boolean) extends PrintElement - private case class Elem(value: String) extends PrintElement - private case object Separator extends PrintElement + private case class Elem(value: String) extends PrintElement + private case object Separator extends PrintElement private def generateStringForCommand(value: Any, indent: Int) = { val objectString = qprint(value) - val oneLine = objectString.fitsOnOneLine + val oneLine = objectString.fitsOnOneLine oneLine match { case true => s"${indent.prefix}> ${objectString}" case false => @@ -42,7 +47,7 @@ class Interpolator( private def readBuffers() = { def orZero(i: Int): Int = if (i < 0) 0 else i - val parts = sc.parts.iterator.toList + val parts = sc.parts.iterator.toList val elements = elementsSeq.toList.map(qprint(_)) val (firstStr, explicitIndent) = readFirst(parts.head) diff --git a/src/main/scala/minisql/util/LoadObject.scala b/src/main/scala/minisql/util/LoadObject.scala index 83bbec0..8c13c8f 100644 --- a/src/main/scala/minisql/util/LoadObject.scala +++ b/src/main/scala/minisql/util/LoadObject.scala @@ -5,10 +5,15 @@ import scala.util.Try object LoadObject { + def apply[T](using Quotes, Type[T]): Try[T] = { + import quotes.reflect.* + apply(TypeRepr.of[T]) + } + def apply[T](using Quotes)(ot: quotes.reflect.TypeRepr): Try[T] = Try { import quotes.reflect.* val moduleClsName = ot.typeSymbol.companionModule.moduleClass.fullName - val moduleCls = Class.forName(moduleClsName) + val moduleCls = Class.forName(moduleClsName) val field = moduleCls .getFields() .find { f => diff --git a/src/main/scala/minisql/util/Message.scala b/src/main/scala/minisql/util/Message.scala new file mode 100644 index 0000000..f748456 --- /dev/null +++ b/src/main/scala/minisql/util/Message.scala @@ -0,0 +1,75 @@ +package minisql.util + +import minisql.AstPrinter +import minisql.idiom.Idiom +import minisql.util.IndentUtil._ + +object Messages { + + private def variable(propName: String, envName: String, default: String) = + Option(System.getProperty(propName)) + .orElse(sys.env.get(envName)) + .getOrElse(default) + + private[util] val prettyPrint = + variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean + private[util] val debugEnabled = + variable("quill.macro.log", "quill_macro_log", "true").toBoolean + private[util] val traceEnabled = + variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean + private[util] val traceColors = + variable("quill.trace.color", "quill_trace_color,", "false").toBoolean + private[util] val traceOpinions = + variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean + private[util] val traceAstSimple = variable( + "quill.trace.ast.simple", + "quill_trace_ast_simple", + "false" + ).toBoolean + private[minisql] val cacheDynamicQueries = variable( + "quill.query.cacheDaynamic", + "query_query_cacheDaynamic", + "true" + ).toBoolean + private[util] val traces: List[TraceType] = + variable("quill.trace.types", "quill_trace_types", "standard") + .split(",") + .toList + .map(_.trim) + .flatMap(trace => + TraceType.values.filter(traceType => trace == traceType.value) + ) + + def tracesEnabled(tt: TraceType) = + traceEnabled && traces.contains(tt) + + sealed trait TraceType { def value: String } + object TraceType { + case object Normalizations extends TraceType { val value = "norm" } + case object Standard extends TraceType { val value = "standard" } + case object NestedQueryExpansion extends TraceType { val value = "nest" } + + def values: List[TraceType] = + List(Standard, Normalizations, NestedQueryExpansion) + } + + val qprint = AstPrinter() + + def fail(msg: String) = + throw new IllegalStateException(msg) + + def trace[T]( + label: String, + numIndent: Int = 0, + traceType: TraceType = TraceType.Standard + ) = + (v: T) => { + val indent = (0 to numIndent).map(_ => "").mkString(" ") + if (tracesEnabled(traceType)) + println(s"$indent$label\n${{ + if (traceColors) qprint.apply(v) + else qprint.apply(v) + }.split("\n").map(s"$indent " + _).mkString("\n")}") + v + } +} diff --git a/src/main/scala/minisql/util/Show.scala b/src/main/scala/minisql/util/Show.scala index b4acc97..3496caf 100644 --- a/src/main/scala/minisql/util/Show.scala +++ b/src/main/scala/minisql/util/Show.scala @@ -1,21 +1,20 @@ package minisql.util -object Show { - trait Show[T] { - def show(v: T): String +trait Show[T] { + extension (v: T) { + def show: String } +} - object Show { - def apply[T](f: T => String) = new Show[T] { - def show(v: T) = f(v) +object Show { + + def apply[T](f: T => String) = new Show[T] { + extension (v: T) { + def show: String = f(v) } } - implicit class Shower[T](v: T)(implicit shower: Show[T]) { - def show = shower.show(v) - } - - implicit def listShow[T](implicit shower: Show[T]): Show[List[T]] = + given listShow[T](using shower: Show[T]): Show[List[T]] = Show[List[T]] { case list => list.map(_.show).mkString(", ") } diff --git a/src/test/scala/minisql/parsing/ParsingSuite.scala b/src/test/scala/minisql/parsing/ParsingSuite.scala new file mode 100644 index 0000000..b90f9f8 --- /dev/null +++ b/src/test/scala/minisql/parsing/ParsingSuite.scala @@ -0,0 +1,38 @@ +package minisql.parsing + +import minisql.ast.* + +class ParsingSuite extends munit.FunSuite { + + test("Ident") { + val x = 1 + assertEquals(Parsing.parse(x), Ident("x")) + } + + test("NumericOperator.+") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a + b), + BinaryOperation(Ident("a"), NumericOperator.+, Ident("b")) + ) + } + + test("NumericOperator.-") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a - b), + BinaryOperation(Ident("a"), NumericOperator.-, Ident("b")) + ) + } + + test("NumericOperator.*") { + val a = 1 + val b = 2 + assertEquals( + Parsing.parse(a * b), + BinaryOperation(Ident("a"), NumericOperator.*, Ident("b")) + ) + } +} diff --git a/src/test/scala/minisql/parsing/QuerySuite.scala b/src/test/scala/minisql/parsing/QuerySuite.scala new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/test/scala/minisql/parsing/QuerySuite.scala @@ -0,0 +1 @@ + diff --git a/src/test/scala/minisql/parsing/QuotedSuite.scala b/src/test/scala/minisql/parsing/QuotedSuite.scala new file mode 100644 index 0000000..2bd02eb --- /dev/null +++ b/src/test/scala/minisql/parsing/QuotedSuite.scala @@ -0,0 +1,35 @@ +package minisql + +import minisql.ast.* + +class QuotedSuite extends munit.FunSuite { + private inline def testQuoted(label: String)( + inline x: Quoted, + expect: Ast + ) = test(label) { + assertEquals(compileTimeAst(x), Some(expect.toString())) + } + + case class Foo(id: Long) + + inline def Foos = query[Foo]("foo") + val entityFoo = Entity("foo", Nil) + val idx = Ident("x") + + testQuoted("EntityQuery")(Foos, entityFoo) + + testQuoted("Query/filter")( + Foos.filter(x => x.id > 0), + Filter( + entityFoo, + idx, + BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0)) + ) + ) + + testQuoted("Query/map")( + Foos.map(x => x.id), + Map(entityFoo, idx, Property(idx, "id")) + ) + +}