Compare commits

...
Sign in to create a new pull request.

26 commits

Author SHA1 Message Date
jilen
071b27abcf Better test case 2025-07-02 15:46:40 +08:00
jilen
ed1952b915 增加returningGenerated 2025-07-02 15:39:42 +08:00
jilen
06850823d7 增加更多join 2025-07-02 13:54:53 +08:00
jilen
48cb1003bb 支持Infix 2025-07-02 12:07:54 +08:00
jilen
f5e43657b3 增加解析 case (x, y) => 函数定义 2025-07-02 10:33:47 +08:00
jilen
adc60400a7 unified extractTerm 2025-06-30 19:33:39 +08:00
jilen
c1f26a0704 Assert sql 2025-06-29 19:15:27 +08:00
jilen
23c0484609 More instance 2025-06-29 17:02:18 +08:00
jilen
a1201a67aa Add more test case. Expand query elements 2025-06-29 16:12:12 +08:00
jilen
2753f01001 fix naming 2025-06-29 10:25:38 +08:00
jilen
24f7f6aec0 Convert to using 2025-06-27 19:50:11 +08:00
jilen
3a9d15f015 Try add insert support 2025-06-22 21:21:05 +08:00
jilen
184ab0b884 Add insert placeholder 2025-06-22 20:45:26 +08:00
jilen
2b52ef3203 Add property alist 2025-06-22 14:27:15 +08:00
jilen
17e97495b7 Simplify Mirror Codec 2025-06-19 18:49:14 +08:00
jilen
1bc6baad68 add sql idiom 2025-06-18 16:59:06 +08:00
jilen
63a9a0cad3 Allow both decoder and encoder 2025-06-17 19:52:30 +08:00
jilen
cb0c6082d0 Add statement 2025-06-17 17:31:36 +08:00
jilen
47cf808e8f simplify typeclass 2024-12-29 20:19:07 +08:00
jilen
7f5092c396 add mirror context 2024-12-19 12:36:44 +08:00
jilen
87f1b70b27 try implement context 2024-12-19 11:42:23 +08:00
jilen
6be96aba2c save 2024-12-18 19:28:43 +08:00
jilen
59f969a232 test simple quoted ast 2024-12-18 16:09:08 +08:00
jilen
2e7e7df4a3 move package 2024-12-17 19:51:19 +08:00
jilen
a0ceea91a9 add one test case 2024-12-15 21:11:14 +08:00
jilen
8103d45178 try parsing function body 2024-12-15 20:51:38 +08:00
90 changed files with 6951 additions and 295 deletions

1
.gitignore vendored
View file

@ -3,3 +3,4 @@ target/
.metals/
.bloop/
project/metals.sbt
.aider*

View file

@ -5,7 +5,7 @@
大部分场景不用在 `macro` 对 Ast 进行复杂模式匹配来分析代码。
## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作
## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作
`FromExpr``scala3` 内置的 typeclass用来获取编译期值 。

View file

@ -1,8 +1,15 @@
name := "minisql"
scalaVersion := "3.6.2"
scalaVersion := "3.7.1"
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % "1.1.1" % Test
)
scalacOptions ++= Seq("-experimental", "-language:experimental.namedTuples")
scalacOptions ++= Seq(
"-deprecation",
"-feature",
"-source:3.7-migration",
"-rewrite"
)

View file

@ -1 +1 @@
sbt.version=1.10.5
sbt.version=1.11.2

View file

@ -0,0 +1,3 @@
package minisql
type QueryMeta

View file

@ -1,8 +1,22 @@
package minisql
trait ParamEncoder[E] {
import scala.util.Try
trait ParamEncoder[E] {
type Stmt
def setParam(s: Stmt, idx: Int, v: E): Unit
def setParam(s: Stmt, idx: Int, v: E): Stmt
}
trait ColumnDecoder[X] {
type DBRow
def decode(row: DBRow, idx: Int): Try[X]
}
object ColumnDecoder {
type Aux[R, X] = ColumnDecoder[X] {
type DBRow = R
}
}

View file

@ -1,127 +0,0 @@
package minisql.parsing
import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]
private[minisql] inline def parseParamAt[F](
inline f: F,
inline n: Int
): ast.Ident = ${
parseParamAt('f, 'n)
}
private[minisql] inline def parseBody[X](
inline f: X
): ast.Ast = ${
parseBody('f)
}
private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
Quotes
): Expr[ast.Ident] = {
import quotes.reflect.*
val pIdx = n.value.getOrElse(
report.errorAndAbort(s"Param index ${n.show} is not know")
)
extractTerm(f.asTerm) match {
case Lambda(vals, _) =>
vals(pIdx) match {
case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
}
}
}
private[minisql] def parseBody[X](
x: Expr[X]
)(using Quotes): Expr[Ast] = {
import quotes.reflect.*
extractTerm(x.asTerm) match {
case Lambda(vals, body) =>
astParser(body.asExpr)
case o =>
report.errorAndAbort(s"Can only parse function")
}
}
private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
import quotes.reflect.*
x.asTerm.tpe match {
case t if t <:< TypeRepr.of[Int] => true
case t if t <:< TypeRepr.of[Long] => true
case t if t <:< TypeRepr.of[Float] => true
case t if t <:< TypeRepr.of[Double] => true
case t if t <:< TypeRepr.of[BigDecimal] => true
case t if t <:< TypeRepr.of[BigInt] => true
case t if t <:< TypeRepr.of[java.lang.Integer] => true
case t if t <:< TypeRepr.of[java.lang.Long] => true
case t if t <:< TypeRepr.of[java.lang.Float] => true
case t if t <:< TypeRepr.of[java.lang.Double] => true
case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
case _ => false
}
}
private def identParser(using Quotes): Parser[ast.Ident] = {
import quotes.reflect.*
{ (x: Expr[?]) =>
extractTerm(x.asTerm) match {
case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
case _ => None
}
}.unlift
}
private lazy val astParser: Quotes ?=> Parser[Ast] = {
identParser.orElse(propertyParser(astParser))
}
private object IsPropertySelect {
def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
import quotes.reflect.*
x.asTerm match {
case Select(x, n) => Some(x.asExpr, n)
case _ => None
}
}
}
def propertyParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.Property] = {
case IsPropertySelect(expr, n) =>
'{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
}
def optionOperationParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.OptionOperation] = {
case '{ ($x: Option[t]).isEmpty } =>
'{ ast.OptionIsEmpty(${ astParser(x) }) }
}
def binaryOperationParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.BinaryOperation] = {
???
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
unwrapTerm(x)
}

View file

@ -0,0 +1,298 @@
package minisql
import minisql.*
import minisql.idiom.*
import minisql.parsing.*
import minisql.util.*
import minisql.ast.{
Ast,
Entity,
Map,
Property,
Ident,
Filter,
PropertyAlias,
JoinType,
Join,
given
}
import scala.quoted.*
import scala.deriving.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type Action[E] <: Quoted = Quoted
opaque type Insert[E] <: Action[Long] = Quoted
object Insert {
extension [E](inline insert: Insert[E]) {
inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
transform(insert)(f)(ast.Returning.apply)
}
inline def returningGenerated[E1](
inline f: E => E1
): InsertReturning[E1] = {
transform(insert)(f)(ast.ReturningGenerated.apply)
}
}
}
opaque type InsertReturning[E] <: Action[E] = Quoted
sealed trait Joined[E1, E2]
opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted
object Joined {
def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
new Joined[E1, E2] {}
extension [E1, E2](inline j: Joined[E1, E2]) {
inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] =
joinOn(j, f)
}
}
private inline def joinOn[E1, E2](
inline j: Joined[E1, E2],
inline f: (E1, E2) => Boolean
): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body)
extension [E1, E2](inline j: Joined[E1, E2]) {
private inline def toJoinQuery(
inline aliasA: Ident,
inline aliasB: Ident,
inline on: Ast
): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) }
}
private def joinQueryOf[E1, E2](
x: Expr[Joined[E1, E2]],
aliasA: Expr[Ident],
aliasB: Expr[Ident],
on: Expr[Ast]
)(using Quotes, Type[E1], Type[E2]): Expr[Join] = {
import quotes.reflect.*
extractTerm(x.asTerm).asExpr match {
case '{ Joined[E1, E2]($jt, $a, $b) } =>
'{
Join($jt, $a, $b, $aliasA, $aliasB, $on)
}
}
}
private inline def quotedLift[X](x: X)(using
e: ParamEncoder[X]
): ast.ScalarValueLift = ${
quotedLiftImpl[X]('x, 'e)
}
private def quotedLiftImpl[X: Type](
x: Expr[X],
e: Expr[ParamEncoder[X]]
)(using Quotes): Expr[ast.ScalarValueLift] = {
import quotes.reflect.*
val name = x.asTerm.show
val liftId = liftIdOfExpr(x)
'{
ast.ScalarValueLift(
${ Expr(name) },
${ Expr(liftId) },
Some(($x, $e))
)
}
}
object Query {
extension [E](inline e: Query[E]) {
private[minisql] inline def expanded: Query[E] = {
expandFields[E](e)
}
inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, Option[E1]] =
Joined[E, Option[E1]](JoinType.LeftJoin, e, e1)
inline def rightJoin[E1](inline e1: Query[E1]): Joined[Option[E], E1] =
Joined[Option[E], E1](JoinType.RightJoin, e, e1)
inline def join[E1](inline e1: Query[E1]): Joined[E, E1] =
Joined[E, E1](JoinType.InnerJoin, e, e1)
inline def fullJoin[E1](
inline e1: Query[E1]
): Joined[Option[E], Option[E1]] =
Joined[Option[E], Option[E1]](JoinType.FullJoin, e, e1)
inline def map[E1](inline f: E => E1): Query[E1] = {
transform(e)(f)(Map.apply)
}
inline def filter(inline f: E => Boolean): Query[E] = {
transform(e)(f)(Filter.apply)
}
inline def withFilter(inline f: E => Boolean): Query[E] = {
transform(e)(f)(Filter.apply)
}
}
}
opaque type EntityQuery[E] <: Query[E] = Query[E]
object EntityQuery {
extension [E](inline e: EntityQuery[E]) {
inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
transform(e)(f)(Map.apply)
}
inline def filter(inline f: E => Boolean): EntityQuery[E] = {
transform(e)(f)(Filter.apply)
}
inline def insert(v: E): Insert[E] = {
ast.Insert(e, transformCaseClassToAssignments[E](v))
}
}
}
private inline def transformCaseClassToAssignments[E](
v: E
): List[ast.Assignment] = ${
transformCaseClassToAssignmentsImpl[E]('v)
}
private def transformCaseClassToAssignmentsImpl[E: Type](
v: Expr[E]
)(using Quotes): Expr[List[ast.Assignment]] = {
import quotes.reflect.*
val fields = TypeRepr.of[E].typeSymbol.caseFields
val assignments = fields.map { field =>
val fieldName = field.name
val fieldType = field.tree match {
case v: ValDef => v.tpt.tpe
case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName")
}
fieldType.asType match {
case '[t] =>
'{
ast.Assignment(
ast.Ident("v"),
ast.Property(ast.Ident("v"), ${ Expr(fieldName) }),
quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using
summonInline[ParamEncoder[t]]
)
)
}
}
}
Expr.ofList(assignments)
}
private inline def transform[A, B](inline q1: Quoted)(
inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
fast(q1, f.param0, f.body)
}
inline def alias(inline from: String, inline to: String): PropertyAlias =
PropertyAlias(List(from), to)
inline def query[E](
inline table: String,
inline alias: PropertyAlias*
): EntityQuery[E] =
Entity(
table,
List(alias*)
)
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def body = parsing.parseBody(f1)
}
extension [A1, A2, B](inline f1: (A1, A2) => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def param1 = parsing.parseParamAt(f1, 1)
private inline def body = parsing.parseBody(f1)
}
def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()
class NonQuotedException extends Exception("Cannot be used at runtime")
private[minisql] inline def compileTimeAst(inline q: Ast): Option[String] =
${ compileTimeAstImpl('q) }
private def compileTimeAstImpl(e: Expr[Ast])(using
Quotes
): Expr[Option[String]] = {
import quotes.reflect.*
e.value match {
case Some(v) => '{ Some(${ Expr(v.toString()) }) }
case None => '{ None }
}
}
private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy](
inline q: Quoted,
inline idiom: I,
inline naming: N
): Statement = ${ compileImpl[I, N]('q, 'idiom, 'naming) }
private def compileImpl[I <: Idiom, N <: NamingStrategy](
q: Expr[Quoted],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
q.value match {
case Some(ast) =>
val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile")
)
val naming = LoadNaming
.static[N]
.getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile"))
val stmt = idiom.translate(ast)(using naming)
report.info(s"Static Query: ${stmt._2}")
Expr(stmt._2)
case None =>
report.info("Dynamic Query")
'{
$idiom.translate($q)(using $n)._2
}
}
}
private inline def expandFields[E](inline base: Ast): Ast =
${ expandFieldsImpl[E]('base) }
private def expandFieldsImpl[E](baseExpr: Expr[Ast])(using
Quotes,
Type[E]
): Expr[Ast] = {
import quotes.reflect.*
val values = TypeRepr.of[E].typeSymbol.caseFields.map { f =>
'{ Property(ast.Ident("x"), ${ Expr(f.name) }) }
}
'{ Map(${ baseExpr }, ast.Ident("x"), ast.Tuple(${ Expr.ofList(values) })) }
}

View file

@ -0,0 +1,7 @@
package minisql
enum ReturnAction {
case ReturnNothing
case ReturnColumns(columns: List[String])
case ReturnRecord
}

View file

@ -0,0 +1,12 @@
package minisql
import minisql.ast.Ast
import scala.quoted.*
sealed trait InfixValue {
def as[T]: T
}
extension (sc: StringContext) {
def infix(args: Any*): InfixValue = throw NonQuotedException()
}

View file

@ -1,6 +1,7 @@
package minisql.ast
import minisql.NamingStrategy
import minisql.ParamEncoder
import scala.quoted.*
@ -58,9 +59,9 @@ object Entity {
object Opinionated {
inline def apply(
name: String,
properties: List[PropertyAlias],
renameableNew: Renameable
inline name: String,
inline properties: List[PropertyAlias],
inline renameableNew: Renameable
): Entity = Entity(name, properties, renameableNew)
def unapply(e: Entity) =
@ -84,13 +85,14 @@ case class SortBy(query: Ast, alias: Ident, criterias: Ast, ordering: Ordering)
sealed trait Ordering extends Ast
case class TupleOrdering(elems: List[Ordering]) extends Ordering
sealed trait PropertyOrdering extends Ordering
case object Asc extends PropertyOrdering
case object Desc extends PropertyOrdering
case object AscNullsFirst extends PropertyOrdering
case object DescNullsFirst extends PropertyOrdering
case object AscNullsLast extends PropertyOrdering
case object DescNullsLast extends PropertyOrdering
enum PropertyOrdering extends Ordering {
case Asc
case Desc
case AscNullsFirst
case DescNullsFirst
case AscNullsLast
case DescNullsLast
}
case class GroupBy(query: Ast, alias: Ident, body: Ast) extends Query
@ -153,11 +155,14 @@ case class Ident(name: String, visibility: Visibility) extends Ast {
* ExpandNestedQueries phase, needs to be marked invisible.
*/
object Ident {
def apply(name: String): Ident = Ident(name, Visibility.neutral)
def unapply(p: Ident) = Some((p.name))
inline def apply(inline name: String): Ident = Ident(name, Visibility.neutral)
def unapply(p: Ident) = Some((p.name))
object Opinionated {
def apply(name: String, visibilityNew: Visibility): Ident =
inline def apply(
inline name: String,
inline visibilityNew: Visibility
): Ident =
Ident(name, visibilityNew)
def unapply(p: Ident) =
Some((p.name, p.visibility))
@ -378,14 +383,21 @@ sealed trait ScalarLift extends Lift
case class ScalarValueLift(
name: String,
liftId: String
liftId: String,
value: Option[(Any, ParamEncoder[?])]
) extends ScalarLift
case class ScalarQueryLift(
name: String,
liftId: String,
value: Option[(Seq[Any], ParamEncoder[?])]
) extends ScalarLift
object ScalarLift {
given ToExpr[ScalarLift] with {
def apply(l: ScalarLift)(using Quotes) = l match {
case ScalarValueLift(n, id) =>
'{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) }
case ScalarValueLift(n, id, v) =>
'{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) }
}
}
}

View file

@ -0,0 +1,94 @@
package minisql.ast
object Implicits {
implicit class AstOps(body: Ast) {
private[minisql] def +||+(other: Ast) =
BinaryOperation(body, BooleanOperator.`||`, other)
private[minisql] def +&&+(other: Ast) =
BinaryOperation(body, BooleanOperator.`&&`, other)
private[minisql] def +==+(other: Ast) =
BinaryOperation(body, EqualityOperator.`==`, other)
private[minisql] def +!=+(other: Ast) =
BinaryOperation(body, EqualityOperator.`!=`, other)
}
}
object +||+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two))
case _ => None
}
}
}
object +&&+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two))
case _ => None
}
}
}
val EqOp = EqualityOperator.`==`
val NeqOp = EqualityOperator.`!=`
object +==+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, EqOp, two) => Some((one, two))
case _ => None
}
}
}
object +!=+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, NeqOp, two) => Some((one, two))
case _ => None
}
}
}
object IsNotNullCheck {
def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue)
def unapply(ast: Ast): Option[Ast] = {
ast match {
case BinaryOperation(cond, NeqOp, NullValue) => Some(cond)
case _ => None
}
}
}
object IsNullCheck {
def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue)
def unapply(ast: Ast): Option[Ast] = {
ast match {
case BinaryOperation(cond, EqOp, NullValue) => Some(cond)
case _ => None
}
}
}
object IfExistElseNull {
def apply(exists: Ast, `then`: Ast) =
If(IsNotNullCheck(exists), `then`, NullValue)
def unapply(ast: Ast) = ast match {
case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t))
case _ => None
}
}
object IfExist {
def apply(exists: Ast, `then`: Ast, otherwise: Ast) =
If(IsNotNullCheck(exists), `then`, otherwise)
def unapply(ast: Ast) = ast match {
case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e))
case _ => None
}
}

View file

@ -45,14 +45,14 @@ private given FromExpr[Infix] with {
private given FromExpr[ScalarValueLift] with {
def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] =
x match {
case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } =>
Some(ScalarValueLift(n, id))
case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } =>
Some(ScalarValueLift(n, id, None))
}
}
private given FromExpr[Ident] with {
def unapply(x: Expr[Ident])(using Quotes): Option[Ident] = x match {
case '{ Ident(${ Expr(n) }) } => Some(Ident(n))
case '{ Ident(${ Expr(n) }, ${ Expr(v) }) } => Some(Ident(n, v))
}
}
@ -69,19 +69,19 @@ private given FromExpr[Property] with {
)
} =>
Some(Property(a, n, r, v))
case o =>
println(s"Cannot extrat ${o.show}")
None
}
}
private given FromExpr[Ordering] with {
def unapply(x: Expr[Ordering])(using Quotes): Option[Ordering] = {
import PropertyOrdering.*
x match {
case '{ Asc } => Some(Asc)
case '{ Desc } => Some(Desc)
case '{ AscNullsFirst } => Some(AscNullsFirst)
case '{ AscNullsLast } => Some(AscNullsLast)
case '{ DescNullsFirst } => Some(DescNullsFirst)
case '{ DescNullsLast } => Some(DescNullsLast)
case '{ TupleOrdering($xs) } => xs.value.map(TupleOrdering(_))
}
}
@ -122,15 +122,48 @@ private given FromExpr[Query] with {
Some(FlatMap(b, id, body))
case '{ ConcatMap(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
Some(ConcatMap(b, id, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(id) }
val z: Ast = ${ Expr(body) }
ConcatMap(x, y, z)
} =>
Some(ConcatMap(b, id, body))
case '{ Drop(${ Expr(b) }, ${ Expr(n) }) } =>
Some(Drop(b, n))
case '{ Take(${ Expr(b) }, ${ Expr[Ast](n) }) } =>
Some(Take(b, n))
case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } =>
Some(SortBy(b, p, s, o))
case o =>
println(s"Cannot extract ${o.show}")
None
case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
Some(GroupBy(b, p, body))
case '{ Distinct(${ Expr(a) }) } =>
Some(Distinct(a))
case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
Some(DistinctOn(q, a, body))
case '{ Aggregation(${ Expr(op) }, ${ Expr(a) }) } =>
Some(Aggregation(op, a))
case '{ Union(${ Expr(a) }, ${ Expr(b) }) } =>
Some(Union(a, b))
case '{ UnionAll(${ Expr(a) }, ${ Expr(b) }) } =>
Some(UnionAll(a, b))
case '{
Join(
${ Expr(t) },
${ Expr(a) },
${ Expr(b) },
${ Expr(ia) },
${ Expr(ib) },
${ Expr(on) }
)
} =>
Some(Join(t, a, b, ia, ib, on))
case '{
FlatJoin(${ Expr(t) }, ${ Expr(a) }, ${ Expr(ia) }, ${ Expr(on) })
} =>
Some(FlatJoin(t, a, ia, on))
case '{ Nested(${ Expr(a) }) } =>
Some(Nested(a))
}
}
@ -145,17 +178,22 @@ private given FromExpr[BinaryOperator] with {
case '{ NumericOperator.- } => Some(NumericOperator.-)
case '{ NumericOperator.* } => Some(NumericOperator.*)
case '{ NumericOperator./ } => Some(NumericOperator./)
case '{ NumericOperator.> } => Some(NumericOperator.>)
case '{ NumericOperator.>= } => Some(NumericOperator.>=)
case '{ NumericOperator.< } => Some(NumericOperator.<)
case '{ NumericOperator.<= } => Some(NumericOperator.<=)
case '{ NumericOperator.% } => Some(NumericOperator.%)
case '{ StringOperator.split } => Some(StringOperator.split)
case '{ StringOperator.startsWith } => Some(StringOperator.startsWith)
case '{ StringOperator.concat } => Some(StringOperator.concat)
case '{ BooleanOperator.&& } => Some(BooleanOperator.&&)
case '{ BooleanOperator.|| } => Some(BooleanOperator.||)
case '{ SetOperator.contains } => Some(SetOperator.contains)
}
}
}
private given FromExpr[UnaryOperator] with {
def unapply(x: Expr[UnaryOperator])(using Quotes): Option[UnaryOperator] = {
x match {
case '{ BooleanOperator.! } => Some(BooleanOperator.!)
@ -163,6 +201,33 @@ private given FromExpr[UnaryOperator] with {
case '{ StringOperator.toLowerCase } => Some(StringOperator.toLowerCase)
case '{ StringOperator.toLong } => Some(StringOperator.toLong)
case '{ StringOperator.toInt } => Some(StringOperator.toInt)
case '{ NumericOperator.- } => Some(NumericOperator.-)
case '{ SetOperator.nonEmpty } => Some(SetOperator.nonEmpty)
case '{ SetOperator.isEmpty } => Some(SetOperator.isEmpty)
}
}
}
private given FromExpr[AggregationOperator] with {
def unapply(
x: Expr[AggregationOperator]
)(using Quotes): Option[AggregationOperator] = {
x match {
case '{ AggregationOperator.min } => Some(AggregationOperator.min)
case '{ AggregationOperator.max } => Some(AggregationOperator.max)
case '{ AggregationOperator.avg } => Some(AggregationOperator.avg)
case '{ AggregationOperator.sum } => Some(AggregationOperator.sum)
case '{ AggregationOperator.size } => Some(AggregationOperator.size)
}
}
}
private given FromExpr[Operator] with {
def unapply(x: Expr[Operator])(using Quotes): Option[Operator] = {
x match {
case '{ $x: BinaryOperator } => x.value
case '{ $x: UnaryOperator } => x.value
case '{ $x: AggregationOperator } => x.value
}
}
}
@ -205,30 +270,50 @@ private given FromExpr[Action] with {
ass.sequence.map { ass1 =>
Update(a, ass1)
}
case '{ Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } =>
case '{
Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) })
} =>
Some(Returning(act, id, body))
case '{
val x: Ast = ${ Expr(act) }
val y: Ident = ${ Expr(id) }
val z: Ast = ${ Expr(body) }
Returning(x, y, z)
} =>
Some(Returning(act, id, body))
case '{
ReturningGenerated(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) })
} =>
Some(ReturningGenerated(act, id, body))
case '{
val x: Ast = ${ Expr(act) }
val y: Ident = ${ Expr(id) }
val z: Ast = ${ Expr(body) }
ReturningGenerated(x, y, z)
} =>
Some(ReturningGenerated(act, id, body))
}
}
}
extension [A](xs: Seq[Expr[A]]) {
private def sequence(using FromExpr[A], Quotes): Option[List[A]] = {
val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) =>
for {
_r <- r
_x <- x.value
} yield _r += _x
if (xs.isEmpty) Some(Nil)
else {
val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) =>
for {
_r <- r
_x <- x.value
} yield _r += _x
}
acc.map(_.result())
}
acc.map(b => b.result())
}
}
private given FromExpr[Constant] with {
def unapply(x: Expr[Constant])(using Quotes): Option[Constant] = {
private given FromExpr[Value] with {
def unapply(x: Expr[Value])(using Quotes): Option[Value] = {
import quotes.reflect.{Constant => *, *}
x match {
case '{ Constant($ce) } =>
@ -236,8 +321,92 @@ private given FromExpr[Constant] with {
case Literal(v) =>
Some(Constant(v.value))
}
case '{ NullValue } =>
Some(NullValue)
case '{ $x: CaseClass } => x.value
}
}
}
private given FromExpr[OptionOperation] with {
def unapply(
x: Expr[OptionOperation]
)(using Quotes): Option[OptionOperation] = {
x match {
case '{ OptionFlatten(${ Expr(ast) }) } =>
Some(OptionFlatten(ast))
case '{ OptionGetOrElse(${ Expr(ast) }, ${ Expr(body) }) } =>
Some(OptionGetOrElse(ast, body))
case '{
OptionFlatMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionFlatMap(ast, alias, body))
case '{ OptionMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) } =>
Some(OptionMap(ast, alias, body))
case '{
OptionForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionForall(ast, alias, body))
case '{
OptionExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionExists(ast, alias, body))
case '{ OptionContains(${ Expr(ast) }, ${ Expr(body) }) } =>
Some(OptionContains(ast, body))
case '{ OptionIsEmpty(${ Expr(ast) }) } =>
Some(OptionIsEmpty(ast))
case '{ OptionNonEmpty(${ Expr(ast) }) } =>
Some(OptionNonEmpty(ast))
case '{ OptionIsDefined(${ Expr(ast) }) } =>
Some(OptionIsDefined(ast))
case '{
OptionTableFlatMap(
${ Expr(ast) },
${ Expr(alias) },
${ Expr(body) }
)
} =>
Some(OptionTableFlatMap(ast, alias, body))
case '{
OptionTableMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionTableMap(ast, alias, body))
case '{
OptionTableExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionTableExists(ast, alias, body))
case '{
OptionTableForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
} =>
Some(OptionTableForall(ast, alias, body))
case '{ OptionNone } => Some(OptionNone)
case '{ OptionSome(${ Expr(ast) }) } => Some(OptionSome(ast))
case '{ OptionApply(${ Expr(ast) }) } => Some(OptionApply(ast))
case '{ OptionOrNull(${ Expr(ast) }) } => Some(OptionOrNull(ast))
case '{ OptionGetOrNull(${ Expr(ast) }) } => Some(OptionGetOrNull(ast))
case _ => None
}
}
}
private given FromExpr[CaseClass] with {
def unapply(x: Expr[CaseClass])(using Quotes): Option[CaseClass] = {
import quotes.reflect.*
x match {
case '{ CaseClass(${ Expr(values) }) } =>
// Verify the values are properly structured as List[(String, Ast)]
try {
Some(CaseClass(values))
} catch {
case e: Exception =>
report.warning(
s"Failed to extract CaseClass values: ${e.getMessage}",
x.asTerm.pos
)
None
}
case _ => None
}
}
}
@ -248,53 +417,26 @@ private given FromExpr[If] with {
}
}
private def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
private given FromExpr[Block] with {
def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
case '{ Block(${ Expr(statements) }) } =>
Some(Block(statements))
}
}
private given FromExpr[Val] with {
def unapply(x: Expr[Val])(using Quotes): Option[Val] = x match {
case '{ Val(${ Expr(n) }, ${ Expr(b) }) } =>
Some(Val(n, b))
}
val o = unwrapTerm(x)
o
}
extension (e: Expr[Any]) {
def toTerm(using Quotes) = {
private def toTerm(using Quotes) = {
import quotes.reflect.*
e.asTerm
}
}
private def fromBlock(using
Quotes
)(block: quotes.reflect.Block): Option[Ast] = {
println(s"Show block ${block.show}")
import quotes.reflect.*
val empty: Option[List[Ast]] = Some(Nil)
val stmts = block.statements.foldLeft(empty) { (r, stmt) =>
stmt match {
case ValDef(n, _, Some(body)) =>
r.flatMap { astList =>
body.asExprOf[Ast].value.map { v =>
astList :+ v
}
}
case o =>
None
}
}
stmts.flatMap { stmts =>
block.expr.asExprOf[Ast].value.map { last =>
minisql.ast.Block(stmts :+ last)
}
}
}
given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
@ -306,13 +448,17 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
case '{ $x: ScalarValueLift } => x.value
case '{ $x: Property } => x.value
case '{ $x: Ident } => x.value
case '{ $x: Val } => x.value
case '{ $x: Tuple } => x.value
case '{ $x: Constant } => x.value
case '{ $x: Value } => x.value
case '{ $x: Operation } => x.value
case '{ $x: Ordering } => x.value
case '{ $x: Action } => x.value
case '{ $x: If } => x.value
case '{ $x: Infix } => x.value
case '{ $x: CaseClass } => x.value
case '{ $x: OptionOperation } => x.value
case '{ $x: Block } => x.value
case o =>
import quotes.reflect.*
report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos)

View file

@ -1,8 +1,22 @@
package minisql.ast
sealed trait JoinType
import scala.quoted.*
case object InnerJoin extends JoinType
case object LeftJoin extends JoinType
case object RightJoin extends JoinType
case object FullJoin extends JoinType
enum JoinType {
case InnerJoin
case LeftJoin
case RightJoin
case FullJoin
}
object JoinType {
given FromExpr[JoinType] with {
def unapply(x: Expr[JoinType])(using Quotes): Option[JoinType] = x match {
case '{ JoinType.InnerJoin } => Some(JoinType.InnerJoin)
case '{ JoinType.LeftJoin } => Some(JoinType.LeftJoin)
case '{ JoinType.RightJoin } => Some(JoinType.RightJoin)
case '{ JoinType.FullJoin } => Some(JoinType.FullJoin)
}
}
}

View file

@ -0,0 +1,137 @@
package minisql.context
import minisql.util.*
import minisql.idiom.{Idiom, Statement, ReifyStatement}
import minisql.{NamingStrategy, ParamEncoder}
import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.*
import scala.compiletime.*
import scala.util.Try
import scala.annotation.targetName
trait RowExtract[A, Row] {
def extract(row: Row): Try[A]
}
object RowExtract {
private[context] def single[Row, E](
decoder: ColumnDecoder.Aux[Row, E]
): RowExtract[E, Row] = new RowExtract[E, Row] {
def extract(row: Row): Try[E] = {
decoder.decode(row, 0)
}
}
private def extractorImpl[A, Row](
decoders: IArray[Any],
m: Mirror.ProductOf[A]
): RowExtract[A, Row] = new RowExtract[A, Row] {
def extract(row: Row): Try[A] = {
val decodedFields = decoders.zipWithIndex.traverse {
case (d, i) =>
d.asInstanceOf[ColumnDecoder.Aux[Row, ?]].decode(row, i)
}
decodedFields.map { vs =>
m.fromProduct(Tuple.fromIArray(vs))
}
}
}
inline given [P <: Product, Row, Decoder[_]](using
m: Mirror.ProductOf[P]
): RowExtract[P, Row] = {
val decoders =
summonAll[
Tuple.Map[m.MirroredElemTypes, [X] =>> ColumnDecoder[
X
] { type DBRow = Row }]
]
extractorImpl(decoders.toIArray.asInstanceOf, m)
}
}
trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
val idiom: I
val naming: N
type DBStatement
type DBRow
type DBResultSet
type Encoder[X] = ParamEncoder[X] {
type Stmt = DBStatement
}
type Decoder[X] = ColumnDecoder.Aux[DBRow, X]
type DBIO[E] = (
sql: String,
params: List[(Any, Encoder[?])],
mapper: Iterable[DBRow] => Try[E]
)
extension (ast: Ast) {
private def liftMap = {
val lifts = CollectAst.byType[ScalarValueLift](ast)
lifts.map(l => l.liftId -> l.value.get).toMap
}
}
extension (stmt: Statement) {
def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) =
ReifyStatement(
idiom.liftingPlaceholder,
idiom.emptySetContainsToken,
stmt,
liftMap
)
}
@targetName("ioAction")
inline def io[E](inline q: minisql.Action[E]): DBIO[E] = {
val extractor = summonFrom {
case e: RowExtract[E, DBRow] => e
case e: ColumnDecoder.Aux[DBRow, E] =>
RowExtract.single(e)
}
val lifts = q.liftMap
val stmt = minisql.compile[I, N](q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]),
mapper = (rows) =>
rows
.traverse(extractor.extract)
.flatMap(
_.headOption.toRight(new Exception(s"No value return")).toTry
)
)
}
@targetName("ioQuery")
inline def io[E](
inline q: minisql.Query[E]
): DBIO[IArray[E]] = {
val (stmt, extractor) = summonFrom {
case e: RowExtract[E, DBRow] =>
minisql.compile[I, N](q.expanded, idiom, naming) -> e
case e: ColumnDecoder.Aux[DBRow, E] =>
minisql.compile[I, N](q, idiom, naming) -> RowExtract.single(e)
}: @unchecked
val lifts = q.liftMap
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = (rows) => rows.traverse(extractor.extract)
)
}
}

View file

@ -0,0 +1,23 @@
package minisql
import minisql.context.mirror.*
import minisql.util.Messages.fail
import scala.reflect.ClassTag
class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy](
val idiom: Idiom,
val naming: Naming
) extends context.Context[Idiom, Naming]
with MirrorCodecs {
type DBRow = IArray[Any] *: EmptyTuple
type DBResultSet = Iterable[DBRow]
type DBStatement = Map[Int, Any]
extension (r: DBRow) {
def data: IArray[Any] = r._1
def add(value: Any): DBRow = (r.data :+ value) *: EmptyTuple
}
}

View file

@ -0,0 +1,64 @@
package minisql.context
sealed trait ReturningCapability
/**
* Data cannot be returned Insert/Update/etc... clauses in the target database.
*/
sealed trait ReturningNotSupported extends ReturningCapability
/**
* Returning a single field from Insert/Update/etc... clauses is supported. This
* is the most common databases e.g. MySQL, Sqlite, and H2 (although as of
* h2database/h2database#1972 this may change. See #1496 regarding this.
* Typically this needs to be setup in the JDBC
* `connection.prepareStatement(sql, Array("returnColumn"))`.
*/
sealed trait ReturningSingleFieldSupported extends ReturningCapability
/**
* Returning multiple columns from Insert/Update/etc... clauses is supported.
* This generally means that columns besides auto-incrementing ones can be
* returned. This is supported by Oracle. In JDBC, the following is done:
* `connection.prepareStatement(sql, Array("column1, column2, ..."))`.
*/
sealed trait ReturningMultipleFieldSupported extends ReturningCapability
/**
* An actual `RETURNING` clause is supported in the SQL dialect of the specified
* database e.g. Postgres. this typically means that columns returned from
* Insert/Update/etc... clauses can have other database operations done on them
* such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In
* JDBC, the following is done: `connection.prepareStatement(sql,
* Statement.RETURN_GENERATED_KEYS))`.
*/
sealed trait ReturningClauseSupported extends ReturningCapability
object ReturningNotSupported extends ReturningNotSupported
object ReturningSingleFieldSupported extends ReturningSingleFieldSupported
object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported
object ReturningClauseSupported extends ReturningClauseSupported
trait Capabilities {
def idiomReturningCapability: ReturningCapability
}
trait CanReturnClause extends Capabilities {
override def idiomReturningCapability: ReturningClauseSupported =
ReturningClauseSupported
}
trait CanReturnField extends Capabilities {
override def idiomReturningCapability: ReturningSingleFieldSupported =
ReturningSingleFieldSupported
}
trait CanReturnMultiField extends Capabilities {
override def idiomReturningCapability: ReturningMultipleFieldSupported =
ReturningMultipleFieldSupported
}
trait CannotReturn extends Capabilities {
override def idiomReturningCapability: ReturningNotSupported =
ReturningNotSupported
}

View file

@ -0,0 +1,139 @@
package minisql.context.mirror
import minisql.MirrorContext
import java.time.LocalDate
import java.util.{Date, UUID}
import minisql.{ParamEncoder, ColumnDecoder}
import minisql.util.Messages.fail
import scala.util.{Failure, Success, Try}
import scala.util.Try
import scala.reflect.ClassTag
trait MirrorCodecs {
ctx: MirrorContext[?, ?] =>
final protected def mirrorEncoder[V]: Encoder[V] = new ParamEncoder[V] {
type Stmt = ctx.DBStatement
def setParam(s: Stmt, idx: Int, v: V): Stmt = {
s + (idx -> v)
}
}
final protected def mirrorColumnDecoder[X](
conv: Any => Option[X]
): Decoder[X] =
new ColumnDecoder[X] {
type DBRow = ctx.DBRow
def decode(row: DBRow, idx: Int): Try[X] = {
row.data
.lift(idx)
.flatMap { x =>
conv(x)
}
.toRight(new Exception(s"Cannot convert value at ${idx}"))
.toTry
}
}
given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]] = {
new ColumnDecoder[Option[T]] {
type DBRow = ctx.DBRow
override def decode(row: DBRow, idx: Int): Try[Option[T]] =
row.data.lift(idx) match {
case Some(null) => Success(None)
case Some(value) => d.decode(row, idx).map(Some(_))
case None => Success(None)
}
}
}
given optionEncoder[T](using e: Encoder[T]): Encoder[Option[T]] =
new ParamEncoder[Option[T]] {
type Stmt = ctx.DBStatement
override def setParam(
s: Stmt,
idx: Int,
v: Option[T]
): Stmt =
v match {
case Some(value) => e.setParam(s, idx, value)
case None =>
s + (idx -> null)
}
}
// Implement all required decoders using mirrorColumnDecoder from MirrorCodecs
given stringDecoder: Decoder[String] = mirrorColumnDecoder[String](x =>
x match { case s: String => Some(s); case _ => None }
)
given bigDecimalDecoder: Decoder[BigDecimal] =
mirrorColumnDecoder[BigDecimal](x =>
x match {
case bd: BigDecimal => Some(bd); case i: Int => Some(BigDecimal(i));
case l: Long => Some(BigDecimal(l));
case d: Double => Some(BigDecimal(d)); case _ => None
}
)
given booleanDecoder: Decoder[Boolean] = mirrorColumnDecoder[Boolean](x =>
x match { case b: Boolean => Some(b); case _ => None }
)
given byteDecoder: Decoder[Byte] = mirrorColumnDecoder[Byte](x =>
x match {
case b: Byte => Some(b); case i: Int => Some(i.toByte); case _ => None
}
)
given shortDecoder: Decoder[Short] = mirrorColumnDecoder[Short](x =>
x match {
case s: Short => Some(s); case i: Int => Some(i.toShort); case _ => None
}
)
given intDecoder: Decoder[Int] = mirrorColumnDecoder[Int](x =>
x match { case i: Int => Some(i); case _ => None }
)
given longDecoder: Decoder[Long] = mirrorColumnDecoder[Long](x =>
x match {
case l: Long => Some(l); case i: Int => Some(i.toLong); case _ => None
}
)
given floatDecoder: Decoder[Float] = mirrorColumnDecoder[Float](x =>
x match {
case f: Float => Some(f); case d: Double => Some(d.toFloat);
case _ => None
}
)
given doubleDecoder: Decoder[Double] = mirrorColumnDecoder[Double](x =>
x match {
case d: Double => Some(d); case f: Float => Some(f.toDouble);
case _ => None
}
)
given byteArrayDecoder: Decoder[Array[Byte]] =
mirrorColumnDecoder[Array[Byte]](x =>
x match { case ba: Array[Byte] => Some(ba); case _ => None }
)
given dateDecoder: Decoder[Date] = mirrorColumnDecoder[Date](x =>
x match { case d: Date => Some(d); case _ => None }
)
given localDateDecoder: Decoder[LocalDate] =
mirrorColumnDecoder[LocalDate](x =>
x match { case ld: LocalDate => Some(ld); case _ => None }
)
given uuidDecoder: Decoder[UUID] = mirrorColumnDecoder[UUID](x =>
x match { case uuid: UUID => Some(uuid); case _ => None }
)
// Implement all required encoders using mirrorEncoder from MirrorCodecs
given stringEncoder: Encoder[String] = mirrorEncoder[String]
given bigDecimalEncoder: Encoder[BigDecimal] = mirrorEncoder[BigDecimal]
given booleanEncoder: Encoder[Boolean] = mirrorEncoder[Boolean]
given byteEncoder: Encoder[Byte] = mirrorEncoder[Byte]
given shortEncoder: Encoder[Short] = mirrorEncoder[Short]
given intEncoder: Encoder[Int] = mirrorEncoder[Int]
given longEncoder: Encoder[Long] = mirrorEncoder[Long]
given floatEncoder: Encoder[Float] = mirrorEncoder[Float]
given doubleEncoder: Encoder[Double] = mirrorEncoder[Double]
given byteArrayEncoder: Encoder[Array[Byte]] = mirrorEncoder[Array[Byte]]
given dateEncoder: Encoder[Date] = mirrorEncoder[Date]
given localDateEncoder: Encoder[LocalDate] = mirrorEncoder[LocalDate]
given uuidEncoder: Encoder[UUID] = mirrorEncoder[UUID]
}

View file

@ -0,0 +1,17 @@
package minisql.context.sql.idiom
import minisql.util.Messages
trait ConcatSupport {
this: SqlIdiom =>
override def concatFunction = "UNNEST"
}
trait NoConcatSupport {
this: SqlIdiom =>
override def concatFunction = Messages.fail(
s"`concatMap` not supported by ${this.getClass.getSimpleName}"
)
}

View file

@ -0,0 +1,24 @@
package minisql.context.sql
import minisql.{NamingStrategy, MirrorContext}
import minisql.context.Context
import minisql.idiom.Idiom // Changed from minisql.idiom.* to avoid ambiguity with Statement
import minisql.context.mirror.MirrorCodecs
import minisql.context.ReturningClauseSupported
import minisql.context.ReturningCapability
class MirrorSqlIdiom extends idiom.SqlIdiom {
override def concatFunction: String = "CONCAT"
override def idiomReturningCapability: ReturningCapability =
ReturningClauseSupported
// Implementations previously provided by MirrorIdiomBase
override def prepareForProbing(string: String): String = string
override def liftingPlaceholder(index: Int): String = "?"
}
object MirrorSqlIdiom extends MirrorSqlIdiom
class MirrorSqlContext[N <: NamingStrategy](naming: N)
extends MirrorContext[MirrorSqlIdiom, N](MirrorSqlIdiom, naming)
with SqlContext[MirrorSqlIdiom, N]
with MirrorCodecs {}

View file

@ -0,0 +1,52 @@
package minisql.context.sql
import minisql.context.{
CanReturnClause,
CanReturnField,
CanReturnMultiField,
CannotReturn
}
import minisql.context.sql.idiom.SqlIdiom
import minisql.context.sql.idiom.QuestionMarkBindVariables
import minisql.context.sql.idiom.ConcatSupport
trait MirrorSqlDialect
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnField
trait MirrorSqlDialectWithReturnMulti
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnMultiField
trait MirrorSqlDialectWithReturnClause
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnClause
trait MirrorSqlDialectWithNoReturn
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CannotReturn
object MirrorSqlDialect extends MirrorSqlDialect {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithReturnClause
extends MirrorSqlDialectWithReturnClause {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn {
override def prepareForProbing(string: String) = string
}

View file

@ -0,0 +1,70 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.idiom.StatementInterpolator._
import minisql.idiom.Token
import minisql.NamingStrategy
import minisql.util.Messages.fail
trait OnConflictSupport {
self: SqlIdiom =>
implicit def conflictTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OnConflict] = {
val customEntityTokenizer = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t"
}
val customAstTokenizer =
Tokenizer.withFallback[Ast](self.astTokenizer(using _, strategy)) {
case _: OnConflict.Excluded => stmt"EXCLUDED"
case OnConflict.Existing(a) => stmt"${a.token}"
case a: Action =>
self
.actionTokenizer(customEntityTokenizer)(
using actionAstTokenizer,
strategy
)
.token(a)
}
import OnConflict._
def doUpdateStmt(i: Token, t: Token, u: Update) = {
val assignments = u.assignments
.map(a =>
stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(using customAstTokenizer)}"
)
.mkStmt()
stmt"$i ON CONFLICT $t DO UPDATE SET $assignments"
}
def doNothingStmt(i: Ast, t: Token) =
stmt"${i.token} ON CONFLICT $t DO NOTHING"
implicit val conflictTargetPropsTokenizer: Tokenizer[Properties] =
Tokenizer[Properties] {
case OnConflict.Properties(props) =>
stmt"(${props.map(n => n.renameable.fixedOr(n.name)(strategy.column(n.name))).mkStmt(",")})"
}
def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
Tokenizer[OnConflict] {
case OnConflict(_, NoTarget, _: Update) =>
fail("'DO UPDATE' statement requires explicit conflict target")
case OnConflict(i, p: Properties, u: Update) =>
doUpdateStmt(i.token, p.token, u)
case OnConflict(i, NoTarget, Ignore) =>
stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING"
case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token)
}
tokenizer(using customAstTokenizer)
}
}

View file

@ -0,0 +1,6 @@
package minisql.context.sql.idiom
trait PositionalBindVariables { self: SqlIdiom =>
override def liftingPlaceholder(index: Int): String = s"$$${index + 1}"
}

View file

@ -0,0 +1,6 @@
package minisql.context.sql.idiom
trait QuestionMarkBindVariables { self: SqlIdiom =>
override def liftingPlaceholder(index: Int): String = s"?"
}

View file

@ -0,0 +1,44 @@
package minisql.context.sql
import java.time.LocalDate
import minisql.idiom.{Idiom => BaseIdiom}
import java.util.{Date, UUID}
import minisql.context.Context
import minisql.NamingStrategy
trait SqlContext[Idiom <: BaseIdiom, Naming <: NamingStrategy]
extends Context[Idiom, Naming] {
given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]]
given optionEncoder[T](using d: Encoder[T]): Encoder[Option[T]]
given stringDecoder: Decoder[String]
given bigDecimalDecoder: Decoder[BigDecimal]
given booleanDecoder: Decoder[Boolean]
given byteDecoder: Decoder[Byte]
given shortDecoder: Decoder[Short]
given intDecoder: Decoder[Int]
given longDecoder: Decoder[Long]
given floatDecoder: Decoder[Float]
given doubleDecoder: Decoder[Double]
given byteArrayDecoder: Decoder[Array[Byte]]
given dateDecoder: Decoder[Date]
given localDateDecoder: Decoder[LocalDate]
given uuidDecoder: Decoder[UUID]
given stringEncoder: Encoder[String]
given bigDecimalEncoder: Encoder[BigDecimal]
given booleanEncoder: Encoder[Boolean]
given byteEncoder: Encoder[Byte]
given shortEncoder: Encoder[Short]
given intEncoder: Encoder[Int]
given longEncoder: Encoder[Long]
given floatEncoder: Encoder[Float]
given doubleEncoder: Encoder[Double]
given byteArrayEncoder: Encoder[Array[Byte]]
given dateEncoder: Encoder[Date]
given localDateEncoder: Encoder[LocalDate]
given uuidEncoder: Encoder[UUID]
}

View file

@ -0,0 +1,706 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.ast.BooleanOperator._
import minisql.ast.Lift
import minisql.context.sql._
import minisql.context.sql.norm._
import minisql.idiom._
import minisql.idiom.StatementInterpolator._
import minisql.NamingStrategy
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Hidden
import minisql.context.{ReturningCapability, ReturningClauseSupported}
import minisql.util.Interleave
import minisql.util.Messages.{fail, trace}
import minisql.idiom.Token
import minisql.norm.EqualityBehavior
import minisql.norm.ConcatBehavior
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.ExpandReturning
trait SqlIdiom extends Idiom {
override def prepareForProbing(string: String): String
protected def concatBehavior: ConcatBehavior = AnsiConcat
protected def equalityBehavior: EqualityBehavior = AnsiEquality
protected def actionAlias: Option[Ident] = None
override def format(queryString: String): String = queryString
def querifyAst(ast: Ast) = SqlQuery(ast)
private def doTranslate(ast: Ast, cached: Boolean)(using
naming: NamingStrategy
): (Ast, Statement) = {
val normalizedAst =
SqlNormalize(ast, concatBehavior, equalityBehavior)
given Tokenizer[Ast] = defaultTokenizer
val token =
normalizedAst match {
case q: Query =>
val sql = querifyAst(q)
trace("sql")(sql)
VerifySqlQuery(sql).map(fail)
val expanded = new ExpandNestedQueries(naming)(sql, List())
trace("expanded sql")(expanded)
val tokenized = expanded.token
trace("tokenized sql")(tokenized)
tokenized
case other =>
other.token
}
(normalizedAst, stmt"$token")
}
override def translate(
ast: Ast
)(implicit naming: NamingStrategy): (Ast, Statement) = {
doTranslate(ast, false)
}
def defaultTokenizer(using naming: NamingStrategy): Tokenizer[Ast] =
new Tokenizer[Ast] {
private val stableTokenizer = astTokenizer(using this, naming)
extension (v: Ast) {
def token = stableTokenizer.token(v)
}
}
def astTokenizer(using
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Ast] =
Tokenizer[Ast] {
case a: Query => SqlQuery(a).token
case a: Operation => a.token
case a: Infix => a.token
case a: Action => a.token
case a: Ident => a.token
case a: ExternalIdent => a.token
case a: Property => a.token
case a: Value => a.token
case a: If => a.token
case a: Lift => a.token
case a: Assignment => a.token
case a: OptionOperation => a.token
case a @ (
_: Function | _: FunctionApply | _: Dynamic | _: OptionOperation |
_: Block | _: Val | _: Ordering | _: IterableOperation |
_: OnConflict.Excluded | _: OnConflict.Existing
) =>
fail(s"Malformed or unsupported construct: $a.")
}
implicit def ifTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[If] = Tokenizer[If] {
case ast: If =>
def flatten(ast: Ast): (List[(Ast, Ast)], Ast) =
ast match {
case If(cond, a, b) =>
val (l, e) = flatten(b)
((cond, a) +: l, e)
case other =>
(List(), other)
}
val (l, e) = flatten(ast)
val conditions =
for ((cond, body) <- l) yield {
stmt"WHEN ${cond.token} THEN ${body.token}"
}
stmt"CASE ${conditions.mkStmt(" ")} ELSE ${e.token} END"
}
def concatFunction: String
protected def tokenizeGroupBy(
values: Ast
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token =
values.token
protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
) {
import q._
def selectTokenizer =
select match {
case Nil => stmt"*"
case _ => select.token
}
def distinctTokenizer = (
distinct match {
case DistinctKind.Distinct => stmt"DISTINCT "
case DistinctKind.DistinctOn(props) =>
stmt"DISTINCT ON (${props.token}) "
case DistinctKind.None => stmt""
}
)
def withDistinct = stmt"$distinctTokenizer${selectTokenizer}"
def withFrom =
from match {
case Nil => withDistinct
case head :: tail =>
val t = tail.foldLeft(stmt"${head.token}") {
case (a, b: FlatJoinContext) =>
stmt"$a ${(b: FromContext).token}"
case (a, b) =>
stmt"$a, ${b.token}"
}
stmt"$withDistinct FROM $t"
}
def withWhere =
where match {
case None => withFrom
case Some(where) => stmt"$withFrom WHERE ${where.token}"
}
def withGroupBy =
groupBy match {
case None => withWhere
case Some(groupBy) =>
stmt"$withWhere GROUP BY ${tokenizeGroupBy(groupBy)}"
}
def withOrderBy =
orderBy match {
case Nil => withGroupBy
case orderBy => stmt"$withGroupBy ${tokenOrderBy(orderBy)}"
}
def withLimitOffset = limitOffsetToken(withOrderBy).token((limit, offset))
def apply = stmt"SELECT $withLimitOffset"
}
implicit def sqlQueryTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] {
case q: FlattenSqlQuery =>
new FlattenSqlQueryTokenizerHelper(q).apply
case SetOperationSqlQuery(a, op, b) =>
stmt"(${a.token}) ${op.token} (${b.token})"
case UnaryOperationSqlQuery(op, q) =>
stmt"SELECT ${op.token} (${q.token})"
}
protected def tokenizeColumn(
strategy: NamingStrategy,
column: String,
renameable: Renameable
) =
renameable match {
case Fixed => column
case _ => strategy.column(column)
}
protected def tokenizeTable(
strategy: NamingStrategy,
table: String,
renameable: Renameable
) =
renameable match {
case Fixed => table
case _ => strategy.table(table)
}
protected def tokenizeAlias(strategy: NamingStrategy, table: String) =
strategy.default(table)
implicit def selectValueTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[SelectValue] = {
def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
Tokenizer[SelectValue] {
case SelectValue(ast, Some(alias), false) => {
stmt"${ast.token} AS ${alias.token}"
}
case SelectValue(ast, Some(alias), true) =>
stmt"${concatFunction.token}(${ast.token}) AS ${alias.token}"
case selectValue =>
val value =
selectValue match {
case SelectValue(Ident("?"), _, _) => "?".token
case SelectValue(Ident(name), _, _) =>
stmt"${strategy.default(name).token}.*"
case SelectValue(ast, _, _) => ast.token
}
selectValue.concat match {
case true => stmt"${concatFunction.token}(${value.token})"
case false => value
}
}
val customAstTokenizer =
Tokenizer.withFallback[Ast](
SqlIdiom.this.astTokenizer(using _, strategy)
) {
case Aggregation(op, Ident(_) | Tuple(_)) => stmt"${op.token}(*)"
case Aggregation(op, Distinct(ast)) =>
stmt"${op.token}(DISTINCT ${ast.token})"
case ast @ Aggregation(op, _: Query) => scopedTokenizer(ast)
case Aggregation(op, ast) => stmt"${op.token}(${ast.token})"
}
tokenizer(using customAstTokenizer)
}
implicit def operationTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Operation] = Tokenizer[Operation] {
case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})"
case BinaryOperation(a, EqualityOperator.`==`, NullValue) =>
stmt"${scopedTokenizer(a)} IS NULL"
case BinaryOperation(NullValue, EqualityOperator.`==`, b) =>
stmt"${scopedTokenizer(b)} IS NULL"
case BinaryOperation(a, EqualityOperator.`!=`, NullValue) =>
stmt"${scopedTokenizer(a)} IS NOT NULL"
case BinaryOperation(NullValue, EqualityOperator.`!=`, b) =>
stmt"${scopedTokenizer(b)} IS NOT NULL"
case BinaryOperation(a, StringOperator.`startsWith`, b) =>
stmt"${scopedTokenizer(a)} LIKE (${(BinaryOperation(b, StringOperator.`concat`, Constant("%")): Ast).token})"
case BinaryOperation(a, op @ StringOperator.`split`, b) =>
stmt"${op.token}(${scopedTokenizer(a)}, ${scopedTokenizer(b)})"
case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
SetContainsToken(scopedTokenizer(b), op.token, a.token)
case BinaryOperation(a, op @ `&&`, b) =>
(a, b) match {
case (BinaryOperation(_, `||`, _), BinaryOperation(_, `||`, _)) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case (BinaryOperation(_, `||`, _), _) =>
stmt"${scopedTokenizer(a)} ${op.token} ${b.token}"
case (_, BinaryOperation(_, `||`, _)) =>
stmt"${a.token} ${op.token} ${scopedTokenizer(b)}"
case _ => stmt"${a.token} ${op.token} ${b.token}"
}
case BinaryOperation(a, op @ `||`, b) =>
stmt"${a.token} ${op.token} ${b.token}"
case BinaryOperation(a, op, b) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case e: FunctionApply => fail(s"Can't translate the ast to sql: '$e'")
}
implicit def optionOperationTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
case OptionIsEmpty(ast) => stmt"${ast.token} IS NULL"
case OptionNonEmpty(ast) => stmt"${ast.token} IS NOT NULL"
case OptionIsDefined(ast) => stmt"${ast.token} IS NOT NULL"
case other => fail(s"Malformed or unsupported construct: $other.")
}
implicit val setOperationTokenizer: Tokenizer[SetOperation] =
Tokenizer[SetOperation] {
case UnionOperation => stmt"UNION"
case UnionAllOperation => stmt"UNION ALL"
}
protected def limitOffsetToken(
query: Statement
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
Tokenizer[(Option[Ast], Option[Ast])] {
case (None, None) => query
case (Some(limit), None) => stmt"$query LIMIT ${limit.token}"
case (Some(limit), Some(offset)) =>
stmt"$query LIMIT ${limit.token} OFFSET ${offset.token}"
case (None, Some(offset)) => stmt"$query OFFSET ${offset.token}"
}
protected def tokenOrderBy(
criterias: List[OrderByCriteria]
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
stmt"ORDER BY ${criterias.token}"
implicit def sourceTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[FromContext] = Tokenizer[FromContext] {
case TableContext(name, alias) =>
stmt"${name.token} ${tokenizeAlias(strategy, alias).token}"
case QueryContext(query, alias) =>
stmt"(${query.token}) AS ${tokenizeAlias(strategy, alias).token}"
case InfixContext(infix, alias) if infix.noParen =>
stmt"${(infix: Ast).token} AS ${strategy.default(alias).token}"
case InfixContext(infix, alias) =>
stmt"(${(infix: Ast).token}) AS ${strategy.default(alias).token}"
case JoinContext(t, a, b, on) =>
stmt"${a.token} ${t.token} ${b.token} ON ${on.token}"
case FlatJoinContext(t, a, on) => stmt"${t.token} ${a.token} ON ${on.token}"
}
implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
case JoinType.InnerJoin => stmt"INNER JOIN"
case JoinType.LeftJoin => stmt"LEFT JOIN"
case JoinType.RightJoin => stmt"RIGHT JOIN"
case JoinType.FullJoin => stmt"FULL JOIN"
}
implicit def orderByCriteriaTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] {
case OrderByCriteria(ast, PropertyOrdering.Asc) =>
stmt"${scopedTokenizer(ast)} ASC"
case OrderByCriteria(ast, PropertyOrdering.Desc) =>
stmt"${scopedTokenizer(ast)} DESC"
case OrderByCriteria(ast, PropertyOrdering.AscNullsFirst) =>
stmt"${scopedTokenizer(ast)} ASC NULLS FIRST"
case OrderByCriteria(ast, PropertyOrdering.DescNullsFirst) =>
stmt"${scopedTokenizer(ast)} DESC NULLS FIRST"
case OrderByCriteria(ast, PropertyOrdering.AscNullsLast) =>
stmt"${scopedTokenizer(ast)} ASC NULLS LAST"
case OrderByCriteria(ast, PropertyOrdering.DescNullsLast) =>
stmt"${scopedTokenizer(ast)} DESC NULLS LAST"
}
implicit val unaryOperatorTokenizer: Tokenizer[UnaryOperator] =
Tokenizer[UnaryOperator] {
case NumericOperator.`-` => stmt"-"
case BooleanOperator.`!` => stmt"NOT"
case StringOperator.`toUpperCase` => stmt"UPPER"
case StringOperator.`toLowerCase` => stmt"LOWER"
case StringOperator.`toLong` => stmt"" // cast is implicit
case StringOperator.`toInt` => stmt"" // cast is implicit
case SetOperator.`isEmpty` => stmt"NOT EXISTS"
case SetOperator.`nonEmpty` => stmt"EXISTS"
}
implicit val aggregationOperatorTokenizer: Tokenizer[AggregationOperator] =
Tokenizer[AggregationOperator] {
case AggregationOperator.`min` => stmt"MIN"
case AggregationOperator.`max` => stmt"MAX"
case AggregationOperator.`avg` => stmt"AVG"
case AggregationOperator.`sum` => stmt"SUM"
case AggregationOperator.`size` => stmt"COUNT"
}
implicit val binaryOperatorTokenizer: Tokenizer[BinaryOperator] =
Tokenizer[BinaryOperator] {
case EqualityOperator.`==` => stmt"="
case EqualityOperator.`!=` => stmt"<>"
case BooleanOperator.`&&` => stmt"AND"
case BooleanOperator.`||` => stmt"OR"
case StringOperator.`concat` => stmt"||"
case StringOperator.`startsWith` =>
fail("bug: this code should be unreachable")
case StringOperator.`split` => stmt"SPLIT"
case NumericOperator.`-` => stmt"-"
case NumericOperator.`+` => stmt"+"
case NumericOperator.`*` => stmt"*"
case NumericOperator.`>` => stmt">"
case NumericOperator.`>=` => stmt">="
case NumericOperator.`<` => stmt"<"
case NumericOperator.`<=` => stmt"<="
case NumericOperator.`/` => stmt"/"
case NumericOperator.`%` => stmt"%"
case SetOperator.`contains` => stmt"IN"
}
implicit def propertyTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Property] = {
def unnest(ast: Ast): (Ast, List[String]) =
ast match {
case Property.Opinionated(a, _, _, Hidden) =>
unnest(a) match {
case (a, nestedName) => (a, nestedName)
}
// Append the property name. This includes tuple indexes.
case Property(a, name) =>
unnest(a) match {
case (ast, nestedName) =>
(ast, nestedName :+ name)
}
case a => (a, Nil)
}
def tokenizePrefixedProperty(
name: String,
prefix: List[String],
strategy: NamingStrategy,
renameable: Renameable
) =
renameable.fixedOr(
(prefix.mkString + name).token
)(tokenizeColumn(strategy, prefix.mkString + name, renameable).token)
Tokenizer[Property] {
case Property.Opinionated(
ast,
name,
renameable,
_ /* Top level property cannot be invisible */
) =>
// When we have things like Embedded tables, properties inside of one another needs to be un-nested.
// E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner
// property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias`
// as opposed to `realTable.embeddedTableAlias.realPropertyAlias`.
unnest(ast) match {
// When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL
// returning clauses have no alias for the original table. I.e. INSERT [...] RETURNING idColumn there's no
// alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'.
// In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
// should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`.
// The exception to this is when a Query inside of a RETURNING clause is used. In that case, assume
// that there is an alias for the inserted table (i.e. `INSERT ... as theAlias values ... RETURNING`)
// and the instances of ExternalIdent use it.
case (ExternalIdent(_), prefix) =>
stmt"${actionAlias
.map(alias => stmt"${scopedTokenizer(alias)}.")
.getOrElse(stmt"")}${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
// In the rare case that the Ident is invisible, do not show it. See the Ident documentation for more info.
case (Ident.Opinionated(_, Hidden), prefix) =>
stmt"${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
// The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
// becomes `realTable.realPropertyAlias`.
case (ast, prefix) =>
stmt"${scopedTokenizer(ast)}.${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
}
}
}
implicit def valueTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Value] = Tokenizer[Value] {
case Constant(v: String) => stmt"'${v.token}'"
case Constant(()) => stmt"1"
case Constant(v) => stmt"${v.toString.token}"
case NullValue => stmt"null"
case Tuple(values) => stmt"${values.token}"
case CaseClass(values) => stmt"${values.map(_._2).token}"
}
implicit def infixTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Infix] = Tokenizer[Infix] {
case Infix(parts, params, _, _) =>
val pt = parts.map(_.token)
val pr = params.map(_.token)
Statement(Interleave(pt, pr))
}
implicit def identTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Ident] =
Tokenizer[Ident](e => strategy.default(e.name).token)
implicit def externalIdentTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[ExternalIdent] =
Tokenizer[ExternalIdent](e => strategy.default(e.name).token)
implicit def assignmentTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(alias, prop, value) =>
stmt"${prop.token} = ${scopedTokenizer(value)}"
}
implicit def defaultAstTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Action] = {
val insertEntityTokenizer = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
stmt"INTO ${tokenizeTable(strategy, name, renameable).token}"
}
actionTokenizer(insertEntityTokenizer)(using actionAstTokenizer, strategy)
}
protected def actionAstTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
) =
Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(using _, strategy)) {
case q: Query => astTokenizer.token(q)
case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL"
case Property(
Property.Opinionated(_, name, renameable, _),
"isDefined"
) =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
case Property(Property.Opinionated(_, name, renameable, _), "nonEmpty") =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
case Property.Opinionated(_, name, renameable, _) =>
renameable.fixedOr(name.token)(
tokenizeColumn(strategy, name, renameable).token
)
}
def returnListTokenizer(implicit
tokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[List[Ast]] = {
val customAstTokenizer =
Tokenizer.withFallback[Ast](
SqlIdiom.this.astTokenizer(using _, strategy)
) {
case sq: Query =>
stmt"(${tokenizer.token(sq)})"
}
Tokenizer[List[Ast]] {
case list =>
list.mkStmt(", ")(using customAstTokenizer)
}
}
protected def actionTokenizer(
insertEntityTokenizer: Tokenizer[Entity]
)(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Action] =
Tokenizer[Action] {
case Insert(entity: Entity, assignments) =>
val table = insertEntityTokenizer.token(entity)
val columns = assignments.map(_.property.token)
val values = assignments.map(_.value)
stmt"INSERT $table${actionAlias.map(alias => stmt" AS ${alias.token}").getOrElse(stmt"")} (${columns
.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})"
case Update(table: Entity, assignments) =>
stmt"UPDATE ${table.token}${actionAlias
.map(alias => stmt" AS ${alias.token}")
.getOrElse(stmt"")} SET ${assignments.token}"
case Update(Filter(table: Entity, x, where), assignments) =>
stmt"UPDATE ${table.token}${actionAlias
.map(alias => stmt" AS ${alias.token}")
.getOrElse(stmt"")} SET ${assignments.token} WHERE ${where.token}"
case Delete(Filter(table: Entity, x, where)) =>
stmt"DELETE FROM ${table.token} WHERE ${where.token}"
case Delete(table: Entity) =>
stmt"DELETE FROM ${table.token}"
case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
idiomReturningCapability match {
// If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
// a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
case ReturningClauseSupported
if (CollectAst.byType[Entity](prop).nonEmpty) =>
SqlIdiom.withActionAlias(this, r)
case ReturningClauseSupported =>
stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
.token(ExpandReturning(r)(this, strategy).map(_._1))}"
case other =>
stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}"
}
case r @ ReturningAction(action, alias, prop) =>
idiomReturningCapability match {
// If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
// a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
case ReturningClauseSupported
if (CollectAst.byType[Entity](prop).nonEmpty) =>
SqlIdiom.withActionAlias(this, r)
case ReturningClauseSupported =>
stmt"${action.token} RETURNING ${returnListTokenizer.token(
ExpandReturning(r)(this, strategy).map(_._1)
)}"
case other =>
stmt"${action.token}"
}
case other =>
fail(s"Action ast can't be translated to sql: '$other'")
}
implicit def entityTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Entity] = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
tokenizeTable(strategy, name, renameable).token
}
protected def scopedTokenizer(ast: Ast)(implicit tokenizer: Tokenizer[Ast]) =
ast match {
case _: Query => stmt"(${ast.token})"
case _: BinaryOperation => stmt"(${ast.token})"
case _: Tuple => stmt"(${ast.token})"
case _ => ast.token
}
}
object SqlIdiom {
private[minisql] def copyIdiom(
parent: SqlIdiom,
newActionAlias: Option[Ident]
): SqlIdiom =
new SqlIdiom {
override protected def actionAlias: Option[Ident] = newActionAlias
override def prepareForProbing(string: String): String =
parent.prepareForProbing(string)
override def concatFunction: String = parent.concatFunction
override def liftingPlaceholder(index: Int): String =
parent.liftingPlaceholder(index)
override def idiomReturningCapability: ReturningCapability =
parent.idiomReturningCapability
}
/**
* Construct a new instance of the specified idiom with `newActionAlias`
* variable specified so that actions (i.e. insert, and update) will be
* rendered with the specified alias. This is needed for RETURNING clauses
* that have queries inside. See #1509 for details.
*/
private[minisql] def withActionAlias(
parentIdiom: SqlIdiom,
query: ReturningAction
)(implicit strategy: NamingStrategy) = {
val idiom = copyIdiom(parentIdiom, Some(query.alias))
import idiom._
implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer(using
new Tokenizer[Ast] { self =>
extension (v: Ast) {
def token = astTokenizer(using self, strategy).token(v)
}
},
strategy
)
query match {
case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
stmt"INSERT INTO ${table.token} AS ${alias.name.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
.token(ExpandReturning(r)(idiom, strategy).map(_._1))}"
case r @ ReturningAction(action, alias, prop) =>
stmt"${action.token} RETURNING ${returnListTokenizer.token(
ExpandReturning(r)(idiom, strategy).map(_._1)
)}"
case r =>
fail(s"Unsupported Returning construct: $r")
}
}
}

View file

@ -0,0 +1,326 @@
package minisql.context.sql
import minisql.ast._
import minisql.context.sql.norm.FlattenGroupByAggregation
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.{Literal, PseudoAst, NamingStrategy}
case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering)
sealed trait FromContext
case class TableContext(entity: Entity, alias: String) extends FromContext
case class QueryContext(query: SqlQuery, alias: String) extends FromContext
case class InfixContext(infix: Infix, alias: String) extends FromContext
case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast)
extends FromContext
case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast)
extends FromContext
sealed trait SqlQuery {
override def toString = {
import MirrorSqlDialect.*
import minisql.idiom.StatementInterpolator.*
given Tokenizer[SqlQuery] = sqlQueryTokenizer(using
defaultTokenizer(using Literal),
Literal
)
summon[Tokenizer[SqlQuery]].token(this).toString()
}
}
sealed trait SetOperation
case object UnionOperation extends SetOperation
case object UnionAllOperation extends SetOperation
sealed trait DistinctKind { def isDistinct: Boolean }
case object DistinctKind {
case object Distinct extends DistinctKind { val isDistinct: Boolean = true }
case class DistinctOn(props: List[Ast]) extends DistinctKind {
val isDistinct: Boolean = true
}
case object None extends DistinctKind { val isDistinct: Boolean = false }
}
case class SetOperationSqlQuery(
a: SqlQuery,
op: SetOperation,
b: SqlQuery
) extends SqlQuery
case class UnaryOperationSqlQuery(
op: UnaryOperator,
q: SqlQuery
) extends SqlQuery
case class SelectValue(
ast: Ast,
alias: Option[String] = None,
concat: Boolean = false
) extends PseudoAst {
override def toString: String =
s"${ast.toString}${alias.map("->" + _).getOrElse("")}"
}
case class FlattenSqlQuery(
from: List[FromContext] = List(),
where: Option[Ast] = None,
groupBy: Option[Ast] = None,
orderBy: List[OrderByCriteria] = Nil,
limit: Option[Ast] = None,
offset: Option[Ast] = None,
select: List[SelectValue],
distinct: DistinctKind = DistinctKind.None
) extends SqlQuery
object TakeDropFlatten {
def unapply(q: Query): Option[(Query, Option[Ast], Option[Ast])] = q match {
case Take(q: FlatMap, n) => Some((q, Some(n), None))
case Drop(q: FlatMap, n) => Some((q, None, Some(n)))
case _ => None
}
}
object SqlQuery {
def apply(query: Ast): SqlQuery =
query match {
case Union(a, b) =>
SetOperationSqlQuery(apply(a), UnionOperation, apply(b))
case UnionAll(a, b) =>
SetOperationSqlQuery(apply(a), UnionAllOperation, apply(b))
case UnaryOperation(op, q: Query) => UnaryOperationSqlQuery(op, apply(q))
case _: Operation | _: Value =>
FlattenSqlQuery(select = List(SelectValue(query)))
case Map(q, a, b) if a == b => apply(q)
case TakeDropFlatten(q, limit, offset) =>
flatten(q, "x").copy(limit = limit, offset = offset)
case q: Query => flatten(q, "x")
case infix: Infix => flatten(infix, "x")
case other =>
fail(
s"Query not properly normalized. Please open a bug report. Ast: '$other'"
)
}
private def flatten(query: Ast, alias: String): FlattenSqlQuery = {
val (sources, finalFlatMapBody) = flattenContexts(query)
flatten(sources, finalFlatMapBody, alias)
}
private def flattenContexts(query: Ast): (List[FromContext], Ast) =
query match {
case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Query) =>
val source = this.source(q, alias)
val (nestedContexts, finalFlatMapBody) = flattenContexts(p)
(source +: nestedContexts, finalFlatMapBody)
case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Infix) =>
fail(s"Infix can't be use as a `flatMap` body. $query")
case other =>
(List.empty, other)
}
object NestedNest {
def unapply(q: Ast): Option[Ast] =
q match {
case _: Nested => recurse(q)
case _ => None
}
private def recurse(q: Ast): Option[Ast] =
q match {
case Nested(qn) => recurse(qn)
case other => Some(other)
}
}
private def flatten(
sources: List[FromContext],
finalFlatMapBody: Ast,
alias: String
): FlattenSqlQuery = {
def select(alias: String) = SelectValue(Ident(alias), None) :: Nil
def base(q: Ast, alias: String) = {
def nest(ctx: FromContext) =
FlattenSqlQuery(from = sources :+ ctx, select = select(alias))
q match {
case Map(_: GroupBy, _, _) => nest(source(q, alias))
case NestedNest(q) => nest(QueryContext(apply(q), alias))
case q: ConcatMap => nest(QueryContext(apply(q), alias))
case Join(tpe, a, b, iA, iB, on) =>
val ctx = source(q, alias)
def aliases(ctx: FromContext): List[String] =
ctx match {
case TableContext(_, alias) => alias :: Nil
case QueryContext(_, alias) => alias :: Nil
case InfixContext(_, alias) => alias :: Nil
case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b)
case FlatJoinContext(_, a, _) => aliases(a)
}
FlattenSqlQuery(
from = ctx :: Nil,
select = aliases(ctx).map(a => SelectValue(Ident(a), None))
)
case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias)
case q if (sources == Nil) => flatten(sources, q, alias)
case other => nest(source(q, alias))
}
}
finalFlatMapBody match {
case ConcatMap(q, Ident(alias), p) =>
FlattenSqlQuery(
from = source(q, alias) :: Nil,
select = selectValues(p).map(_.copy(concat = true))
)
case Map(GroupBy(q, x @ Ident(alias), g), a, p) =>
val b = base(q, alias)
val select = BetaReduction(p, a -> Tuple(List(g, x)))
val flattenSelect = FlattenGroupByAggregation(x)(select)
b.copy(groupBy = Some(g), select = this.selectValues(flattenSelect))
case GroupBy(q, Ident(alias), p) =>
fail("A `groupBy` clause must be followed by `map`.")
case Map(q, Ident(alias), p) =>
val b = base(q, alias)
val agg = b.select.collect {
case s @ SelectValue(_: Aggregation, _, _) => s
}
if (!b.distinct.isDistinct && agg.isEmpty)
b.copy(select = selectValues(p))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = selectValues(p)
)
case Filter(q, Ident(alias), p) =>
val b = base(q, alias)
if (b.where.isEmpty)
b.copy(where = Some(p))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
where = Some(p),
select = select(alias)
)
case SortBy(q, Ident(alias), p, o) =>
val b = base(q, alias)
val criterias = orderByCriterias(p, o)
if (b.orderBy.isEmpty)
b.copy(orderBy = criterias)
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
orderBy = criterias,
select = select(alias)
)
case Aggregation(op, q: Query) =>
val b = flatten(q, alias)
b.select match {
case head :: Nil if !b.distinct.isDistinct =>
b.copy(select = List(head.copy(ast = Aggregation(op, head.ast))))
case other =>
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = List(SelectValue(Aggregation(op, Ident("*"))))
)
}
case Take(q, n) =>
val b = base(q, alias)
if (b.limit.isEmpty)
b.copy(limit = Some(n))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
limit = Some(n),
select = select(alias)
)
case Drop(q, n) =>
val b = base(q, alias)
if (b.offset.isEmpty && b.limit.isEmpty)
b.copy(offset = Some(n))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
offset = Some(n),
select = select(alias)
)
case Distinct(q: Query) =>
val b = base(q, alias)
b.copy(distinct = DistinctKind.Distinct)
case DistinctOn(q, Ident(alias), fields) =>
val distinctList =
fields match {
case Tuple(values) => values
case other => List(other)
}
q match {
// Ideally we don't need to make an extra sub-query for every single case of
// distinct-on but it only works when the parent AST is an entity. That's because DistinctOn
// selects from an alias of an outer clause. For example, query[Person].map(p => Name(p.firstName, p.lastName)).distinctOn(_.name)
// (Let's say Person(firstName, lastName, age), Name(first, last)) will turn into
// SELECT DISTINCT ON (p.name), p.firstName AS first, p.lastName AS last, p.age FROM Person
// This doesn't work beause `name` in `p.name` doesn't exist yet. Therefore we have to nest this in a subquery:
// SELECT DISTINCT ON (p.name) FROM (SELECT p.firstName AS first, p.lastName AS last, p.age FROM Person p) AS p
// The only exception to this is if we are directly selecting from an entity:
// query[Person].distinctOn(_.firstName) which should be fine: SELECT (x.firstName), x.firstName, x.lastName, a.age FROM Person x
// since all the fields inside the (...) of the DISTINCT ON must be contained in the entity.
case _: Entity =>
val b = base(q, alias)
b.copy(distinct = DistinctKind.DistinctOn(distinctList))
case _ =>
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = select(alias),
distinct = DistinctKind.DistinctOn(distinctList)
)
}
case other =>
FlattenSqlQuery(
from = sources :+ source(other, alias),
select = select(alias)
)
}
}
private def selectValues(ast: Ast) =
ast match {
case Tuple(values) => values.map(SelectValue(_))
case other => SelectValue(ast) :: Nil
}
private def source(ast: Ast, alias: String): FromContext =
ast match {
case entity: Entity => TableContext(entity, alias)
case infix: Infix => InfixContext(infix, alias)
case Join(t, a, b, ia, ib, on) =>
JoinContext(t, source(a, ia.name), source(b, ib.name), on)
case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on)
case Nested(q) => QueryContext(apply(q), alias)
case other => QueryContext(apply(other), alias)
}
private def orderByCriterias(ast: Ast, ordering: Ast): List[OrderByCriteria] =
(ast, ordering) match {
case (Tuple(properties), ord: PropertyOrdering) =>
properties.flatMap(orderByCriterias(_, ord))
case (Tuple(properties), TupleOrdering(ord)) =>
properties.zip(ord).flatMap { case (a, o) => orderByCriterias(a, o) }
case (a, o: PropertyOrdering) => List(OrderByCriteria(a, o))
case other => fail(s"Invalid order by criteria $ast")
}
}

View file

@ -0,0 +1,122 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.context.sql._
import minisql.norm.FreeVariables
case class Error(free: List[Ident], ast: Ast)
case class InvalidSqlQuery(errors: List[Error]) {
override def toString =
s"The monad composition can't be expressed using applicative joins. " +
errors
.map(error =>
s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'."
)
.mkString(", ")
}
object VerifySqlQuery {
def apply(query: SqlQuery): Option[String] =
verify(query).map(_.toString)
private def verify(query: SqlQuery): Option[InvalidSqlQuery] =
query match {
case q: FlattenSqlQuery => verify(q)
case SetOperationSqlQuery(a, op, b) => verify(a).orElse(verify(b))
case UnaryOperationSqlQuery(op, q) => verify(q)
}
private def verifyFlatJoins(q: FlattenSqlQuery) = {
def loop(l: List[FromContext], available: Set[String]): Set[String] =
l.foldLeft(available) {
case (av, TableContext(_, alias)) => Set(alias)
case (av, InfixContext(_, alias)) => Set(alias)
case (av, QueryContext(_, alias)) => Set(alias)
case (av, JoinContext(_, a, b, on)) =>
av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av)
case (av, FlatJoinContext(_, a, on)) =>
val nav = av ++ loop(a :: Nil, av)
val free = FreeVariables(on).map(_.name)
val invalid = free -- nav
require(
invalid.isEmpty,
s"Found an `ON` table reference of a table that is not available: $invalid. " +
"The `ON` condition can only use tables defined through explicit joins."
)
nav
}
loop(q.from, Set())
}
private def verify(query: FlattenSqlQuery): Option[InvalidSqlQuery] = {
verifyFlatJoins(query)
val aliases =
query.from.flatMap(this.aliases).map(Ident(_)) :+ Ident("*") :+ Ident("?")
def verifyAst(ast: Ast) = {
val freeVariables =
(FreeVariables(ast) -- aliases).toList
val freeIdents =
(CollectAst(ast) {
case ast: Property => None
case Aggregation(_, _: Ident) => None
case ast: Ident => Some(ast)
}).flatten
(freeVariables ++ freeIdents) match {
case Nil => None
case free => Some(Error(free, ast))
}
}
// Recursively expand children until values are fully flattened. Identities in all these should
// be skipped during verification.
def expandSelect(sv: SelectValue): List[SelectValue] =
sv.ast match {
case Tuple(values) =>
values.map(v => SelectValue(v)).flatMap(expandSelect(_))
case CaseClass(values) =>
values.map(v => SelectValue(v._2)).flatMap(expandSelect(_))
case _ => List(sv)
}
val freeVariableErrors: List[Error] =
query.where.flatMap(verifyAst).toList ++
query.orderBy.map(_.ast).flatMap(verifyAst) ++
query.limit.flatMap(verifyAst) ++
query.select
.flatMap(
expandSelect(_)
) // Expand tuple select clauses so their top-level identities are skipped
.map(_.ast)
.filterNot(_.isInstanceOf[Ident])
.flatMap(verifyAst) ++
query.from.flatMap {
case j: JoinContext => verifyAst(j.on)
case j: FlatJoinContext => verifyAst(j.on)
case _ => Nil
}
val nestedErrors =
query.from.collect {
case QueryContext(query, alias) => verify(query).map(_.errors)
}.flatten.flatten
(freeVariableErrors ++ nestedErrors) match {
case Nil => None
case errors => Some(InvalidSqlQuery(errors))
}
}
private def aliases(s: FromContext): List[String] =
s match {
case s: TableContext => List(s.alias)
case s: QueryContext => List(s.alias)
case s: InfixContext => List(s.alias)
case s: JoinContext => aliases(s.a) ++ aliases(s.b)
case s: FlatJoinContext => aliases(s.a)
}
}

View file

@ -0,0 +1,47 @@
package minisql.context.sql.norm
import minisql.ast.Constant
import minisql.context.sql.{FlattenSqlQuery, SqlQuery, _}
/**
* In SQL Server, `Order By` clauses are only allowed in sub-queries if the
* sub-query has a `TOP` or `OFFSET` modifier. Otherwise an exception will be
* thrown. This transformation adds a 'dummy' `OFFSET 0` in this scenario (if an
* `Offset` clause does not exist already).
*/
object AddDropToNestedOrderBy {
def applyInner(q: SqlQuery): SqlQuery =
q match {
case q: FlattenSqlQuery =>
q.copy(
offset =
if (q.orderBy.nonEmpty) q.offset.orElse(Some(Constant(0)))
else q.offset,
from = q.from.map(applyInner(_))
)
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(applyInner(a), op, applyInner(b))
case UnaryOperationSqlQuery(op, a) =>
UnaryOperationSqlQuery(op, applyInner(a))
}
private def applyInner(f: FromContext): FromContext =
f match {
case QueryContext(a, alias) => QueryContext(applyInner(a), alias)
case JoinContext(t, a, b, on) =>
JoinContext(t, applyInner(a), applyInner(b), on)
case FlatJoinContext(t, a, on) => FlatJoinContext(t, applyInner(a), on)
case other => other
}
def apply(q: SqlQuery): SqlQuery =
q match {
case q: FlattenSqlQuery => q.copy(from = q.from.map(applyInner(_)))
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(applyInner(a), op, applyInner(b))
case UnaryOperationSqlQuery(op, a) =>
UnaryOperationSqlQuery(op, applyInner(a))
}
}

View file

@ -0,0 +1,68 @@
package minisql.context.sql.norm
import minisql.ast.Visibility.Hidden
import minisql.ast._
object ExpandDistinct {
@annotation.tailrec
def hasJoin(q: Ast): Boolean = {
q match {
case _: Join => true
case Map(q, _, _) => hasJoin(q)
case Filter(q, _, _) => hasJoin(q)
case _ => false
}
}
def apply(q: Ast): Ast =
q match {
case Distinct(q) =>
Distinct(apply(q))
case q =>
Transform(q) {
case Aggregation(op, Distinct(q)) =>
Aggregation(op, Distinct(apply(q)))
case Distinct(Map(q, x, cc @ Tuple(values))) =>
Map(
Distinct(Map(q, x, cc)),
x,
Tuple(values.zipWithIndex.map {
case (_, i) => Property(x, s"_${i + 1}")
})
)
// Situations like this:
// case class AdHocCaseClass(id: Int, name: String)
// val q = quote {
// query[SomeTable].map(st => AdHocCaseClass(st.id, st.name)).distinct
// }
// ... need some special treatment. Otherwise their values will not be correctly expanded.
case Distinct(Map(q, x, cc @ CaseClass(values))) =>
Map(
Distinct(Map(q, x, cc)),
x,
CaseClass(values.map {
case (name, _) => (name, Property(x, name))
})
)
// Need some special handling to address issues with distinct returning a single embedded entity i.e:
// query[Parent].map(p => p.emb).distinct.map(e => (e.name, e.id))
// cannot treat such a case normally or "confused" queries will result e.g:
// SELECT p.embname, p.embid FROM (SELECT DISTINCT emb.name /* Where the heck is 'emb' coming from? */ AS embname, emb.id AS embid FROM Parent p) AS p
case d @ Distinct(
Map(q, x, p @ Property.Opinionated(_, _, _, Hidden))
) =>
d
// Problems with distinct were first discovered in #1032. Basically, unless
// the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places
// such as "ORDER BY tableIdent" etc...
case Distinct(Map(q, x, p)) =>
val newMap = Map(q, x, Tuple(List(p)))
val newIdent = Ident(x.name)
Map(Distinct(newMap), newIdent, Property(newIdent, "_1"))
}
}
}

View file

@ -0,0 +1,49 @@
package minisql.context.sql.norm
import minisql.ast._
import minisql.norm.BetaReduction
import minisql.norm.Normalize
object ExpandJoin {
def apply(q: Ast) = expand(q, None)
def expand(q: Ast, id: Option[Ident]) =
Transform(q) {
case q @ Join(_, _, _, Ident(a), Ident(b), _) =>
val (qr, tuple) = expandedTuple(q)
Map(qr, id.getOrElse(Ident(s"$a$b")), tuple)
}
private def expandedTuple(q: Join): (Join, Tuple) =
q match {
case Join(t, a: Join, b: Join, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tA -> at, tB -> bt)
(Join(t, ar, br, tA, tB, or), Tuple(List(at, bt)))
case Join(t, a: Join, b, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val or = BetaReduction(o, tA -> at)
(Join(t, ar, b, tA, tB, or), Tuple(List(at, tB)))
case Join(t, a, b: Join, tA, tB, o) =>
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tB -> bt)
(Join(t, a, br, tA, tB, or), Tuple(List(tA, bt)))
case q @ Join(t, a, b, tA, tB, on) =>
(
Join(t, nestedExpand(a, tA), nestedExpand(b, tB), tA, tB, on),
Tuple(List(tA, tB))
)
}
private def nestedExpand(q: Ast, id: Ident) =
Normalize(expand(q, Some(id))) match {
case Map(q, _, _) => q
case q => q
}
}

View file

@ -0,0 +1,12 @@
package minisql.context.sql.norm
import minisql.ast._
object ExpandMappedInfix {
def apply(q: Ast): Ast = {
Transform(q) {
case Map(Infix("" :: parts, (q: Query) :: params, pure, noParen), x, p) =>
Infix("" :: parts, Map(q, x, p) :: params, pure, noParen)
}
}
}

View file

@ -0,0 +1,147 @@
package minisql.context.sql.norm
import minisql.NamingStrategy
import minisql.ast.Ast
import minisql.ast.Ident
import minisql.ast._
import minisql.ast.StatefulTransformer
import minisql.ast.Visibility.Visible
import minisql.context.sql._
import scala.collection.mutable.LinkedHashSet
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.context.sql.norm.nested.ExpandSelect
import minisql.norm.BetaReduction
import scala.collection.mutable
class ExpandNestedQueries(strategy: NamingStrategy) {
val interp = new Interpolator(3)
import interp._
def apply(q: SqlQuery, references: List[Property]): SqlQuery =
apply(q, LinkedHashSet.empty ++ references)
// Using LinkedHashSet despite the fact that it is mutable because it has better characteristics then ListSet.
// Also this collection is strictly internal to ExpandNestedQueries and exposed anywhere else.
private def apply(
q: SqlQuery,
references: LinkedHashSet[Property]
): SqlQuery =
q match {
case q: FlattenSqlQuery =>
val expand = expandNested(
q.copy(select = ExpandSelect(q.select, references, strategy))
)
trace"Expanded Nested Query $q into $expand".andLog()
expand
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(apply(a, references), op, apply(b, references))
case UnaryOperationSqlQuery(op, q) =>
UnaryOperationSqlQuery(op, apply(q, references))
}
private def expandNested(q: FlattenSqlQuery): SqlQuery =
q match {
case FlattenSqlQuery(
from,
where,
groupBy,
orderBy,
limit,
offset,
select,
distinct
) =>
val asts = Nil ++ select.map(_.ast) ++ where ++ groupBy ++ orderBy.map(
_.ast
) ++ limit ++ offset
val expansions = q.from.map(expandContext(_, asts))
val from = expansions.map(_._1)
val references = expansions.flatMap(_._2)
val replacedRefs = references.map(ref => (ref, unhideAst(ref)))
// Need to unhide properties that were used during the query
def replaceProps(ast: Ast) =
BetaReduction(ast, replacedRefs*)
def replacePropsOption(ast: Option[Ast]) =
ast.map(replaceProps(_))
val distinctKind =
q.distinct match {
case DistinctKind.DistinctOn(props) =>
DistinctKind.DistinctOn(props.map(p => replaceProps(p)))
case other => other
}
q.copy(
select = select.map(sv => sv.copy(ast = replaceProps(sv.ast))),
from = from,
where = replacePropsOption(where),
groupBy = replacePropsOption(groupBy),
orderBy = orderBy.map(ob => ob.copy(ast = replaceProps(ob.ast))),
limit = replacePropsOption(limit),
offset = replacePropsOption(offset),
distinct = distinctKind
)
}
def unhideAst(ast: Ast): Ast =
Transform(ast) {
case Property.Opinionated(a, n, r, v) =>
Property.Opinionated(unhideAst(a), n, r, Visible)
}
private def unhideProperties(sv: SelectValue) =
sv.copy(ast = unhideAst(sv.ast))
private def expandContext(
s: FromContext,
asts: List[Ast]
): (FromContext, LinkedHashSet[Property]) =
s match {
case QueryContext(q, alias) =>
val refs = references(alias, asts)
(QueryContext(apply(q, refs), alias), refs)
case JoinContext(t, a, b, on) =>
val (left, leftRefs) = expandContext(a, asts :+ on)
val (right, rightRefs) = expandContext(b, asts :+ on)
(JoinContext(t, left, right, on), leftRefs ++ rightRefs)
case FlatJoinContext(t, a, on) =>
val (next, refs) = expandContext(a, asts :+ on)
(FlatJoinContext(t, next, on), refs)
case _: TableContext | _: InfixContext =>
(s, new mutable.LinkedHashSet[Property]())
}
private def references(alias: String, asts: List[Ast]) =
LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)(
_.apply
)._2.state.references)
}
case class State(ident: Ident, references: List[Property])
case class References(val state: State) extends StatefulTransformer[State] {
import state._
override def apply(a: Ast) =
a match {
case `reference`(p) => (p, References(State(ident, references :+ p)))
case other => super.apply(a)
}
object reference {
def unapply(p: Property): Option[Property] =
p match {
case Property(`ident`, name) => Some(p)
case Property(reference(_), name) => Some(p)
case other => None
}
}
}

View file

@ -0,0 +1,58 @@
package minisql.context.sql.norm
import minisql.ast.Aggregation
import minisql.ast.Ast
import minisql.ast.Drop
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Ident
import minisql.ast.Join
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.SortBy
import minisql.ast.StatelessTransformer
import minisql.ast.Take
import minisql.ast.Union
import minisql.ast.UnionAll
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.ast.ConcatMap
case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer {
override def apply(ast: Ast) =
ast match {
case q: Query if (isGroupByAggregation(q)) =>
q match {
case Aggregation(op, Map(`agg`, ident, body)) =>
Aggregation(op, BetaReduction(body, ident -> agg))
case Map(`agg`, ident, body) =>
BetaReduction(body, ident -> agg)
case q @ Aggregation(op, `agg`) =>
q
case other =>
fail(s"Invalid group by aggregation: '$other'")
}
case other =>
super.apply(other)
}
private def isGroupByAggregation(ast: Ast): Boolean =
ast match {
case Aggregation(a, b) => isGroupByAggregation(b)
case Map(a, b, c) => isGroupByAggregation(a)
case FlatMap(a, b, c) => isGroupByAggregation(a)
case ConcatMap(a, b, c) => isGroupByAggregation(a)
case Filter(a, b, c) => isGroupByAggregation(a)
case SortBy(a, b, c, d) => isGroupByAggregation(a)
case Take(a, b) => isGroupByAggregation(a)
case Drop(a, b) => isGroupByAggregation(a)
case Union(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
case UnionAll(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
case Join(t, a, b, ta, tb, on) =>
isGroupByAggregation(a) || isGroupByAggregation(b)
case `agg` => true
case other => false
}
}

View file

@ -0,0 +1,53 @@
package minisql.context.sql.norm
import minisql.norm._
import minisql.ast.Ast
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.capture.DemarcateExternalAliases
import minisql.util.Messages.trace
object SqlNormalize {
def apply(
ast: Ast,
concatBehavior: ConcatBehavior = AnsiConcat,
equalityBehavior: EqualityBehavior = AnsiEquality
) =
new SqlNormalize(concatBehavior, equalityBehavior)(ast)
}
class SqlNormalize(
concatBehavior: ConcatBehavior,
equalityBehavior: EqualityBehavior
) {
private val normalize =
(identity[Ast])
.andThen(trace("original"))
.andThen(DemarcateExternalAliases.apply)
.andThen(trace("DemarcateReturningAliases"))
.andThen(new FlattenOptionOperation(concatBehavior).apply)
.andThen(trace("FlattenOptionOperation"))
.andThen(new SimplifyNullChecks(equalityBehavior).apply)
.andThen(trace("SimplifyNullChecks"))
.andThen(Normalize.apply)
.andThen(trace("Normalize"))
// Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes
// on which RenameProperties relies
.andThen(RenameProperties.apply)
.andThen(trace("RenameProperties"))
.andThen(ExpandDistinct.apply)
.andThen(trace("ExpandDistinct"))
.andThen(NestImpureMappedInfix.apply)
.andThen(trace("NestMappedInfix"))
.andThen(Normalize.apply)
.andThen(trace("Normalize"))
.andThen(ExpandJoin.apply)
.andThen(trace("ExpandJoin"))
.andThen(ExpandMappedInfix.apply)
.andThen(trace("ExpandMappedInfix"))
.andThen(Normalize.apply)
.andThen(trace("Normalize"))
def apply(ast: Ast) = normalize(ast)
}

View file

@ -0,0 +1,29 @@
package minisql.context.sql.norm.nested
import minisql.PseudoAst
import minisql.context.sql.SelectValue
object Elements {
/**
* In order to be able to reconstruct the original ordering of elements inside
* of a select clause, we need to keep track of their order, not only within
* the top-level select but also it's order within any possible
* tuples/case-classes that in which it is embedded. For example, in the
* query: <pre><code> query[Person].map(p => (p.id, (p.name, p.age))).nested
* // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person
* x) AS p </code></pre> Since the `p.name` and `p.age` elements are selected
* inside of a sub-tuple, their "order" is `List(2,1)` and `List(2,2)`
* respectively as opposed to `p.id` whose "order" is just `List(1)`.
*
* This class keeps track of the values needed in order to perform do this.
*/
case class OrderedSelect(order: List[Int], selectValue: SelectValue)
extends PseudoAst {
override def toString: String = s"[${order.mkString(",")}]${selectValue}"
}
object OrderedSelect {
def apply(order: Int, selectValue: SelectValue) =
new OrderedSelect(List(order), selectValue)
}
}

View file

@ -0,0 +1,262 @@
package minisql.context.sql.norm.nested
import minisql.NamingStrategy
import minisql.ast.Property
import minisql.context.sql.SelectValue
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import scala.collection.mutable.LinkedHashSet
import minisql.context.sql.norm.nested.Elements._
import minisql.ast._
import minisql.norm.BetaReduction
/**
* Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query
* constrct exists) and flattens them from a nested-hiearchical structure (i.e.
* tuples inside case classes inside tuples etc..) into into a single series of
* top-level select elements where needed. In cases where a user wants to select
* an element that contains an entire tuple (i.e. a sub-tuple of the outer
* select clause) we pull out the entire tuple that is being selected and leave
* it to the tokenizer to flatten later.
*
* The part about this operation that is tricky is if there are situations where
* there are infix clauses in a sub-query representing an element that has not
* been selected by the query-query but in order to ensure the SQL operation has
* the same meaning, we need to keep track for it. For example: <pre><code> val
* q = quote { query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int],
* p.name, p.id)).map(t => (t._2, t._3)) } run(q) // SELECT p._2, p._3 FROM
* (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
* </code></pre> Since `DISTINCT ON` significantly changes the behavior of the
* outer query, we need to keep track of it inside of the inner query. In order
* to do this, we need to keep track of the location of the infix in the inner
* query so that we can reconstruct it. This is why the `OrderedSelect` and
* `DoubleOrderedSelect` objects are used. See the notes on these classes for
* more detail.
*
* See issue #1597 for more details and another example.
*/
private class ExpandSelect(
selectValues: List[SelectValue],
references: LinkedHashSet[Property],
strategy: NamingStrategy
) {
val interp = new Interpolator(3)
import interp._
object TupleIndex {
def unapply(s: String): Option[Int] =
if (s.matches("_[0-9]*"))
Some(s.drop(1).toInt - 1)
else
None
}
object MultiTupleIndex {
def unapply(s: String): Boolean =
if (s.matches("(_[0-9]+)+"))
true
else
false
}
val select =
selectValues.zipWithIndex.map {
case (value, index) => OrderedSelect(index, value)
}
def expandColumn(name: String, renameable: Renameable): String =
renameable.fixedOr(name)(strategy.column(name))
def apply: List[SelectValue] =
trace"Expanding Select values: $selectValues into references: $references" andReturn {
def expandReference(ref: Property): OrderedSelect =
trace"Expanding: $ref from $select" andReturn {
def expressIfTupleIndex(str: String) =
str match {
case MultiTupleIndex() => Some(str)
case _ => None
}
def concat(alias: Option[String], idx: Int) =
Some(s"${alias.getOrElse("")}_${idx + 1}")
val orderedSelect = ref match {
case pp @ Property(ast: Property, TupleIndex(idx)) =>
trace"Reference is a sub-property of a tuple index: $idx. Walking inside." `andReturn`
expandReference(ast) match {
case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) =>
trace"Expressing Element $idx of $elems " `andReturn`
OrderedSelect(
o :+ idx,
SelectValue(elems(idx), concat(alias, idx), c)
)
case OrderedSelect(o, SelectValue(ast, alias, c)) =>
trace"Appending $idx to $alias " `andReturn`
OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
}
case pp @ Property.Opinionated(
ast: Property,
name,
renameable,
visible
) =>
trace"Reference is a sub-property. Walking inside." `andReturn`
expandReference(ast) match {
case OrderedSelect(o, SelectValue(ast, nested, c)) =>
// Alias is the name of the column after the naming strategy
// The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the
// alias field when it's value is Some(T).
// Technically the aliases of a column should not be using naming strategies
// but this is an issue to fix at a later date.
// In the current implementation, aliases we add nested tuple names to queries e.g.
// SELECT foo from
// SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors)
// Typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when
// this kind of query is the result of an applicative join that looks like this:
// query[baz].join(query[colors]).nested
// this may need to change based on how distinct appends table names instead of just tuple indexes
// into the property path.
trace"...inside walk completed, continuing to return: " `andReturn`
OrderedSelect(
o,
SelectValue(
// Note: Pass invisible properties to be tokenized by the idiom, they should be excluded there
Property.Opinionated(ast, name, renameable, visible),
// Skip concatonation of invisible properties into the alias e.g. so it will be
Some(
s"${nested.getOrElse("")}${expandColumn(name, renameable)}"
)
)
)
}
case pp @ Property(_, TupleIndex(idx)) =>
trace"Reference is a tuple index: $idx from $select." `andReturn`
select(idx) match {
case OrderedSelect(o, SelectValue(ast, alias, c)) =>
OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
}
case pp @ Property.Opinionated(_, name, renameable, visible) =>
select match {
case List(
OrderedSelect(o, SelectValue(cc: CaseClass, alias, c))
) =>
// Currently case class element name is not being appended. Need to change that in order to ensure
// path name uniqueness in future.
val ((_, ast), index) =
cc.values.zipWithIndex.find(_._1._1 == name) match {
case Some(v) => v
case None =>
throw new IllegalArgumentException(
s"Cannot find element $name in $cc"
)
}
trace"Reference is a case class member: " `andReturn`
OrderedSelect(
o :+ index,
SelectValue(ast, Some(expandColumn(name, renameable)), c)
)
case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) =>
trace"Reference is an identifier: " `andReturn`
OrderedSelect(
o,
SelectValue(
Property.Opinionated(i, name, renameable, visible),
Some(name),
c
)
)
case other =>
trace"Reference is unidentified: $other returning:" `andReturn`
OrderedSelect(
Integer.MAX_VALUE,
SelectValue(
Ident.Opinionated(name, visible),
Some(expandColumn(name, renameable)),
false
)
)
}
}
// For certain very large queries where entities are unwrapped and then re-wrapped into CaseClass/Tuple constructs,
// the actual row-types can contain Tuple/CaseClass values. For this reason. They need to be beta-reduced again.
val normalizedOrderedSelect = orderedSelect.copy(selectValue =
orderedSelect.selectValue.copy(ast =
BetaReduction(orderedSelect.selectValue.ast)
)
)
trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" `andReturn`
normalizedOrderedSelect
}
def deAliasWhenUneeded(os: OrderedSelect) =
os match {
case OrderedSelect(
_,
sv @ SelectValue(Property(Ident(_), propName), Some(alias), _)
) if (propName == alias) =>
trace"Detected select value with un-needed alias: $os removing it:" `andReturn`
os.copy(selectValue = sv.copy(alias = None))
case _ => os
}
references.toList match {
case Nil => select.map(_.selectValue)
case refs => {
// elements first need to be sorted by their order in the select clause. Since some may map to multiple
// properties when expanded, we want to maintain this order of properties as a secondary value.
val mappedRefs =
refs
// Expand all the references to properties that we have selected in the super query
.map(expandReference)
// Once all the recursive calls of expandReference are done, remove the alias if it is not needed.
// We cannot do this because during recursive calls, the aliases of outer clauses are used for inner ones.
.map(deAliasWhenUneeded(_))
trace"Mapped Refs: $mappedRefs".andLog()
// are there any selects that have infix values which we have not already selected? We need to include
// them because they could be doing essential things e.g. RANK ... ORDER BY
val remainingSelectsWithInfixes =
trace"Searching Selects with Infix:" `andReturn`
new FindUnexpressedInfixes(select)(mappedRefs)
implicit val ordering: scala.math.Ordering[List[Int]] =
new scala.math.Ordering[List[Int]] {
override def compare(x: List[Int], y: List[Int]): Int =
(x, y) match {
case (head1 :: tail1, head2 :: tail2) =>
val diff = head1 - head2
if (diff != 0) diff
else compare(tail1, tail2)
case (Nil, Nil) => 0 // List(1,2,3) == List(1,2,3)
case (head1, Nil) => -1 // List(1,2,3) < List(1,2)
case (Nil, head2) => 1 // List(1,2) > List(1,2,3)
}
}
val sortedRefs =
(mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref =>
ref.order
) // (ref.order, ref.secondaryOrder)
sortedRefs.map(_.selectValue)
}
}
}
}
object ExpandSelect {
def apply(
selectValues: List[SelectValue],
references: LinkedHashSet[Property],
strategy: NamingStrategy
): List[SelectValue] =
new ExpandSelect(selectValues, references, strategy).apply
}

View file

@ -0,0 +1,83 @@
package minisql.context.sql.norm.nested
import minisql.context.sql.norm.nested.Elements._
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.ast._
import minisql.context.sql.SelectValue
/**
* The challenge with appeneding infixes (that have not been used but are still
* needed) back into the query, is that they could be inside of
* tuples/case-classes that have already been selected, or inside of sibling
* elements which have been selected. Take for instance a query that looks like
* this: <pre><code> query[Person].map(p => (p.name, (p.id,
* infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1)) </code></pre> In
* this situation, `p.id` which is the sibling of the non-selected infix has
* been selected via `p._2._1` (whose select-order is List(1,0) to represent 1st
* element in 2nd tuple. We need to add it's sibling infix.
*
* Or take the following situation: <pre><code> query[Person].map(p => (p.name,
* (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
* </code></pre> In this case, we have selected the entire 2nd element including
* the infix. We need to know that `P._2._2` does not need to be selected since
* `p._2` was.
*
* In order to do these things, we use the `order` property from `OrderedSelect`
* in order to see which sub-sub-...-element has been selected. If `p._2` (that
* has order `List(1)`) has been selected, we know that any infixes inside of it
* e.g. `p._2._1` (ordering `List(1,0)`) does not need to be.
*/
class FindUnexpressedInfixes(select: List[OrderedSelect]) {
val interp = new Interpolator(3)
import interp._
def apply(refs: List[OrderedSelect]) = {
def pathExists(path: List[Int]) =
refs.map(_.order).contains(path)
def containsInfix(ast: Ast) =
CollectAst.byType[Infix](ast).length > 0
// build paths to every infix and see these paths were not selected already
def findMissingInfixes(
ast: Ast,
parentOrder: List[Int]
): List[(Ast, List[Int])] = {
trace"Searching for infix: $ast in the sub-path $parentOrder".andLog()
if (pathExists(parentOrder))
trace"No infixes found" `andContinue`
List()
else
ast match {
case Tuple(values) =>
values.zipWithIndex
.filter(v => containsInfix(v._1))
.flatMap {
case (ast, index) =>
findMissingInfixes(ast, parentOrder :+ index)
}
case CaseClass(values) =>
values.zipWithIndex
.filter(v => containsInfix(v._1._2))
.flatMap {
case ((_, ast), index) =>
findMissingInfixes(ast, parentOrder :+ index)
}
case other if (containsInfix(other)) =>
trace"Found unexpressed infix inside $other in $parentOrder"
.andLog()
List((other, parentOrder))
case _ =>
List()
}
}
select.flatMap {
case OrderedSelect(o, sv) => findMissingInfixes(sv.ast, o)
}.map {
case (ast, order) => OrderedSelect(order, SelectValue(ast))
}
}
}

View file

@ -1,59 +0,0 @@
package minisql.dsl
import minisql.*
import minisql.parsing.*
import minisql.ast.{Ast, Entity, Map, Property, Ident, given}
import scala.quoted.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type EntityQuery[E] <: Query[E] = Query[E]
extension [E](inline e: EntityQuery[E]) {
inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
transform(e)(f)(Map.apply)
}
}
private inline def transform[A, B](inline q1: Quoted)(
inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
fast(q1, f.param0, f.body)
}
inline def query[E](inline table: String): EntityQuery[E] =
Entity(table, Nil)
inline def compile(inline x: Ast): Option[String] = ${
compileImpl('{ x })
}
private def compileImpl(
x: Expr[Ast]
)(using Quotes): Expr[Option[String]] = {
import quotes.reflect.*
x.value match {
case Some(xv) => '{ Some(${ Expr(xv.toString()) }) }
case None => '{ None }
}
}
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def body = parsing.parseBody(f1)
}
extension [A1, A2, B](inline f1: (A1, A2) => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def param1 = parsing.parseParamAt(f1, 1)
private inline def body = parsing.parseBody(f1)
}
case class Foo(id: Int)
inline def queryFooId = query[Foo]("foo").map(_.id)

View file

@ -0,0 +1,23 @@
package minisql.idiom
import minisql.NamingStrategy
import minisql.ast._
import minisql.context.Capabilities
trait Idiom extends Capabilities {
def emptySetContainsToken(field: Token): Token = StringToken("FALSE")
def defaultAutoGeneratedToken(field: Token): Token = StringToken(
"DEFAULT VALUES"
)
def liftingPlaceholder(index: Int): String
def translate(ast: Ast)(using naming: NamingStrategy): (Ast, Statement)
def format(queryString: String): String = queryString
def prepareForProbing(string: String): String
}

View file

@ -0,0 +1,28 @@
package minisql.idiom
import scala.util.Try
import scala.quoted._
import minisql.NamingStrategy
import minisql.util.CollectTry
import minisql.util.LoadObject
import minisql.CompositeNamingStrategy
object LoadNaming {
def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry {
strategies[C].map(LoadObject[NamingStrategy](_))
}.map(NamingStrategy(_))
private def strategies[C](using Quotes, Type[C]) = {
import quotes.reflect.*
val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy]
val ct = TypeRepr.of[C]
if (isComposite) {
ct.typeArgs.filterNot { t =>
t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing]
}
} else {
List(ct)
}
}
}

View file

@ -0,0 +1,356 @@
package minisql.idiom
import minisql.NamingStrategy
import minisql.ast.Renameable.{ByStrategy, Fixed}
import minisql.ast.Visibility.Hidden
import minisql.ast.*
import minisql.context.CanReturnClause
import minisql.idiom.{Idiom, SetContainsToken, Statement}
import minisql.idiom.StatementInterpolator.*
import minisql.norm.Normalize
import minisql.util.Interleave
object MirrorIdiom extends MirrorIdiom
class MirrorIdiom extends MirrorIdiomBase with CanReturnClause
object MirrorIdiomPrinting extends MirrorIdiom {
override def distinguishHidden: Boolean = true
}
trait MirrorIdiomBase extends Idiom {
def distinguishHidden: Boolean = false
override def prepareForProbing(string: String) = string
override def liftingPlaceholder(index: Int): String = "?"
override def translate(
ast: Ast
)(implicit naming: NamingStrategy): (Ast, Statement) = {
val normalizedAst = Normalize(ast)
(normalizedAst, stmt"${normalizedAst.token}")
}
implicit def astTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Ast] = Tokenizer[Ast] {
case ast: Query => ast.token
case ast: Function => ast.token
case ast: Value => ast.token
case ast: Operation => ast.token
case ast: Action => ast.token
case ast: Ident => ast.token
case ast: ExternalIdent => ast.token
case ast: Property => ast.token
case ast: Infix => ast.token
case ast: OptionOperation => ast.token
case ast: IterableOperation => ast.token
case ast: Dynamic => ast.token
case ast: If => ast.token
case ast: Block => ast.token
case ast: Val => ast.token
case ast: Ordering => ast.token
case ast: Lift => ast.token
case ast: Assignment => ast.token
case ast: OnConflict.Excluded => ast.token
case ast: OnConflict.Existing => ast.token
}
implicit def ifTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[If] = Tokenizer[If] {
case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}"
}
implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] {
case Dynamic(tree) => stmt"${tree.toString.token}"
}
implicit def blockTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Block] = Tokenizer[Block] {
case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }"
}
implicit def valTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Val] = Tokenizer[Val] {
case Val(name, body) => stmt"val ${name.token} = ${body.token}"
}
implicit def queryTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Query] = Tokenizer[Query] {
case Entity.Opinionated(name, Nil, renameable) =>
stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token})"
case Entity.Opinionated(name, prop, renameable) =>
val properties =
prop.map(p => stmt"""_.${p.path.mkStmt(".")} -> "${p.alias.token}"""")
stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token}, ${properties.token})"
case Filter(source, alias, body) =>
stmt"${source.token}.filter(${alias.token} => ${body.token})"
case Map(source, alias, body) =>
stmt"${source.token}.map(${alias.token} => ${body.token})"
case FlatMap(source, alias, body) =>
stmt"${source.token}.flatMap(${alias.token} => ${body.token})"
case ConcatMap(source, alias, body) =>
stmt"${source.token}.concatMap(${alias.token} => ${body.token})"
case SortBy(source, alias, body, ordering) =>
stmt"${source.token}.sortBy(${alias.token} => ${body.token})(${ordering.token})"
case GroupBy(source, alias, body) =>
stmt"${source.token}.groupBy(${alias.token} => ${body.token})"
case Aggregation(op, ast) =>
stmt"${scopedTokenizer(ast)}.${op.token}"
case Take(source, n) =>
stmt"${source.token}.take(${n.token})"
case Drop(source, n) =>
stmt"${source.token}.drop(${n.token})"
case Union(a, b) =>
stmt"${a.token}.union(${b.token})"
case UnionAll(a, b) =>
stmt"${a.token}.unionAll(${b.token})"
case Join(t, a, b, iA, iB, on) =>
stmt"${a.token}.${t.token}(${b.token}).on((${iA.token}, ${iB.token}) => ${on.token})"
case FlatJoin(t, a, iA, on) =>
stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})"
case Distinct(a) =>
stmt"${a.token}.distinct"
case DistinctOn(source, alias, body) =>
stmt"${source.token}.distinctOn(${alias.token} => ${body.token})"
case Nested(a) =>
stmt"${a.token}.nested"
}
implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] {
case TupleOrdering(elems) => stmt"Ord(${elems.token})"
case PropertyOrdering.Asc => stmt"Ord.asc"
case PropertyOrdering.Desc => stmt"Ord.desc"
case PropertyOrdering.AscNullsFirst => stmt"Ord.ascNullsFirst"
case PropertyOrdering.DescNullsFirst => stmt"Ord.descNullsFirst"
case PropertyOrdering.AscNullsLast => stmt"Ord.ascNullsLast"
case PropertyOrdering.DescNullsLast => stmt"Ord.descNullsLast"
}
implicit def optionOperationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
case OptionTableFlatMap(ast, alias, body) =>
stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
case OptionTableMap(ast, alias, body) =>
stmt"${ast.token}.map((${alias.token}) => ${body.token})"
case OptionTableExists(ast, alias, body) =>
stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
case OptionTableForall(ast, alias, body) =>
stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
case OptionFlatten(ast) => stmt"${ast.token}.flatten"
case OptionGetOrElse(ast, body) =>
stmt"${ast.token}.getOrElse(${body.token})"
case OptionFlatMap(ast, alias, body) =>
stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
case OptionMap(ast, alias, body) =>
stmt"${ast.token}.map((${alias.token}) => ${body.token})"
case OptionForall(ast, alias, body) =>
stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
case OptionExists(ast, alias, body) =>
stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case OptionIsEmpty(ast) => stmt"${ast.token}.isEmpty"
case OptionNonEmpty(ast) => stmt"${ast.token}.nonEmpty"
case OptionIsDefined(ast) => stmt"${ast.token}.isDefined"
case OptionSome(ast) => stmt"Some(${ast.token})"
case OptionApply(ast) => stmt"Option(${ast.token})"
case OptionOrNull(ast) => stmt"${ast.token}.orNull"
case OptionGetOrNull(ast) => stmt"${ast.token}.getOrNull"
case OptionNone => stmt"None"
}
implicit def traversableOperationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] {
case MapContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case SetContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case ListContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
}
implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
case JoinType.InnerJoin => stmt"join"
case JoinType.LeftJoin => stmt"leftJoin"
case JoinType.RightJoin => stmt"rightJoin"
case JoinType.FullJoin => stmt"fullJoin"
}
implicit def functionTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Function] = Tokenizer[Function] {
case Function(params, body) => stmt"(${params.token}) => ${body.token}"
}
implicit def operationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Operation] = Tokenizer[Operation] {
case UnaryOperation(op: PrefixUnaryOperator, ast) =>
stmt"${op.token}${scopedTokenizer(ast)}"
case UnaryOperation(op: PostfixUnaryOperator, ast) =>
stmt"${scopedTokenizer(ast)}.${op.token}"
case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
SetContainsToken(scopedTokenizer(b), op.token, a.token)
case BinaryOperation(a, op, b) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case FunctionApply(function, values) =>
stmt"${scopedTokenizer(function)}.apply(${values.token})"
}
implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] {
case o => stmt"${o.toString.token}"
}
def tokenizeName(name: String, renameable: Renameable) =
renameable match {
case ByStrategy => name
case Fixed => s"`${name}`"
}
def bracketIfHidden(name: String, visibility: Visibility) =
(distinguishHidden, visibility) match {
case (true, Hidden) => s"[$name]"
case _ => name
}
implicit def propertyTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Property] = Tokenizer[Property] {
case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) =>
stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
case Property.Opinionated(ref, name, renameable, visibility) =>
stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
}
implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] {
case Constant(v: String) => stmt""""${v.token}""""
case Constant(()) => stmt"{}"
case Constant(v) => stmt"${v.toString.token}"
case NullValue => stmt"null"
case Tuple(values) => stmt"(${values.token})"
case CaseClass(values) =>
stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})"
}
implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] {
case Ident.Opinionated(name, visibility) =>
stmt"${bracketIfHidden(name, visibility).token}"
}
implicit val typeTokenizer: Tokenizer[ExternalIdent] =
Tokenizer[ExternalIdent] {
case e => stmt"${e.name.token}"
}
implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] =
Tokenizer[OnConflict.Excluded] {
case OnConflict.Excluded(ident) => stmt"${ident.token}"
}
implicit val existingTokenizer: Tokenizer[OnConflict.Existing] =
Tokenizer[OnConflict.Existing] {
case OnConflict.Existing(ident) => stmt"${ident.token}"
}
implicit def actionTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Action] = Tokenizer[Action] {
case Update(query, assignments) =>
stmt"${query.token}.update(${assignments.token})"
case Insert(query, assignments) =>
stmt"${query.token}.insert(${assignments.token})"
case Delete(query) => stmt"${query.token}.delete"
case Returning(query, alias, body) =>
stmt"${query.token}.returning((${alias.token}) => ${body.token})"
case ReturningGenerated(query, alias, body) =>
stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})"
case Foreach(query, alias, body) =>
stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
case c: OnConflict => stmt"${c.token}"
}
implicit def conflictTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[OnConflict] = {
def targetProps(l: List[Property]) = l.map(p =>
Transform(p) {
case Ident(_) => Ident("_")
}
)
implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] =
Tokenizer[OnConflict.Target] {
case OnConflict.NoTarget => stmt""
case OnConflict.Properties(props) =>
val listTokens = props.token
stmt"(${listTokens})"
}
val updateAssignsTokenizer = Tokenizer[Assignment] {
case Assignment(i, p, v) =>
stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
}
Tokenizer[OnConflict] {
case OnConflict(i, t, OnConflict.Update(assign)) =>
stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
case OnConflict(i, t, OnConflict.Ignore) =>
stmt"${i.token}.onConflictIgnore${t.token}"
}
}
implicit def assignmentTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(ident, property, value) =>
stmt"${ident.token} => ${property.token} -> ${value.token}"
}
implicit def infixTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Infix] = Tokenizer[Infix] {
case Infix(parts, params, _, _) =>
def tokenParam(ast: Ast) =
ast match {
case ast: Ident => stmt"$$${ast.token}"
case other => stmt"$${${ast.token}}"
}
val pt = parts.map(_.token)
val pr = params.map(tokenParam)
val body = Statement(Interleave(pt, pr))
stmt"""infix"${body.token}""""
}
private def scopedTokenizer(
ast: Ast
)(implicit liftTokenizer: Tokenizer[Lift]) =
ast match {
case _: Function => stmt"(${ast.token})"
case _: BinaryOperation => stmt"(${ast.token})"
case other => ast.token
}
}

View file

@ -0,0 +1,101 @@
package minisql.idiom
import minisql.ParamEncoder
import minisql.ast.*
import minisql.util.Interleave
import minisql.idiom.StatementInterpolator.*
import scala.annotation.tailrec
import scala.collection.immutable.{Map => SMap}
object ReifyStatement {
def apply(
liftingPlaceholder: Int => String,
emptySetContainsToken: Token => Token,
statement: Statement,
liftMap: SMap[String, (Any, ParamEncoder[?])]
): (String, List[ScalarValueLift]) = {
val expanded = expandLiftings(statement, emptySetContainsToken, liftMap)
token2string(expanded, liftMap, liftingPlaceholder)
}
private def token2string(
token: Token,
liftMap: SMap[String, (Any, ParamEncoder[?])],
liftingPlaceholder: Int => String
): (String, List[ScalarValueLift]) = {
val liftBuilder = List.newBuilder[ScalarValueLift]
val sqlBuilder = StringBuilder()
@tailrec
def loop(
workList: Seq[Token],
liftingSize: Int
): Unit = workList match {
case Seq() => ()
case head +: tail =>
head match {
case StringToken(s2) =>
sqlBuilder ++= s2
loop(tail, liftingSize)
case SetContainsToken(a, op, b) =>
loop(
stmt"$a $op ($b)" +: tail,
liftingSize
)
case ScalarLiftToken(lift: ScalarValueLift) =>
sqlBuilder ++= liftingPlaceholder(liftingSize)
liftBuilder += lift.copy(value = liftMap.get(lift.liftId))
loop(tail, liftingSize + 1)
case ScalarLiftToken(o) =>
throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}")
case Statement(tokens) =>
loop(
tokens.foldRight(tail)(_ +: _),
liftingSize
)
}
}
loop(Vector(token), 0)
sqlBuilder.toString() -> liftBuilder.result()
}
private def expandLiftings(
statement: Statement,
emptySetContainsToken: Token => Token,
liftMap: SMap[String, (Any, ParamEncoder[?])]
): (Token) = {
Statement {
val lb = List.newBuilder[Token]
statement.tokens.foldLeft(lb) {
case (
tokens,
SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift))
) =>
val (lv, le) = liftMap(lift.liftId)
lv.asInstanceOf[Iterable[Any]].toVector match {
case Vector() => tokens += emptySetContainsToken(a)
case values =>
val liftings = values.zipWithIndex.map {
case (v, i) =>
ScalarLiftToken(
ScalarValueLift(
s"${lift.name}[${i}]",
s"${lift.liftId}[${i}]",
Some(v -> le)
)
)
}
val separators = Vector.fill(liftings.size - 1)(StringToken(", "))
(tokens += stmt"$a $op (") ++= Interleave(
liftings,
separators
) += StringToken(")")
}
case (tokens, token) =>
tokens += token
}
lb.result()
}
}
}

View file

@ -0,0 +1,47 @@
package minisql.idiom
import scala.quoted._
import minisql.ast._
sealed trait Token
case class StringToken(string: String) extends Token {
override def toString = string
}
case class ScalarLiftToken(lift: ScalarLift) extends Token {
override def toString = s"lift(${lift.name})"
}
case class Statement(tokens: List[Token]) extends Token {
override def toString = tokens.mkString
}
case class SetContainsToken(a: Token, op: Token, b: Token) extends Token {
override def toString = s"${a.toString} ${op.toString} (${b.toString})"
}
object Statement {
given ToExpr[Statement] with {
def apply(t: Statement)(using Quotes): Expr[Statement] = {
'{ Statement(${ Expr(t.tokens) }) }
}
}
}
object Token {
given ToExpr[Token] with {
def apply(t: Token)(using Quotes): Expr[Token] = {
t match {
case StringToken(s) =>
'{ StringToken(${ Expr(s) }) }
case ScalarLiftToken(l) =>
'{ ScalarLiftToken(${ Expr(l) }) }
case SetContainsToken(a, op, b) =>
'{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) }
case s: Statement => Expr(s)
}
}
}
}

View file

@ -0,0 +1,131 @@
package minisql.idiom
import minisql.ast._
import minisql.util.Interleave
import minisql.util.Messages._
import scala.collection.mutable.ListBuffer
object StatementInterpolator {
extension [T](list: List[T]) {
private[minisql] def mkStmt(
sep: String = ", "
)(using tokenize: Tokenizer[T]) = {
val l1 = list.map(_.token)
val l2 = List.fill(l1.size - 1)(StringToken(sep))
Statement(Interleave(l1, l2))
}
}
trait Tokenizer[T] {
extension (v: T) {
def token: Token
}
}
object Tokenizer {
def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] {
extension (v: T) {
def token: Token = f(v)
}
}
def withFallback[T](
fallback: Tokenizer[T] => Tokenizer[T]
)(pf: PartialFunction[T, Token]) =
new Tokenizer[T] {
extension (v: T) {
private def stable = fallback(this)
override def token = pf.applyOrElse(v, stable.token)
}
}
}
extension [T](v: T)(using tokenizer: Tokenizer[T]) {
def token = tokenizer.token(v)
}
given stringTokenizer: Tokenizer[String] =
Tokenizer[String] {
case string => StringToken(string)
}
given liftTokenizer: Tokenizer[Lift] =
Tokenizer[Lift] {
case lift: ScalarLift => ScalarLiftToken(lift)
}
given tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity)
given statementTokenizer: Tokenizer[Statement] =
Tokenizer[Statement](identity)
given stringTokenTokenizer: Tokenizer[StringToken] =
Tokenizer[StringToken](identity)
given liftingTokenTokenizer: Tokenizer[ScalarLiftToken] =
Tokenizer[ScalarLiftToken](identity)
given listTokenizer[T](using
tokenize: Tokenizer[T]
): Tokenizer[List[T]] =
Tokenizer[List[T]] {
case list => list.mkStmt()
}
extension (sc: StringContext) {
def flatten(tokens: List[Token]): List[Token] = {
def unestStatements(tokens: List[Token]): List[Token] = {
tokens.flatMap {
case Statement(innerTokens) => unestStatements(innerTokens)
case token => token :: Nil
}
}
def mergeStringTokens(tokens: List[Token]): List[Token] = {
val (resultBuilder, leftTokens) =
tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) {
case ((builder, acc), stringToken: StringToken) =>
val str = stringToken.string
if (str.nonEmpty)
acc += stringToken.string
(builder, acc)
case ((builder, prev), b) if prev.isEmpty =>
(builder += b.token, prev)
case ((builder, prev), b) /* if prev.nonEmpty */ =>
builder += StringToken(prev.result().mkString)
builder += b.token
(builder, new ListBuffer[String])
}
if (leftTokens.nonEmpty)
resultBuilder += StringToken(leftTokens.result().mkString)
resultBuilder.result()
}
(unestStatements)
.andThen(mergeStringTokens)
.apply(tokens)
}
def checkLengths(
args: scala.collection.Seq[Any],
parts: Seq[String]
): Unit =
if (parts.length != args.length + 1)
throw new IllegalArgumentException(
"wrong number of arguments (" + args.length
+ ") for interpolated string with " + parts.length + " parts"
)
def stmt(args: Token*): Statement = {
checkLengths(args, sc.parts)
val partsIterator = sc.parts.iterator
val argsIterator = args.iterator
val bldr = List.newBuilder[Token]
bldr += StringToken(partsIterator.next())
while (argsIterator.hasNext) {
bldr += argsIterator.next()
bldr += StringToken(partsIterator.next())
}
val tokens = flatten(bldr.result())
Statement(tokens)
}
}
}

View file

@ -0,0 +1,52 @@
package minisql.norm
import minisql.ast.BinaryOperation
import minisql.ast.BooleanOperator
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll
object AdHocReduction {
def unapply(q: Query) =
q match {
// ---------------------------
// *.filter
// a.filter(b => c).filter(d => e) =>
// a.filter(b => c && e[d := b])
case Filter(Filter(a, b, c), d, e) =>
val er = BetaReduction(e, d -> b)
Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er)))
// ---------------------------
// flatMap.*
// a.flatMap(b => c).map(d => e) =>
// a.flatMap(b => c.map(d => e))
case Map(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, Map(c, d, e)))
// a.flatMap(b => c).filter(d => e) =>
// a.flatMap(b => c.filter(d => e))
case Filter(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, Filter(c, d, e)))
// a.flatMap(b => c.union(d))
// a.flatMap(b => c).union(a.flatMap(b => d))
case FlatMap(a, b, Union(c, d)) =>
Some(Union(FlatMap(a, b, c), FlatMap(a, b, d)))
// a.flatMap(b => c.unionAll(d))
// a.flatMap(b => c).unionAll(a.flatMap(b => d))
case FlatMap(a, b, UnionAll(c, d)) =>
Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d)))
case other => None
}
}

View file

@ -0,0 +1,160 @@
package minisql.norm
import minisql.ast._
object ApplyMap {
private def isomorphic(e: Ast, c: Ast, alias: Ident) =
BetaReduction(e, alias -> c) == c
object InfixedTailOperation {
def hasImpureInfix(ast: Ast) =
CollectAst(ast) {
case i @ Infix(_, _, false, _) => i
}.nonEmpty
def unapply(ast: Ast): Option[Ast] =
ast match {
case cc: CaseClass if hasImpureInfix(cc) => Some(cc)
case tup: Tuple if hasImpureInfix(tup) => Some(tup)
case p: Property if hasImpureInfix(p) => Some(p)
case b: BinaryOperation if hasImpureInfix(b) => Some(b)
case u: UnaryOperation if hasImpureInfix(u) => Some(u)
case i @ Infix(_, _, false, _) => Some(i)
case _ => None
}
}
object MapWithoutInfixes {
def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
ast match {
case Map(a, b, InfixedTailOperation(c)) => None
case Map(a, b, c) => Some((a, b, c))
case _ => None
}
}
object DetachableMap {
def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
ast match {
case Map(a: GroupBy, b, c) => None
case Map(a, b, InfixedTailOperation(c)) => None
case Map(a, b, c) => Some((a, b, c))
case _ => None
}
}
def unapply(q: Query): Option[Query] =
q match {
case Map(a: GroupBy, b, c) if (b == c) => None
case Map(a: DistinctOn, b, c) => None
case Map(a: Nested, b, c) if (b == c) => None
case Nested(DetachableMap(a: Join, b, c)) => None
// map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) =>
// map(i => (i.i, i.l)).distinct
case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) =>
Some(Distinct(Map(a, b, c)))
// a.map(b => c).map(d => e) =>
// a.map(b => e[d := c])
case before @ Map(MapWithoutInfixes(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(Map(a, b, er))
// a.map(b => b) =>
// a
case Map(a: Query, b, c) if (b == c) =>
Some(a)
// a.map(b => c).flatMap(d => e) =>
// a.flatMap(b => e[d := c])
case FlatMap(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(FlatMap(a, b, er))
// a.map(b => c).filter(d => e) =>
// a.filter(b => e[d := c]).map(b => c)
case Filter(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(Map(Filter(a, b, er), b, c))
// a.map(b => c).sortBy(d => e) =>
// a.sortBy(b => e[d := c]).map(b => c)
case SortBy(DetachableMap(a, b, c), d, e, f) =>
val er = BetaReduction(e, d -> c)
Some(Map(SortBy(a, b, er, f), b, c))
// a.map(b => c).sortBy(d => e).distinct =>
// a.sortBy(b => e[d := c]).map(b => c).distinct
case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) =>
val er = BetaReduction(e, d -> c)
Some(Distinct(Map(SortBy(a, b, er, f), b, c)))
// a.map(b => c).groupBy(d => e) =>
// a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c)))
case GroupBy(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
val x = Ident("x")
val x1 = Property(
Ident("x"),
"_1"
) // These introduced property should not be renamed
val x2 = Property(Ident("x"), "_2") // due to any naming convention.
val body = Tuple(List(x1, Map(x2, b, c)))
Some(Map(GroupBy(a, b, er), x, body))
// a.map(b => c).drop(d) =>
// a.drop(d).map(b => c)
case Drop(DetachableMap(a, b, c), d) =>
Some(Map(Drop(a, d), b, c))
// a.map(b => c).take(d) =>
// a.drop(d).map(b => c)
case Take(DetachableMap(a, b, c), d) =>
Some(Map(Take(a, d), b, c))
// a.map(b => c).nested =>
// a.nested.map(b => c)
case Nested(DetachableMap(a, b, c)) =>
Some(Map(Nested(a), b, c))
// a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on)
// a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2]))
case Join(
tpe,
DetachableMap(a, b, c),
DetachableMap(d, e, f),
iA,
iB,
on
) =>
val onr = BetaReduction(on, iA -> c, iB -> f)
val t = Ident("t")
val t1 = BetaReduction(c, b -> Property(t, "_1"))
val t2 = BetaReduction(f, e -> Property(t, "_2"))
Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2))))
// a.*join(b.map(c => d)).on((iA, iB) => on)
// a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2]))
case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) =>
val onr = BetaReduction(on, iB -> d)
val t = Ident("t")
val t1 = Property(t, "_1")
val t2 = BetaReduction(d, c -> Property(t, "_2"))
Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2))))
// a.map(b => c).*join(d).on((iA, iB) => on)
// a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2))
case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) =>
val onr = BetaReduction(on, iA -> c)
val t = Ident("t")
val t1 = BetaReduction(c, b -> Property(t, "_1"))
val t2 = Property(t, "_2")
Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2))))
case other => None
}
}

View file

@ -0,0 +1,48 @@
package minisql.norm
import minisql.util.Messages.fail
import minisql.ast._
object AttachToEntity {
private object IsEntity {
def unapply(q: Ast): Option[Ast] =
q match {
case q: Entity => Some(q)
case q: Infix => Some(q)
case _ => None
}
}
def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
q: Ast
): Ast =
q match {
case Map(IsEntity(a), b, c) => Map(f(a, b), b, c)
case FlatMap(IsEntity(a), b, c) => FlatMap(f(a, b), b, c)
case ConcatMap(IsEntity(a), b, c) => ConcatMap(f(a, b), b, c)
case Filter(IsEntity(a), b, c) => Filter(f(a, b), b, c)
case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d)
case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)
case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
_: FlatJoin =>
f(q, alias.getOrElse(Ident("x")))
case Map(a: Query, b, c) => Map(apply(f, Some(b))(a), b, c)
case FlatMap(a: Query, b, c) => FlatMap(apply(f, Some(b))(a), b, c)
case ConcatMap(a: Query, b, c) => ConcatMap(apply(f, Some(b))(a), b, c)
case Filter(a: Query, b, c) => Filter(apply(f, Some(b))(a), b, c)
case SortBy(a: Query, b, c, d) => SortBy(apply(f, Some(b))(a), b, c, d)
case Take(a: Query, b) => Take(apply(f, alias)(a), b)
case Drop(a: Query, b) => Drop(apply(f, alias)(a), b)
case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a))
case Distinct(a: Query) => Distinct(apply(f, alias)(a))
case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)
case IsEntity(q) => f(q, alias.getOrElse(Ident("x")))
case other => fail(s"Can't find an 'Entity' in '$q'")
}
}

View file

@ -1,6 +1,6 @@
package minisql.util
package minisql.norm
import minisql.ast.*
import minisql.ast._
import scala.collection.immutable.{Map => IMap}
case class BetaReduction(replacements: Replacements)

View file

@ -0,0 +1,7 @@
package minisql.norm
trait ConcatBehavior
object ConcatBehavior {
case object AnsiConcat extends ConcatBehavior
case object NonAnsiConcat extends ConcatBehavior
}

View file

@ -0,0 +1,7 @@
package minisql.norm
trait EqualityBehavior
object EqualityBehavior {
case object AnsiEquality extends EqualityBehavior
case object NonAnsiEquality extends EqualityBehavior
}

View file

@ -0,0 +1,74 @@
package minisql.norm
import minisql.ReturnAction.ReturnColumns
import minisql.{NamingStrategy, ReturnAction}
import minisql.ast._
import minisql.context.{
ReturningClauseSupported,
ReturningMultipleFieldSupported,
ReturningNotSupported,
ReturningSingleFieldSupported
}
import minisql.idiom.{Idiom, Statement}
/**
* Take the `.returning` part in a query that contains it and return the array
* of columns representing of the returning seccovtion with any other operations
* etc... that they might contain.
*/
object ExpandReturning {
def applyMap(
returning: ReturningAction
)(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = {
val initialExpand = ExpandReturning.apply(returning)(idiom, naming)
idiom.idiomReturningCapability match {
case ReturningClauseSupported =>
ReturnAction.ReturnRecord
case ReturningMultipleFieldSupported =>
ReturnColumns(initialExpand.map {
case (ast, statement) => f(ast, statement)
})
case ReturningSingleFieldSupported =>
if (initialExpand.length == 1)
ReturnColumns(initialExpand.map {
case (ast, statement) => f(ast, statement)
})
else
throw new IllegalArgumentException(
s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
)
case ReturningNotSupported =>
throw new IllegalArgumentException(
s"RETURNING columns are not allowed in the ${idiom} dialect."
)
}
}
def apply(
returning: ReturningAction
)(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
val ReturningAction(_, alias, properties) = returning: @unchecked
// Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
// => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
val dePropertized =
Transform(properties) {
case `alias` => ExternalIdent(alias.name)
}
val aliasName = alias.name
// Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
// => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
val deTuplified = dePropertized match {
case Tuple(values) => values
case CaseClass(values) => values.map(_._2)
case other => List(other)
}
implicit val namingStrategy: NamingStrategy = naming
deTuplified.map(v => idiom.translate(v))
}
}

View file

@ -0,0 +1,108 @@
package minisql.norm
import minisql.ast.*
import minisql.ast.Implicits.*
import minisql.norm.ConcatBehavior.NonAnsiConcat
class FlattenOptionOperation(concatBehavior: ConcatBehavior)
extends StatelessTransformer {
private def emptyOrNot(b: Boolean, ast: Ast) =
if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)
def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
apply(BetaReduction(body, alias -> ast))
def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
val reduced = BetaReduction(body, alias -> ast)
apply((IsNullCheck(ast) +||+ reduced): Ast)
}
def containsNonFallthroughElement(ast: Ast) =
CollectAst(ast) {
case If(_, _, _) => true
case Infix(_, _, _, _) => true
case BinaryOperation(_, StringOperator.`concat`, _)
if (concatBehavior == NonAnsiConcat) =>
true
}.nonEmpty
override def apply(ast: Ast): Ast =
ast match {
case OptionTableFlatMap(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableMap(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableExists(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableForall(ast, alias, body) =>
uncheckedForall(ast, alias, body)
case OptionFlatten(ast) =>
apply(ast)
case OptionSome(ast) =>
apply(ast)
case OptionApply(ast) =>
apply(ast)
case OptionOrNull(ast) =>
apply(ast)
case OptionGetOrNull(ast) =>
apply(ast)
case OptionNone => NullValue
case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)
case OptionGetOrElse(ast, body) =>
apply(If(IsNotNullCheck(ast), ast, body))
case OptionFlatMap(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduced = BetaReduction(body, alias -> ast)
apply(IfExistElseNull(ast, reduced))
} else {
uncheckedReduction(ast, alias, body)
}
case OptionMap(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduced = BetaReduction(body, alias -> ast)
apply(IfExistElseNull(ast, reduced))
} else {
uncheckedReduction(ast, alias, body)
}
case OptionForall(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduction = BetaReduction(body, alias -> ast)
apply(
(IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast
)
} else {
uncheckedForall(ast, alias, body)
}
case OptionExists(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduction = BetaReduction(body, alias -> ast)
apply((IsNotNullCheck(ast) +&&+ reduction): Ast)
} else {
uncheckedReduction(ast, alias, body)
}
case OptionContains(ast, body) =>
apply((ast +==+ body): Ast)
case other =>
super.apply(other)
}
}

View file

@ -0,0 +1,120 @@
package minisql.norm
import minisql.ast.*
import collection.immutable.Set
case class State(seen: Set[Ident], free: Set[Ident])
case class FreeVariables(state: State) extends StatefulTransformer[State] {
override def apply(ast: Ast): (Ast, StatefulTransformer[State]) =
ast match {
case ident: Ident if (!state.seen.contains(ident)) =>
(ident, FreeVariables(State(state.seen, state.free + ident)))
case f @ Function(params, body) =>
val (_, t) =
FreeVariables(State(state.seen ++ params, state.free))(body)
(f, FreeVariables(State(state.seen, state.free ++ t.state.free)))
case q @ Foreach(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(
o: OptionOperation
): (OptionOperation, StatefulTransformer[State]) =
o match {
case q @ OptionTableFlatMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableExists(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableForall(a, b, c) =>
(q, free(a, b, c))
case q @ OptionFlatMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionForall(a, b, c) =>
(q, free(a, b, c))
case q @ OptionExists(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) =
e match {
case Assignment(a, b, c) =>
val t = FreeVariables(State(state.seen + a, state.free))
val (bt, btt) = t(b)
val (ct, ctt) = t(c)
(
Assignment(a, bt, ct),
FreeVariables(
State(state.seen, state.free ++ btt.state.free ++ ctt.state.free)
)
)
}
override def apply(action: Action): (Action, StatefulTransformer[State]) =
action match {
case q @ Returning(a, b, c) =>
(q, free(a, b, c))
case q @ ReturningGenerated(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(
e: OnConflict.Target
): (OnConflict.Target, StatefulTransformer[State]) = (e, this)
override def apply(query: Query): (Query, StatefulTransformer[State]) =
query match {
case q @ Filter(a, b, c) => (q, free(a, b, c))
case q @ Map(a, b, c) => (q, free(a, b, c))
case q @ DistinctOn(a, b, c) => (q, free(a, b, c))
case q @ FlatMap(a, b, c) => (q, free(a, b, c))
case q @ ConcatMap(a, b, c) => (q, free(a, b, c))
case q @ SortBy(a, b, c, d) => (q, free(a, b, c))
case q @ GroupBy(a, b, c) => (q, free(a, b, c))
case q @ FlatJoin(t, a, b, c) => (q, free(a, b, c))
case q @ Join(t, a, b, iA, iB, on) =>
val (_, freeA) = apply(a)
val (_, freeB) = apply(b)
val (_, freeOn) =
FreeVariables(State(state.seen + iA + iB, Set.empty))(on)
(
q,
FreeVariables(
State(
state.seen,
state.free ++ freeA.state.free ++ freeB.state.free ++ freeOn.state.free
)
)
)
case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll |
_: Aggregation | _: Distinct | _: Nested =>
super.apply(query)
}
private def free(a: Ast, ident: Ident, c: Ast) = {
val (_, ta) = apply(a)
val (_, tc) = FreeVariables(State(state.seen + ident, state.free))(c)
FreeVariables(
State(state.seen, state.free ++ ta.state.free ++ tc.state.free)
)
}
}
object FreeVariables {
def apply(ast: Ast): Set[Ident] =
new FreeVariables(State(Set.empty, Set.empty))(ast) match {
case (_, transformer) =>
transformer.state.free
}
}

View file

@ -0,0 +1,76 @@
package minisql.norm
import minisql.ast._
/**
* A problem occurred in the original way infixes were done in that it was
* assumed that infix clauses represented pure functions. While this is true of
* many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others
* e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations
* that are done in `ApplyMap` on standard AST `Map` clauses cannot be done
* therefore additional safety checks were introduced there in order to assure
* this does not happen. In addition to this however, it is necessary to add
* this normalization step which inserts `Nested` AST elements in every map that
* contains impure infix. See more information and examples in #1534.
*/
object NestImpureMappedInfix extends StatelessTransformer {
// Are there any impure infixes that exist inside the specified ASTs
def hasInfix(asts: Ast*): Boolean =
asts.exists(ast =>
CollectAst(ast) {
case i @ Infix(_, _, false, _) => i
}.nonEmpty
)
// Continue exploring into the Map to see if there are additional impure infix clauses inside.
private def applyInside(m: Map) =
Map(apply(m.query), m.alias, m.body)
override def apply(ast: Ast): Ast =
ast match {
// If there is already a nested clause inside the map, there is no reason to insert another one
case Nested(Map(inner, a, b)) =>
Nested(Map(apply(inner), a, b))
case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m)
Map(
Nested(applyInside(m)),
x,
CaseClass(values.map {
case (name, _) =>
(
name,
Property(x, name)
) // mappings of nested-query case class properties should not be renamed
})
)
case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
Map(
Nested(applyInside(m)),
x,
Tuple(values.zipWithIndex.map {
case (_, i) =>
Property(
x,
s"_${i + 1}"
) // mappings of nested-query tuple properties should not be renamed
})
)
case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case other => super.apply(other)
}
}

View file

@ -0,0 +1,51 @@
package minisql.norm
import minisql.ast.Ast
import minisql.ast.Query
import minisql.ast.StatelessTransformer
import minisql.norm.capture.AvoidCapture
import minisql.ast.Action
import minisql.util.Messages.trace
import minisql.util.Messages.TraceType.Normalizations
import scala.annotation.tailrec
object Normalize extends StatelessTransformer {
override def apply(q: Ast): Ast =
super.apply(BetaReduction(q))
override def apply(q: Action): Action =
NormalizeReturning(super.apply(q))
override def apply(q: Query): Query =
norm(AvoidCapture(q))
private def traceNorm[T](label: String) =
trace[T](s"${label} (Normalize)", 1, Normalizations)
@tailrec
private def norm(q: Query): Query =
q match {
case NormalizeNestedStructures(query) =>
traceNorm("NormalizeNestedStructures")(query)
norm(query)
case ApplyMap(query) =>
traceNorm("ApplyMap")(query)
norm(query)
case SymbolicReduction(query) =>
traceNorm("SymbolicReduction")(query)
norm(query)
case AdHocReduction(query) =>
traceNorm("AdHocReduction")(query)
norm(query)
case OrderTerms(query) =>
traceNorm("OrderTerms")(query)
norm(query)
case NormalizeAggregationIdent(query) =>
traceNorm("NormalizeAggregationIdent")(query)
norm(query)
case other =>
other
}
}

View file

@ -0,0 +1,29 @@
package minisql.norm
import minisql.ast._
object NormalizeAggregationIdent {
def unapply(q: Query) =
q match {
// a => a.b.map(x => x.c).agg =>
// a => a.b.map(a => a.c).agg
case Aggregation(
op,
Map(
p @ Property(i: Ident, _),
mi,
Property.Opinionated(_: Ident, n, renameable, visibility)
)
) if i != mi =>
Some(
Aggregation(
op,
Map(p, i, Property.Opinionated(i, n, renameable, visibility))
)
) // in example aove, if c in x.c is fixed c in a.c should also be
case _ => None
}
}

View file

@ -0,0 +1,47 @@
package minisql.norm
import minisql.ast._
object NormalizeNestedStructures {
def unapply(q: Query): Option[Query] =
q match {
case e: Entity => None
case Map(a, b, c) => apply(a, c)(Map(_, b, _))
case FlatMap(a, b, c) => apply(a, c)(FlatMap(_, b, _))
case ConcatMap(a, b, c) => apply(a, c)(ConcatMap(_, b, _))
case Filter(a, b, c) => apply(a, c)(Filter(_, b, _))
case SortBy(a, b, c, d) => apply(a, c)(SortBy(_, b, _, d))
case GroupBy(a, b, c) => apply(a, c)(GroupBy(_, b, _))
case Aggregation(a, b) => apply(b)(Aggregation(a, _))
case Take(a, b) => apply(a, b)(Take.apply)
case Drop(a, b) => apply(a, b)(Drop.apply)
case Union(a, b) => apply(a, b)(Union.apply)
case UnionAll(a, b) => apply(a, b)(UnionAll.apply)
case Distinct(a) => apply(a)(Distinct.apply)
case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
case Nested(a) => apply(a)(Nested.apply)
case FlatJoin(t, a, iA, on) =>
(Normalize(a), Normalize(on)) match {
case (`a`, `on`) => None
case (a, on) => Some(FlatJoin(t, a, iA, on))
}
case Join(t, a, b, iA, iB, on) =>
(Normalize(a), Normalize(b), Normalize(on)) match {
case (`a`, `b`, `on`) => None
case (a, b, on) => Some(Join(t, a, b, iA, iB, on))
}
}
private def apply(a: Ast)(f: Ast => Query) =
(Normalize(a)) match {
case (`a`) => None
case (a) => Some(f(a))
}
private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
(Normalize(a), Normalize(b)) match {
case (`a`, `b`) => None
case (a, b) => Some(f(a, b))
}
}

View file

@ -0,0 +1,154 @@
package minisql.norm
import minisql.ast._
import minisql.norm.capture.AvoidAliasConflict
/**
* When actions are used with a `.returning` clause, remove the columns used in
* the returning clause from the action. E.g. for `insert(Person(id,
* name)).returning(_.id)` remove the `id` column from the original insert.
*/
object NormalizeReturning {
def apply(e: Action): Action = {
e match {
case ReturningGenerated(a: Action, alias, body) =>
// De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action.
// This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres.
// For example:
// query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
// Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally
// be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case.
// In order to fix this, we need to change `t` into a different alias.
val newBody = dealiasBody(body, alias)
ReturningGenerated(apply(a, newBody, alias), alias, newBody)
// For a regular return clause, do not need to exclude assignments from insertion however, we still
// need to de-alias the Action body in case conflicts result. For example the following query:
// query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
// would incorrectly be interpreted as:
// INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t
// whereas it should be:
// INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
case Returning(a: Action, alias, body) =>
val newBody = dealiasBody(body, alias)
Returning(a, alias, newBody)
case _ => e
}
}
/**
* In some situations, a query can exist inside of a `returning` clause. In
* this case, we need to rename if the aliases used in that query override the
* alias used in the `returning` clause otherwise they will be treated as
* returning-clause aliases ExpandReturning (i.e. they will become
* ExternalAlias instances) and later be tokenized incorrectly.
*/
private def dealiasBody(body: Ast, alias: Ident): Ast =
Transform(body) {
case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
}
private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
e match {
case Insert(query, assignments) =>
Insert(query, filterReturnedColumn(assignments, body, returningIdent))
case Update(query, assignments) =>
Update(query, filterReturnedColumn(assignments, body, returningIdent))
case OnConflict(a: Action, target, act) =>
OnConflict(apply(a, body, returningIdent), target, act)
case _ => e
}
private def filterReturnedColumn(
assignments: List[Assignment],
column: Ast,
returningIdent: Ident
): List[Assignment] =
assignments.flatMap(filterReturnedColumn(_, column, returningIdent))
/**
* In situations like Property(Property(ident, foo), bar) pull out the
* inner-most ident
*/
object NestedProperty {
def unapply(ast: Property): Option[Ast] = {
ast match {
case p @ Property(subAst, _) => Some(innerMost(subAst))
}
}
private def innerMost(ast: Ast): Ast = ast match {
case Property(inner, _) => innerMost(inner)
case other => other
}
}
/**
* Remove the specified column from the assignment. For example, in a query
* like `insert(Person(id, name)).returning(r => r.id)` we need to remove the
* `id` column from the insertion. The value of the `column:Ast` in this case
* will be `Property(Ident(r), id)` and the values fo the assignment `p1`
* property will typically be `v.id` and `v.name` (the `v` variable is a
* default used for `insert` queries).
*/
private def filterReturnedColumn(
assignment: Assignment,
body: Ast,
returningIdent: Ident
): Option[Assignment] =
assignment match {
case Assignment(_, p1: Property, _) => {
// Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but
// if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2))
// or it can even be a query since queries are allowed to be in return sections e.g:
// query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max))
// In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment
// in order to know what to exclude.
val matchedProps =
CollectAst(body) {
// case prop @ NestedProperty(`returningIdent`) => prop
case prop @ NestedProperty(Ident(name))
if (name == returningIdent.name) =>
prop
case prop @ NestedProperty(ExternalIdent(name))
if (name == returningIdent.name) =>
prop
}
if (
matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))
)
None
else
Some(assignment)
}
case assignment => Some(assignment)
}
object SomeIdent {
def unapply(ast: Ast): Option[Ast] =
ast match {
case id: Ident => Some(id)
case id: ExternalIdent => Some(id)
case _ => None
}
}
/**
* Is it the same property (but possibly of a different identity). E.g.
* `p.foo.bar` and `v.foo.bar`
*/
private def isSameProperties(p1: Property, p2: Property): Boolean =
(p1.ast, p2.ast) match {
case (SomeIdent(_), SomeIdent(_)) =>
p1.name == p2.name
// If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the
// outer properties are the same before moving on to the inner ones.
case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
isSameProperties(pp1, pp2)
case _ =>
false
}
}

View file

@ -0,0 +1,29 @@
package minisql.norm
import minisql.ast._
object OrderTerms {
def unapply(q: Query) =
q match {
case Take(Map(a: GroupBy, b, c), d) => None
// a.sortBy(b => c).filter(d => e) =>
// a.filter(d => e).sortBy(b => c)
case Filter(SortBy(a, b, c, d), e, f) =>
Some(SortBy(Filter(a, e, f), b, c, d))
// a.flatMap(b => c).take(n).map(d => e) =>
// a.flatMap(b => c).map(d => e).take(n)
case Map(Take(fm: FlatMap, n), ma, mb) =>
Some(Take(Map(fm, ma, mb), n))
// a.flatMap(b => c).drop(n).map(d => e) =>
// a.flatMap(b => c).map(d => e).drop(n)
case Map(Drop(fm: FlatMap, n), ma, mb) =>
Some(Drop(Map(fm, ma, mb), n))
case other => None
}
}

View file

@ -0,0 +1,491 @@
package minisql.norm
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Visible
import minisql.ast._
import minisql.util.Interpolator
object RenameProperties extends StatelessTransformer {
val interp = new Interpolator(3)
import interp._
def traceDifferent(one: Any, two: Any) =
if (one != two)
trace"Replaced $one with $two".andLog()
else
trace"Replacements did not match".andLog()
override def apply(q: Query): Query =
applySchemaOnly(q)
override def apply(q: Action): Action =
applySchema(q) match {
case (q, schema) => q
}
override def apply(e: Operation): Operation =
e match {
case UnaryOperation(o, c: Query) =>
UnaryOperation(o, applySchemaOnly(apply(c)))
case _ => super.apply(e)
}
private def applySchema(q: Action): (Action, Schema) =
q match {
case Insert(q: Query, assignments) =>
applySchema(q, assignments, Insert.apply)
case Update(q: Query, assignments) =>
applySchema(q, assignments, Update.apply)
case Delete(q: Query) =>
applySchema(q) match {
case (q, schema) => (Delete(q), schema)
}
case Returning(action: Action, alias, body) =>
applySchema(action) match {
case (action, schema) =>
val replace =
trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val bodyr = BetaReduction(body, replace*)
traceDifferent(body, bodyr)
(Returning(action, alias, bodyr), schema)
}
case ReturningGenerated(action: Action, alias, body) =>
applySchema(action) match {
case (action, schema) =>
val replace =
trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val bodyr = BetaReduction(body, replace*)
traceDifferent(body, bodyr)
(ReturningGenerated(action, alias, bodyr), schema)
}
case OnConflict(a: Action, target, act) =>
applySchema(a) match {
case (action, schema) =>
val targetR = target match {
case OnConflict.Properties(props) =>
val propsR = props.map { prop =>
val replace =
trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn`
replacements(
prop.ast,
schema
) // A BetaReduction on a Property will always give back a Property
BetaReduction(prop, replace*).asInstanceOf[Property]
}
traceDifferent(props, propsR)
OnConflict.Properties(propsR)
case OnConflict.NoTarget => target
}
val actR = act match {
case OnConflict.Update(assignments) =>
OnConflict.Update(replaceAssignments(assignments, schema))
case _ => act
}
(OnConflict(action, targetR, actR), schema)
}
case q => (q, TupleSchema.empty)
}
private def replaceAssignments(
a: List[Assignment],
schema: Schema
): List[Assignment] =
a.map {
case Assignment(alias, prop, value) =>
val replace =
trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val propR = BetaReduction(prop, replace*)
traceDifferent(prop, propR)
val valueR = BetaReduction(value, replace*)
traceDifferent(value, valueR)
Assignment(alias, propR, valueR)
}
private def applySchema(
q: Query,
a: List[Assignment],
f: (Query, List[Assignment]) => Action
): (Action, Schema) =
applySchema(q) match {
case (q, schema) =>
(f(q, replaceAssignments(a, schema)), schema)
}
private def applySchemaOnly(q: Query): Query =
applySchema(q) match {
case (q, _) => q
}
object TupleIndex {
def unapply(s: String): Option[Int] =
if (s.matches("_[0-9]*"))
Some(s.drop(1).toInt - 1)
else
None
}
sealed trait Schema {
def lookup(property: List[String]): Option[Schema] =
(property, this) match {
case (Nil, schema) =>
trace"Nil at $property returning " `andReturn`
Some(schema)
case (path, e @ EntitySchema(_)) =>
trace"Entity at $path returning " `andReturn`
Some(e.subSchemaOrEmpty(path))
case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) =>
trace"Case class at $property returning " `andReturn`
props(head).lookup(tail)
case (TupleIndex(idx) :: tail, TupleSchema(values))
if values.contains(idx) =>
trace"Tuple at at $property returning " `andReturn`
values(idx).lookup(tail)
case _ =>
trace"Nothing found at $property returning " `andReturn`
None
}
}
// Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...))
object PropertyMatroshka {
def traverse(initial: Property): Option[(Ident, List[String])] =
initial match {
// If it's a nested-property walk inside and append the name to the result (if something is returned)
case Property(inner: Property, name) =>
traverse(inner).map { case (id, list) => (id, list :+ name) }
// If it's a property with ident in the core, return that
case Property(id: Ident, name) =>
Some((id, List(name)))
// Otherwise an ident property is not inside so don't return anything
case _ =>
None
}
def unapply(ast: Ast): Option[(Ident, List[String])] =
ast match {
case p: Property => traverse(p)
case _ => None
}
}
def protractSchema(
body: Ast,
ident: Ident,
schema: Schema
): Option[Schema] = {
def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] =
body match {
// if any values yield a sub-schema which is not an entity, recurse into that
case cc @ CaseClass(values) =>
trace"Protracting CaseClass $cc into new schema:" `andReturn`
CaseClassSchema(
values.collect {
case (name, innerBody @ HierarchicalAstEntity()) =>
(name, protractSchemaRecurse(innerBody, schema))
// pass the schema into a recursive call an extract from it when we non tuple/caseclass element
case (name, innerBody @ PropertyMatroshka(`ident`, path)) =>
(name, protractSchemaRecurse(innerBody, schema))
// we have reached an ident i.e. recurse to pass the current schema into the case class
case (name, `ident`) =>
(name, protractSchemaRecurse(ident, schema))
}.collect {
case (name, Some(subSchema)) => (name, subSchema)
}
).notEmpty
case tup @ Tuple(values) =>
trace"Protracting Tuple $tup into new schema:" `andReturn`
TupleSchema
.fromIndexes(
values.zipWithIndex.collect {
case (innerBody @ HierarchicalAstEntity(), index) =>
(index, protractSchemaRecurse(innerBody, schema))
// pass the schema into a recursive call an extract from it when we non tuple/caseclass element
case (innerBody @ PropertyMatroshka(`ident`, path), index) =>
(index, protractSchemaRecurse(innerBody, schema))
// we have reached an ident i.e. recurse to pass the current schema into the tuple
case (`ident`, index) =>
(index, protractSchemaRecurse(ident, schema))
}.collect {
case (index, Some(subSchema)) => (index, subSchema)
}
)
.notEmpty
case prop @ PropertyMatroshka(`ident`, path) =>
trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn`
schema match {
// case e: EntitySchema => Some(e)
case _ => schema.lookup(path)
}
case `ident` =>
trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn`
Some(schema)
case other =>
trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn`
Some(schema)
}
protractSchemaRecurse(body, schema)
}
case object EmptySchema extends Schema
case class EntitySchema(e: Entity) extends Schema {
def noAliases = e.properties.isEmpty
private def subSchema(path: List[String]) =
EntitySchema(
Entity(
s"sub-${e.name}",
e.properties.flatMap {
case PropertyAlias(aliasPath, alias) =>
if (aliasPath == path)
List(PropertyAlias(aliasPath, alias))
else if (aliasPath.startsWith(path))
List(PropertyAlias(aliasPath.diff(path), alias))
else
List()
}
)
)
def subSchemaOrEmpty(path: List[String]): Schema =
trace"Creating sub-schema for entity $e at path $path will be" andReturn {
val sub = subSchema(path)
if (sub.noAliases) EmptySchema else sub
}
}
case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ )
extends Schema {
def list = m.toList.sortBy(_._1)
def notEmpty =
if (this.m.nonEmpty) Some(this) else None
}
case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema {
def list = m.toList
def notEmpty =
if (this.m.nonEmpty) Some(this) else None
}
object CaseClassSchema {
def apply(property: String, value: Schema): CaseClassSchema =
CaseClassSchema(collection.Map(property -> value))
def apply(list: List[(String, Schema)]): CaseClassSchema =
CaseClassSchema(list.toMap)
}
object TupleSchema {
def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema =
TupleSchema(schemas.toMap)
def apply(schemas: List[Schema]): TupleSchema =
TupleSchema(schemas.zipWithIndex.map(_.swap).toMap)
def apply(index: Int, schema: Schema): TupleSchema =
TupleSchema(collection.Map(index -> schema))
def empty: TupleSchema = TupleSchema(List.empty)
}
object HierarchicalAstEntity {
def unapply(ast: Ast): Boolean =
ast match {
case cc: CaseClass => true
case tup: Tuple => true
case _ => false
}
}
private def applySchema(q: Query): (Query, Schema) = {
q match {
// Don't understand why this is needed....
case Map(q: Query, x, p) =>
applySchema(q) match {
case (q, subSchema) =>
val replace =
trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn`
replacements(x, subSchema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
val schema =
trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` {
protractSchema(prr, x, subSchema)
}.getOrElse(EmptySchema)
(Map(q, x, prr), schema)
}
case e: Entity => (e, EntitySchema(e))
case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply)
case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply)
case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _))
case Take(q: Query, n) => applySchema(q, Take(_, n))
case Drop(q: Query, n) => applySchema(q, Drop(_, n))
case Nested(q: Query) => applySchema(q, Nested.apply)
case Distinct(q: Query) => applySchema(q, Distinct.apply)
case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on))
case FlatMap(q: Query, x, p) =>
applySchema(q, x, p, FlatMap.apply) match {
case (FlatMap(q, x, p: Query), oldSchema) =>
val (pr, newSchema) = applySchema(p)
(FlatMap(q, x, pr), newSchema)
case (flatMap, oldSchema) =>
(flatMap, TupleSchema.empty)
}
case ConcatMap(q: Query, x, p) =>
applySchema(q, x, p, ConcatMap.apply) match {
case (ConcatMap(q, x, p: Query), oldSchema) =>
val (pr, newSchema) = applySchema(p)
(ConcatMap(q, x, pr), newSchema)
case (concatMap, oldSchema) =>
(concatMap, TupleSchema.empty)
}
case Join(typ, a: Query, b: Query, iA, iB, on) =>
(applySchema(a), applySchema(b)) match {
case ((a, schemaA), (b, schemaB)) =>
val combinedReplacements =
trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn {
val replaceA = replacements(iA, schemaA)
val replaceB = replacements(iB, schemaB)
replaceA ++ replaceB
}
val onr = BetaReduction(on, combinedReplacements*)
traceDifferent(on, onr)
(Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB)))
}
case FlatJoin(typ, a: Query, iA, on) =>
applySchema(a) match {
case (a, schemaA) =>
val replaceA =
trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn`
replacements(iA, schemaA)
val onr = BetaReduction(on, replaceA*)
traceDifferent(on, onr)
(FlatJoin(typ, a, iA, onr), schemaA)
}
case Map(q: Operation, x, p) if x == p =>
(Map(apply(q), x, p), TupleSchema.empty)
case Map(Infix(parts, params, pure, paren), x, p) =>
val transformed =
params.map {
case q: Query =>
val (qr, schema) = applySchema(q)
traceDifferent(q, qr)
(qr, Some(schema))
case q =>
(q, None)
}
val schema =
transformed.collect {
case (_, Some(schema)) => schema
} match {
case e :: Nil => e
case ls => TupleSchema(ls)
}
val replace =
trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
replacements(x, schema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
(Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema)
case q =>
(q, TupleSchema.empty)
}
}
private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) =
applySchema(ast) match {
case (ast, schema) =>
(f(ast), schema)
}
private def applySchema[T](
q: Query,
x: Ident,
p: Ast,
f: (Ast, Ident, Ast) => T
): (T, Schema) =
applySchema(q) match {
case (q, schema) =>
val replace =
trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
replacements(x, schema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
(f(q, x, prr), schema)
}
private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] =
schema match {
// The entity renameable property should already have been marked as Fixed
case EntitySchema(Entity(entity, properties)) =>
// trace"%4 Entity Schema: " andReturn
properties.flatMap {
// A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias)
// or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit.
// In this case, we want to rename the properties based on the property aliases as well as mark
// them Fixed since they should not be renamed based on
// the naming strategy wherever they are tokenized (e.g. in SqlIdiom)
case PropertyAlias(path, alias) =>
def apply(base: Ast, path: List[String]): Ast =
path match {
case Nil => base
case head :: tail => apply(Property(base, head), tail)
}
List(
apply(base, path) -> Property.Opinionated(
base,
alias,
Fixed,
Visible
) // Hidden properties cannot be renamed
)
}
case tup: TupleSchema =>
// trace"%4 Tuple Schema: " andReturn
tup.list.flatMap {
case (idx, value) =>
replacements(
// Should not matter whether property is fixed or variable here
// since beta reduction ignores that
Property(base, s"_${idx + 1}"),
value
)
}
case cc: CaseClassSchema =>
// trace"%4 CaseClass Schema: " andReturn
cc.list.flatMap {
case (property, value) =>
replacements(
// Should not matter whether property is fixed or variable here
// since beta reduction ignores that
Property(base, property),
value
)
}
// Do nothing if it is an empty schema
case EmptySchema => List()
}
}

View file

@ -1,4 +1,4 @@
package minisql.util
package minisql.norm
import minisql.ast.Ast
import scala.collection.immutable.Map

View file

@ -0,0 +1,124 @@
package minisql.norm
import minisql.ast.*
import minisql.norm.EqualityBehavior.AnsiEquality
/**
* Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in
* `FlattenOptionOperation` in order to resolve #1053, as well as to support
* non-ansi compliant string concatenation as outlined in #1295, large
* conditional composites became common. For example: <code><pre> case class
* Holder(value:Option[String])
*
* // The following statement query[Holder].map(h => h.value.map(_ + "foo")) //
* Will yield the following result SELECT CASE WHEN h.value IS NOT NULL THEN
* h.value || 'foo' ELSE null END FROM Holder h </pre></code> Now, let's add a
* <code>getOrElse</code> statement to the clause that requires an additional
* wrapped null check. We cannot rely on there being a <code>map</code> call
* beforehand since we could be reading <code>value</code> as a nullable field
* directly from the database). <code><pre> // The following statement
* query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar")) // Yields the
* following result: SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value
* \|| 'foo' ELSE null END IS NOT NULL THEN CASE WHEN h.value IS NOT NULL THEN
* h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h </pre></code>
* This of course is highly redundant and can be reduced to simply: <code><pre>
* SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL THEN
* h.value || 'foo' ELSE 'bar' END FROM Holder h </pre></code> This reduction is
* done by the "Center Rule." There are some other simplification rules as well.
* Note how we are force to null-check both `h.value` as well as `(h.value ||
* 'foo')` because a user may use `Option[T].flatMap` and explicitly transform a
* particular value to `null`.
*/
class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
extends StatelessTransformer {
override def apply(ast: Ast): Ast = {
import minisql.ast.Implicits.*
ast match {
// Center rule
case IfExist(
IfExistElseNull(condA, thenA),
IfExistElseNull(condB, thenB),
otherwise
) if (condA == condB && thenA == thenB) =>
apply(
If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
)
// Left hand rule
case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
apply(
If(
IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
value,
otherwise
)
)
// Right hand rule
case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
apply(
If(
IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
innerThen,
NullValue
)
)
case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
Optional(b)
) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
a1
) `== or !=` Optional(_)))
if (a == a1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
_
) `== or !=` Optional(b1)))
if (b == b1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
apply(other)
case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
apply(other)
case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
apply(other)
case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
apply(other)
case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left)
case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left)
case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)
case other =>
super.apply(other)
}
}
object `== or !=` {
def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
case a +==+ b => Some((a, b))
case a +!=+ b => Some((a, b))
case _ => None
}
}
/**
* Simple extractor that looks inside of an optional values to see if the
* thing inside can be pulled out. If not, it just returns whatever element it
* can find.
*/
object Optional {
def unapply(a: Ast): Option[Ast] = a match {
case OptionApply(value) => Some(value)
case OptionSome(value) => Some(value)
case value => Some(value)
}
}
}

View file

@ -0,0 +1,38 @@
package minisql.norm
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll
object SymbolicReduction {
def unapply(q: Query) =
q match {
// a.filter(b => c).flatMap(d => e.$) =>
// a.flatMap(d => e.filter(_ => c[b := d]).$)
case FlatMap(Filter(a, b, c), d, e: Query) =>
val cr = BetaReduction(c, b -> d)
val er = AttachToEntity(Filter(_, _, cr))(e)
Some(FlatMap(a, d, er))
// a.flatMap(b => c).flatMap(d => e) =>
// a.flatMap(b => c.flatMap(d => e))
case FlatMap(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, FlatMap(c, d, e)))
// a.union(b).flatMap(c => d)
// a.flatMap(c => d).union(b.flatMap(c => d))
case FlatMap(Union(a, b), c, d) =>
Some(Union(FlatMap(a, c, d), FlatMap(b, c, d)))
// a.unionAll(b).flatMap(c => d)
// a.flatMap(c => d).unionAll(b.flatMap(c => d))
case FlatMap(UnionAll(a, b), c, d) =>
Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d)))
case other => None
}
}

View file

@ -0,0 +1,174 @@
package minisql.norm.capture
import minisql.ast.{
Entity,
Filter,
FlatJoin,
FlatMap,
GroupBy,
Ident,
Join,
Map,
Query,
SortBy,
StatefulTransformer,
_
}
import minisql.norm.{BetaReduction, Normalize}
import scala.collection.immutable.Set
private[minisql] case class AvoidAliasConflict(state: Set[Ident])
extends StatefulTransformer[Set[Ident]] {
object Unaliased {
private def isUnaliased(q: Ast): Boolean =
q match {
case Nested(q: Query) => isUnaliased(q)
case Take(q: Query, _) => isUnaliased(q)
case Drop(q: Query, _) => isUnaliased(q)
case Aggregation(_, q: Query) => isUnaliased(q)
case Distinct(q: Query) => isUnaliased(q)
case _: Entity | _: Infix => true
case _ => false
}
def unapply(q: Ast): Option[Ast] =
q match {
case q if (isUnaliased(q)) => Some(q)
case _ => None
}
}
override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
q match {
case FlatMap(Unaliased(q), x, p) =>
apply(x, p)(FlatMap(q, _, _))
case ConcatMap(Unaliased(q), x, p) =>
apply(x, p)(ConcatMap(q, _, _))
case Map(Unaliased(q), x, p) =>
apply(x, p)(Map(q, _, _))
case Filter(Unaliased(q), x, p) =>
apply(x, p)(Filter(q, _, _))
case SortBy(Unaliased(q), x, p, o) =>
apply(x, p)(SortBy(q, _, _, o))
case GroupBy(Unaliased(q), x, p) =>
apply(x, p)(GroupBy(q, _, _))
case DistinctOn(Unaliased(q), x, p) =>
apply(x, p)(DistinctOn(q, _, _))
case Join(t, a, b, iA, iB, o) =>
val (ar, art) = apply(a)
val (br, brt) = art.apply(b)
val freshA = freshIdent(iA, brt.state)
val freshB = freshIdent(iB, brt.state + freshA)
val or = BetaReduction(o, iA -> freshA, iB -> freshB)
val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
(Join(t, ar, br, freshA, freshB, orr), orrt)
case FlatJoin(t, a, iA, o) =>
val (ar, art) = apply(a)
val freshA = freshIdent(iA)
val or = BetaReduction(o, iA -> freshA)
val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
(FlatJoin(t, ar, freshA, orr), orrt)
case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
_: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
_: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
super.apply(q)
}
private def apply(x: Ident, p: Ast)(
f: (Ident, Ast) => Query
): (Query, StatefulTransformer[Set[Ident]]) = {
val fresh = freshIdent(x)
val pr = BetaReduction(p, x -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(f(fresh, prr), t)
}
private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
def loop(x: Ident, n: Int): Ident = {
val fresh = Ident(s"${x.name}$n")
if (!state.contains(fresh))
fresh
else
loop(x, n + 1)
}
if (!state.contains(x))
x
else
loop(x, 1)
}
/**
* Sometimes we need to change the variables in a function because they will
* might conflict with some variable further up in the macro. Right now, this
* only happens when you do something like this: <code> val q = quote { (v:
* Foo) => query[Foo].insert(v) } run(q(lift(v))) </code> Since 'v' is used by
* actionMeta in order to map keys to values for insertion, using it as a
* function argument messes up the output SQL like so: <code> INSERT INTO
* MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?) </code>
* Therefore, we need to have a method to remove such conflicting variables
* from Function ASTs
*/
private def applyFunction(f: Function): Function = {
val (newBody, _, newParams) =
f.params.foldLeft((f.body, state, List[Ident]())) {
case ((body, state, newParams), param) => {
val fresh = freshIdent(param)
val pr = BetaReduction(body, param -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(prr, t.state, newParams :+ fresh)
}
}
Function(newParams, newBody)
}
private def applyForeach(f: Foreach): Foreach = {
val fresh = freshIdent(f.alias)
val pr = BetaReduction(f.body, f.alias -> fresh)
val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
Foreach(f.query, fresh, prr)
}
}
private[minisql] object AvoidAliasConflict {
def apply(q: Query): Query =
AvoidAliasConflict(Set[Ident]())(q) match {
case (q, _) => q
}
/**
* Make sure query parameters do not collide with paramters of a AST function.
* Do this by walkning through the function's subtree and transforming and
* queries encountered.
*/
def sanitizeVariables(
f: Function,
dangerousVariables: Set[Ident]
): Function = {
AvoidAliasConflict(dangerousVariables).applyFunction(f)
}
/** Same is `sanitizeVariables` but for Foreach * */
def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
AvoidAliasConflict(dangerousVariables).applyForeach(f)
}
def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
AvoidAliasConflict(dangerousVariables).apply(q) match {
// Propagate aliasing changes to the rest of the query
case (q, _) => Normalize(q)
}
}
}

View file

@ -0,0 +1,9 @@
package minisql.norm.capture
import minisql.ast.Query
object AvoidCapture {
def apply(q: Query): Query =
Dealias(AvoidAliasConflict(q))
}

View file

@ -0,0 +1,72 @@
package minisql.norm.capture
import minisql.ast._
import minisql.norm.BetaReduction
case class Dealias(state: Option[Ident])
extends StatefulTransformer[Option[Ident]] {
override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) =
q match {
case FlatMap(a, b, c) =>
dealias(a, b, c)(FlatMap.apply) match {
case (FlatMap(a, b, c), _) =>
val (cn, cnt) = apply(c)
(FlatMap(a, b, cn), cnt)
}
case ConcatMap(a, b, c) =>
dealias(a, b, c)(ConcatMap.apply) match {
case (ConcatMap(a, b, c), _) =>
val (cn, cnt) = apply(c)
(ConcatMap(a, b, cn), cnt)
}
case Map(a, b, c) =>
dealias(a, b, c)(Map.apply)
case Filter(a, b, c) =>
dealias(a, b, c)(Filter.apply)
case SortBy(a, b, c, d) =>
dealias(a, b, c)(SortBy(_, _, _, d))
case GroupBy(a, b, c) =>
dealias(a, b, c)(GroupBy.apply)
case DistinctOn(a, b, c) =>
dealias(a, b, c)(DistinctOn.apply)
case Take(a, b) =>
val (an, ant) = apply(a)
(Take(an, b), ant)
case Drop(a, b) =>
val (an, ant) = apply(a)
(Drop(an, b), ant)
case Union(a, b) =>
val (an, _) = apply(a)
val (bn, _) = apply(b)
(Union(an, bn), Dealias(None))
case UnionAll(a, b) =>
val (an, _) = apply(a)
val (bn, _) = apply(b)
(UnionAll(an, bn), Dealias(None))
case Join(t, a, b, iA, iB, o) =>
val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
(Join(t, an, bn, iAn, iBn, onn), Dealias(None))
case FlatJoin(t, a, iA, o) =>
val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
(FlatJoin(t, an, iAn, on), Dealias(Some(iA)))
case _: Entity | _: Distinct | _: Aggregation | _: Nested =>
(q, Dealias(None))
}
private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) =
apply(a) match {
case (an, t @ Dealias(Some(alias))) =>
(f(an, alias, BetaReduction(c, b -> alias)), t)
case other =>
(f(a, b, c), Dealias(Some(b)))
}
}
object Dealias {
def apply(query: Query) =
new Dealias(None)(query) match {
case (q, _) => q
}
}

View file

@ -0,0 +1,98 @@
package minisql.norm.capture
import minisql.ast.*
/**
* Walk through any Queries that a returning clause has and replace Ident of the
* returning variable with ExternalIdent so that in later steps involving filter
* simplification, it will not be mistakenly dealiased with a potential shadow.
* Take this query for instance: <pre> query[TestEntity]
* .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r =>
* (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) ) </pre>
* The returning clause has an alias `Ident("r")` as well as the first filter
* clause. These two filters will be combined into one at which point the
* meaning of `r.i` in the 2nd filter will be confused for the first filter's
* alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this
* vunerable `r.i` in the second filter clause to an `ExternalIdent` before any
* of the simplifications are done.
*
* Note that we only want to do this for Queries inside of a `Returning` clause
* body. Other places where this needs to be done (e.g. in a Tuple that
* `Returning` returns) are done in `ExpandReturning`.
*/
private[minisql] case class DemarcateExternalAliases(externalIdent: Ident)
extends StatelessTransformer {
def applyNonOverride(idents: Ident*)(ast: Ast) =
if (idents.forall(_ != externalIdent)) apply(ast)
else ast
override def apply(ast: Ast): Ast = ast match {
case FlatMap(q, i, b) =>
FlatMap(apply(q), i, applyNonOverride(i)(b))
case ConcatMap(q, i, b) =>
ConcatMap(apply(q), i, applyNonOverride(i)(b))
case Map(q, i, b) =>
Map(apply(q), i, applyNonOverride(i)(b))
case Filter(q, i, b) =>
Filter(apply(q), i, applyNonOverride(i)(b))
case SortBy(q, i, p, o) =>
SortBy(apply(q), i, applyNonOverride(i)(p), o)
case GroupBy(q, i, b) =>
GroupBy(apply(q), i, applyNonOverride(i)(b))
case DistinctOn(q, i, b) =>
DistinctOn(apply(q), i, applyNonOverride(i)(b))
case Join(t, a, b, iA, iB, o) =>
Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o))
case FlatJoin(t, a, iA, o) =>
FlatJoin(t, a, iA, applyNonOverride(iA)(o))
case p @ Property.Opinionated(
id @ Ident(_),
value,
renameable,
visibility
) =>
if (id == externalIdent)
Property.Opinionated(
ExternalIdent(externalIdent.name),
value,
renameable,
visibility
)
else
p
case other =>
super.apply(other)
}
}
object DemarcateExternalAliases {
private def demarcateQueriesInBody(id: Ident, body: Ast) =
Transform(body) {
// Apply to the AST defined apply method about, not to the superclass method that takes Query
case q: Query =>
new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast])
}
def apply(ast: Ast): Ast = ast match {
case Returning(a, id, body) =>
Returning(a, id, demarcateQueriesInBody(id, body))
case ReturningGenerated(a, id, body) =>
val d = demarcateQueriesInBody(id, body)
ReturningGenerated(a, id, demarcateQueriesInBody(id, body))
case other =>
other
}
}

View file

@ -0,0 +1,66 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
type SParser[X] =
(q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]]
private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using
Quotes
): SParser[ast.Ast] = {
import quotes.reflect.*
@annotation.nowarn
lazy val valDefParser: SParser[ast.Val] = {
case ValDef(n, _, Some(b)) =>
val body = astParser(b.asExpr)
'{ ast.Val(ast.Ident(${ Expr(n) }), $body) }
}
valDefParser
}
private[parsing] def parseBlockList(
astParser: => Parser[ast.Ast],
e: Expr[Any]
)(using Quotes): List[Expr[ast.Ast]] = {
import quotes.reflect.*
lazy val statementParser = statementParsing(astParser)
e.asTerm match {
case Block(st, t) =>
(st :+ t).map {
case e if e.isExpr => astParser(e.asExpr)
case `statementParser`(x) => x
case o =>
report.errorAndAbort(s"Cannot parse statement: ${o.show}")
}
}
}
private[parsing] def blockParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
import quotes.reflect.*
lazy val statementParser = statementParsing(astParser)
termParser {
case Block(Nil, t) => astParser(t.asExpr)
case b @ Block(st, t) =>
val asts = (st :+ t).map {
case e if e.isExpr => astParser(e.asExpr)
case `statementParser`(x) => x
case o =>
report.errorAndAbort(s"Cannot parse statement: ${o.show}")
}
if (asts.size > 1) {
'{ ast.Block(${ Expr.ofList(asts) }) }
} else asts(0)
}
}

View file

@ -0,0 +1,31 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def boxingParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
case '{ BigDecimal.int2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.long2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.double2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v)
case '{ Predef.byte2Byte($v) } => astParser(v)
case '{ Predef.short2Short($v) } => astParser(v)
case '{ Predef.char2Character($v) } => astParser(v)
case '{ Predef.int2Integer($v) } => astParser(v)
case '{ Predef.long2Long($v) } => astParser(v)
case '{ Predef.float2Float($v) } => astParser(v)
case '{ Predef.double2Double($v) } => astParser(v)
case '{ Predef.boolean2Boolean($v) } => astParser(v)
case '{ Predef.augmentString($v) } => astParser(v)
case '{ Predef.Byte2byte($v) } => astParser(v)
case '{ Predef.Short2short($v) } => astParser(v)
case '{ Predef.Character2char($v) } => astParser(v)
case '{ Predef.Integer2int($v) } => astParser(v)
case '{ Predef.Long2long($v) } => astParser(v)
case '{ Predef.Float2float($v) } => astParser(v)
case '{ Predef.Double2double($v) } => astParser(v)
case '{ Predef.Boolean2boolean($v) } => astParser(v)
}

View file

@ -0,0 +1,29 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def infixParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Infix] = {
import quotes.reflect.*
{
case '{ ($x: minisql.InfixValue).as[t] } => infixParsing(astParser)(x)
case '{
minisql.infix(StringContext(${ Varargs(partsExprs) }*))(${
Varargs(argsExprs)
}*)
} =>
val parts = partsExprs.map { p =>
p.value.getOrElse(
report.errorAndAbort(
s"Expected a string literal in StringContext parts, but got: ${p.show}"
)
)
}.toList
val params = argsExprs.map(arg => astParser(arg)).toList
'{ ast.Infix(${ Expr(parts) }, ${ Expr.ofList(params) }, true, false) }
}
}

View file

@ -0,0 +1,17 @@
package minisql.parsing
import scala.quoted.*
import minisql.ParamEncoder
import minisql.ast
import minisql.*
import minisql.util.*
private[parsing] def liftParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Lift] = {
case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } =>
import quotes.reflect.*
val name = x.show
val liftId = liftIdOfExpr(x)
'{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) }
}

View file

@ -0,0 +1,113 @@
package minisql.parsing
import minisql.ast
import minisql.ast.{
EqualityOperator,
StringOperator,
NumericOperator,
BooleanOperator
}
import minisql.*
import scala.quoted._
private[parsing] def operationParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Operation] = {
import quotes.reflect.*
def isNumeric(t: TypeRepr) = {
t <:< TypeRepr.of[Int]
|| t <:< TypeRepr.of[Long]
|| t <:< TypeRepr.of[Byte]
|| t <:< TypeRepr.of[Float]
|| t <:< TypeRepr.of[Double]
|| t <:< TypeRepr.of[java.math.BigDecimal]
|| t <:< TypeRepr.of[scala.math.BigDecimal]
}
def parseBinary(
left: Expr[Any],
right: Expr[Any],
op: Expr[ast.BinaryOperator]
) = {
val leftE = astParser(left)
val rightE = astParser(right)
'{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) }
}
def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = {
val base = astParser(expr)
'{ ast.UnaryOperation($op, $base) }
}
val universalOpParser: Parser[ast.BinaryOperation] = termParser {
case Apply(Select(leftT, UniversalOp(op)), List(rightT)) =>
parseBinary(leftT.asExpr, rightT.asExpr, op)
}
val stringOpParser: Parser[ast.Operation] = {
case '{ ($x: String) + ($y: String) } =>
parseBinary(x, y, '{ StringOperator.concat })
case '{ ($x: String).startsWith($y) } =>
parseBinary(x, y, '{ StringOperator.startsWith })
case '{ ($x: String).split($y) } =>
parseBinary(x, y, '{ StringOperator.split })
case '{ ($x: String).toUpperCase } =>
parseUnary(x, '{ StringOperator.toUpperCase })
case '{ ($x: String).toLowerCase } =>
parseUnary(x, '{ StringOperator.toLowerCase })
case '{ ($x: String).toLong } =>
parseUnary(x, '{ StringOperator.toLong })
case '{ ($x: String).toInt } =>
parseUnary(x, '{ StringOperator.toInt })
}
val numericOpParser = termParser {
case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) =>
parseBinary(lt.asExpr, rt.asExpr, op)
case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) =>
val leftExpr = astParser(leftTerm.asExpr)
'{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) }
}
val booleanOpParser: Parser[ast.Operation] = {
case '{ ($x: Boolean) && $y } =>
parseBinary(x, y, '{ BooleanOperator.&& })
case '{ ($x: Boolean) || $y } =>
parseBinary(x, y, '{ BooleanOperator.|| })
case '{ !($x: Boolean) } =>
parseUnary(x, '{ BooleanOperator.! })
}
universalOpParser
.orElse(stringOpParser)
.orElse(numericOpParser)
.orElse(booleanOpParser)
}
private object UniversalOp {
def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
op match {
case "==" | "equals" => Some('{ EqualityOperator.== })
case "!=" => Some('{ EqualityOperator.!= })
case _ => None
}
}
private object NumericOp {
def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
op match {
case "+" => Some('{ NumericOperator.+ })
case "-" => Some('{ NumericOperator.- })
case "*" => Some('{ NumericOperator.* })
case "/" => Some('{ NumericOperator./ })
case ">" => Some('{ NumericOperator.> })
case ">=" => Some('{ NumericOperator.>= })
case "<" => Some('{ NumericOperator.< })
case "<=" => Some('{ NumericOperator.<= })
case "%" => Some('{ NumericOperator.% })
case _ => None
}
}

View file

@ -0,0 +1,48 @@
package minisql.parsing
import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
import minisql.util.*
private[minisql] inline def parseParamAt[F](
inline f: F,
inline n: Int
): ast.Ident = ${
parseParamAt('f, 'n)
}
private[minisql] inline def parseBody[X](
inline f: X
): ast.Ast = ${
parseBody('f)
}
private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
Quotes
): Expr[ast.Ident] = {
import quotes.reflect.*
val pIdx = n.value.getOrElse(
report.errorAndAbort(s"Param index ${n.show} is not know")
)
extractTerm(f.asTerm) match {
case Lambda(vals, _) =>
vals(pIdx) match {
case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
}
}
}
private[minisql] def parseBody[X](
x: Expr[X]
)(using Quotes): Expr[Ast] = {
import quotes.reflect.*
extractTerm(x.asTerm) match {
case Lambda(vals, body) =>
Parsing.parseExpr(body.asExpr)
case o =>
report.errorAndAbort(s"Can only parse function")
}
}

View file

@ -0,0 +1,127 @@
package minisql.parsing
import minisql.ast
import minisql.context.{ReturningMultipleFieldSupported, _}
import minisql.norm.BetaReduction
import minisql.norm.capture.AvoidAliasConflict
import minisql.idiom.Idiom
import scala.annotation.tailrec
import minisql.ast.Implicits._
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.{Hidden, Visible}
import minisql.util.{Interleave, extractTerm}
import scala.quoted.*
type Parser[A] = PartialFunction[Expr[Any], Expr[A]]
private def termParser[A](using q: Quotes)(
pf: PartialFunction[q.reflect.Term, Expr[A]]
): Parser[A] = {
import quotes.reflect._
{
case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm)
}
}
private def parser[A](
f: PartialFunction[Expr[Any], Expr[A]]
)(using Quotes): Parser[A] = {
case e if f.isDefinedAt(e) => f(e)
}
private[minisql] object Parsing {
def parseExpr(
expr: Expr[?]
)(using q: Quotes): Expr[ast.Ast] = {
import q.reflect._
def unwrapped(
f: Parser[ast.Ast]
): Parser[ast.Ast] = {
case expr =>
val t = extractTerm(expr.asTerm)
if (t.isExpr)
f(t.asExpr)
else f(expr)
}
lazy val astParser: Parser[ast.Ast] =
unwrapped {
typedParser
.orElse(propertyParser)
.orElse(liftParser)
.orElse(infixParser)
.orElse(identParser)
.orElse(valueParser)
.orElse(operationParser)
.orElse(constantParser)
.orElse(blockParser)
.orElse(boxingParser)
.orElse(ifParser)
.orElse(traversableOperationParser)
.orElse(patMatchParser)
.orElse {
case o =>
val str = scala.util.Try(o.show).getOrElse("")
report.errorAndAbort(
s"cannot parse ${str}",
o.asTerm.pos
)
}
}
lazy val typedParser: Parser[ast.Ast] = termParser {
case (Typed(t, _)) =>
astParser(t.asExpr)
}
lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser)
lazy val valueParser: Parser[ast.Value] = valueParsing(astParser)
lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser)
lazy val constantParser: Parser[ast.Constant] = termParser {
case Literal(x) =>
'{ ast.Constant(${ Literal(x).asExpr }) }
}
lazy val identParser: Parser[ast.Ident] = termParser {
case x @ Ident(n) =>
'{ ast.Ident(${ Expr(n) }) }
}
lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser)
lazy val operationParser: Parser[ast.Operation] = operationParsing(
astParser
)
lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser)
lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
}
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
lazy val traversableOperationParser: Parser[ast.IterableOperation] =
traversableOperationParsing(astParser)
lazy val infixParser: Parser[ast.Infix] = infixParsing(
astParser
)
astParser(expr)
}
private[minisql] inline def parse[A](
inline a: A
): ast.Ast = ${
parseExpr('a)
}
}

View file

@ -0,0 +1,50 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def patMatchParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
import quotes.reflect.*
termParser {
// Val defs that showd pattern variables will cause error
case e @ Match(
Ident(t),
List(CaseDef(IsTupleUnapply(binds), None, body))
) =>
val bindStmts = binds.map {
case Bind(bn, _) =>
'{
ast.Val(
ast.Ident(${ Expr(bn) }),
ast.Property(ast.Ident(${ Expr(t) }), "_1")
)
}
}
val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
'{ ast.Block(${ Expr.ofList(allStmts.toList) }) }
}
}
object IsTupleUnapply {
def unapply(using
Quotes
)(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = {
import quotes.reflect.*
def isTupleNUnapply(x: Term) = {
val fn = x.symbol.fullName
fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply")
}
t match {
case Unapply(m, _, binds) if isTupleNUnapply(m) =>
Some(binds)
case _ => None
}
}
}

View file

@ -0,0 +1,29 @@
package minisql.parsing
import minisql.ast
import scala.quoted._
private[parsing] def propertyParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Property] = {
import quotes.reflect.*
def isAccessor(s: Select) = {
s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name)
}
val parseApply: Parser[ast.Property] = termParser {
case m @ Select(base, n) if isAccessor(m) =>
val obj = astParser(base.asExpr)
'{ ast.Property($obj, ${ Expr(n) }) }
}
val parseOptionGet: Parser[ast.Property] = {
case '{ ($e: Option[t]).get } =>
report.errorAndAbort(
"Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead."
)
}
parseApply.orElse(parseOptionGet)
}

View file

@ -0,0 +1,16 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private def traversableOperationParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.IterableOperation] = {
case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } =>
'{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) }
case '{ ($s: Set[e]).contains($i) } =>
'{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) }
case '{ ($s: Seq[e]).contains($i) } =>
'{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) }
}

View file

@ -0,0 +1,72 @@
package minisql
package parsing
import scala.quoted._
import minisql.util.*
private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using
Quotes
): Parser[ast.Value] = {
import quotes.reflect.*
val parseTupleApply: Parser[ast.Tuple] = {
case IsTupleApply(args) =>
val t = args.map(astParser)
'{ ast.Tuple(${ Expr.ofList(t) }) }
}
val parseAssocTuple: Parser[ast.Tuple] = {
case '{ ($x: tx) -> ($y: ty) } =>
'{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) }
}
parseTupleApply.orElse(parseAssocTuple)
}
private[minisql] object IsTupleApply {
def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = {
import quotes.reflect.*
def isTupleNApply(t: Term) = {
val fn = t.symbol.fullName
fn.startsWith("scala.Tuple") && fn.endsWith("$.apply")
}
def isTupleXXLApply(t: Term) = {
t.symbol.fullName == "scala.runtime.TupleXXL$.apply"
}
extractTerm(e.asTerm) match {
// TupleN(0-22).apply
case Apply(b, args) if isTupleNApply(b) =>
Some(args.map(_.asExpr))
case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) =>
t.asExpr match {
case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } =>
Some(args)
}
case o =>
None
}
}
}
private[parsing] object IsTuple2 {
def unapply(using
Quotes
)(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] =
t match {
case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2))
case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2))
case _ => None
}
def unapply(using
Quotes
)(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] =
unapply(t.asExpr)
}

View file

@ -1,8 +1,26 @@
package minisql.util
import scala.util.Try
import scala.util.*
object CollectTry {
extension [A](xs: Iterable[A]) {
private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = {
val out = IArray.newBuilder[Any]
var left: Option[Throwable] = None
xs.foreach { (v) =>
if (!left.isDefined) {
f(v) match {
case Failure(e) =>
left = Some(e)
case Success(r) =>
out += r
}
}
}
left.toLeft(out.result().asInstanceOf).toTry
}
}
private[minisql] object CollectTry {
def apply[T](list: List[Try[T]]): Try[List[T]] =
list.foldLeft(Try(List.empty[T])) {
case (list, t) =>

View file

@ -11,20 +11,25 @@ import scala.util.matching.Regex
class Interpolator(
defaultIndent: Int = 0,
qprint: AstPrinter = AstPrinter(),
out: PrintStream = System.out,
out: PrintStream = System.out
) {
extension (sc: StringContext) {
def trace(elements: Any*) = new Traceable(sc, elements)
}
class Traceable(sc: StringContext, elementsSeq: Seq[Any]) {
private val elementPrefix = "| "
private sealed trait PrintElement
private case class Str(str: String, first: Boolean) extends PrintElement
private case class Elem(value: String) extends PrintElement
private case object Separator extends PrintElement
private case class Elem(value: String) extends PrintElement
private case object Separator extends PrintElement
private def generateStringForCommand(value: Any, indent: Int) = {
val objectString = qprint(value)
val oneLine = objectString.fitsOnOneLine
val oneLine = objectString.fitsOnOneLine
oneLine match {
case true => s"${indent.prefix}> ${objectString}"
case false =>
@ -42,7 +47,7 @@ class Interpolator(
private def readBuffers() = {
def orZero(i: Int): Int = if (i < 0) 0 else i
val parts = sc.parts.iterator.toList
val parts = sc.parts.iterator.toList
val elements = elementsSeq.toList.map(qprint(_))
val (firstStr, explicitIndent) = readFirst(parts.head)

View file

@ -5,10 +5,15 @@ import scala.util.Try
object LoadObject {
def apply[T](using Quotes, Type[T]): Try[T] = {
import quotes.reflect.*
apply(TypeRepr.of[T])
}
def apply[T](using Quotes)(ot: quotes.reflect.TypeRepr): Try[T] = Try {
import quotes.reflect.*
val moduleClsName = ot.typeSymbol.companionModule.moduleClass.fullName
val moduleCls = Class.forName(moduleClsName)
val moduleCls = Class.forName(moduleClsName)
val field = moduleCls
.getFields()
.find { f =>

View file

@ -0,0 +1,75 @@
package minisql.util
import minisql.AstPrinter
import minisql.idiom.Idiom
import minisql.util.IndentUtil._
object Messages {
private def variable(propName: String, envName: String, default: String) =
Option(System.getProperty(propName))
.orElse(sys.env.get(envName))
.getOrElse(default)
private[util] val prettyPrint =
variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean
private[util] val debugEnabled =
variable("quill.macro.log", "quill_macro_log", "true").toBoolean
private[util] val traceEnabled =
variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean
private[util] val traceColors =
variable("quill.trace.color", "quill_trace_color,", "false").toBoolean
private[util] val traceOpinions =
variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean
private[util] val traceAstSimple = variable(
"quill.trace.ast.simple",
"quill_trace_ast_simple",
"false"
).toBoolean
private[minisql] val cacheDynamicQueries = variable(
"quill.query.cacheDaynamic",
"query_query_cacheDaynamic",
"true"
).toBoolean
private[util] val traces: List[TraceType] =
variable("quill.trace.types", "quill_trace_types", "standard")
.split(",")
.toList
.map(_.trim)
.flatMap(trace =>
TraceType.values.filter(traceType => trace == traceType.value)
)
def tracesEnabled(tt: TraceType) =
traceEnabled && traces.contains(tt)
sealed trait TraceType { def value: String }
object TraceType {
case object Normalizations extends TraceType { val value = "norm" }
case object Standard extends TraceType { val value = "standard" }
case object NestedQueryExpansion extends TraceType { val value = "nest" }
def values: List[TraceType] =
List(Standard, Normalizations, NestedQueryExpansion)
}
val qprint = AstPrinter()
def fail(msg: String) =
throw new IllegalStateException(msg)
def trace[T](
label: String,
numIndent: Int = 0,
traceType: TraceType = TraceType.Standard
) =
(v: T) => {
val indent = (0 to numIndent).map(_ => "").mkString(" ")
if (tracesEnabled(traceType))
println(s"$indent$label\n${{
if (traceColors) qprint.apply(v)
else qprint.apply(v)
}.split("\n").map(s"$indent " + _).mkString("\n")}")
v
}
}

View file

@ -0,0 +1,41 @@
package minisql.util
import scala.quoted.*
private[minisql] def splicePkgPath(using Quotes) = {
import quotes.reflect.*
def recurse(sym: Symbol): String =
sym match {
case s if s.isPackageDef => s.fullName
case s if s.isNoSymbol => ""
case _ =>
recurse(sym.maybeOwner)
}
recurse(Symbol.spliceOwner)
}
private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = {
import quotes.reflect.*
val name = x.asTerm.show
val packageName = splicePkgPath
val pos = x.asTerm.pos
val fileName = pos.sourceFile.name
s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}"
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
val o = unwrapTerm(x)
o
}

View file

@ -1,21 +1,20 @@
package minisql.util
object Show {
trait Show[T] {
def show(v: T): String
trait Show[T] {
extension (v: T) {
def show: String
}
}
object Show {
def apply[T](f: T => String) = new Show[T] {
def show(v: T) = f(v)
object Show {
def apply[T](f: T => String) = new Show[T] {
extension (v: T) {
def show: String = f(v)
}
}
implicit class Shower[T](v: T)(implicit shower: Show[T]) {
def show = shower.show(v)
}
implicit def listShow[T](implicit shower: Show[T]): Show[List[T]] =
given listShow[T](using shower: Show[T]): Show[List[T]] =
Show[List[T]] {
case list => list.map(_.show).mkString(", ")
}

View file

@ -0,0 +1,188 @@
package minisql.ast
import munit.FunSuite
import minisql.ast.*
import scala.quoted.*
class FromExprsSuite extends FunSuite {
// Helper to test both compile-time and runtime extraction
inline def testFor[A <: Ast](label: String)(inline ast: A) = {
test(label) {
// Test compile-time extraction
val compileTimeResult = minisql.compileTimeAst(ast)
assert(compileTimeResult.contains(ast.toString))
}
}
testFor("Ident") {
Ident("test")
}
testFor("Ident with visibility") {
Ident.Opinionated("test", Visibility.Hidden)
}
testFor("Property") {
Property(Ident("a"), "b")
}
testFor("Property with opinions") {
Property.Opinionated(Ident("a"), "b", Renameable.Fixed, Visibility.Visible)
}
testFor("BinaryOperation") {
BinaryOperation(Ident("a"), EqualityOperator.==, Ident("b"))
}
testFor("UnaryOperation") {
UnaryOperation(BooleanOperator.!, Ident("flag"))
}
testFor("ScalarValueLift") {
ScalarValueLift("name", "id", None)
}
testFor("Ordering") {
PropertyOrdering.Asc
}
testFor("TupleOrdering") {
TupleOrdering(List(PropertyOrdering.Asc, PropertyOrdering.Desc))
}
testFor("Entity") {
Entity("people", Nil)
}
testFor("Entity with properties") {
Entity("people", List(PropertyAlias(List("name"), "full_name")))
}
testFor("Action - Insert") {
Insert(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
)
}
testFor("Action - Update") {
Update(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
)
}
testFor("Action - Returning") {
Returning(
Insert(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
),
Ident("x"),
Property(Ident("x"), "id")
)
}
testFor("Action - ReturningGenerated") {
ReturningGenerated(
Insert(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
),
Ident("x"),
Property(Ident("x"), "generatedId")
)
}
testFor("Action - Val outside") {
val p1 = Update(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
)
val p2 = Ident("x")
val p3 = Property(Ident("x"), "id")
Returning(p1, p2, p3)
}
testFor("Action - ReturningGenerated with Update") {
val p1 = Update(
Ident("table"),
List(Assignment(Ident("x"), Ident("col"), Ident("val")))
)
val p2 = Ident("x")
val p3 = Property(Ident("x"), "id")
ReturningGenerated(p1, p2, p3)
}
testFor("If expression") {
If(Ident("cond"), Ident("then"), Ident("else"))
}
testFor("Infix") {
Infix(
List("func(", ")"),
List(Ident("param")),
pure = true,
noParen = false
)
}
testFor("Infix with different parameters") {
Infix(
List("?", " + ", "?"),
List(Constant(1), Constant(2)),
pure = true,
noParen = true
)
}
testFor("OptionOperation - OptionMap") {
OptionMap(Ident("opt"), Ident("x"), Ident("x"))
}
testFor("OptionOperation - OptionFlatMap") {
OptionFlatMap(Ident("opt"), Ident("x"), Ident("x"))
}
testFor("OptionOperation - OptionGetOrElse") {
OptionGetOrElse(Ident("opt"), Ident("default"))
}
testFor("Join") {
Join(
JoinType.InnerJoin,
Ident("a"),
Ident("b"),
Ident("a1"),
Ident("b1"),
BinaryOperation(Ident("a1.id"), EqualityOperator.==, Ident("b1.id"))
)
}
testFor("Distinct") {
Distinct(Ident("query"))
}
testFor("GroupBy") {
GroupBy(Ident("query"), Ident("alias"), Ident("body"))
}
testFor("Aggregation") {
Aggregation(AggregationOperator.avg, Ident("field"))
}
testFor("CaseClass") {
CaseClass(List(("name", Ident("value"))))
}
testFor("Block") { // Also tested Val
Block(
List(
Val(Ident("x"), Constant(1)),
Val(Ident("y"), Constant(2)),
BinaryOperation(Ident("x"), NumericOperator.+, Ident("y"))
)
)
}
}

View file

@ -0,0 +1,63 @@
package minisql.context.sql
import minisql.*
import minisql.ast.*
import minisql.idiom.*
import minisql.NamingStrategy
import minisql.MirrorContext
import minisql.context.mirror.{*, given}
class MirrorSqlContextSuite extends munit.FunSuite {
case class Foo(id: Long, name: String)
inline def Foos = query[Foo]("foo")
import testContext.given
test("SimpleQuery") {
val o = testContext.io(
query[Foo](
"foo",
alias("id", "id1")
).filter(x => x.id > 0)
)
assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0")
}
test("Insert") {
val v: Foo = Foo(0L, "foo")
val o = testContext.io(Foos.insert(v))
assertEquals(
o.sql,
"INSERT INTO foo (id,name) VALUES (?, ?)"
)
}
test("InsertReturningGenerated") {
val v: Foo = Foo(0L, "foo")
val o = testContext.io(Foos.insert(v).returningGenerated(_.id))
assertEquals(
o.sql,
"INSERT INTO foo (name) VALUES (?) RETURNING id"
)
}
test("LeftJoin") {
val o = testContext
.io(Foos.join(Foos).on((f1, f2) => f1.id == f2.id).map {
case (f1, f2) => (f1.id, f2.id)
})
println(o)
}
test("Infix string interpolation") {
val o = testContext.io(
Foos.map(f => infix"CONCAT(${f.name}, ' ', ${f.id})".as[String])
)
assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f")
}
}

View file

@ -0,0 +1,5 @@
package minisql.context.sql
import minisql.*
val testContext = new MirrorSqlContext(Literal)

View file

@ -0,0 +1,38 @@
package minisql.parsing
import minisql.ast.*
class ParsingSuite extends munit.FunSuite {
test("Ident") {
val x = 1
assertEquals(Parsing.parse(x), Ident("x"))
}
test("NumericOperator.+") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a + b),
BinaryOperation(Ident("a"), NumericOperator.+, Ident("b"))
)
}
test("NumericOperator.-") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a - b),
BinaryOperation(Ident("a"), NumericOperator.-, Ident("b"))
)
}
test("NumericOperator.*") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a * b),
BinaryOperation(Ident("a"), NumericOperator.*, Ident("b"))
)
}
}

View file

@ -0,0 +1 @@