Compare commits: main...add_parsin (26 commits)

071b27abcf
ed1952b915
06850823d7
48cb1003bb
f5e43657b3
adc60400a7
c1f26a0704
23c0484609
a1201a67aa
2753f01001
24f7f6aec0
3a9d15f015
184ab0b884
2b52ef3203
17e97495b7
1bc6baad68
63a9a0cad3
cb0c6082d0
47cf808e8f
7f5092c396
87f1b70b27
6be96aba2c
59f969a232
2e7e7df4a3
a0ceea91a9
8103d45178
90 changed files with 6951 additions and 295 deletions
.gitignore (vendored, 3 changed lines)

```diff
@@ -2,4 +2,5 @@ target/
 .bsp/
 .metals/
 .bloop/
 project/metals.sbt
+.aider*
```
Documentation changes:

```diff
@@ -5,7 +5,7 @@
 In most scenarios there is no need to analyze code with complex pattern matching over the Ast inside a `macro`.

-## Core idea: use inline and `FromExpr` to replace most of the parsing work
+## Core idea: use inline and `FromExpr` to replace part of the parsing work

 `FromExpr` is a typeclass built into `scala3` for extracting values known at compile time.
```
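For context on the `FromExpr` mechanism the documentation refers to, here is a minimal sketch of the standard `scala.quoted.FromExpr` protocol, using a made-up `Point` type that is not part of this PR (in real use the macro implementation must live in a separate compilation unit from its callers):

```scala
import scala.quoted.*

case class Point(x: Int, y: Int)

object Point {
  // A FromExpr instance lets `Expr[Point].value` recover a Point
  // whenever the expression tree is statically known.
  given FromExpr[Point] with {
    def unapply(e: Expr[Point])(using Quotes): Option[Point] = e match {
      case '{ Point(${ Expr(x) }, ${ Expr(y) }) } => Some(Point(x, y))
      case _                                      => None
    }
  }
}

inline def showPoint(inline p: Point): String = ${ showPointImpl('p) }

private def showPointImpl(p: Expr[Point])(using Quotes): Expr[String] =
  p.value match {
    case Some(pt) => Expr(s"static: $pt") // no manual Ast matching needed
    case None     => '{ "dynamic: " + $p.toString }
  }
```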
build.sbt (11 changed lines)

```diff
@@ -1,8 +1,15 @@
 name := "minisql"

-scalaVersion := "3.6.2"
+scalaVersion := "3.7.1"

 libraryDependencies ++= Seq(
   "org.scalameta" %% "munit" % "1.1.1" % Test
 )
 scalacOptions ++= Seq("-experimental", "-language:experimental.namedTuples")
+
+scalacOptions ++= Seq(
+  "-deprecation",
+  "-feature",
+  "-source:3.7-migration",
+  "-rewrite"
+)
```
Changes to the sbt version:

```diff
@@ -1 +1 @@
-sbt.version=1.10.5
+sbt.version=1.11.2
```
src/main/scala/minisql/Meta.scala (new file, 3 lines)

```scala
package minisql

type QueryMeta
```
Codec typeclass changes (`ParamEncoder` reworked, `ColumnDecoder` added alongside):

```diff
@@ -1,8 +1,22 @@
 package minisql

+import scala.util.Try
+
 trait ParamEncoder[E] {
-  def setParam(s: Stmt, idx: Int, v: E): Unit
+  type Stmt
+
+  def setParam(s: Stmt, idx: Int, v: E): Stmt
 }
+
+trait ColumnDecoder[X] {
+
+  type DBRow
+
+  def decode(row: DBRow, idx: Int): Try[X]
+}
+
+object ColumnDecoder {
+  type Aux[R, X] = ColumnDecoder[X] {
+    type DBRow = R
+  }
+}
```
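To make the encoder/decoder contract concrete, here is a minimal sketch of instances for these two typeclasses; `FakeRow` and `FakeStmt` are invented stand-ins for this example only (the PR's real reference implementations are the mirror codecs later in this diff):

```scala
import minisql.{ColumnDecoder, ParamEncoder}
import scala.util.Try

type FakeRow  = IndexedSeq[Any] // hypothetical row shape
type FakeStmt = Map[Int, Any]   // hypothetical statement shape

given ParamEncoder[Int] with {
  type Stmt = FakeStmt
  // Returns the updated statement, matching the new Stmt return type.
  def setParam(s: Stmt, idx: Int, v: Int): Stmt = s + (idx -> v)
}

given ColumnDecoder[Int] with {
  type DBRow = FakeRow
  def decode(row: DBRow, idx: Int): Try[Int] =
    Try(row(idx).asInstanceOf[Int])
}
```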
Deleted: the old `minisql.parsing` parser (127 lines removed):

```scala
package minisql.parsing

import minisql.ast
import minisql.ast.Ast
import scala.quoted.*

type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]

private[minisql] inline def parseParamAt[F](
    inline f: F,
    inline n: Int
): ast.Ident = ${
  parseParamAt('f, 'n)
}

private[minisql] inline def parseBody[X](
    inline f: X
): ast.Ast = ${
  parseBody('f)
}

private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
    Quotes
): Expr[ast.Ident] = {

  import quotes.reflect.*

  val pIdx = n.value.getOrElse(
    report.errorAndAbort(s"Param index ${n.show} is not know")
  )
  extractTerm(f.asTerm) match {
    case Lambda(vals, _) =>
      vals(pIdx) match {
        case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
      }
  }
}

private[minisql] def parseBody[X](
    x: Expr[X]
)(using Quotes): Expr[Ast] = {
  import quotes.reflect.*
  extractTerm(x.asTerm) match {
    case Lambda(vals, body) =>
      astParser(body.asExpr)
    case o =>
      report.errorAndAbort(s"Can only parse function")
  }
}

private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
  import quotes.reflect.*
  x.asTerm.tpe match {
    case t if t <:< TypeRepr.of[Int]                  => true
    case t if t <:< TypeRepr.of[Long]                 => true
    case t if t <:< TypeRepr.of[Float]                => true
    case t if t <:< TypeRepr.of[Double]               => true
    case t if t <:< TypeRepr.of[BigDecimal]           => true
    case t if t <:< TypeRepr.of[BigInt]               => true
    case t if t <:< TypeRepr.of[java.lang.Integer]    => true
    case t if t <:< TypeRepr.of[java.lang.Long]       => true
    case t if t <:< TypeRepr.of[java.lang.Float]      => true
    case t if t <:< TypeRepr.of[java.lang.Double]     => true
    case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
    case _                                            => false
  }
}

private def identParser(using Quotes): Parser[ast.Ident] = {
  import quotes.reflect.*
  { (x: Expr[?]) =>
    extractTerm(x.asTerm) match {
      case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
      case _        => None
    }
  }.unlift
}

private lazy val astParser: Quotes ?=> Parser[Ast] = {
  identParser.orElse(propertyParser(astParser))
}

private object IsPropertySelect {
  def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
    import quotes.reflect.*
    x.asTerm match {
      case Select(x, n) => Some(x.asExpr, n)
      case _            => None
    }
  }
}

def propertyParser(
    astParser: => Parser[Ast]
)(using Quotes): Parser[ast.Property] = {
  case IsPropertySelect(expr, n) =>
    '{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
}

def optionOperationParser(
    astParser: => Parser[Ast]
)(using Quotes): Parser[ast.OptionOperation] = {
  case '{ ($x: Option[t]).isEmpty } =>
    '{ ast.OptionIsEmpty(${ astParser(x) }) }
}

def binaryOperationParser(
    astParser: => Parser[Ast]
)(using Quotes): Parser[ast.BinaryOperation] = {
  ???
}

private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
  import quotes.reflect.*
  def unwrapTerm(t: Term): Term = t match {
    case Inlined(_, _, o)            => unwrapTerm(o)
    case Block(Nil, last)            => last
    case Typed(t, _)                 => unwrapTerm(t)
    case Select(t, "$asInstanceOf$") => unwrapTerm(t)
    case TypeApply(t, _)             => unwrapTerm(t)
    case o                           => o
  }
  unwrapTerm(x)
}
```
src/main/scala/minisql/Quoted.scala (new file, 298 lines)

```scala
package minisql

import minisql.*
import minisql.idiom.*
import minisql.parsing.*
import minisql.util.*
import minisql.ast.{
  Ast,
  Entity,
  Map,
  Property,
  Ident,
  Filter,
  PropertyAlias,
  JoinType,
  Join,
  given
}
import scala.quoted.*
import scala.deriving.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}

opaque type Quoted <: Ast = Ast

opaque type Query[E] <: Quoted = Quoted

opaque type Action[E] <: Quoted = Quoted

opaque type Insert[E] <: Action[Long] = Quoted

object Insert {
  extension [E](inline insert: Insert[E]) {
    inline def returning[E1](inline f: E => E1): InsertReturning[E1] = {
      transform(insert)(f)(ast.Returning.apply)
    }

    inline def returningGenerated[E1](
        inline f: E => E1
    ): InsertReturning[E1] = {
      transform(insert)(f)(ast.ReturningGenerated.apply)
    }
  }
}

opaque type InsertReturning[E] <: Action[E] = Quoted

sealed trait Joined[E1, E2]

opaque type JoinQuery[E1, E2] <: Query[(E1, E2)] = Quoted

object Joined {

  def apply[E1, E2](joinType: JoinType, ta: Ast, tb: Ast): Joined[E1, E2] =
    new Joined[E1, E2] {}

  extension [E1, E2](inline j: Joined[E1, E2]) {
    inline def on(inline f: (E1, E2) => Boolean): JoinQuery[E1, E2] =
      joinOn(j, f)
  }
}

private inline def joinOn[E1, E2](
    inline j: Joined[E1, E2],
    inline f: (E1, E2) => Boolean
): JoinQuery[E1, E2] = j.toJoinQuery(f.param0, f.param1, f.body)

extension [E1, E2](inline j: Joined[E1, E2]) {
  private inline def toJoinQuery(
      inline aliasA: Ident,
      inline aliasB: Ident,
      inline on: Ast
  ): Ast = ${ joinQueryOf('j, 'aliasA, 'aliasB, 'on) }
}

private def joinQueryOf[E1, E2](
    x: Expr[Joined[E1, E2]],
    aliasA: Expr[Ident],
    aliasB: Expr[Ident],
    on: Expr[Ast]
)(using Quotes, Type[E1], Type[E2]): Expr[Join] = {
  import quotes.reflect.*
  extractTerm(x.asTerm).asExpr match {
    case '{ Joined[E1, E2]($jt, $a, $b) } =>
      '{
        Join($jt, $a, $b, $aliasA, $aliasB, $on)
      }
  }
}

private inline def quotedLift[X](x: X)(using
    e: ParamEncoder[X]
): ast.ScalarValueLift = ${
  quotedLiftImpl[X]('x, 'e)
}

private def quotedLiftImpl[X: Type](
    x: Expr[X],
    e: Expr[ParamEncoder[X]]
)(using Quotes): Expr[ast.ScalarValueLift] = {
  import quotes.reflect.*
  val name = x.asTerm.show
  val liftId = liftIdOfExpr(x)
  '{
    ast.ScalarValueLift(
      ${ Expr(name) },
      ${ Expr(liftId) },
      Some(($x, $e))
    )
  }
}

object Query {

  extension [E](inline e: Query[E]) {

    private[minisql] inline def expanded: Query[E] = {
      expandFields[E](e)
    }

    inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, Option[E1]] =
      Joined[E, Option[E1]](JoinType.LeftJoin, e, e1)

    inline def rightJoin[E1](inline e1: Query[E1]): Joined[Option[E], E1] =
      Joined[Option[E], E1](JoinType.RightJoin, e, e1)

    inline def join[E1](inline e1: Query[E1]): Joined[E, E1] =
      Joined[E, E1](JoinType.InnerJoin, e, e1)

    inline def fullJoin[E1](
        inline e1: Query[E1]
    ): Joined[Option[E], Option[E1]] =
      Joined[Option[E], Option[E1]](JoinType.FullJoin, e, e1)

    inline def map[E1](inline f: E => E1): Query[E1] = {
      transform(e)(f)(Map.apply)
    }

    inline def filter(inline f: E => Boolean): Query[E] = {
      transform(e)(f)(Filter.apply)
    }

    inline def withFilter(inline f: E => Boolean): Query[E] = {
      transform(e)(f)(Filter.apply)
    }

  }
}

opaque type EntityQuery[E] <: Query[E] = Query[E]

object EntityQuery {

  extension [E](inline e: EntityQuery[E]) {

    inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
      transform(e)(f)(Map.apply)
    }

    inline def filter(inline f: E => Boolean): EntityQuery[E] = {
      transform(e)(f)(Filter.apply)
    }

    inline def insert(v: E): Insert[E] = {
      ast.Insert(e, transformCaseClassToAssignments[E](v))
    }
  }
}

private inline def transformCaseClassToAssignments[E](
    v: E
): List[ast.Assignment] = ${
  transformCaseClassToAssignmentsImpl[E]('v)
}

private def transformCaseClassToAssignmentsImpl[E: Type](
    v: Expr[E]
)(using Quotes): Expr[List[ast.Assignment]] = {
  import quotes.reflect.*

  val fields = TypeRepr.of[E].typeSymbol.caseFields
  val assignments = fields.map { field =>
    val fieldName = field.name
    val fieldType = field.tree match {
      case v: ValDef => v.tpt.tpe
      case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName")
    }
    fieldType.asType match {
      case '[t] =>
        '{
          ast.Assignment(
            ast.Ident("v"),
            ast.Property(ast.Ident("v"), ${ Expr(fieldName) }),
            quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using
              summonInline[ParamEncoder[t]]
            )
          )
        }
    }
  }

  Expr.ofList(assignments)
}

private inline def transform[A, B](inline q1: Quoted)(
    inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
  fast(q1, f.param0, f.body)
}

inline def alias(inline from: String, inline to: String): PropertyAlias =
  PropertyAlias(List(from), to)

inline def query[E](
    inline table: String,
    inline alias: PropertyAlias*
): EntityQuery[E] =
  Entity(
    table,
    List(alias*)
  )

extension [A, B](inline f1: A => B) {
  private inline def param0 = parsing.parseParamAt(f1, 0)
  private inline def body = parsing.parseBody(f1)
}

extension [A1, A2, B](inline f1: (A1, A2) => B) {
  private inline def param0 = parsing.parseParamAt(f1, 0)
  private inline def param1 = parsing.parseParamAt(f1, 1)
  private inline def body = parsing.parseBody(f1)
}

def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()

class NonQuotedException extends Exception("Cannot be used at runtime")

private[minisql] inline def compileTimeAst(inline q: Ast): Option[String] =
  ${ compileTimeAstImpl('q) }

private def compileTimeAstImpl(e: Expr[Ast])(using
    Quotes
): Expr[Option[String]] = {
  import quotes.reflect.*
  e.value match {
    case Some(v) => '{ Some(${ Expr(v.toString()) }) }
    case None    => '{ None }
  }
}

private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy](
    inline q: Quoted,
    inline idiom: I,
    inline naming: N
): Statement = ${ compileImpl[I, N]('q, 'idiom, 'naming) }

private def compileImpl[I <: Idiom, N <: NamingStrategy](
    q: Expr[Quoted],
    idiom: Expr[I],
    n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
  import quotes.reflect.*
  q.value match {
    case Some(ast) =>
      val idiom = LoadObject[I].getOrElse(
        report.errorAndAbort(s"Idiom not known at compile")
      )

      val naming = LoadNaming
        .static[N]
        .getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile"))

      val stmt = idiom.translate(ast)(using naming)
      report.info(s"Static Query: ${stmt._2}")
      Expr(stmt._2)
    case None =>
      report.info("Dynamic Query")
      '{
        $idiom.translate($q)(using $n)._2
      }

  }
}

private inline def expandFields[E](inline base: Ast): Ast =
  ${ expandFieldsImpl[E]('base) }

private def expandFieldsImpl[E](baseExpr: Expr[Ast])(using
    Quotes,
    Type[E]
): Expr[Ast] = {
  import quotes.reflect.*
  val values = TypeRepr.of[E].typeSymbol.caseFields.map { f =>
    '{ Property(ast.Ident("x"), ${ Expr(f.name) }) }
  }
  '{ Map(${ baseExpr }, ast.Ident("x"), ast.Tuple(${ Expr.ofList(values) })) }
}
```
src/main/scala/minisql/ReturnAction.scala (new file, 7 lines)

```scala
package minisql

enum ReturnAction {
  case ReturnNothing
  case ReturnColumns(columns: List[String])
  case ReturnRecord
}
```
src/main/scala/minisql/SqlInfix.scala (new file, 12 lines)

```scala
package minisql

import minisql.ast.Ast
import scala.quoted.*

sealed trait InfixValue {
  def as[T]: T
}

extension (sc: StringContext) {
  def infix(args: Any*): InfixValue = throw NonQuotedException()
}
```
Changes in the `minisql.ast` package:

```diff
@@ -1,6 +1,7 @@
 package minisql.ast

 import minisql.NamingStrategy
+import minisql.ParamEncoder

 import scala.quoted.*

@@ -58,9 +59,9 @@ object Entity {
   object Opinionated {
     inline def apply(
-        name: String,
-        properties: List[PropertyAlias],
-        renameableNew: Renameable
+        inline name: String,
+        inline properties: List[PropertyAlias],
+        inline renameableNew: Renameable
     ): Entity = Entity(name, properties, renameableNew)

     def unapply(e: Entity) =

@@ -84,13 +85,14 @@ case class SortBy(query: Ast, alias: Ident, criterias: Ast, ordering: Ordering)
 sealed trait Ordering extends Ast
 case class TupleOrdering(elems: List[Ordering]) extends Ordering

-sealed trait PropertyOrdering extends Ordering
-case object Asc extends PropertyOrdering
-case object Desc extends PropertyOrdering
-case object AscNullsFirst extends PropertyOrdering
-case object DescNullsFirst extends PropertyOrdering
-case object AscNullsLast extends PropertyOrdering
-case object DescNullsLast extends PropertyOrdering
+enum PropertyOrdering extends Ordering {
+  case Asc
+  case Desc
+  case AscNullsFirst
+  case DescNullsFirst
+  case AscNullsLast
+  case DescNullsLast
+}

 case class GroupBy(query: Ast, alias: Ident, body: Ast) extends Query

@@ -153,11 +155,14 @@ case class Ident(name: String, visibility: Visibility) extends Ast {
  * ExpandNestedQueries phase, needs to be marked invisible.
  */
 object Ident {
-  def apply(name: String): Ident = Ident(name, Visibility.neutral)
+  inline def apply(inline name: String): Ident = Ident(name, Visibility.neutral)
   def unapply(p: Ident) = Some((p.name))

   object Opinionated {
-    def apply(name: String, visibilityNew: Visibility): Ident =
+    inline def apply(
+        inline name: String,
+        inline visibilityNew: Visibility
+    ): Ident =
       Ident(name, visibilityNew)
     def unapply(p: Ident) =
       Some((p.name, p.visibility))

@@ -378,14 +383,21 @@ sealed trait ScalarLift extends Lift

 case class ScalarValueLift(
     name: String,
-    liftId: String
+    liftId: String,
+    value: Option[(Any, ParamEncoder[?])]
 ) extends ScalarLift

+case class ScalarQueryLift(
+    name: String,
+    liftId: String,
+    value: Option[(Seq[Any], ParamEncoder[?])]
+) extends ScalarLift
+
 object ScalarLift {
   given ToExpr[ScalarLift] with {
     def apply(l: ScalarLift)(using Quotes) = l match {
-      case ScalarValueLift(n, id) =>
-        '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) }
+      case ScalarValueLift(n, id, v) =>
+        '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) }
     }
   }
 }
```
src/main/scala/minisql/ast/AstOps.scala (new file, 94 lines)

```scala
package minisql.ast

object Implicits {
  implicit class AstOps(body: Ast) {
    private[minisql] def +||+(other: Ast) =
      BinaryOperation(body, BooleanOperator.`||`, other)
    private[minisql] def +&&+(other: Ast) =
      BinaryOperation(body, BooleanOperator.`&&`, other)
    private[minisql] def +==+(other: Ast) =
      BinaryOperation(body, EqualityOperator.`==`, other)
    private[minisql] def +!=+(other: Ast) =
      BinaryOperation(body, EqualityOperator.`!=`, other)
  }
}

object +||+ {
  def unapply(a: Ast): Option[(Ast, Ast)] = {
    a match {
      case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two))
      case _ => None
    }
  }
}

object +&&+ {
  def unapply(a: Ast): Option[(Ast, Ast)] = {
    a match {
      case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two))
      case _ => None
    }
  }
}

val EqOp = EqualityOperator.`==`
val NeqOp = EqualityOperator.`!=`

object +==+ {
  def unapply(a: Ast): Option[(Ast, Ast)] = {
    a match {
      case BinaryOperation(one, EqOp, two) => Some((one, two))
      case _ => None
    }
  }
}

object +!=+ {
  def unapply(a: Ast): Option[(Ast, Ast)] = {
    a match {
      case BinaryOperation(one, NeqOp, two) => Some((one, two))
      case _ => None
    }
  }
}

object IsNotNullCheck {
  def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue)

  def unapply(ast: Ast): Option[Ast] = {
    ast match {
      case BinaryOperation(cond, NeqOp, NullValue) => Some(cond)
      case _ => None
    }
  }
}

object IsNullCheck {
  def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue)

  def unapply(ast: Ast): Option[Ast] = {
    ast match {
      case BinaryOperation(cond, EqOp, NullValue) => Some(cond)
      case _ => None
    }
  }
}

object IfExistElseNull {
  def apply(exists: Ast, `then`: Ast) =
    If(IsNotNullCheck(exists), `then`, NullValue)

  def unapply(ast: Ast) = ast match {
    case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t))
    case _ => None
  }
}

object IfExist {
  def apply(exists: Ast, `then`: Ast, otherwise: Ast) =
    If(IsNotNullCheck(exists), `then`, otherwise)

  def unapply(ast: Ast) = ast match {
    case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e))
    case _ => None
  }
}
```
Changes to the `FromExpr` givens for the Ast types:

```diff
@@ -45,14 +45,14 @@ private given FromExpr[Infix] with {
 private given FromExpr[ScalarValueLift] with {
   def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] =
     x match {
-      case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } =>
-        Some(ScalarValueLift(n, id))
+      case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } =>
+        Some(ScalarValueLift(n, id, None))
     }
 }

 private given FromExpr[Ident] with {
   def unapply(x: Expr[Ident])(using Quotes): Option[Ident] = x match {
     case '{ Ident(${ Expr(n) }) } => Some(Ident(n))
     case '{ Ident(${ Expr(n) }, ${ Expr(v) }) } => Some(Ident(n, v))
   }
 }
```
@ -69,19 +69,19 @@ private given FromExpr[Property] with {
|
|||
)
|
||||
} =>
|
||||
Some(Property(a, n, r, v))
|
||||
case o =>
|
||||
println(s"Cannot extrat ${o.show}")
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
private given FromExpr[Ordering] with {
|
||||
def unapply(x: Expr[Ordering])(using Quotes): Option[Ordering] = {
|
||||
import PropertyOrdering.*
|
||||
x match {
|
||||
case '{ Asc } => Some(Asc)
|
||||
case '{ Desc } => Some(Desc)
|
||||
case '{ AscNullsFirst } => Some(AscNullsFirst)
|
||||
case '{ AscNullsLast } => Some(AscNullsLast)
|
||||
case '{ DescNullsFirst } => Some(DescNullsFirst)
|
||||
case '{ DescNullsLast } => Some(DescNullsLast)
|
||||
case '{ TupleOrdering($xs) } => xs.value.map(TupleOrdering(_))
|
||||
}
|
||||
}
|
||||
|
```diff
@@ -122,15 +122,48 @@ private given FromExpr[Query] with {
       Some(FlatMap(b, id, body))
-    case '{ ConcatMap(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
-      Some(ConcatMap(b, id, body))
+    case '{
+          val x: Ast = ${ Expr(b) }
+          val y: Ident = ${ Expr(id) }
+          val z: Ast = ${ Expr(body) }
+          ConcatMap(x, y, z)
+        } =>
+      Some(ConcatMap(b, id, body))
     case '{ Drop(${ Expr(b) }, ${ Expr(n) }) } =>
       Some(Drop(b, n))
     case '{ Take(${ Expr(b) }, ${ Expr[Ast](n) }) } =>
       Some(Take(b, n))
     case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } =>
       Some(SortBy(b, p, s, o))
-    case o =>
-      println(s"Cannot extract ${o.show}")
-      None
+    case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } =>
+      Some(GroupBy(b, p, body))
+    case '{ Distinct(${ Expr(a) }) } =>
+      Some(Distinct(a))
+    case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } =>
+      Some(DistinctOn(q, a, body))
+    case '{ Aggregation(${ Expr(op) }, ${ Expr(a) }) } =>
+      Some(Aggregation(op, a))
+    case '{ Union(${ Expr(a) }, ${ Expr(b) }) } =>
+      Some(Union(a, b))
+    case '{ UnionAll(${ Expr(a) }, ${ Expr(b) }) } =>
+      Some(UnionAll(a, b))
+    case '{
+          Join(
+            ${ Expr(t) },
+            ${ Expr(a) },
+            ${ Expr(b) },
+            ${ Expr(ia) },
+            ${ Expr(ib) },
+            ${ Expr(on) }
+          )
+        } =>
+      Some(Join(t, a, b, ia, ib, on))
+    case '{
+          FlatJoin(${ Expr(t) }, ${ Expr(a) }, ${ Expr(ia) }, ${ Expr(on) })
+        } =>
+      Some(FlatJoin(t, a, ia, on))
+    case '{ Nested(${ Expr(a) }) } =>
+      Some(Nested(a))
   }
 }
```
```diff
@@ -145,17 +178,22 @@ private given FromExpr[BinaryOperator] with {
       case '{ NumericOperator.- } => Some(NumericOperator.-)
       case '{ NumericOperator.* } => Some(NumericOperator.*)
       case '{ NumericOperator./ } => Some(NumericOperator./)
       case '{ NumericOperator.> } => Some(NumericOperator.>)
       case '{ NumericOperator.>= } => Some(NumericOperator.>=)
       case '{ NumericOperator.< } => Some(NumericOperator.<)
       case '{ NumericOperator.<= } => Some(NumericOperator.<=)
       case '{ NumericOperator.% } => Some(NumericOperator.%)
       case '{ StringOperator.split } => Some(StringOperator.split)
       case '{ StringOperator.startsWith } => Some(StringOperator.startsWith)
       case '{ StringOperator.concat } => Some(StringOperator.concat)
       case '{ BooleanOperator.&& } => Some(BooleanOperator.&&)
       case '{ BooleanOperator.|| } => Some(BooleanOperator.||)
       case '{ SetOperator.contains } => Some(SetOperator.contains)
     }
   }
 }
```
```diff
@@ -163,6 +201,33 @@ private given FromExpr[UnaryOperator] with {
       case '{ StringOperator.toLowerCase } => Some(StringOperator.toLowerCase)
       case '{ StringOperator.toLong } => Some(StringOperator.toLong)
       case '{ StringOperator.toInt } => Some(StringOperator.toInt)
+      case '{ NumericOperator.- } => Some(NumericOperator.-)
+      case '{ SetOperator.nonEmpty } => Some(SetOperator.nonEmpty)
+      case '{ SetOperator.isEmpty } => Some(SetOperator.isEmpty)
     }
   }
 }
+
+private given FromExpr[AggregationOperator] with {
+  def unapply(
+      x: Expr[AggregationOperator]
+  )(using Quotes): Option[AggregationOperator] = {
+    x match {
+      case '{ AggregationOperator.min } => Some(AggregationOperator.min)
+      case '{ AggregationOperator.max } => Some(AggregationOperator.max)
+      case '{ AggregationOperator.avg } => Some(AggregationOperator.avg)
+      case '{ AggregationOperator.sum } => Some(AggregationOperator.sum)
+      case '{ AggregationOperator.size } => Some(AggregationOperator.size)
+    }
+  }
+}
+
+private given FromExpr[Operator] with {
+  def unapply(x: Expr[Operator])(using Quotes): Option[Operator] = {
+    x match {
+      case '{ $x: BinaryOperator } => x.value
+      case '{ $x: UnaryOperator } => x.value
+      case '{ $x: AggregationOperator } => x.value
+    }
+  }
+}
```
```diff
@@ -205,30 +270,50 @@ private given FromExpr[Action] with {
       ass.sequence.map { ass1 =>
         Update(a, ass1)
       }
-    case '{ Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } =>
+    case '{
+          Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) })
+        } =>
       Some(Returning(act, id, body))
+    case '{
+          val x: Ast = ${ Expr(act) }
+          val y: Ident = ${ Expr(id) }
+          val z: Ast = ${ Expr(body) }
+          Returning(x, y, z)
+        } =>
+      Some(Returning(act, id, body))
     case '{
           ReturningGenerated(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) })
         } =>
       Some(ReturningGenerated(act, id, body))
+    case '{
+          val x: Ast = ${ Expr(act) }
+          val y: Ident = ${ Expr(id) }
+          val z: Ast = ${ Expr(body) }
+          ReturningGenerated(x, y, z)
+        } =>
+      Some(ReturningGenerated(act, id, body))

   }
 }

 extension [A](xs: Seq[Expr[A]]) {
   private def sequence(using FromExpr[A], Quotes): Option[List[A]] = {
-    val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) =>
-      for {
-        _r <- r
-        _x <- x.value
-      } yield _r += _x
-    }
-    acc.map(b => b.result())
+    if (xs.isEmpty) Some(Nil)
+    else {
+      val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) =>
+        for {
+          _r <- r
+          _x <- x.value
+        } yield _r += _x
+      }
+      acc.map(_.result())
+    }
   }
 }

-private given FromExpr[Constant] with {
-  def unapply(x: Expr[Constant])(using Quotes): Option[Constant] = {
+private given FromExpr[Value] with {
+  def unapply(x: Expr[Value])(using Quotes): Option[Value] = {
     import quotes.reflect.{Constant => _, *}
     x match {
       case '{ Constant($ce) } =>
```
```diff
@@ -236,8 +321,92 @@ private given FromExpr[Value] with {
         case Literal(v) =>
           Some(Constant(v.value))
       }
+    case '{ NullValue } =>
+      Some(NullValue)
+    case '{ $x: CaseClass } => x.value
   }
 }

+private given FromExpr[OptionOperation] with {
+  def unapply(
+      x: Expr[OptionOperation]
+  )(using Quotes): Option[OptionOperation] = {
+    x match {
+      case '{ OptionFlatten(${ Expr(ast) }) } =>
+        Some(OptionFlatten(ast))
+      case '{ OptionGetOrElse(${ Expr(ast) }, ${ Expr(body) }) } =>
+        Some(OptionGetOrElse(ast, body))
+      case '{
+            OptionFlatMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionFlatMap(ast, alias, body))
+      case '{ OptionMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) } =>
+        Some(OptionMap(ast, alias, body))
+      case '{
+            OptionForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionForall(ast, alias, body))
+      case '{
+            OptionExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionExists(ast, alias, body))
+      case '{ OptionContains(${ Expr(ast) }, ${ Expr(body) }) } =>
+        Some(OptionContains(ast, body))
+      case '{ OptionIsEmpty(${ Expr(ast) }) } =>
+        Some(OptionIsEmpty(ast))
+      case '{ OptionNonEmpty(${ Expr(ast) }) } =>
+        Some(OptionNonEmpty(ast))
+      case '{ OptionIsDefined(${ Expr(ast) }) } =>
+        Some(OptionIsDefined(ast))
+      case '{
+            OptionTableFlatMap(
+              ${ Expr(ast) },
+              ${ Expr(alias) },
+              ${ Expr(body) }
+            )
+          } =>
+        Some(OptionTableFlatMap(ast, alias, body))
+      case '{
+            OptionTableMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionTableMap(ast, alias, body))
+      case '{
+            OptionTableExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionTableExists(ast, alias, body))
+      case '{
+            OptionTableForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) })
+          } =>
+        Some(OptionTableForall(ast, alias, body))
+      case '{ OptionNone } => Some(OptionNone)
+      case '{ OptionSome(${ Expr(ast) }) } => Some(OptionSome(ast))
+      case '{ OptionApply(${ Expr(ast) }) } => Some(OptionApply(ast))
+      case '{ OptionOrNull(${ Expr(ast) }) } => Some(OptionOrNull(ast))
+      case '{ OptionGetOrNull(${ Expr(ast) }) } => Some(OptionGetOrNull(ast))
+      case _ => None
+    }
+  }
+}
+
+private given FromExpr[CaseClass] with {
+  def unapply(x: Expr[CaseClass])(using Quotes): Option[CaseClass] = {
+    import quotes.reflect.*
+    x match {
+      case '{ CaseClass(${ Expr(values) }) } =>
+        // Verify the values are properly structured as List[(String, Ast)]
+        try {
+          Some(CaseClass(values))
+        } catch {
+          case e: Exception =>
+            report.warning(
+              s"Failed to extract CaseClass values: ${e.getMessage}",
+              x.asTerm.pos
+            )
+            None
+        }
+      case _ => None
+    }
+  }
+}
```
```diff
@@ -248,53 +417,26 @@ private given FromExpr[If] with {
   }
 }

-private def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
-  import quotes.reflect.*
-  def unwrapTerm(t: Term): Term = t match {
-    case Inlined(_, _, o) => unwrapTerm(o)
-    case Block(Nil, last) => last
-    case Typed(t, _) =>
-      unwrapTerm(t)
-    case Select(t, "$asInstanceOf$") =>
-      unwrapTerm(t)
-    case TypeApply(t, _) =>
-      unwrapTerm(t)
-    case o => o
-  }
-  val o = unwrapTerm(x)
-  o
-}
+private given FromExpr[Block] with {
+  def unapply(x: Expr[Block])(using Quotes): Option[Block] = x match {
+    case '{ Block(${ Expr(statements) }) } =>
+      Some(Block(statements))
+  }
+}
+
+private given FromExpr[Val] with {
+  def unapply(x: Expr[Val])(using Quotes): Option[Val] = x match {
+    case '{ Val(${ Expr(n) }, ${ Expr(b) }) } =>
+      Some(Val(n, b))
+  }
+}

 extension (e: Expr[Any]) {
-  def toTerm(using Quotes) = {
+  private def toTerm(using Quotes) = {
     import quotes.reflect.*
     e.asTerm
   }
 }

+private def fromBlock(using
+    Quotes
+)(block: quotes.reflect.Block): Option[Ast] = {
+  println(s"Show block ${block.show}")
+  import quotes.reflect.*
+  val empty: Option[List[Ast]] = Some(Nil)
+  val stmts = block.statements.foldLeft(empty) { (r, stmt) =>
+    stmt match {
+      case ValDef(n, _, Some(body)) =>
+        r.flatMap { astList =>
+          body.asExprOf[Ast].value.map { v =>
+            astList :+ v
+          }
+        }
+      case o =>
+        None
+    }
+  }
+  stmts.flatMap { stmts =>
+    block.expr.asExprOf[Ast].value.map { last =>
+      minisql.ast.Block(stmts :+ last)
+    }
+  }
+}
+
 given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {

@@ -306,13 +448,17 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] {
     case '{ $x: ScalarValueLift } => x.value
     case '{ $x: Property } => x.value
     case '{ $x: Ident } => x.value
+    case '{ $x: Val } => x.value
     case '{ $x: Tuple } => x.value
-    case '{ $x: Constant } => x.value
+    case '{ $x: Value } => x.value
     case '{ $x: Operation } => x.value
     case '{ $x: Ordering } => x.value
     case '{ $x: Action } => x.value
     case '{ $x: If } => x.value
     case '{ $x: Infix } => x.value
+    case '{ $x: CaseClass } => x.value
+    case '{ $x: OptionOperation } => x.value
+    case '{ $x: Block } => x.value
     case o =>
       import quotes.reflect.*
       report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos)
```
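Taken together, these givens let any macro recover a full `Ast` value with a single `.value` call. A minimal sketch of that round trip (the `describe` helper is invented for illustration; `astFromExpr` is the public given defined above, and the pattern mirrors `compileTimeAstImpl` in Quoted.scala):

```scala
import scala.quoted.*
import minisql.ast.{Ast, given}

// Try to read the Ast at compile time, fall back to runtime otherwise.
def describe(e: Expr[Ast])(using Quotes): Expr[String] =
  e.value match {
    case Some(a) => Expr(a.toString)  // Ast was statically known
    case None    => '{ $e.toString }  // dynamic fallback
  }
```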
Changes to `JoinType` in `minisql.ast`:

```diff
@@ -1,8 +1,22 @@
 package minisql.ast

-sealed trait JoinType
+import scala.quoted.*

-case object InnerJoin extends JoinType
-case object LeftJoin extends JoinType
-case object RightJoin extends JoinType
-case object FullJoin extends JoinType
+enum JoinType {
+  case InnerJoin
+  case LeftJoin
+  case RightJoin
+  case FullJoin
+}
+
+object JoinType {
+  given FromExpr[JoinType] with {
+
+    def unapply(x: Expr[JoinType])(using Quotes): Option[JoinType] = x match {
+      case '{ JoinType.InnerJoin } => Some(JoinType.InnerJoin)
+      case '{ JoinType.LeftJoin } => Some(JoinType.LeftJoin)
+      case '{ JoinType.RightJoin } => Some(JoinType.RightJoin)
+      case '{ JoinType.FullJoin } => Some(JoinType.FullJoin)
+    }
+  }
+}
```
src/main/scala/minisql/context/Context.scala (new file, 137 lines)

```scala
package minisql.context

import minisql.util.*
import minisql.idiom.{Idiom, Statement, ReifyStatement}
import minisql.{NamingStrategy, ParamEncoder}
import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
import scala.deriving.*
import scala.compiletime.*
import scala.util.Try
import scala.annotation.targetName

trait RowExtract[A, Row] {
  def extract(row: Row): Try[A]
}

object RowExtract {

  private[context] def single[Row, E](
      decoder: ColumnDecoder.Aux[Row, E]
  ): RowExtract[E, Row] = new RowExtract[E, Row] {
    def extract(row: Row): Try[E] = {
      decoder.decode(row, 0)
    }
  }

  private def extractorImpl[A, Row](
      decoders: IArray[Any],
      m: Mirror.ProductOf[A]
  ): RowExtract[A, Row] = new RowExtract[A, Row] {
    def extract(row: Row): Try[A] = {
      val decodedFields = decoders.zipWithIndex.traverse {
        case (d, i) =>
          d.asInstanceOf[ColumnDecoder.Aux[Row, ?]].decode(row, i)
      }
      decodedFields.map { vs =>
        m.fromProduct(Tuple.fromIArray(vs))
      }
    }
  }

  inline given [P <: Product, Row, Decoder[_]](using
      m: Mirror.ProductOf[P]
  ): RowExtract[P, Row] = {
    val decoders =
      summonAll[
        Tuple.Map[m.MirroredElemTypes, [X] =>> ColumnDecoder[
          X
        ] { type DBRow = Row }]
      ]
    extractorImpl(decoders.toIArray.asInstanceOf, m)
  }
}

trait Context[I <: Idiom, N <: NamingStrategy] { selft =>

  val idiom: I
  val naming: N

  type DBStatement
  type DBRow
  type DBResultSet

  type Encoder[X] = ParamEncoder[X] {
    type Stmt = DBStatement
  }

  type Decoder[X] = ColumnDecoder.Aux[DBRow, X]

  type DBIO[E] = (
      sql: String,
      params: List[(Any, Encoder[?])],
      mapper: Iterable[DBRow] => Try[E]
  )

  extension (ast: Ast) {
    private def liftMap = {
      val lifts = CollectAst.byType[ScalarValueLift](ast)
      lifts.map(l => l.liftId -> l.value.get).toMap
    }
  }

  extension (stmt: Statement) {
    def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) =
      ReifyStatement(
        idiom.liftingPlaceholder,
        idiom.emptySetContainsToken,
        stmt,
        liftMap
      )
  }

  @targetName("ioAction")
  inline def io[E](inline q: minisql.Action[E]): DBIO[E] = {
    val extractor = summonFrom {
      case e: RowExtract[E, DBRow] => e
      case e: ColumnDecoder.Aux[DBRow, E] =>
        RowExtract.single(e)
    }

    val lifts = q.liftMap
    val stmt = minisql.compile[I, N](q, idiom, naming)
    val (sql, params) = stmt.expand(lifts)
    (
      sql = sql,
      params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]),
      mapper = (rows) =>
        rows
          .traverse(extractor.extract)
          .flatMap(
            _.headOption.toRight(new Exception(s"No value return")).toTry
          )
    )
  }

  @targetName("ioQuery")
  inline def io[E](
      inline q: minisql.Query[E]
  ): DBIO[IArray[E]] = {

    val (stmt, extractor) = summonFrom {
      case e: RowExtract[E, DBRow] =>
        minisql.compile[I, N](q.expanded, idiom, naming) -> e
      case e: ColumnDecoder.Aux[DBRow, E] =>
        minisql.compile[I, N](q, idiom, naming) -> RowExtract.single(e)
    }: @unchecked

    val lifts = q.liftMap
    val (sql, params) = stmt.expand(lifts)
    (
      sql = sql,
      params = params.map(_.value.get.asInstanceOf),
      mapper = (rows) => rows.traverse(extractor.extract)
    )
  }

}
```
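A minimal usage sketch of the `io` entry points, tying the pieces of this PR together. It uses the `MirrorSqlContext` defined later in this diff; `Person`, the `"person"` table name, and the `SnakeCase` naming object are assumptions for illustration, not shown in the PR:

```scala
import minisql.*
import minisql.context.sql.MirrorSqlContext

case class Person(id: Long, name: String)

object Example {
  // SnakeCase is an assumed NamingStrategy instance.
  val ctx = new MirrorSqlContext(SnakeCase)
  import ctx.*

  inline def people = query[Person]("person")

  // io compiles the quoted query to SQL (at compile time when the Ast is
  // statically known) and returns the (sql, params, mapper) tuple of DBIO.
  val dbio = io(people.map(p => p.name))
  // dbio.sql    -- the generated SQL string
  // dbio.params -- lifted parameters paired with their encoders
  // dbio.mapper -- decodes raw rows into IArray[String]
}
```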
src/main/scala/minisql/context/MirrorContext.scala (new file, 23 lines)

```scala
package minisql

import minisql.context.mirror.*
import minisql.util.Messages.fail
import scala.reflect.ClassTag

class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy](
    val idiom: Idiom,
    val naming: Naming
) extends context.Context[Idiom, Naming]
    with MirrorCodecs {

  type DBRow = IArray[Any] *: EmptyTuple
  type DBResultSet = Iterable[DBRow]
  type DBStatement = Map[Int, Any]

  extension (r: DBRow) {

    def data: IArray[Any] = r._1
    def add(value: Any): DBRow = (r.data :+ value) *: EmptyTuple
  }

}
```
src/main/scala/minisql/context/ReturnFieldCapability.scala (new file, 64 lines)

```scala
package minisql.context

sealed trait ReturningCapability

/**
 * Data cannot be returned Insert/Update/etc... clauses in the target database.
 */
sealed trait ReturningNotSupported extends ReturningCapability

/**
 * Returning a single field from Insert/Update/etc... clauses is supported. This
 * is the most common databases e.g. MySQL, Sqlite, and H2 (although as of
 * h2database/h2database#1972 this may change. See #1496 regarding this.
 * Typically this needs to be setup in the JDBC
 * `connection.prepareStatement(sql, Array("returnColumn"))`.
 */
sealed trait ReturningSingleFieldSupported extends ReturningCapability

/**
 * Returning multiple columns from Insert/Update/etc... clauses is supported.
 * This generally means that columns besides auto-incrementing ones can be
 * returned. This is supported by Oracle. In JDBC, the following is done:
 * `connection.prepareStatement(sql, Array("column1, column2, ..."))`.
 */
sealed trait ReturningMultipleFieldSupported extends ReturningCapability

/**
 * An actual `RETURNING` clause is supported in the SQL dialect of the specified
 * database e.g. Postgres. this typically means that columns returned from
 * Insert/Update/etc... clauses can have other database operations done on them
 * such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In
 * JDBC, the following is done: `connection.prepareStatement(sql,
 * Statement.RETURN_GENERATED_KEYS))`.
 */
sealed trait ReturningClauseSupported extends ReturningCapability

object ReturningNotSupported extends ReturningNotSupported
object ReturningSingleFieldSupported extends ReturningSingleFieldSupported
object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported
object ReturningClauseSupported extends ReturningClauseSupported

trait Capabilities {
  def idiomReturningCapability: ReturningCapability
}

trait CanReturnClause extends Capabilities {
  override def idiomReturningCapability: ReturningClauseSupported =
    ReturningClauseSupported
}

trait CanReturnField extends Capabilities {
  override def idiomReturningCapability: ReturningSingleFieldSupported =
    ReturningSingleFieldSupported
}

trait CanReturnMultiField extends Capabilities {
  override def idiomReturningCapability: ReturningMultipleFieldSupported =
    ReturningMultipleFieldSupported
}

trait CannotReturn extends Capabilities {
  override def idiomReturningCapability: ReturningNotSupported =
    ReturningNotSupported
}
```
src/main/scala/minisql/context/mirror.scala (new file, 139 lines)

```scala
package minisql.context.mirror

import minisql.MirrorContext
import java.time.LocalDate
import java.util.{Date, UUID}
import minisql.{ParamEncoder, ColumnDecoder}
import minisql.util.Messages.fail
import scala.util.{Failure, Success, Try}
import scala.util.Try
import scala.reflect.ClassTag

trait MirrorCodecs {
  ctx: MirrorContext[?, ?] =>

  final protected def mirrorEncoder[V]: Encoder[V] = new ParamEncoder[V] {
    type Stmt = ctx.DBStatement
    def setParam(s: Stmt, idx: Int, v: V): Stmt = {
      s + (idx -> v)
    }
  }

  final protected def mirrorColumnDecoder[X](
      conv: Any => Option[X]
  ): Decoder[X] =
    new ColumnDecoder[X] {
      type DBRow = ctx.DBRow
      def decode(row: DBRow, idx: Int): Try[X] = {
        row.data
          .lift(idx)
          .flatMap { x =>
            conv(x)
          }
          .toRight(new Exception(s"Cannot convert value at ${idx}"))
          .toTry
      }
    }

  given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]] = {
    new ColumnDecoder[Option[T]] {
      type DBRow = ctx.DBRow
      override def decode(row: DBRow, idx: Int): Try[Option[T]] =
        row.data.lift(idx) match {
          case Some(null)  => Success(None)
          case Some(value) => d.decode(row, idx).map(Some(_))
          case None        => Success(None)
        }
    }
  }

  given optionEncoder[T](using e: Encoder[T]): Encoder[Option[T]] =
    new ParamEncoder[Option[T]] {
      type Stmt = ctx.DBStatement
      override def setParam(
          s: Stmt,
          idx: Int,
          v: Option[T]
      ): Stmt =
        v match {
          case Some(value) => e.setParam(s, idx, value)
          case None =>
            s + (idx -> null)
        }
    }

  // Implement all required decoders using mirrorColumnDecoder from MirrorCodecs
  given stringDecoder: Decoder[String] = mirrorColumnDecoder[String](x =>
    x match { case s: String => Some(s); case _ => None }
  )
  given bigDecimalDecoder: Decoder[BigDecimal] =
    mirrorColumnDecoder[BigDecimal](x =>
      x match {
        case bd: BigDecimal => Some(bd); case i: Int => Some(BigDecimal(i));
        case l: Long => Some(BigDecimal(l));
        case d: Double => Some(BigDecimal(d)); case _ => None
      }
    )
  given booleanDecoder: Decoder[Boolean] = mirrorColumnDecoder[Boolean](x =>
    x match { case b: Boolean => Some(b); case _ => None }
  )
  given byteDecoder: Decoder[Byte] = mirrorColumnDecoder[Byte](x =>
    x match {
      case b: Byte => Some(b); case i: Int => Some(i.toByte); case _ => None
    }
  )
  given shortDecoder: Decoder[Short] = mirrorColumnDecoder[Short](x =>
    x match {
      case s: Short => Some(s); case i: Int => Some(i.toShort); case _ => None
    }
  )
  given intDecoder: Decoder[Int] = mirrorColumnDecoder[Int](x =>
    x match { case i: Int => Some(i); case _ => None }
  )
  given longDecoder: Decoder[Long] = mirrorColumnDecoder[Long](x =>
    x match {
      case l: Long => Some(l); case i: Int => Some(i.toLong); case _ => None
    }
  )
  given floatDecoder: Decoder[Float] = mirrorColumnDecoder[Float](x =>
    x match {
      case f: Float => Some(f); case d: Double => Some(d.toFloat);
      case _ => None
    }
  )
  given doubleDecoder: Decoder[Double] = mirrorColumnDecoder[Double](x =>
    x match {
      case d: Double => Some(d); case f: Float => Some(f.toDouble);
      case _ => None
    }
  )
  given byteArrayDecoder: Decoder[Array[Byte]] =
    mirrorColumnDecoder[Array[Byte]](x =>
      x match { case ba: Array[Byte] => Some(ba); case _ => None }
    )
  given dateDecoder: Decoder[Date] = mirrorColumnDecoder[Date](x =>
    x match { case d: Date => Some(d); case _ => None }
  )
  given localDateDecoder: Decoder[LocalDate] =
    mirrorColumnDecoder[LocalDate](x =>
      x match { case ld: LocalDate => Some(ld); case _ => None }
    )
  given uuidDecoder: Decoder[UUID] = mirrorColumnDecoder[UUID](x =>
    x match { case uuid: UUID => Some(uuid); case _ => None }
  )

  // Implement all required encoders using mirrorEncoder from MirrorCodecs
  given stringEncoder: Encoder[String] = mirrorEncoder[String]
  given bigDecimalEncoder: Encoder[BigDecimal] = mirrorEncoder[BigDecimal]
  given booleanEncoder: Encoder[Boolean] = mirrorEncoder[Boolean]
  given byteEncoder: Encoder[Byte] = mirrorEncoder[Byte]
  given shortEncoder: Encoder[Short] = mirrorEncoder[Short]
  given intEncoder: Encoder[Int] = mirrorEncoder[Int]
  given longEncoder: Encoder[Long] = mirrorEncoder[Long]
  given floatEncoder: Encoder[Float] = mirrorEncoder[Float]
  given doubleEncoder: Encoder[Double] = mirrorEncoder[Double]
  given byteArrayEncoder: Encoder[Array[Byte]] = mirrorEncoder[Array[Byte]]
  given dateEncoder: Encoder[Date] = mirrorEncoder[Date]
  given localDateEncoder: Encoder[LocalDate] = mirrorEncoder[LocalDate]
  given uuidEncoder: Encoder[UUID] = mirrorEncoder[UUID]
}
```
src/main/scala/minisql/context/sql/ConcatSupport.scala (new file, 17 lines)

```scala
package minisql.context.sql.idiom

import minisql.util.Messages

trait ConcatSupport {
  this: SqlIdiom =>

  override def concatFunction = "UNNEST"
}

trait NoConcatSupport {
  this: SqlIdiom =>

  override def concatFunction = Messages.fail(
    s"`concatMap` not supported by ${this.getClass.getSimpleName}"
  )
}
```
src/main/scala/minisql/context/sql/MirrorSqlContext.scala (new file, 24 lines)

```scala
package minisql.context.sql

import minisql.{NamingStrategy, MirrorContext}
import minisql.context.Context
import minisql.idiom.Idiom // Changed from minisql.idiom.* to avoid ambiguity with Statement
import minisql.context.mirror.MirrorCodecs
import minisql.context.ReturningClauseSupported
import minisql.context.ReturningCapability

class MirrorSqlIdiom extends idiom.SqlIdiom {
  override def concatFunction: String = "CONCAT"
  override def idiomReturningCapability: ReturningCapability =
    ReturningClauseSupported

  // Implementations previously provided by MirrorIdiomBase
  override def prepareForProbing(string: String): String = string
  override def liftingPlaceholder(index: Int): String = "?"
}
object MirrorSqlIdiom extends MirrorSqlIdiom

class MirrorSqlContext[N <: NamingStrategy](naming: N)
    extends MirrorContext[MirrorSqlIdiom, N](MirrorSqlIdiom, naming)
    with SqlContext[MirrorSqlIdiom, N]
    with MirrorCodecs {}
```
src/main/scala/minisql/context/sql/MirrorSqlDialect.scala (new file, 52 lines)

```scala
package minisql.context.sql

import minisql.context.{
  CanReturnClause,
  CanReturnField,
  CanReturnMultiField,
  CannotReturn
}
import minisql.context.sql.idiom.SqlIdiom
import minisql.context.sql.idiom.QuestionMarkBindVariables
import minisql.context.sql.idiom.ConcatSupport

trait MirrorSqlDialect
    extends SqlIdiom
    with QuestionMarkBindVariables
    with ConcatSupport
    with CanReturnField

trait MirrorSqlDialectWithReturnMulti
    extends SqlIdiom
    with QuestionMarkBindVariables
    with ConcatSupport
    with CanReturnMultiField

trait MirrorSqlDialectWithReturnClause
    extends SqlIdiom
    with QuestionMarkBindVariables
    with ConcatSupport
    with CanReturnClause

trait MirrorSqlDialectWithNoReturn
    extends SqlIdiom
    with QuestionMarkBindVariables
    with ConcatSupport
    with CannotReturn

object MirrorSqlDialect extends MirrorSqlDialect {
  override def prepareForProbing(string: String) = string
}

object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti {
  override def prepareForProbing(string: String) = string
}

object MirrorSqlDialectWithReturnClause
    extends MirrorSqlDialectWithReturnClause {
  override def prepareForProbing(string: String) = string
}

object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn {
  override def prepareForProbing(string: String) = string
}
```
src/main/scala/minisql/context/sql/OnConflictSupport.scala (new file, 70 lines)

```scala
package minisql.context.sql.idiom

import minisql.ast._
import minisql.idiom.StatementInterpolator._
import minisql.idiom.Token
import minisql.NamingStrategy
import minisql.util.Messages.fail

trait OnConflictSupport {
  self: SqlIdiom =>

  implicit def conflictTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[OnConflict] = {

    val customEntityTokenizer = Tokenizer[Entity] {
      case Entity.Opinionated(name, _, renameable) =>
        stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t"
    }

    val customAstTokenizer =
      Tokenizer.withFallback[Ast](self.astTokenizer(using _, strategy)) {
        case _: OnConflict.Excluded => stmt"EXCLUDED"
        case OnConflict.Existing(a) => stmt"${a.token}"
        case a: Action =>
          self
            .actionTokenizer(customEntityTokenizer)(
              using actionAstTokenizer,
              strategy
            )
            .token(a)
      }

    import OnConflict._

    def doUpdateStmt(i: Token, t: Token, u: Update) = {
      val assignments = u.assignments
        .map(a =>
          stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(using customAstTokenizer)}"
        )
        .mkStmt()

      stmt"$i ON CONFLICT $t DO UPDATE SET $assignments"
    }

    def doNothingStmt(i: Ast, t: Token) =
      stmt"${i.token} ON CONFLICT $t DO NOTHING"

    implicit val conflictTargetPropsTokenizer: Tokenizer[Properties] =
      Tokenizer[Properties] {
        case OnConflict.Properties(props) =>
          stmt"(${props.map(n => n.renameable.fixedOr(n.name)(strategy.column(n.name))).mkStmt(",")})"
      }

    def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
      Tokenizer[OnConflict] {
        case OnConflict(_, NoTarget, _: Update) =>
          fail("'DO UPDATE' statement requires explicit conflict target")
        case OnConflict(i, p: Properties, u: Update) =>
          doUpdateStmt(i.token, p.token, u)

        case OnConflict(i, NoTarget, Ignore) =>
          stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING"
        case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token)
      }

    tokenizer(using customAstTokenizer)
  }
}
```
New file (6 lines) in `minisql.context.sql.idiom`:

```scala
package minisql.context.sql.idiom

trait PositionalBindVariables { self: SqlIdiom =>

  override def liftingPlaceholder(index: Int): String = s"$$${index + 1}"
}
```
New file (6 lines) in `minisql.context.sql.idiom`:

```scala
package minisql.context.sql.idiom

trait QuestionMarkBindVariables { self: SqlIdiom =>

  override def liftingPlaceholder(index: Int): String = s"?"
}
```
44
src/main/scala/minisql/context/sql/SqlContext.scala
Normal file
44
src/main/scala/minisql/context/sql/SqlContext.scala
Normal file
|
@@ -0,0 +1,44 @@
package minisql.context.sql

import java.time.LocalDate

import minisql.idiom.{Idiom => BaseIdiom}
import java.util.{Date, UUID}

import minisql.context.Context
import minisql.NamingStrategy

trait SqlContext[Idiom <: BaseIdiom, Naming <: NamingStrategy]
    extends Context[Idiom, Naming] {

  given optionDecoder[T](using d: Decoder[T]): Decoder[Option[T]]
  given optionEncoder[T](using d: Encoder[T]): Encoder[Option[T]]

  given stringDecoder: Decoder[String]
  given bigDecimalDecoder: Decoder[BigDecimal]
  given booleanDecoder: Decoder[Boolean]
  given byteDecoder: Decoder[Byte]
  given shortDecoder: Decoder[Short]
  given intDecoder: Decoder[Int]
  given longDecoder: Decoder[Long]
  given floatDecoder: Decoder[Float]
  given doubleDecoder: Decoder[Double]
  given byteArrayDecoder: Decoder[Array[Byte]]
  given dateDecoder: Decoder[Date]
  given localDateDecoder: Decoder[LocalDate]
  given uuidDecoder: Decoder[UUID]

  given stringEncoder: Encoder[String]
  given bigDecimalEncoder: Encoder[BigDecimal]
  given booleanEncoder: Encoder[Boolean]
  given byteEncoder: Encoder[Byte]
  given shortEncoder: Encoder[Short]
  given intEncoder: Encoder[Int]
  given longEncoder: Encoder[Long]
  given floatEncoder: Encoder[Float]
  given doubleEncoder: Encoder[Double]
  given byteArrayEncoder: Encoder[Array[Byte]]
  given dateEncoder: Encoder[Date]
  given localDateEncoder: Encoder[LocalDate]
  given uuidEncoder: Encoder[UUID]
}

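Note: SqlContext only declares the encoder/decoder givens; each concrete context supplies them. A sketch of the usual shape of optionDecoder, assuming a decoder is essentially a (row, index) => T reader (the Row/Decoder types below are illustrative stand-ins, not the actual Context definitions):

trait Row { def isNull(idx: Int): Boolean }
trait Decoder[T] { def decode(row: Row, idx: Int): T }

def optionDecoder[T](d: Decoder[T]): Decoder[Option[T]] =
  (row, idx) => if (row.isNull(idx)) None else Some(d.decode(row, idx))
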
706
src/main/scala/minisql/context/sql/SqlIdiom.scala
Normal file
@@ -0,0 +1,706 @@
package minisql.context.sql.idiom

import minisql.ast._
import minisql.ast.BooleanOperator._
import minisql.ast.Lift
import minisql.context.sql._
import minisql.context.sql.norm._
import minisql.idiom._
import minisql.idiom.StatementInterpolator._
import minisql.NamingStrategy
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Hidden
import minisql.context.{ReturningCapability, ReturningClauseSupported}
import minisql.util.Interleave
import minisql.util.Messages.{fail, trace}
import minisql.idiom.Token
import minisql.norm.EqualityBehavior
import minisql.norm.ConcatBehavior
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.ExpandReturning

trait SqlIdiom extends Idiom {

  override def prepareForProbing(string: String): String

  protected def concatBehavior: ConcatBehavior = AnsiConcat
  protected def equalityBehavior: EqualityBehavior = AnsiEquality
  protected def actionAlias: Option[Ident] = None
  override def format(queryString: String): String = queryString

  def querifyAst(ast: Ast) = SqlQuery(ast)

  private def doTranslate(ast: Ast, cached: Boolean)(using
      naming: NamingStrategy
  ): (Ast, Statement) = {
    val normalizedAst =
      SqlNormalize(ast, concatBehavior, equalityBehavior)

    given Tokenizer[Ast] = defaultTokenizer

    val token =
      normalizedAst match {
        case q: Query =>
          val sql = querifyAst(q)
          trace("sql")(sql)
          VerifySqlQuery(sql).map(fail)
          val expanded = new ExpandNestedQueries(naming)(sql, List())
          trace("expanded sql")(expanded)
          val tokenized = expanded.token
          trace("tokenized sql")(tokenized)
          tokenized
        case other =>
          other.token
      }

    (normalizedAst, stmt"$token")
  }

  override def translate(
      ast: Ast
  )(implicit naming: NamingStrategy): (Ast, Statement) = {
    doTranslate(ast, false)
  }

  def defaultTokenizer(using naming: NamingStrategy): Tokenizer[Ast] =
    new Tokenizer[Ast] {
      private val stableTokenizer = astTokenizer(using this, naming)

      extension (v: Ast) {
        def token = stableTokenizer.token(v)
      }
    }

  def astTokenizer(using
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Ast] =
    Tokenizer[Ast] {
      case a: Query => SqlQuery(a).token
      case a: Operation => a.token
      case a: Infix => a.token
      case a: Action => a.token
      case a: Ident => a.token
      case a: ExternalIdent => a.token
      case a: Property => a.token
      case a: Value => a.token
      case a: If => a.token
      case a: Lift => a.token
      case a: Assignment => a.token
      case a: OptionOperation => a.token
      case a @ (
            _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation |
            _: Block | _: Val | _: Ordering | _: IterableOperation |
            _: OnConflict.Excluded | _: OnConflict.Existing
          ) =>
        fail(s"Malformed or unsupported construct: $a.")
    }

  implicit def ifTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[If] = Tokenizer[If] {
    case ast: If =>
      def flatten(ast: Ast): (List[(Ast, Ast)], Ast) =
        ast match {
          case If(cond, a, b) =>
            val (l, e) = flatten(b)
            ((cond, a) +: l, e)
          case other =>
            (List(), other)
        }

      val (l, e) = flatten(ast)
      val conditions =
        for ((cond, body) <- l) yield {
          stmt"WHEN ${cond.token} THEN ${body.token}"
        }
      stmt"CASE ${conditions.mkStmt(" ")} ELSE ${e.token} END"
  }

  def concatFunction: String

  protected def tokenizeGroupBy(
      values: Ast
  )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token =
    values.token

  protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ) {
    import q._

    def selectTokenizer =
      select match {
        case Nil => stmt"*"
        case _ => select.token
      }

    def distinctTokenizer = (
      distinct match {
        case DistinctKind.Distinct => stmt"DISTINCT "
        case DistinctKind.DistinctOn(props) =>
          stmt"DISTINCT ON (${props.token}) "
        case DistinctKind.None => stmt""
      }
    )

    def withDistinct = stmt"$distinctTokenizer${selectTokenizer}"

    def withFrom =
      from match {
        case Nil => withDistinct
        case head :: tail =>
          val t = tail.foldLeft(stmt"${head.token}") {
            case (a, b: FlatJoinContext) =>
              stmt"$a ${(b: FromContext).token}"
            case (a, b) =>
              stmt"$a, ${b.token}"
          }

          stmt"$withDistinct FROM $t"
      }

    def withWhere =
      where match {
        case None => withFrom
        case Some(where) => stmt"$withFrom WHERE ${where.token}"
      }
    def withGroupBy =
      groupBy match {
        case None => withWhere
        case Some(groupBy) =>
          stmt"$withWhere GROUP BY ${tokenizeGroupBy(groupBy)}"
      }
    def withOrderBy =
      orderBy match {
        case Nil => withGroupBy
        case orderBy => stmt"$withGroupBy ${tokenOrderBy(orderBy)}"
      }
    def withLimitOffset = limitOffsetToken(withOrderBy).token((limit, offset))

    def apply = stmt"SELECT $withLimitOffset"
  }

  implicit def sqlQueryTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] {
    case q: FlattenSqlQuery =>
      new FlattenSqlQueryTokenizerHelper(q).apply
    case SetOperationSqlQuery(a, op, b) =>
      stmt"(${a.token}) ${op.token} (${b.token})"
    case UnaryOperationSqlQuery(op, q) =>
      stmt"SELECT ${op.token} (${q.token})"
  }

  protected def tokenizeColumn(
      strategy: NamingStrategy,
      column: String,
      renameable: Renameable
  ) =
    renameable match {
      case Fixed => column
      case _ => strategy.column(column)
    }

  protected def tokenizeTable(
      strategy: NamingStrategy,
      table: String,
      renameable: Renameable
  ) =
    renameable match {
      case Fixed => table
      case _ => strategy.table(table)
    }

  protected def tokenizeAlias(strategy: NamingStrategy, table: String) =
    strategy.default(table)

  implicit def selectValueTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[SelectValue] = {

    def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
      Tokenizer[SelectValue] {
        case SelectValue(ast, Some(alias), false) => {
          stmt"${ast.token} AS ${alias.token}"
        }
        case SelectValue(ast, Some(alias), true) =>
          stmt"${concatFunction.token}(${ast.token}) AS ${alias.token}"
        case selectValue =>
          val value =
            selectValue match {
              case SelectValue(Ident("?"), _, _) => "?".token
              case SelectValue(Ident(name), _, _) =>
                stmt"${strategy.default(name).token}.*"
              case SelectValue(ast, _, _) => ast.token
            }
          selectValue.concat match {
            case true => stmt"${concatFunction.token}(${value.token})"
            case false => value
          }
      }

    val customAstTokenizer =
      Tokenizer.withFallback[Ast](
        SqlIdiom.this.astTokenizer(using _, strategy)
      ) {
        case Aggregation(op, Ident(_) | Tuple(_)) => stmt"${op.token}(*)"
        case Aggregation(op, Distinct(ast)) =>
          stmt"${op.token}(DISTINCT ${ast.token})"
        case ast @ Aggregation(op, _: Query) => scopedTokenizer(ast)
        case Aggregation(op, ast) => stmt"${op.token}(${ast.token})"
      }

    tokenizer(using customAstTokenizer)
  }

  implicit def operationTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Operation] = Tokenizer[Operation] {
    case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})"
    case BinaryOperation(a, EqualityOperator.`==`, NullValue) =>
      stmt"${scopedTokenizer(a)} IS NULL"
    case BinaryOperation(NullValue, EqualityOperator.`==`, b) =>
      stmt"${scopedTokenizer(b)} IS NULL"
    case BinaryOperation(a, EqualityOperator.`!=`, NullValue) =>
      stmt"${scopedTokenizer(a)} IS NOT NULL"
    case BinaryOperation(NullValue, EqualityOperator.`!=`, b) =>
      stmt"${scopedTokenizer(b)} IS NOT NULL"
    case BinaryOperation(a, StringOperator.`startsWith`, b) =>
      stmt"${scopedTokenizer(a)} LIKE (${(BinaryOperation(b, StringOperator.`concat`, Constant("%")): Ast).token})"
    case BinaryOperation(a, op @ StringOperator.`split`, b) =>
      stmt"${op.token}(${scopedTokenizer(a)}, ${scopedTokenizer(b)})"
    case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
      SetContainsToken(scopedTokenizer(b), op.token, a.token)
    case BinaryOperation(a, op @ `&&`, b) =>
      (a, b) match {
        case (BinaryOperation(_, `||`, _), BinaryOperation(_, `||`, _)) =>
          stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
        case (BinaryOperation(_, `||`, _), _) =>
          stmt"${scopedTokenizer(a)} ${op.token} ${b.token}"
        case (_, BinaryOperation(_, `||`, _)) =>
          stmt"${a.token} ${op.token} ${scopedTokenizer(b)}"
        case _ => stmt"${a.token} ${op.token} ${b.token}"
      }
    case BinaryOperation(a, op @ `||`, b) =>
      stmt"${a.token} ${op.token} ${b.token}"
    case BinaryOperation(a, op, b) =>
      stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
    case e: FunctionApply => fail(s"Can't translate the ast to sql: '$e'")
  }

  implicit def optionOperationTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
    case OptionIsEmpty(ast) => stmt"${ast.token} IS NULL"
    case OptionNonEmpty(ast) => stmt"${ast.token} IS NOT NULL"
    case OptionIsDefined(ast) => stmt"${ast.token} IS NOT NULL"
    case other => fail(s"Malformed or unsupported construct: $other.")
  }

  implicit val setOperationTokenizer: Tokenizer[SetOperation] =
    Tokenizer[SetOperation] {
      case UnionOperation => stmt"UNION"
      case UnionAllOperation => stmt"UNION ALL"
    }

  protected def limitOffsetToken(
      query: Statement
  )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
    Tokenizer[(Option[Ast], Option[Ast])] {
      case (None, None) => query
      case (Some(limit), None) => stmt"$query LIMIT ${limit.token}"
      case (Some(limit), Some(offset)) =>
        stmt"$query LIMIT ${limit.token} OFFSET ${offset.token}"
      case (None, Some(offset)) => stmt"$query OFFSET ${offset.token}"
    }

  protected def tokenOrderBy(
      criterias: List[OrderByCriteria]
  )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
    stmt"ORDER BY ${criterias.token}"

  implicit def sourceTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[FromContext] = Tokenizer[FromContext] {
    case TableContext(name, alias) =>
      stmt"${name.token} ${tokenizeAlias(strategy, alias).token}"
    case QueryContext(query, alias) =>
      stmt"(${query.token}) AS ${tokenizeAlias(strategy, alias).token}"
    case InfixContext(infix, alias) if infix.noParen =>
      stmt"${(infix: Ast).token} AS ${strategy.default(alias).token}"
    case InfixContext(infix, alias) =>
      stmt"(${(infix: Ast).token}) AS ${strategy.default(alias).token}"
    case JoinContext(t, a, b, on) =>
      stmt"${a.token} ${t.token} ${b.token} ON ${on.token}"
    case FlatJoinContext(t, a, on) => stmt"${t.token} ${a.token} ON ${on.token}"
  }

  implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
    case JoinType.InnerJoin => stmt"INNER JOIN"
    case JoinType.LeftJoin => stmt"LEFT JOIN"
    case JoinType.RightJoin => stmt"RIGHT JOIN"
    case JoinType.FullJoin => stmt"FULL JOIN"
  }

  implicit def orderByCriteriaTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] {
    case OrderByCriteria(ast, PropertyOrdering.Asc) =>
      stmt"${scopedTokenizer(ast)} ASC"
    case OrderByCriteria(ast, PropertyOrdering.Desc) =>
      stmt"${scopedTokenizer(ast)} DESC"
    case OrderByCriteria(ast, PropertyOrdering.AscNullsFirst) =>
      stmt"${scopedTokenizer(ast)} ASC NULLS FIRST"
    case OrderByCriteria(ast, PropertyOrdering.DescNullsFirst) =>
      stmt"${scopedTokenizer(ast)} DESC NULLS FIRST"
    case OrderByCriteria(ast, PropertyOrdering.AscNullsLast) =>
      stmt"${scopedTokenizer(ast)} ASC NULLS LAST"
    case OrderByCriteria(ast, PropertyOrdering.DescNullsLast) =>
      stmt"${scopedTokenizer(ast)} DESC NULLS LAST"
  }

  implicit val unaryOperatorTokenizer: Tokenizer[UnaryOperator] =
    Tokenizer[UnaryOperator] {
      case NumericOperator.`-` => stmt"-"
      case BooleanOperator.`!` => stmt"NOT"
      case StringOperator.`toUpperCase` => stmt"UPPER"
      case StringOperator.`toLowerCase` => stmt"LOWER"
      case StringOperator.`toLong` => stmt"" // cast is implicit
      case StringOperator.`toInt` => stmt"" // cast is implicit
      case SetOperator.`isEmpty` => stmt"NOT EXISTS"
      case SetOperator.`nonEmpty` => stmt"EXISTS"
    }

  implicit val aggregationOperatorTokenizer: Tokenizer[AggregationOperator] =
    Tokenizer[AggregationOperator] {
      case AggregationOperator.`min` => stmt"MIN"
      case AggregationOperator.`max` => stmt"MAX"
      case AggregationOperator.`avg` => stmt"AVG"
      case AggregationOperator.`sum` => stmt"SUM"
      case AggregationOperator.`size` => stmt"COUNT"
    }

  implicit val binaryOperatorTokenizer: Tokenizer[BinaryOperator] =
    Tokenizer[BinaryOperator] {
      case EqualityOperator.`==` => stmt"="
      case EqualityOperator.`!=` => stmt"<>"
      case BooleanOperator.`&&` => stmt"AND"
      case BooleanOperator.`||` => stmt"OR"
      case StringOperator.`concat` => stmt"||"
      case StringOperator.`startsWith` =>
        fail("bug: this code should be unreachable")
      case StringOperator.`split` => stmt"SPLIT"
      case NumericOperator.`-` => stmt"-"
      case NumericOperator.`+` => stmt"+"
      case NumericOperator.`*` => stmt"*"
      case NumericOperator.`>` => stmt">"
      case NumericOperator.`>=` => stmt">="
      case NumericOperator.`<` => stmt"<"
      case NumericOperator.`<=` => stmt"<="
      case NumericOperator.`/` => stmt"/"
      case NumericOperator.`%` => stmt"%"
      case SetOperator.`contains` => stmt"IN"
    }

  implicit def propertyTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Property] = {

    def unnest(ast: Ast): (Ast, List[String]) =
      ast match {
        case Property.Opinionated(a, _, _, Hidden) =>
          unnest(a) match {
            case (a, nestedName) => (a, nestedName)
          }
        // Append the property name. This includes tuple indexes.
        case Property(a, name) =>
          unnest(a) match {
            case (ast, nestedName) =>
              (ast, nestedName :+ name)
          }
        case a => (a, Nil)
      }

    def tokenizePrefixedProperty(
        name: String,
        prefix: List[String],
        strategy: NamingStrategy,
        renameable: Renameable
    ) =
      renameable.fixedOr(
        (prefix.mkString + name).token
      )(tokenizeColumn(strategy, prefix.mkString + name, renameable).token)

    Tokenizer[Property] {
      case Property.Opinionated(
            ast,
            name,
            renameable,
            _ /* Top level property cannot be invisible */
          ) =>
        // When we have things like Embedded tables, properties inside of one another need to be un-nested.
        // E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner
        // property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias`
        // as opposed to `realTable.embeddedTableAlias.realPropertyAlias`.
        unnest(ast) match {
          // When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL
          // returning clauses have no alias for the original table. I.e. in INSERT [...] RETURNING idColumn there's no
          // alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'.
          // In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
          // should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`.
          // The exception to this is when a Query inside of a RETURNING clause is used. In that case, assume
          // that there is an alias for the inserted table (i.e. `INSERT ... as theAlias values ... RETURNING`)
          // and the instances of ExternalIdent use it.
          case (ExternalIdent(_), prefix) =>
            stmt"${actionAlias
                .map(alias => stmt"${scopedTokenizer(alias)}.")
                .getOrElse(stmt"")}${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"

          // In the rare case that the Ident is invisible, do not show it. See the Ident documentation for more info.
          case (Ident.Opinionated(_, Hidden), prefix) =>
            stmt"${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"

          // The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
          // becomes `realTable.realPropertyAlias`.
          case (ast, prefix) =>
            stmt"${scopedTokenizer(ast)}.${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
        }
    }
  }

  implicit def valueTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Value] = Tokenizer[Value] {
    case Constant(v: String) => stmt"'${v.token}'"
    case Constant(()) => stmt"1"
    case Constant(v) => stmt"${v.toString.token}"
    case NullValue => stmt"null"
    case Tuple(values) => stmt"${values.token}"
    case CaseClass(values) => stmt"${values.map(_._2).token}"
  }

  implicit def infixTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Infix] = Tokenizer[Infix] {
    case Infix(parts, params, _, _) =>
      val pt = parts.map(_.token)
      val pr = params.map(_.token)
      Statement(Interleave(pt, pr))
  }

  implicit def identTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Ident] =
    Tokenizer[Ident](e => strategy.default(e.name).token)

  implicit def externalIdentTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[ExternalIdent] =
    Tokenizer[ExternalIdent](e => strategy.default(e.name).token)

  implicit def assignmentTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Assignment] = Tokenizer[Assignment] {
    case Assignment(alias, prop, value) =>
      stmt"${prop.token} = ${scopedTokenizer(value)}"
  }

  implicit def defaultAstTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Action] = {
    val insertEntityTokenizer = Tokenizer[Entity] {
      case Entity.Opinionated(name, _, renameable) =>
        stmt"INTO ${tokenizeTable(strategy, name, renameable).token}"
    }
    actionTokenizer(insertEntityTokenizer)(using actionAstTokenizer, strategy)
  }

  protected def actionAstTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ) =
    Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(using _, strategy)) {
      case q: Query => astTokenizer.token(q)
      case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") =>
        stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL"
      case Property(
            Property.Opinionated(_, name, renameable, _),
            "isDefined"
          ) =>
        stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
      case Property(Property.Opinionated(_, name, renameable, _), "nonEmpty") =>
        stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
      case Property.Opinionated(_, name, renameable, _) =>
        renameable.fixedOr(name.token)(
          tokenizeColumn(strategy, name, renameable).token
        )
    }

  def returnListTokenizer(implicit
      tokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[List[Ast]] = {
    val customAstTokenizer =
      Tokenizer.withFallback[Ast](
        SqlIdiom.this.astTokenizer(using _, strategy)
      ) {
        case sq: Query =>
          stmt"(${tokenizer.token(sq)})"
      }

    Tokenizer[List[Ast]] {
      case list =>
        list.mkStmt(", ")(using customAstTokenizer)
    }
  }

  protected def actionTokenizer(
      insertEntityTokenizer: Tokenizer[Entity]
  )(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Action] =
    Tokenizer[Action] {

      case Insert(entity: Entity, assignments) =>
        val table = insertEntityTokenizer.token(entity)
        val columns = assignments.map(_.property.token)
        val values = assignments.map(_.value)
        stmt"INSERT $table${actionAlias.map(alias => stmt" AS ${alias.token}").getOrElse(stmt"")} (${columns
            .mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})"

      case Update(table: Entity, assignments) =>
        stmt"UPDATE ${table.token}${actionAlias
            .map(alias => stmt" AS ${alias.token}")
            .getOrElse(stmt"")} SET ${assignments.token}"

      case Update(Filter(table: Entity, x, where), assignments) =>
        stmt"UPDATE ${table.token}${actionAlias
            .map(alias => stmt" AS ${alias.token}")
            .getOrElse(stmt"")} SET ${assignments.token} WHERE ${where.token}"

      case Delete(Filter(table: Entity, x, where)) =>
        stmt"DELETE FROM ${table.token} WHERE ${where.token}"

      case Delete(table: Entity) =>
        stmt"DELETE FROM ${table.token}"

      case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
        idiomReturningCapability match {
          // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
          // a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
          case ReturningClauseSupported
              if (CollectAst.byType[Entity](prop).nonEmpty) =>
            SqlIdiom.withActionAlias(this, r)
          case ReturningClauseSupported =>
            stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
                .token(ExpandReturning(r)(this, strategy).map(_._1))}"
          case other =>
            stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}"
        }

      case r @ ReturningAction(action, alias, prop) =>
        idiomReturningCapability match {
          // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
          // a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
          case ReturningClauseSupported
              if (CollectAst.byType[Entity](prop).nonEmpty) =>
            SqlIdiom.withActionAlias(this, r)
          case ReturningClauseSupported =>
            stmt"${action.token} RETURNING ${returnListTokenizer.token(
                ExpandReturning(r)(this, strategy).map(_._1)
              )}"
          case other =>
            stmt"${action.token}"
        }

      case other =>
        fail(s"Action ast can't be translated to sql: '$other'")
    }

  implicit def entityTokenizer(implicit
      astTokenizer: Tokenizer[Ast],
      strategy: NamingStrategy
  ): Tokenizer[Entity] = Tokenizer[Entity] {
    case Entity.Opinionated(name, _, renameable) =>
      tokenizeTable(strategy, name, renameable).token
  }

  protected def scopedTokenizer(ast: Ast)(implicit tokenizer: Tokenizer[Ast]) =
    ast match {
      case _: Query => stmt"(${ast.token})"
      case _: BinaryOperation => stmt"(${ast.token})"
      case _: Tuple => stmt"(${ast.token})"
      case _ => ast.token
    }
}

object SqlIdiom {
  private[minisql] def copyIdiom(
      parent: SqlIdiom,
      newActionAlias: Option[Ident]
  ): SqlIdiom =
    new SqlIdiom {
      override protected def actionAlias: Option[Ident] = newActionAlias
      override def prepareForProbing(string: String): String =
        parent.prepareForProbing(string)
      override def concatFunction: String = parent.concatFunction
      override def liftingPlaceholder(index: Int): String =
        parent.liftingPlaceholder(index)
      override def idiomReturningCapability: ReturningCapability =
        parent.idiomReturningCapability
    }

  /**
   * Construct a new instance of the specified idiom with `newActionAlias`
   * variable specified so that actions (i.e. insert, and update) will be
   * rendered with the specified alias. This is needed for RETURNING clauses
   * that have queries inside. See #1509 for details.
   */
  private[minisql] def withActionAlias(
      parentIdiom: SqlIdiom,
      query: ReturningAction
  )(implicit strategy: NamingStrategy) = {
    val idiom = copyIdiom(parentIdiom, Some(query.alias))
    import idiom._

    implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer(using
      new Tokenizer[Ast] { self =>
        extension (v: Ast) {
          def token = astTokenizer(using self, strategy).token(v)
        }
      },
      strategy
    )

    query match {
      case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
        stmt"INSERT INTO ${table.token} AS ${alias.name.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
            .token(ExpandReturning(r)(idiom, strategy).map(_._1))}"
      case r @ ReturningAction(action, alias, prop) =>
        stmt"${action.token} RETURNING ${returnListTokenizer.token(
            ExpandReturning(r)(idiom, strategy).map(_._1)
          )}"
      case r =>
        fail(s"Unsupported Returning construct: $r")
    }
  }
}

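Note: a concrete dialect is just SqlIdiom plus choices for the remaining abstract members. A sketch only (assumptions: Idiom declares no further abstract members beyond those visible in this diff, and the ReturningClauseSupported choice is illustrative):

object SketchDialect extends SqlIdiom with QuestionMarkBindVariables {
  override def prepareForProbing(string: String): String = string
  // Assumption: a Postgres-style set-returning function for concat selects.
  override def concatFunction: String = "UNNEST"
  // Assumption: the backend supports RETURNING clauses.
  override def idiomReturningCapability: ReturningCapability =
    ReturningClauseSupported
}
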
326
src/main/scala/minisql/context/sql/SqlQuery.scala
Normal file
@@ -0,0 +1,326 @@
package minisql.context.sql

import minisql.ast._
import minisql.context.sql.norm.FlattenGroupByAggregation
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.{Literal, PseudoAst, NamingStrategy}

case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering)

sealed trait FromContext
case class TableContext(entity: Entity, alias: String) extends FromContext
case class QueryContext(query: SqlQuery, alias: String) extends FromContext
case class InfixContext(infix: Infix, alias: String) extends FromContext
case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast)
    extends FromContext
case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast)
    extends FromContext

sealed trait SqlQuery {
  override def toString = {
    import MirrorSqlDialect.*
    import minisql.idiom.StatementInterpolator.*
    given Tokenizer[SqlQuery] = sqlQueryTokenizer(using
      defaultTokenizer(using Literal),
      Literal
    )
    summon[Tokenizer[SqlQuery]].token(this).toString()
  }
}

sealed trait SetOperation
case object UnionOperation extends SetOperation
case object UnionAllOperation extends SetOperation

sealed trait DistinctKind { def isDistinct: Boolean }
case object DistinctKind {
  case object Distinct extends DistinctKind { val isDistinct: Boolean = true }
  case class DistinctOn(props: List[Ast]) extends DistinctKind {
    val isDistinct: Boolean = true
  }
  case object None extends DistinctKind { val isDistinct: Boolean = false }
}

case class SetOperationSqlQuery(
    a: SqlQuery,
    op: SetOperation,
    b: SqlQuery
) extends SqlQuery

case class UnaryOperationSqlQuery(
    op: UnaryOperator,
    q: SqlQuery
) extends SqlQuery

case class SelectValue(
    ast: Ast,
    alias: Option[String] = None,
    concat: Boolean = false
) extends PseudoAst {
  override def toString: String =
    s"${ast.toString}${alias.map("->" + _).getOrElse("")}"
}

case class FlattenSqlQuery(
    from: List[FromContext] = List(),
    where: Option[Ast] = None,
    groupBy: Option[Ast] = None,
    orderBy: List[OrderByCriteria] = Nil,
    limit: Option[Ast] = None,
    offset: Option[Ast] = None,
    select: List[SelectValue],
    distinct: DistinctKind = DistinctKind.None
) extends SqlQuery

object TakeDropFlatten {
  def unapply(q: Query): Option[(Query, Option[Ast], Option[Ast])] = q match {
    case Take(q: FlatMap, n) => Some((q, Some(n), None))
    case Drop(q: FlatMap, n) => Some((q, None, Some(n)))
    case _ => None
  }
}

object SqlQuery {

  def apply(query: Ast): SqlQuery =
    query match {
      case Union(a, b) =>
        SetOperationSqlQuery(apply(a), UnionOperation, apply(b))
      case UnionAll(a, b) =>
        SetOperationSqlQuery(apply(a), UnionAllOperation, apply(b))
      case UnaryOperation(op, q: Query) => UnaryOperationSqlQuery(op, apply(q))
      case _: Operation | _: Value =>
        FlattenSqlQuery(select = List(SelectValue(query)))
      case Map(q, a, b) if a == b => apply(q)
      case TakeDropFlatten(q, limit, offset) =>
        flatten(q, "x").copy(limit = limit, offset = offset)
      case q: Query => flatten(q, "x")
      case infix: Infix => flatten(infix, "x")
      case other =>
        fail(
          s"Query not properly normalized. Please open a bug report. Ast: '$other'"
        )
    }

  private def flatten(query: Ast, alias: String): FlattenSqlQuery = {
    val (sources, finalFlatMapBody) = flattenContexts(query)
    flatten(sources, finalFlatMapBody, alias)
  }

  private def flattenContexts(query: Ast): (List[FromContext], Ast) =
    query match {
      case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Query) =>
        val source = this.source(q, alias)
        val (nestedContexts, finalFlatMapBody) = flattenContexts(p)
        (source +: nestedContexts, finalFlatMapBody)
      case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Infix) =>
        fail(s"Infix can't be used as a `flatMap` body. $query")
      case other =>
        (List.empty, other)
    }

  object NestedNest {
    def unapply(q: Ast): Option[Ast] =
      q match {
        case _: Nested => recurse(q)
        case _ => None
      }

    private def recurse(q: Ast): Option[Ast] =
      q match {
        case Nested(qn) => recurse(qn)
        case other => Some(other)
      }
  }

  private def flatten(
      sources: List[FromContext],
      finalFlatMapBody: Ast,
      alias: String
  ): FlattenSqlQuery = {

    def select(alias: String) = SelectValue(Ident(alias), None) :: Nil

    def base(q: Ast, alias: String) = {
      def nest(ctx: FromContext) =
        FlattenSqlQuery(from = sources :+ ctx, select = select(alias))
      q match {
        case Map(_: GroupBy, _, _) => nest(source(q, alias))
        case NestedNest(q) => nest(QueryContext(apply(q), alias))
        case q: ConcatMap => nest(QueryContext(apply(q), alias))
        case Join(tpe, a, b, iA, iB, on) =>
          val ctx = source(q, alias)
          def aliases(ctx: FromContext): List[String] =
            ctx match {
              case TableContext(_, alias) => alias :: Nil
              case QueryContext(_, alias) => alias :: Nil
              case InfixContext(_, alias) => alias :: Nil
              case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b)
              case FlatJoinContext(_, a, _) => aliases(a)
            }
          FlattenSqlQuery(
            from = ctx :: Nil,
            select = aliases(ctx).map(a => SelectValue(Ident(a), None))
          )
        case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias)
        case q if (sources == Nil) => flatten(sources, q, alias)
        case other => nest(source(q, alias))
      }
    }

    finalFlatMapBody match {

      case ConcatMap(q, Ident(alias), p) =>
        FlattenSqlQuery(
          from = source(q, alias) :: Nil,
          select = selectValues(p).map(_.copy(concat = true))
        )

      case Map(GroupBy(q, x @ Ident(alias), g), a, p) =>
        val b = base(q, alias)
        val select = BetaReduction(p, a -> Tuple(List(g, x)))
        val flattenSelect = FlattenGroupByAggregation(x)(select)
        b.copy(groupBy = Some(g), select = this.selectValues(flattenSelect))

      case GroupBy(q, Ident(alias), p) =>
        fail("A `groupBy` clause must be followed by `map`.")

      case Map(q, Ident(alias), p) =>
        val b = base(q, alias)
        val agg = b.select.collect {
          case s @ SelectValue(_: Aggregation, _, _) => s
        }
        if (!b.distinct.isDistinct && agg.isEmpty)
          b.copy(select = selectValues(p))
        else
          FlattenSqlQuery(
            from = QueryContext(apply(q), alias) :: Nil,
            select = selectValues(p)
          )

      case Filter(q, Ident(alias), p) =>
        val b = base(q, alias)
        if (b.where.isEmpty)
          b.copy(where = Some(p))
        else
          FlattenSqlQuery(
            from = QueryContext(apply(q), alias) :: Nil,
            where = Some(p),
            select = select(alias)
          )

      case SortBy(q, Ident(alias), p, o) =>
        val b = base(q, alias)
        val criterias = orderByCriterias(p, o)
        if (b.orderBy.isEmpty)
          b.copy(orderBy = criterias)
        else
          FlattenSqlQuery(
            from = QueryContext(apply(q), alias) :: Nil,
            orderBy = criterias,
            select = select(alias)
          )

      case Aggregation(op, q: Query) =>
        val b = flatten(q, alias)
        b.select match {
          case head :: Nil if !b.distinct.isDistinct =>
            b.copy(select = List(head.copy(ast = Aggregation(op, head.ast))))
          case other =>
            FlattenSqlQuery(
              from = QueryContext(apply(q), alias) :: Nil,
              select = List(SelectValue(Aggregation(op, Ident("*"))))
            )
        }

      case Take(q, n) =>
        val b = base(q, alias)
        if (b.limit.isEmpty)
          b.copy(limit = Some(n))
        else
          FlattenSqlQuery(
            from = QueryContext(apply(q), alias) :: Nil,
            limit = Some(n),
            select = select(alias)
          )

      case Drop(q, n) =>
        val b = base(q, alias)
        if (b.offset.isEmpty && b.limit.isEmpty)
          b.copy(offset = Some(n))
        else
          FlattenSqlQuery(
            from = QueryContext(apply(q), alias) :: Nil,
            offset = Some(n),
            select = select(alias)
          )

      case Distinct(q: Query) =>
        val b = base(q, alias)
        b.copy(distinct = DistinctKind.Distinct)

      case DistinctOn(q, Ident(alias), fields) =>
        val distinctList =
          fields match {
            case Tuple(values) => values
            case other => List(other)
          }

        q match {
          // Ideally we don't need to make an extra sub-query for every single case of
          // distinct-on but it only works when the parent AST is an entity. That's because DistinctOn
          // selects from an alias of an outer clause. For example, query[Person].map(p => Name(p.firstName, p.lastName)).distinctOn(_.name)
          // (Let's say Person(firstName, lastName, age), Name(first, last)) will turn into
          // SELECT DISTINCT ON (p.name), p.firstName AS first, p.lastName AS last, p.age FROM Person
          // This doesn't work because `name` in `p.name` doesn't exist yet. Therefore we have to nest this in a subquery:
          // SELECT DISTINCT ON (p.name) FROM (SELECT p.firstName AS first, p.lastName AS last, p.age FROM Person p) AS p
          // The only exception to this is if we are directly selecting from an entity:
          // query[Person].distinctOn(_.firstName) which should be fine: SELECT (x.firstName), x.firstName, x.lastName, a.age FROM Person x
          // since all the fields inside the (...) of the DISTINCT ON must be contained in the entity.
          case _: Entity =>
            val b = base(q, alias)
            b.copy(distinct = DistinctKind.DistinctOn(distinctList))
          case _ =>
            FlattenSqlQuery(
              from = QueryContext(apply(q), alias) :: Nil,
              select = select(alias),
              distinct = DistinctKind.DistinctOn(distinctList)
            )
        }

      case other =>
        FlattenSqlQuery(
          from = sources :+ source(other, alias),
          select = select(alias)
        )
    }
  }

  private def selectValues(ast: Ast) =
    ast match {
      case Tuple(values) => values.map(SelectValue(_))
      case other => SelectValue(ast) :: Nil
    }

  private def source(ast: Ast, alias: String): FromContext =
    ast match {
      case entity: Entity => TableContext(entity, alias)
      case infix: Infix => InfixContext(infix, alias)
      case Join(t, a, b, ia, ib, on) =>
        JoinContext(t, source(a, ia.name), source(b, ib.name), on)
      case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on)
      case Nested(q) => QueryContext(apply(q), alias)
      case other => QueryContext(apply(other), alias)
    }

  private def orderByCriterias(ast: Ast, ordering: Ast): List[OrderByCriteria] =
    (ast, ordering) match {
      case (Tuple(properties), ord: PropertyOrdering) =>
        properties.flatMap(orderByCriterias(_, ord))
      case (Tuple(properties), TupleOrdering(ord)) =>
        properties.zip(ord).flatMap { case (a, o) => orderByCriterias(a, o) }
      case (a, o: PropertyOrdering) => List(OrderByCriteria(a, o))
      case other => fail(s"Invalid order by criteria $ast")
    }
}

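Note: the heart of SqlQuery.flatten is a merge-or-nest decision: a clause lands in the current FlattenSqlQuery when its slot is still free, otherwise the query built so far is wrapped as a sub-select. The rule in isolation (an illustrative model only, not code from this diff):

case class MiniQuery(where: Option[String] = None, limit: Option[Int] = None)

// Mirrors the Take/Filter cases above: merge if the slot is empty, else nest.
def addLimit(q: MiniQuery, n: Int): MiniQuery =
  if (q.limit.isEmpty) q.copy(limit = Some(n)) // merge into the current query
  else MiniQuery(limit = Some(n)) // fresh query over the previous one as a sub-select
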
122
src/main/scala/minisql/context/sql/VerifySqlQuery.scala
Normal file
@@ -0,0 +1,122 @@
package minisql.context.sql.idiom

import minisql.ast._
import minisql.context.sql._
import minisql.norm.FreeVariables

case class Error(free: List[Ident], ast: Ast)
case class InvalidSqlQuery(errors: List[Error]) {
  override def toString =
    s"The monad composition can't be expressed using applicative joins. " +
      errors
        .map(error =>
          s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'."
        )
        .mkString(", ")
}

object VerifySqlQuery {

  def apply(query: SqlQuery): Option[String] =
    verify(query).map(_.toString)

  private def verify(query: SqlQuery): Option[InvalidSqlQuery] =
    query match {
      case q: FlattenSqlQuery => verify(q)
      case SetOperationSqlQuery(a, op, b) => verify(a).orElse(verify(b))
      case UnaryOperationSqlQuery(op, q) => verify(q)
    }

  private def verifyFlatJoins(q: FlattenSqlQuery) = {

    def loop(l: List[FromContext], available: Set[String]): Set[String] =
      l.foldLeft(available) {
        case (av, TableContext(_, alias)) => Set(alias)
        case (av, InfixContext(_, alias)) => Set(alias)
        case (av, QueryContext(_, alias)) => Set(alias)
        case (av, JoinContext(_, a, b, on)) =>
          av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av)
        case (av, FlatJoinContext(_, a, on)) =>
          val nav = av ++ loop(a :: Nil, av)
          val free = FreeVariables(on).map(_.name)
          val invalid = free -- nav
          require(
            invalid.isEmpty,
            s"Found an `ON` table reference of a table that is not available: $invalid. " +
              "The `ON` condition can only use tables defined through explicit joins."
          )
          nav
      }
    loop(q.from, Set())
  }

  private def verify(query: FlattenSqlQuery): Option[InvalidSqlQuery] = {

    verifyFlatJoins(query)

    val aliases =
      query.from.flatMap(this.aliases).map(Ident(_)) :+ Ident("*") :+ Ident("?")

    def verifyAst(ast: Ast) = {
      val freeVariables =
        (FreeVariables(ast) -- aliases).toList
      val freeIdents =
        (CollectAst(ast) {
          case ast: Property => None
          case Aggregation(_, _: Ident) => None
          case ast: Ident => Some(ast)
        }).flatten
      (freeVariables ++ freeIdents) match {
        case Nil => None
        case free => Some(Error(free, ast))
      }
    }

    // Recursively expand children until values are fully flattened. Identities in all these should
    // be skipped during verification.
    def expandSelect(sv: SelectValue): List[SelectValue] =
      sv.ast match {
        case Tuple(values) =>
          values.map(v => SelectValue(v)).flatMap(expandSelect(_))
        case CaseClass(values) =>
          values.map(v => SelectValue(v._2)).flatMap(expandSelect(_))
        case _ => List(sv)
      }

    val freeVariableErrors: List[Error] =
      query.where.flatMap(verifyAst).toList ++
        query.orderBy.map(_.ast).flatMap(verifyAst) ++
        query.limit.flatMap(verifyAst) ++
        query.select
          .flatMap(
            expandSelect(_)
          ) // Expand tuple select clauses so their top-level identities are skipped
          .map(_.ast)
          .filterNot(_.isInstanceOf[Ident])
          .flatMap(verifyAst) ++
        query.from.flatMap {
          case j: JoinContext => verifyAst(j.on)
          case j: FlatJoinContext => verifyAst(j.on)
          case _ => Nil
        }

    val nestedErrors =
      query.from.collect {
        case QueryContext(query, alias) => verify(query).map(_.errors)
      }.flatten.flatten

    (freeVariableErrors ++ nestedErrors) match {
      case Nil => None
      case errors => Some(InvalidSqlQuery(errors))
    }
  }

  private def aliases(s: FromContext): List[String] =
    s match {
      case s: TableContext => List(s.alias)
      case s: QueryContext => List(s.alias)
      case s: InfixContext => List(s.alias)
      case s: JoinContext => aliases(s.a) ++ aliases(s.b)
      case s: FlatJoinContext => aliases(s.a)
    }
}

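Note: verifyAst above reduces to an alias check: any identifier not bound by a FROM context (or by the '*'/'?' placeholders) is reported as free. In miniature (illustrative, not code from this diff):

object AliasCheckDemo {
  def unbound(used: Set[String], fromAliases: Set[String]): Set[String] =
    used -- (fromAliases ++ Set("*", "?"))

  def main(args: Array[String]): Unit =
    assert(unbound(Set("p", "emb"), Set("p")) == Set("emb"))
}
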
@@ -0,0 +1,47 @@
package minisql.context.sql.norm

import minisql.ast.Constant
import minisql.context.sql.{FlattenSqlQuery, SqlQuery, _}

/**
 * In SQL Server, `Order By` clauses are only allowed in sub-queries if the
 * sub-query has a `TOP` or `OFFSET` modifier. Otherwise an exception will be
 * thrown. This transformation adds a 'dummy' `OFFSET 0` in this scenario (if an
 * `Offset` clause does not exist already).
 */
object AddDropToNestedOrderBy {

  def applyInner(q: SqlQuery): SqlQuery =
    q match {
      case q: FlattenSqlQuery =>
        q.copy(
          offset =
            if (q.orderBy.nonEmpty) q.offset.orElse(Some(Constant(0)))
            else q.offset,
          from = q.from.map(applyInner(_))
        )

      case SetOperationSqlQuery(a, op, b) =>
        SetOperationSqlQuery(applyInner(a), op, applyInner(b))
      case UnaryOperationSqlQuery(op, a) =>
        UnaryOperationSqlQuery(op, applyInner(a))
    }

  private def applyInner(f: FromContext): FromContext =
    f match {
      case QueryContext(a, alias) => QueryContext(applyInner(a), alias)
      case JoinContext(t, a, b, on) =>
        JoinContext(t, applyInner(a), applyInner(b), on)
      case FlatJoinContext(t, a, on) => FlatJoinContext(t, applyInner(a), on)
      case other => other
    }

  def apply(q: SqlQuery): SqlQuery =
    q match {
      case q: FlattenSqlQuery => q.copy(from = q.from.map(applyInner(_)))
      case SetOperationSqlQuery(a, op, b) =>
        SetOperationSqlQuery(applyInner(a), op, applyInner(b))
      case UnaryOperationSqlQuery(op, a) =>
        UnaryOperationSqlQuery(op, applyInner(a))
    }
}

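Note: the transformation above touches only nested queries; the outermost query keeps its ORDER BY without a synthetic OFFSET. The core rule in isolation (illustrative only):

// An inner query that sorts but has no OFFSET gets OFFSET 0 so that
// SQL Server accepts its ORDER BY; everything else is left untouched.
def ensureOffset(orderByNonEmpty: Boolean, offset: Option[Int]): Option[Int] =
  if (orderByNonEmpty) offset.orElse(Some(0)) else offset
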
68
src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala
Normal file
@@ -0,0 +1,68 @@
package minisql.context.sql.norm

import minisql.ast.Visibility.Hidden
import minisql.ast._

object ExpandDistinct {

  @annotation.tailrec
  def hasJoin(q: Ast): Boolean = {
    q match {
      case _: Join => true
      case Map(q, _, _) => hasJoin(q)
      case Filter(q, _, _) => hasJoin(q)
      case _ => false
    }
  }

  def apply(q: Ast): Ast =
    q match {
      case Distinct(q) =>
        Distinct(apply(q))
      case q =>
        Transform(q) {
          case Aggregation(op, Distinct(q)) =>
            Aggregation(op, Distinct(apply(q)))
          case Distinct(Map(q, x, cc @ Tuple(values))) =>
            Map(
              Distinct(Map(q, x, cc)),
              x,
              Tuple(values.zipWithIndex.map {
                case (_, i) => Property(x, s"_${i + 1}")
              })
            )

          // Situations like this:
          //   case class AdHocCaseClass(id: Int, name: String)
          //   val q = quote {
          //     query[SomeTable].map(st => AdHocCaseClass(st.id, st.name)).distinct
          //   }
          // ... need some special treatment. Otherwise their values will not be correctly expanded.
          case Distinct(Map(q, x, cc @ CaseClass(values))) =>
            Map(
              Distinct(Map(q, x, cc)),
              x,
              CaseClass(values.map {
                case (name, _) => (name, Property(x, name))
              })
            )

          // Need some special handling to address issues with distinct returning a single embedded entity i.e:
          //   query[Parent].map(p => p.emb).distinct.map(e => (e.name, e.id))
          // cannot treat such a case normally or "confused" queries will result e.g:
          //   SELECT p.embname, p.embid FROM (SELECT DISTINCT emb.name /* Where the heck is 'emb' coming from? */ AS embname, emb.id AS embid FROM Parent p) AS p
          case d @ Distinct(
                Map(q, x, p @ Property.Opinionated(_, _, _, Hidden))
              ) =>
            d

          // Problems with distinct were first discovered in #1032. Basically, unless
          // the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places
          // such as "ORDER BY tableIdent" etc...
          case Distinct(Map(q, x, p)) =>
            val newMap = Map(q, x, Tuple(List(p)))
            val newIdent = Ident(x.name)
            Map(Distinct(newMap), newIdent, Property(newIdent, "_1"))
        }
    }
}

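Note: the Tuple case above wraps the distinct in an outer map that re-selects each field positionally (_1, _2, ...), so the distinct sub-select exposes stable column names. The outer re-selection in isolation (illustrative only):

def outerSelect(arity: Int, alias: String): List[String] =
  (1 to arity).map(i => s"$alias._$i").toList

// outerSelect(2, "x") == List("x._1", "x._2")
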
49
src/main/scala/minisql/context/sql/norm/ExpandJoin.scala
Normal file
@@ -0,0 +1,49 @@
package minisql.context.sql.norm

import minisql.ast._
import minisql.norm.BetaReduction
import minisql.norm.Normalize

object ExpandJoin {

  def apply(q: Ast) = expand(q, None)

  def expand(q: Ast, id: Option[Ident]) =
    Transform(q) {
      case q @ Join(_, _, _, Ident(a), Ident(b), _) =>
        val (qr, tuple) = expandedTuple(q)
        Map(qr, id.getOrElse(Ident(s"$a$b")), tuple)
    }

  private def expandedTuple(q: Join): (Join, Tuple) =
    q match {

      case Join(t, a: Join, b: Join, tA, tB, o) =>
        val (ar, at) = expandedTuple(a)
        val (br, bt) = expandedTuple(b)
        val or = BetaReduction(o, tA -> at, tB -> bt)
        (Join(t, ar, br, tA, tB, or), Tuple(List(at, bt)))

      case Join(t, a: Join, b, tA, tB, o) =>
        val (ar, at) = expandedTuple(a)
        val or = BetaReduction(o, tA -> at)
        (Join(t, ar, b, tA, tB, or), Tuple(List(at, tB)))

      case Join(t, a, b: Join, tA, tB, o) =>
        val (br, bt) = expandedTuple(b)
        val or = BetaReduction(o, tB -> bt)
        (Join(t, a, br, tA, tB, or), Tuple(List(tA, bt)))

      case q @ Join(t, a, b, tA, tB, on) =>
        (
          Join(t, nestedExpand(a, tA), nestedExpand(b, tB), tA, tB, on),
          Tuple(List(tA, tB))
        )
    }

  private def nestedExpand(q: Ast, id: Ident) =
    Normalize(expand(q, Some(id))) match {
      case Map(q, _, _) => q
      case q => q
    }
}

@@ -0,0 +1,12 @@
package minisql.context.sql.norm

import minisql.ast._

object ExpandMappedInfix {
  def apply(q: Ast): Ast = {
    Transform(q) {
      case Map(Infix("" :: parts, (q: Query) :: params, pure, noParen), x, p) =>
        Infix("" :: parts, Map(q, x, p) :: params, pure, noParen)
    }
  }
}

@@ -0,0 +1,147 @@
package minisql.context.sql.norm

import minisql.NamingStrategy
import minisql.ast.Ast
import minisql.ast.Ident
import minisql.ast._
import minisql.ast.StatefulTransformer
import minisql.ast.Visibility.Visible
import minisql.context.sql._

import scala.collection.mutable.LinkedHashSet
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.context.sql.norm.nested.ExpandSelect
import minisql.norm.BetaReduction

import scala.collection.mutable

class ExpandNestedQueries(strategy: NamingStrategy) {

  val interp = new Interpolator(3)
  import interp._

  def apply(q: SqlQuery, references: List[Property]): SqlQuery =
    apply(q, LinkedHashSet.empty ++ references)

  // Using LinkedHashSet despite the fact that it is mutable because it has better characteristics than ListSet.
  // Also this collection is strictly internal to ExpandNestedQueries and not exposed anywhere else.
  private def apply(
      q: SqlQuery,
      references: LinkedHashSet[Property]
  ): SqlQuery =
    q match {
      case q: FlattenSqlQuery =>
        val expand = expandNested(
          q.copy(select = ExpandSelect(q.select, references, strategy))
        )
        trace"Expanded Nested Query $q into $expand".andLog()
        expand
      case SetOperationSqlQuery(a, op, b) =>
        SetOperationSqlQuery(apply(a, references), op, apply(b, references))
      case UnaryOperationSqlQuery(op, q) =>
        UnaryOperationSqlQuery(op, apply(q, references))
    }

  private def expandNested(q: FlattenSqlQuery): SqlQuery =
    q match {
      case FlattenSqlQuery(
            from,
            where,
            groupBy,
            orderBy,
            limit,
            offset,
            select,
            distinct
          ) =>
        val asts = Nil ++ select.map(_.ast) ++ where ++ groupBy ++ orderBy.map(
          _.ast
        ) ++ limit ++ offset
        val expansions = q.from.map(expandContext(_, asts))
        val from = expansions.map(_._1)
        val references = expansions.flatMap(_._2)

        val replacedRefs = references.map(ref => (ref, unhideAst(ref)))

        // Need to unhide properties that were used during the query
        def replaceProps(ast: Ast) =
          BetaReduction(ast, replacedRefs*)
        def replacePropsOption(ast: Option[Ast]) =
          ast.map(replaceProps(_))

        val distinctKind =
          q.distinct match {
            case DistinctKind.DistinctOn(props) =>
              DistinctKind.DistinctOn(props.map(p => replaceProps(p)))
            case other => other
          }

        q.copy(
          select = select.map(sv => sv.copy(ast = replaceProps(sv.ast))),
          from = from,
          where = replacePropsOption(where),
          groupBy = replacePropsOption(groupBy),
          orderBy = orderBy.map(ob => ob.copy(ast = replaceProps(ob.ast))),
          limit = replacePropsOption(limit),
          offset = replacePropsOption(offset),
          distinct = distinctKind
        )
    }

  def unhideAst(ast: Ast): Ast =
    Transform(ast) {
      case Property.Opinionated(a, n, r, v) =>
        Property.Opinionated(unhideAst(a), n, r, Visible)
    }

  private def unhideProperties(sv: SelectValue) =
    sv.copy(ast = unhideAst(sv.ast))

  private def expandContext(
      s: FromContext,
      asts: List[Ast]
  ): (FromContext, LinkedHashSet[Property]) =
    s match {
      case QueryContext(q, alias) =>
        val refs = references(alias, asts)
        (QueryContext(apply(q, refs), alias), refs)
      case JoinContext(t, a, b, on) =>
        val (left, leftRefs) = expandContext(a, asts :+ on)
        val (right, rightRefs) = expandContext(b, asts :+ on)
        (JoinContext(t, left, right, on), leftRefs ++ rightRefs)
      case FlatJoinContext(t, a, on) =>
        val (next, refs) = expandContext(a, asts :+ on)
        (FlatJoinContext(t, next, on), refs)
      case _: TableContext | _: InfixContext =>
        (s, new mutable.LinkedHashSet[Property]())
    }

  private def references(alias: String, asts: List[Ast]) =
    LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)(
      _.apply
    )._2.state.references)
}

case class State(ident: Ident, references: List[Property])

case class References(val state: State) extends StatefulTransformer[State] {

  import state._

  override def apply(a: Ast) =
    a match {
      case `reference`(p) => (p, References(State(ident, references :+ p)))
      case other => super.apply(a)
    }

  object reference {
    def unapply(p: Property): Option[Property] =
      p match {
        case Property(`ident`, name) => Some(p)
        case Property(reference(_), name) => Some(p)
        case other => None
      }
  }
}

@@ -0,0 +1,58 @@
package minisql.context.sql.norm

import minisql.ast.Aggregation
import minisql.ast.Ast
import minisql.ast.Drop
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Ident
import minisql.ast.Join
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.SortBy
import minisql.ast.StatelessTransformer
import minisql.ast.Take
import minisql.ast.Union
import minisql.ast.UnionAll
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.ast.ConcatMap

case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer {

  override def apply(ast: Ast) =
    ast match {
      case q: Query if (isGroupByAggregation(q)) =>
        q match {
          case Aggregation(op, Map(`agg`, ident, body)) =>
            Aggregation(op, BetaReduction(body, ident -> agg))
          case Map(`agg`, ident, body) =>
            BetaReduction(body, ident -> agg)
          case q @ Aggregation(op, `agg`) =>
            q
          case other =>
            fail(s"Invalid group by aggregation: '$other'")
        }
      case other =>
        super.apply(other)
    }

  private def isGroupByAggregation(ast: Ast): Boolean =
    ast match {
      case Aggregation(a, b) => isGroupByAggregation(b)
      case Map(a, b, c) => isGroupByAggregation(a)
      case FlatMap(a, b, c) => isGroupByAggregation(a)
      case ConcatMap(a, b, c) => isGroupByAggregation(a)
      case Filter(a, b, c) => isGroupByAggregation(a)
      case SortBy(a, b, c, d) => isGroupByAggregation(a)
      case Take(a, b) => isGroupByAggregation(a)
      case Drop(a, b) => isGroupByAggregation(a)
      case Union(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
      case UnionAll(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
      case Join(t, a, b, ta, tb, on) =>
        isGroupByAggregation(a) || isGroupByAggregation(b)
      case `agg` => true
      case other => false
    }

}

53
src/main/scala/minisql/context/sql/norm/SqlNormalize.scala
Normal file
@@ -0,0 +1,53 @@
package minisql.context.sql.norm

import minisql.norm._
import minisql.ast.Ast
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.capture.DemarcateExternalAliases
import minisql.util.Messages.trace

object SqlNormalize {
  def apply(
      ast: Ast,
      concatBehavior: ConcatBehavior = AnsiConcat,
      equalityBehavior: EqualityBehavior = AnsiEquality
  ) =
    new SqlNormalize(concatBehavior, equalityBehavior)(ast)
}

class SqlNormalize(
    concatBehavior: ConcatBehavior,
    equalityBehavior: EqualityBehavior
) {

  private val normalize =
    (identity[Ast])
      .andThen(trace("original"))
      .andThen(DemarcateExternalAliases.apply)
      .andThen(trace("DemarcateReturningAliases"))
      .andThen(new FlattenOptionOperation(concatBehavior).apply)
      .andThen(trace("FlattenOptionOperation"))
      .andThen(new SimplifyNullChecks(equalityBehavior).apply)
      .andThen(trace("SimplifyNullChecks"))
      .andThen(Normalize.apply)
      .andThen(trace("Normalize"))
      // Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes
      // on which RenameProperties relies
      .andThen(RenameProperties.apply)
      .andThen(trace("RenameProperties"))
      .andThen(ExpandDistinct.apply)
      .andThen(trace("ExpandDistinct"))
      .andThen(NestImpureMappedInfix.apply)
      .andThen(trace("NestMappedInfix"))
      .andThen(Normalize.apply)
      .andThen(trace("Normalize"))
      .andThen(ExpandJoin.apply)
      .andThen(trace("ExpandJoin"))
      .andThen(ExpandMappedInfix.apply)
      .andThen(trace("ExpandMappedInfix"))
      .andThen(Normalize.apply)
      .andThen(trace("Normalize"))

  def apply(ast: Ast) = normalize(ast)
}
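The `normalize` pipeline above is plain function composition over `Ast => Ast` phases. A minimal, self-contained sketch of the same `identity.andThen(...)` pattern (phase bodies here are illustrative, not minisql phases):

// Each phase is a total function; trace steps log without changing the value.
def trace[A](label: String): A => A = { a => println(s"$label: $a"); a }

val pipeline: Int => Int =
  (identity[Int])
    .andThen(trace("original"))
    .andThen(_ * 2)
    .andThen(trace("doubled"))
    .andThen(_ + 1)
    .andThen(trace("incremented"))

pipeline(20) // prints each stage and returns 41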
@@ -0,0 +1,29 @@
package minisql.context.sql.norm.nested

import minisql.PseudoAst
import minisql.context.sql.SelectValue

object Elements {

  /**
   * In order to be able to reconstruct the original ordering of elements inside
   * of a select clause, we need to keep track of their order, not only within
   * the top-level select but also their order within any possible
   * tuples/case-classes in which they are embedded. For example, in the query:
   * <pre><code>
   * query[Person].map(p => (p.id, (p.name, p.age))).nested
   * // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person x) AS p
   * </code></pre>
   * Since the `p.name` and `p.age` elements are selected inside of a sub-tuple,
   * their "order" is `List(2,1)` and `List(2,2)` respectively, as opposed to
   * `p.id` whose "order" is just `List(1)`.
   *
   * This class keeps track of the values needed in order to do this.
   */
  case class OrderedSelect(order: List[Int], selectValue: SelectValue)
      extends PseudoAst {
    override def toString: String = s"[${order.mkString(",")}]${selectValue}"
  }
  object OrderedSelect {
    def apply(order: Int, selectValue: SelectValue) =
      new OrderedSelect(List(order), selectValue)
  }
}
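To make the scaladoc example concrete, these are the order paths it describes for `query[Person].map(p => (p.id, (p.name, p.age))).nested` (1-based as in the comment; the implementation itself indexes from zero via `zipWithIndex`):

val idOrder   = List(1)    // p.id: first element of the top-level tuple
val nameOrder = List(2, 1) // p.name: first element of the nested tuple
val ageOrder  = List(2, 2) // p.age: second element of the nested tuple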
@@ -0,0 +1,262 @@
package minisql.context.sql.norm.nested

import minisql.NamingStrategy
import minisql.ast.Property
import minisql.context.sql.SelectValue
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion

import scala.collection.mutable.LinkedHashSet
import minisql.context.sql.norm.nested.Elements._
import minisql.ast._
import minisql.norm.BetaReduction

/**
 * Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query
 * construct exists) and flattens them from a nested-hierarchical structure
 * (i.e. tuples inside case classes inside tuples etc...) into a single series
 * of top-level select elements where needed. In cases where a user wants to
 * select an element that contains an entire tuple (i.e. a sub-tuple of the
 * outer select clause) we pull out the entire tuple that is being selected and
 * leave it to the tokenizer to flatten later.
 *
 * The tricky part of this operation is that there are situations where infix
 * clauses in a sub-query represent an element that has not been selected by
 * the super-query, but in order to ensure the SQL operation has the same
 * meaning, we need to keep track of it. For example:
 * <pre><code>
 * val q = quote {
 *   query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int], p.name, p.id)).map(t => (t._2, t._3))
 * }
 * run(q)
 * // SELECT p._2, p._3 FROM (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
 * </code></pre>
 * Since `DISTINCT ON` significantly changes the behavior of the outer query,
 * we need to keep track of it inside of the inner query. In order to do this,
 * we need to keep track of the location of the infix in the inner query so
 * that we can reconstruct it. This is why the `OrderedSelect` and
 * `DoubleOrderedSelect` objects are used. See the notes on these classes for
 * more detail.
 *
 * See issue #1597 for more details and another example.
 */
private class ExpandSelect(
    selectValues: List[SelectValue],
    references: LinkedHashSet[Property],
    strategy: NamingStrategy
) {
  val interp = new Interpolator(3)
  import interp._

  object TupleIndex {
    def unapply(s: String): Option[Int] =
      if (s.matches("_[0-9]*"))
        Some(s.drop(1).toInt - 1)
      else
        None
  }

  object MultiTupleIndex {
    def unapply(s: String): Boolean =
      if (s.matches("(_[0-9]+)+"))
        true
      else
        false
  }

  val select =
    selectValues.zipWithIndex.map {
      case (value, index) => OrderedSelect(index, value)
    }

  def expandColumn(name: String, renameable: Renameable): String =
    renameable.fixedOr(name)(strategy.column(name))

  def apply: List[SelectValue] =
    trace"Expanding Select values: $selectValues into references: $references" andReturn {

      def expandReference(ref: Property): OrderedSelect =
        trace"Expanding: $ref from $select" andReturn {

          def expressIfTupleIndex(str: String) =
            str match {
              case MultiTupleIndex() => Some(str)
              case _                 => None
            }

          def concat(alias: Option[String], idx: Int) =
            Some(s"${alias.getOrElse("")}_${idx + 1}")

          val orderedSelect = ref match {
            case pp @ Property(ast: Property, TupleIndex(idx)) =>
              trace"Reference is a sub-property of a tuple index: $idx. Walking inside." `andReturn`
                expandReference(ast) match {
                  case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) =>
                    trace"Expressing Element $idx of $elems " `andReturn`
                      OrderedSelect(
                        o :+ idx,
                        SelectValue(elems(idx), concat(alias, idx), c)
                      )
                  case OrderedSelect(o, SelectValue(ast, alias, c)) =>
                    trace"Appending $idx to $alias " `andReturn`
                      OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
                }
            case pp @ Property.Opinionated(
                  ast: Property,
                  name,
                  renameable,
                  visible
                ) =>
              trace"Reference is a sub-property. Walking inside." `andReturn`
                expandReference(ast) match {
                  case OrderedSelect(o, SelectValue(ast, nested, c)) =>
                    // Alias is the name of the column after the naming strategy.
                    // The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the
                    // alias field when its value is Some(T).
                    // Technically the aliases of a column should not be using naming strategies
                    // but this is an issue to fix at a later date.

                    // In the current implementation, we add nested tuple names to query aliases, e.g.
                    //   SELECT foo FROM
                    //     SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors)
                    // typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when
                    // this kind of query is the result of an applicative join that looks like this:
                    //   query[baz].join(query[colors]).nested
                    // This may need to change based on how distinct appends table names instead of just tuple indexes
                    // into the property path.

                    trace"...inside walk completed, continuing to return: " `andReturn`
                      OrderedSelect(
                        o,
                        SelectValue(
                          // Note: Pass invisible properties to be tokenized by the idiom, they should be excluded there
                          Property.Opinionated(ast, name, renameable, visible),
                          // Skip concatenation of invisible properties into the alias e.g. so it will be
                          Some(
                            s"${nested.getOrElse("")}${expandColumn(name, renameable)}"
                          )
                        )
                      )
                }
            case pp @ Property(_, TupleIndex(idx)) =>
              trace"Reference is a tuple index: $idx from $select." `andReturn`
                select(idx) match {
                  case OrderedSelect(o, SelectValue(ast, alias, c)) =>
                    OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
                }
            case pp @ Property.Opinionated(_, name, renameable, visible) =>
              select match {
                case List(
                      OrderedSelect(o, SelectValue(cc: CaseClass, alias, c))
                    ) =>
                  // Currently case class element name is not being appended. Need to change that in order to ensure
                  // path name uniqueness in future.
                  val ((_, ast), index) =
                    cc.values.zipWithIndex.find(_._1._1 == name) match {
                      case Some(v) => v
                      case None =>
                        throw new IllegalArgumentException(
                          s"Cannot find element $name in $cc"
                        )
                    }
                  trace"Reference is a case class member: " `andReturn`
                    OrderedSelect(
                      o :+ index,
                      SelectValue(ast, Some(expandColumn(name, renameable)), c)
                    )
                case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) =>
                  trace"Reference is an identifier: " `andReturn`
                    OrderedSelect(
                      o,
                      SelectValue(
                        Property.Opinionated(i, name, renameable, visible),
                        Some(name),
                        c
                      )
                    )
                case other =>
                  trace"Reference is unidentified: $other returning:" `andReturn`
                    OrderedSelect(
                      Integer.MAX_VALUE,
                      SelectValue(
                        Ident.Opinionated(name, visible),
                        Some(expandColumn(name, renameable)),
                        false
                      )
                    )
              }
          }

          // For certain very large queries where entities are unwrapped and then re-wrapped into CaseClass/Tuple constructs,
          // the actual row-types can contain Tuple/CaseClass values. For this reason, they need to be beta-reduced again.
          val normalizedOrderedSelect = orderedSelect.copy(selectValue =
            orderedSelect.selectValue.copy(ast =
              BetaReduction(orderedSelect.selectValue.ast)
            )
          )

          trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" `andReturn`
            normalizedOrderedSelect
        }

      def deAliasWhenUneeded(os: OrderedSelect) =
        os match {
          case OrderedSelect(
                _,
                sv @ SelectValue(Property(Ident(_), propName), Some(alias), _)
              ) if (propName == alias) =>
            trace"Detected select value with un-needed alias: $os removing it:" `andReturn`
              os.copy(selectValue = sv.copy(alias = None))
          case _ => os
        }

      references.toList match {
        case Nil => select.map(_.selectValue)
        case refs => {
          // Elements first need to be sorted by their order in the select clause. Since some may map to multiple
          // properties when expanded, we want to maintain this order of properties as a secondary value.
          val mappedRefs =
            refs
              // Expand all the references to properties that we have selected in the super query
              .map(expandReference)
              // Once all the recursive calls of expandReference are done, remove the alias if it is not needed.
              // We cannot do this earlier because during recursive calls, the aliases of outer clauses are used for inner ones.
              .map(deAliasWhenUneeded(_))

          trace"Mapped Refs: $mappedRefs".andLog()

          // Are there any selects that have infix values which we have not already selected? We need to include
          // them because they could be doing essential things e.g. RANK ... ORDER BY
          val remainingSelectsWithInfixes =
            trace"Searching Selects with Infix:" `andReturn`
              new FindUnexpressedInfixes(select)(mappedRefs)

          implicit val ordering: scala.math.Ordering[List[Int]] =
            new scala.math.Ordering[List[Int]] {
              override def compare(x: List[Int], y: List[Int]): Int =
                (x, y) match {
                  case (head1 :: tail1, head2 :: tail2) =>
                    val diff = head1 - head2
                    if (diff != 0) diff
                    else compare(tail1, tail2)
                  case (Nil, Nil)   => 0  // List(1,2,3) == List(1,2,3)
                  case (head1, Nil) => -1 // List(1,2,3) < List(1,2)
                  case (Nil, head2) => 1  // List(1,2) > List(1,2,3)
                }
            }

          val sortedRefs =
            (mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref =>
              ref.order
            ) // (ref.order, ref.secondaryOrder)

          sortedRefs.map(_.selectValue)
        }
      }
    }
}

object ExpandSelect {
  def apply(
      selectValues: List[SelectValue],
      references: LinkedHashSet[Property],
      strategy: NamingStrategy
  ): List[SelectValue] =
    new ExpandSelect(selectValues, references, strategy).apply
}
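The `Ordering[List[Int]]` defined inside `apply` above sorts a longer path before its own prefix (`List(1,2,3) < List(1,2)`). A standalone copy of the same comparison, runnable outside minisql:

implicit val pathOrdering: Ordering[List[Int]] = new Ordering[List[Int]] {
  override def compare(x: List[Int], y: List[Int]): Int =
    (x, y) match {
      case (h1 :: t1, h2 :: t2) =>
        val diff = h1 - h2
        if (diff != 0) diff else compare(t1, t2)
      case (Nil, Nil) => 0
      case (_, Nil)   => -1 // a longer path sorts before its prefix
      case (Nil, _)   => 1
    }
}

List(List(2, 2), List(1), List(2, 1)).sorted
// => List(List(1), List(2, 1), List(2, 2))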
@@ -0,0 +1,83 @@
package minisql.context.sql.norm.nested

import minisql.context.sql.norm.nested.Elements._
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.ast._
import minisql.context.sql.SelectValue

/**
 * The challenge with appending infixes (that have not been used but are still
 * needed) back into the query is that they could be inside of
 * tuples/case-classes that have already been selected, or inside of sibling
 * elements which have been selected. Take for instance a query that looks like
 * this:
 * <pre><code>
 * query[Person].map(p => (p.name, (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1))
 * </code></pre>
 * In this situation, `p.id`, which is the sibling of the non-selected infix,
 * has been selected via `p._2._1` (whose select-order is List(1,0),
 * representing the 1st element in the 2nd tuple). We need to add its sibling
 * infix.
 *
 * Or take the following situation:
 * <pre><code>
 * query[Person].map(p => (p.name, (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
 * </code></pre>
 * In this case, we have selected the entire 2nd element including the infix.
 * We need to know that `p._2._2` does not need to be selected since `p._2`
 * was.
 *
 * In order to do these things, we use the `order` property from
 * `OrderedSelect` to see which sub-sub-...-element has been selected. If
 * `p._2` (which has order `List(1)`) has been selected, we know that any
 * infixes inside of it, e.g. `p._2._1` (ordering `List(1,0)`), do not need to
 * be.
 */
class FindUnexpressedInfixes(select: List[OrderedSelect]) {
  val interp = new Interpolator(3)
  import interp._

  def apply(refs: List[OrderedSelect]) = {

    def pathExists(path: List[Int]) =
      refs.map(_.order).contains(path)

    def containsInfix(ast: Ast) =
      CollectAst.byType[Infix](ast).length > 0

    // Build paths to every infix and check that these paths were not selected already
    def findMissingInfixes(
        ast: Ast,
        parentOrder: List[Int]
    ): List[(Ast, List[Int])] = {
      trace"Searching for infix: $ast in the sub-path $parentOrder".andLog()
      if (pathExists(parentOrder))
        trace"No infixes found" `andContinue`
          List()
      else
        ast match {
          case Tuple(values) =>
            values.zipWithIndex
              .filter(v => containsInfix(v._1))
              .flatMap {
                case (ast, index) =>
                  findMissingInfixes(ast, parentOrder :+ index)
              }
          case CaseClass(values) =>
            values.zipWithIndex
              .filter(v => containsInfix(v._1._2))
              .flatMap {
                case ((_, ast), index) =>
                  findMissingInfixes(ast, parentOrder :+ index)
              }
          case other if (containsInfix(other)) =>
            trace"Found unexpressed infix inside $other in $parentOrder"
              .andLog()
            List((other, parentOrder))
          case _ =>
            List()
        }
    }

    select.flatMap {
      case OrderedSelect(o, sv) => findMissingInfixes(sv.ast, o)
    }.map {
      case (ast, order) => OrderedSelect(order, SelectValue(ast))
    }
  }
}
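A sketch of the `pathExists` bookkeeping above: descent stops as soon as the current path was itself selected, so infixes under an already-selected sub-tuple are never re-added (the selected paths below are illustrative):

val selectedOrders = Set(List(1))                    // p._2 was selected
def pathExists(path: List[Int]) = selectedOrders.contains(path)

pathExists(List(1))    // true: the whole sub-tuple is selected, stop here
pathExists(List(1, 0)) // false, but never reached once List(1) matched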
@@ -1,59 +0,0 @@
package minisql.dsl

import minisql.*
import minisql.parsing.*
import minisql.ast.{Ast, Entity, Map, Property, Ident, given}
import scala.quoted.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}

opaque type Quoted <: Ast = Ast

opaque type Query[E] <: Quoted = Quoted

opaque type EntityQuery[E] <: Query[E] = Query[E]

extension [E](inline e: EntityQuery[E]) {
  inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
    transform(e)(f)(Map.apply)
  }
}

private inline def transform[A, B](inline q1: Quoted)(
    inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
  fast(q1, f.param0, f.body)
}

inline def query[E](inline table: String): EntityQuery[E] =
  Entity(table, Nil)

inline def compile(inline x: Ast): Option[String] = ${
  compileImpl('{ x })
}

private def compileImpl(
    x: Expr[Ast]
)(using Quotes): Expr[Option[String]] = {
  import quotes.reflect.*
  x.value match {
    case Some(xv) => '{ Some(${ Expr(xv.toString()) }) }
    case None     => '{ None }
  }
}

extension [A, B](inline f1: A => B) {
  private inline def param0 = parsing.parseParamAt(f1, 0)
  private inline def body   = parsing.parseBody(f1)
}

extension [A1, A2, B](inline f1: (A1, A2) => B) {
  private inline def param0 = parsing.parseParamAt(f1, 0)
  private inline def param1 = parsing.parseParamAt(f1, 1)
  private inline def body   = parsing.parseBody(f1)
}

case class Foo(id: Int)

inline def queryFooId = query[Foo]("foo").map(_.id)
23
src/main/scala/minisql/idiom/Idiom.scala
Normal file
@@ -0,0 +1,23 @@
package minisql.idiom

import minisql.NamingStrategy
import minisql.ast._
import minisql.context.Capabilities

trait Idiom extends Capabilities {

  def emptySetContainsToken(field: Token): Token = StringToken("FALSE")

  def defaultAutoGeneratedToken(field: Token): Token = StringToken(
    "DEFAULT VALUES"
  )

  def liftingPlaceholder(index: Int): String

  def translate(ast: Ast)(using naming: NamingStrategy): (Ast, Statement)

  def format(queryString: String): String = queryString

  def prepareForProbing(string: String): String

}
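`liftingPlaceholder` is the per-dialect knob for rendering lift positions; a hypothetical sketch of the two common placeholder styles (these objects are not part of this diff):

// JDBC-style dialects ignore the index; Postgres-style dialects use it.
object QuestionMarkPlaceholders {
  def liftingPlaceholder(index: Int): String = "?"
}
object DollarPlaceholders {
  def liftingPlaceholder(index: Int): String = s"$$${index + 1}"
}

DollarPlaceholders.liftingPlaceholder(0) // "$1"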
28
src/main/scala/minisql/idiom/LoadNaming.scala
Normal file
@@ -0,0 +1,28 @@
package minisql.idiom

import scala.util.Try
import scala.quoted._
import minisql.NamingStrategy
import minisql.util.CollectTry
import minisql.util.LoadObject
import minisql.CompositeNamingStrategy

object LoadNaming {

  def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry {
    strategies[C].map(LoadObject[NamingStrategy](_))
  }.map(NamingStrategy(_))

  private def strategies[C](using Quotes, Type[C]) = {
    import quotes.reflect.*
    val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy]
    val ct = TypeRepr.of[C]
    if (isComposite) {
      // Drop the bare NamingStrategy/Nothing type arguments; only concrete
      // strategies can be loaded.
      ct.typeArgs.filterNot { t =>
        t =:= TypeRepr.of[NamingStrategy] || t =:= TypeRepr.of[Nothing]
      }
    } else {
      List(ct)
    }
  }
}
356
src/main/scala/minisql/idiom/MirrorIdiom.scala
Normal file
@@ -0,0 +1,356 @@
package minisql.idiom

import minisql.NamingStrategy
import minisql.ast.Renameable.{ByStrategy, Fixed}
import minisql.ast.Visibility.Hidden
import minisql.ast.*
import minisql.context.CanReturnClause
import minisql.idiom.{Idiom, SetContainsToken, Statement}
import minisql.idiom.StatementInterpolator.*
import minisql.norm.Normalize
import minisql.util.Interleave

object MirrorIdiom extends MirrorIdiom
class MirrorIdiom extends MirrorIdiomBase with CanReturnClause

object MirrorIdiomPrinting extends MirrorIdiom {
  override def distinguishHidden: Boolean = true
}

trait MirrorIdiomBase extends Idiom {

  def distinguishHidden: Boolean = false

  override def prepareForProbing(string: String) = string

  override def liftingPlaceholder(index: Int): String = "?"

  override def translate(
      ast: Ast
  )(implicit naming: NamingStrategy): (Ast, Statement) = {
    val normalizedAst = Normalize(ast)
    (normalizedAst, stmt"${normalizedAst.token}")
  }

  implicit def astTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Ast] = Tokenizer[Ast] {
    case ast: Query               => ast.token
    case ast: Function            => ast.token
    case ast: Value               => ast.token
    case ast: Operation           => ast.token
    case ast: Action              => ast.token
    case ast: Ident               => ast.token
    case ast: ExternalIdent       => ast.token
    case ast: Property            => ast.token
    case ast: Infix               => ast.token
    case ast: OptionOperation     => ast.token
    case ast: IterableOperation   => ast.token
    case ast: Dynamic             => ast.token
    case ast: If                  => ast.token
    case ast: Block               => ast.token
    case ast: Val                 => ast.token
    case ast: Ordering            => ast.token
    case ast: Lift                => ast.token
    case ast: Assignment          => ast.token
    case ast: OnConflict.Excluded => ast.token
    case ast: OnConflict.Existing => ast.token
  }

  implicit def ifTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[If] = Tokenizer[If] {
    case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}"
  }

  implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] {
    case Dynamic(tree) => stmt"${tree.toString.token}"
  }

  implicit def blockTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Block] = Tokenizer[Block] {
    case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }"
  }

  implicit def valTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Val] = Tokenizer[Val] {
    case Val(name, body) => stmt"val ${name.token} = ${body.token}"
  }

  implicit def queryTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Query] = Tokenizer[Query] {

    case Entity.Opinionated(name, Nil, renameable) =>
      stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token})"

    case Entity.Opinionated(name, prop, renameable) =>
      val properties =
        prop.map(p => stmt"""_.${p.path.mkStmt(".")} -> "${p.alias.token}"""")
      stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token}, ${properties.token})"

    case Filter(source, alias, body) =>
      stmt"${source.token}.filter(${alias.token} => ${body.token})"

    case Map(source, alias, body) =>
      stmt"${source.token}.map(${alias.token} => ${body.token})"

    case FlatMap(source, alias, body) =>
      stmt"${source.token}.flatMap(${alias.token} => ${body.token})"

    case ConcatMap(source, alias, body) =>
      stmt"${source.token}.concatMap(${alias.token} => ${body.token})"

    case SortBy(source, alias, body, ordering) =>
      stmt"${source.token}.sortBy(${alias.token} => ${body.token})(${ordering.token})"

    case GroupBy(source, alias, body) =>
      stmt"${source.token}.groupBy(${alias.token} => ${body.token})"

    case Aggregation(op, ast) =>
      stmt"${scopedTokenizer(ast)}.${op.token}"

    case Take(source, n) =>
      stmt"${source.token}.take(${n.token})"

    case Drop(source, n) =>
      stmt"${source.token}.drop(${n.token})"

    case Union(a, b) =>
      stmt"${a.token}.union(${b.token})"

    case UnionAll(a, b) =>
      stmt"${a.token}.unionAll(${b.token})"

    case Join(t, a, b, iA, iB, on) =>
      stmt"${a.token}.${t.token}(${b.token}).on((${iA.token}, ${iB.token}) => ${on.token})"

    case FlatJoin(t, a, iA, on) =>
      stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})"

    case Distinct(a) =>
      stmt"${a.token}.distinct"

    case DistinctOn(source, alias, body) =>
      stmt"${source.token}.distinctOn(${alias.token} => ${body.token})"

    case Nested(a) =>
      stmt"${a.token}.nested"
  }

  implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] {
    case TupleOrdering(elems)            => stmt"Ord(${elems.token})"
    case PropertyOrdering.Asc            => stmt"Ord.asc"
    case PropertyOrdering.Desc           => stmt"Ord.desc"
    case PropertyOrdering.AscNullsFirst  => stmt"Ord.ascNullsFirst"
    case PropertyOrdering.DescNullsFirst => stmt"Ord.descNullsFirst"
    case PropertyOrdering.AscNullsLast   => stmt"Ord.ascNullsLast"
    case PropertyOrdering.DescNullsLast  => stmt"Ord.descNullsLast"
  }

  implicit def optionOperationTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
    case OptionTableFlatMap(ast, alias, body) =>
      stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
    case OptionTableMap(ast, alias, body) =>
      stmt"${ast.token}.map((${alias.token}) => ${body.token})"
    case OptionTableExists(ast, alias, body) =>
      stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
    case OptionTableForall(ast, alias, body) =>
      stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
    case OptionFlatten(ast) => stmt"${ast.token}.flatten"
    case OptionGetOrElse(ast, body) =>
      stmt"${ast.token}.getOrElse(${body.token})"
    case OptionFlatMap(ast, alias, body) =>
      stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
    case OptionMap(ast, alias, body) =>
      stmt"${ast.token}.map((${alias.token}) => ${body.token})"
    case OptionForall(ast, alias, body) =>
      stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
    case OptionExists(ast, alias, body) =>
      stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
    case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
    case OptionIsEmpty(ast)        => stmt"${ast.token}.isEmpty"
    case OptionNonEmpty(ast)       => stmt"${ast.token}.nonEmpty"
    case OptionIsDefined(ast)      => stmt"${ast.token}.isDefined"
    case OptionSome(ast)           => stmt"Some(${ast.token})"
    case OptionApply(ast)          => stmt"Option(${ast.token})"
    case OptionOrNull(ast)         => stmt"${ast.token}.orNull"
    case OptionGetOrNull(ast)      => stmt"${ast.token}.getOrNull"
    case OptionNone                => stmt"None"
  }

  implicit def traversableOperationTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] {
    case MapContains(ast, body)  => stmt"${ast.token}.contains(${body.token})"
    case SetContains(ast, body)  => stmt"${ast.token}.contains(${body.token})"
    case ListContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
  }

  implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
    case JoinType.InnerJoin => stmt"join"
    case JoinType.LeftJoin  => stmt"leftJoin"
    case JoinType.RightJoin => stmt"rightJoin"
    case JoinType.FullJoin  => stmt"fullJoin"
  }

  implicit def functionTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Function] = Tokenizer[Function] {
    case Function(params, body) => stmt"(${params.token}) => ${body.token}"
  }

  implicit def operationTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Operation] = Tokenizer[Operation] {
    case UnaryOperation(op: PrefixUnaryOperator, ast) =>
      stmt"${op.token}${scopedTokenizer(ast)}"
    case UnaryOperation(op: PostfixUnaryOperator, ast) =>
      stmt"${scopedTokenizer(ast)}.${op.token}"
    case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
      SetContainsToken(scopedTokenizer(b), op.token, a.token)
    case BinaryOperation(a, op, b) =>
      stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
    case FunctionApply(function, values) =>
      stmt"${scopedTokenizer(function)}.apply(${values.token})"
  }

  implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] {
    case o => stmt"${o.toString.token}"
  }

  def tokenizeName(name: String, renameable: Renameable) =
    renameable match {
      case ByStrategy => name
      case Fixed      => s"`${name}`"
    }

  def bracketIfHidden(name: String, visibility: Visibility) =
    (distinguishHidden, visibility) match {
      case (true, Hidden) => s"[$name]"
      case _              => name
    }

  implicit def propertyTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Property] = Tokenizer[Property] {
    case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) =>
      stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
    case Property.Opinionated(ref, name, renameable, visibility) =>
      stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
  }

  implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] {
    case Constant(v: String) => stmt""""${v.token}""""
    case Constant(())        => stmt"{}"
    case Constant(v)         => stmt"${v.toString.token}"
    case NullValue           => stmt"null"
    case Tuple(values)       => stmt"(${values.token})"
    case CaseClass(values) =>
      stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})"
  }

  implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] {
    case Ident.Opinionated(name, visibility) =>
      stmt"${bracketIfHidden(name, visibility).token}"
  }

  implicit val typeTokenizer: Tokenizer[ExternalIdent] =
    Tokenizer[ExternalIdent] {
      case e => stmt"${e.name.token}"
    }

  implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] =
    Tokenizer[OnConflict.Excluded] {
      case OnConflict.Excluded(ident) => stmt"${ident.token}"
    }

  implicit val existingTokenizer: Tokenizer[OnConflict.Existing] =
    Tokenizer[OnConflict.Existing] {
      case OnConflict.Existing(ident) => stmt"${ident.token}"
    }

  implicit def actionTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Action] = Tokenizer[Action] {
    case Update(query, assignments) =>
      stmt"${query.token}.update(${assignments.token})"
    case Insert(query, assignments) =>
      stmt"${query.token}.insert(${assignments.token})"
    case Delete(query) => stmt"${query.token}.delete"
    case Returning(query, alias, body) =>
      stmt"${query.token}.returning((${alias.token}) => ${body.token})"
    case ReturningGenerated(query, alias, body) =>
      stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})"
    case Foreach(query, alias, body) =>
      stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
    case c: OnConflict => stmt"${c.token}"
  }

  implicit def conflictTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[OnConflict] = {

    def targetProps(l: List[Property]) = l.map(p =>
      Transform(p) {
        case Ident(_) => Ident("_")
      }
    )

    implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] =
      Tokenizer[OnConflict.Target] {
        case OnConflict.NoTarget => stmt""
        case OnConflict.Properties(props) =>
          val listTokens = props.token
          stmt"(${listTokens})"
      }

    val updateAssignsTokenizer = Tokenizer[Assignment] {
      case Assignment(i, p, v) =>
        stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
    }

    Tokenizer[OnConflict] {
      case OnConflict(i, t, OnConflict.Update(assign)) =>
        stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
      case OnConflict(i, t, OnConflict.Ignore) =>
        stmt"${i.token}.onConflictIgnore${t.token}"
    }
  }

  implicit def assignmentTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Assignment] = Tokenizer[Assignment] {
    case Assignment(ident, property, value) =>
      stmt"${ident.token} => ${property.token} -> ${value.token}"
  }

  implicit def infixTokenizer(implicit
      liftTokenizer: Tokenizer[Lift]
  ): Tokenizer[Infix] = Tokenizer[Infix] {
    case Infix(parts, params, _, _) =>
      def tokenParam(ast: Ast) =
        ast match {
          case ast: Ident => stmt"$$${ast.token}"
          case other      => stmt"$${${ast.token}}"
        }

      val pt = parts.map(_.token)
      val pr = params.map(tokenParam)
      val body = Statement(Interleave(pt, pr))
      stmt"""infix"${body.token}""""
  }

  private def scopedTokenizer(
      ast: Ast
  )(implicit liftTokenizer: Tokenizer[Lift]) =
    ast match {
      case _: Function        => stmt"(${ast.token})"
      case _: BinaryOperation => stmt"(${ast.token})"
      case other              => ast.token
    }
}
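A rough illustration of what `MirrorIdiom.translate` renders for a simple AST, assuming the constructors from this diff (expected output is shown as a comment, not verified against a test):

import minisql.ast.*

val ast = Map(Entity("Person", Nil), Ident("p"), Property(Ident("p"), "name"))
// MirrorIdiom.translate(ast) would yield a Statement printing roughly as:
//   querySchema("Person").map(p => p.name)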
101
src/main/scala/minisql/idiom/ReifyStatement.scala
Normal file
@@ -0,0 +1,101 @@
package minisql.idiom

import minisql.ParamEncoder
import minisql.ast.*
import minisql.util.Interleave
import minisql.idiom.StatementInterpolator.*
import scala.annotation.tailrec
import scala.collection.immutable.{Map => SMap}

object ReifyStatement {

  def apply(
      liftingPlaceholder: Int => String,
      emptySetContainsToken: Token => Token,
      statement: Statement,
      liftMap: SMap[String, (Any, ParamEncoder[?])]
  ): (String, List[ScalarValueLift]) = {
    val expanded = expandLiftings(statement, emptySetContainsToken, liftMap)
    token2string(expanded, liftMap, liftingPlaceholder)
  }

  private def token2string(
      token: Token,
      liftMap: SMap[String, (Any, ParamEncoder[?])],
      liftingPlaceholder: Int => String
  ): (String, List[ScalarValueLift]) = {

    val liftBuilder = List.newBuilder[ScalarValueLift]
    val sqlBuilder = StringBuilder()
    @tailrec
    def loop(
        workList: Seq[Token],
        liftingSize: Int
    ): Unit = workList match {
      case Seq() => ()
      case head +: tail =>
        head match {
          case StringToken(s2) =>
            sqlBuilder ++= s2
            loop(tail, liftingSize)
          case SetContainsToken(a, op, b) =>
            loop(
              stmt"$a $op ($b)" +: tail,
              liftingSize
            )
          case ScalarLiftToken(lift: ScalarValueLift) =>
            sqlBuilder ++= liftingPlaceholder(liftingSize)
            liftBuilder += lift.copy(value = liftMap.get(lift.liftId))
            loop(tail, liftingSize + 1)
          case ScalarLiftToken(o) =>
            throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}")
          case Statement(tokens) =>
            loop(
              tokens.foldRight(tail)(_ +: _),
              liftingSize
            )
        }
    }
    loop(Vector(token), 0)
    sqlBuilder.toString() -> liftBuilder.result()
  }

  private def expandLiftings(
      statement: Statement,
      emptySetContainsToken: Token => Token,
      liftMap: SMap[String, (Any, ParamEncoder[?])]
  ): Token = {
    Statement {
      val lb = List.newBuilder[Token]
      statement.tokens.foldLeft(lb) {
        case (
              tokens,
              SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift))
            ) =>
          val (lv, le) = liftMap(lift.liftId)
          lv.asInstanceOf[Iterable[Any]].toVector match {
            case Vector() => tokens += emptySetContainsToken(a)
            case values =>
              val liftings = values.zipWithIndex.map {
                case (v, i) =>
                  ScalarLiftToken(
                    ScalarValueLift(
                      s"${lift.name}[${i}]",
                      s"${lift.liftId}[${i}]",
                      Some(v -> le)
                    )
                  )
              }
              val separators = Vector.fill(liftings.size - 1)(StringToken(", "))
              (tokens += stmt"$a $op (") ++= Interleave(
                liftings,
                separators
              ) += StringToken(")")
          }
        case (tokens, token) =>
          tokens += token
      }
      lb.result()
    }
  }
}
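A standalone sketch of the expansion idea in `expandLiftings`: one lifted collection becomes one placeholder per element, joined by ", " (the helper below is illustrative, not the minisql API):

def expandPlaceholders(values: Seq[Any], placeholder: Int => String): String =
  values.indices.map(placeholder).mkString(", ")

expandPlaceholders(Seq(1, 2, 3), _ => "?")           // "?, ?, ?"
expandPlaceholders(Seq(1, 2, 3), i => s"$$${i + 1}") // "$1, $2, $3"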
47
src/main/scala/minisql/idiom/Statement.scala
Normal file
@@ -0,0 +1,47 @@
package minisql.idiom

import scala.quoted._
import minisql.ast._

sealed trait Token

case class StringToken(string: String) extends Token {
  override def toString = string
}

case class ScalarLiftToken(lift: ScalarLift) extends Token {
  override def toString = s"lift(${lift.name})"
}

case class Statement(tokens: List[Token]) extends Token {
  override def toString = tokens.mkString
}

case class SetContainsToken(a: Token, op: Token, b: Token) extends Token {
  override def toString = s"${a.toString} ${op.toString} (${b.toString})"
}

object Statement {
  given ToExpr[Statement] with {
    def apply(t: Statement)(using Quotes): Expr[Statement] = {
      '{ Statement(${ Expr(t.tokens) }) }
    }
  }
}

object Token {

  given ToExpr[Token] with {
    def apply(t: Token)(using Quotes): Expr[Token] = {
      t match {
        case StringToken(s) =>
          '{ StringToken(${ Expr(s) }) }
        case ScalarLiftToken(l) =>
          '{ ScalarLiftToken(${ Expr(l) }) }
        case SetContainsToken(a, op, b) =>
          '{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) }
        case s: Statement => Expr(s)
      }
    }
  }
}
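The `ToExpr` givens above let a macro splice a compile-time `Token` tree back into generated code; a hypothetical sketch (the macro names are illustrative, and macro definitions must live in a separate compilation unit from their call sites):

import scala.quoted.*
import minisql.idiom.*

inline def liftedStatement: Statement = ${ liftedStatementImpl }

private def liftedStatementImpl(using Quotes): Expr[Statement] =
  // Uses the `given ToExpr[Statement]` (and transitively ToExpr[Token]) above.
  Expr(Statement(List(StringToken("SELECT 1"))))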
131
src/main/scala/minisql/idiom/StatementInterpolator.scala
Normal file
@@ -0,0 +1,131 @@
package minisql.idiom

import minisql.ast._
import minisql.util.Interleave
import minisql.util.Messages._

import scala.collection.mutable.ListBuffer

object StatementInterpolator {

  extension [T](list: List[T]) {
    private[minisql] def mkStmt(
        sep: String = ", "
    )(using tokenize: Tokenizer[T]) = {
      val l1 = list.map(_.token)
      val l2 = List.fill(l1.size - 1)(StringToken(sep))
      Statement(Interleave(l1, l2))
    }
  }
  trait Tokenizer[T] {
    extension (v: T) {
      def token: Token
    }
  }
  object Tokenizer {
    def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] {
      extension (v: T) {
        def token: Token = f(v)
      }
    }
    def withFallback[T](
        fallback: Tokenizer[T] => Tokenizer[T]
    )(pf: PartialFunction[T, Token]) =
      new Tokenizer[T] {
        extension (v: T) {
          private def stable = fallback(this)
          override def token = pf.applyOrElse(v, stable.token)
        }
      }
  }

  extension [T](v: T)(using tokenizer: Tokenizer[T]) {
    def token = tokenizer.token(v)
  }

  given stringTokenizer: Tokenizer[String] =
    Tokenizer[String] {
      case string => StringToken(string)
    }

  given liftTokenizer: Tokenizer[Lift] =
    Tokenizer[Lift] {
      case lift: ScalarLift => ScalarLiftToken(lift)
    }

  given tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity)
  given statementTokenizer: Tokenizer[Statement] =
    Tokenizer[Statement](identity)
  given stringTokenTokenizer: Tokenizer[StringToken] =
    Tokenizer[StringToken](identity)
  given liftingTokenTokenizer: Tokenizer[ScalarLiftToken] =
    Tokenizer[ScalarLiftToken](identity)

  given listTokenizer[T](using
      tokenize: Tokenizer[T]
  ): Tokenizer[List[T]] =
    Tokenizer[List[T]] {
      case list => list.mkStmt()
    }

  extension (sc: StringContext) {

    def flatten(tokens: List[Token]): List[Token] = {

      def unestStatements(tokens: List[Token]): List[Token] = {
        tokens.flatMap {
          case Statement(innerTokens) => unestStatements(innerTokens)
          case token                  => token :: Nil
        }
      }

      def mergeStringTokens(tokens: List[Token]): List[Token] = {
        val (resultBuilder, leftTokens) =
          tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) {
            case ((builder, acc), stringToken: StringToken) =>
              val str = stringToken.string
              if (str.nonEmpty)
                acc += stringToken.string
              (builder, acc)
            case ((builder, prev), b) if prev.isEmpty =>
              (builder += b.token, prev)
            case ((builder, prev), b) /* if prev.nonEmpty */ =>
              builder += StringToken(prev.result().mkString)
              builder += b.token
              (builder, new ListBuffer[String])
          }
        if (leftTokens.nonEmpty)
          resultBuilder += StringToken(leftTokens.result().mkString)
        resultBuilder.result()
      }

      (unestStatements)
        .andThen(mergeStringTokens)
        .apply(tokens)
    }

    def checkLengths(
        args: scala.collection.Seq[Any],
        parts: Seq[String]
    ): Unit =
      if (parts.length != args.length + 1)
        throw new IllegalArgumentException(
          "wrong number of arguments (" + args.length
            + ") for interpolated string with " + parts.length + " parts"
        )

    def stmt(args: Token*): Statement = {
      checkLengths(args, sc.parts)
      val partsIterator = sc.parts.iterator
      val argsIterator = args.iterator
      val bldr = List.newBuilder[Token]
      bldr += StringToken(partsIterator.next())
      while (argsIterator.hasNext) {
        bldr += argsIterator.next()
        bldr += StringToken(partsIterator.next())
      }
      val tokens = flatten(bldr.result())
      Statement(tokens)
    }
  }
}
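Usage sketch for the `stmt` interpolator defined above (tokens are illustrative): adjacent string tokens get merged and nested statements flattened, so the result prints as plain SQL.

import minisql.idiom.*
import minisql.idiom.StatementInterpolator.*

val col: Token = StringToken("name")
val tbl: Token = StringToken("person")
val s: Statement = stmt"SELECT $col FROM $tbl"
s.toString // "SELECT name FROM person"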
52
src/main/scala/minisql/norm/AdHocReduction.scala
Normal file
@@ -0,0 +1,52 @@
package minisql.norm

import minisql.ast.BinaryOperation
import minisql.ast.BooleanOperator
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll

object AdHocReduction {

  def unapply(q: Query) =
    q match {

      // ---------------------------
      // *.filter

      // a.filter(b => c).filter(d => e) =>
      //    a.filter(b => c && e[d := b])
      case Filter(Filter(a, b, c), d, e) =>
        val er = BetaReduction(e, d -> b)
        Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er)))

      // ---------------------------
      // flatMap.*

      // a.flatMap(b => c).map(d => e) =>
      //    a.flatMap(b => c.map(d => e))
      case Map(FlatMap(a, b, c), d, e) =>
        Some(FlatMap(a, b, Map(c, d, e)))

      // a.flatMap(b => c).filter(d => e) =>
      //    a.flatMap(b => c.filter(d => e))
      case Filter(FlatMap(a, b, c), d, e) =>
        Some(FlatMap(a, b, Filter(c, d, e)))

      // a.flatMap(b => c.union(d)) =>
      //    a.flatMap(b => c).union(a.flatMap(b => d))
      case FlatMap(a, b, Union(c, d)) =>
        Some(Union(FlatMap(a, b, c), FlatMap(a, b, d)))

      // a.flatMap(b => c.unionAll(d)) =>
      //    a.flatMap(b => c).unionAll(a.flatMap(b => d))
      case FlatMap(a, b, UnionAll(c, d)) =>
        Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d)))

      case other => None
    }

}
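The filter-fusion rule above, applied to a concrete AST (constructors as in this diff; the boolean bodies are placeholders):

import minisql.ast.*

val people = Entity("Person", Nil)
val twice =
  Filter(Filter(people, Ident("b"), Constant(true)), Ident("d"), Constant(false))

// AdHocReduction.unapply(twice) would yield roughly:
//   Filter(people, Ident("b"),
//     BinaryOperation(Constant(true), BooleanOperator.`&&`, Constant(false)))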
160
src/main/scala/minisql/norm/ApplyMap.scala
Normal file
@@ -0,0 +1,160 @@
package minisql.norm

import minisql.ast._

object ApplyMap {

  private def isomorphic(e: Ast, c: Ast, alias: Ident) =
    BetaReduction(e, alias -> c) == c

  object InfixedTailOperation {

    def hasImpureInfix(ast: Ast) =
      CollectAst(ast) {
        case i @ Infix(_, _, false, _) => i
      }.nonEmpty

    def unapply(ast: Ast): Option[Ast] =
      ast match {
        case cc: CaseClass if hasImpureInfix(cc)     => Some(cc)
        case tup: Tuple if hasImpureInfix(tup)       => Some(tup)
        case p: Property if hasImpureInfix(p)        => Some(p)
        case b: BinaryOperation if hasImpureInfix(b) => Some(b)
        case u: UnaryOperation if hasImpureInfix(u)  => Some(u)
        case i @ Infix(_, _, false, _)               => Some(i)
        case _                                       => None
      }
  }

  object MapWithoutInfixes {
    def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
      ast match {
        case Map(a, b, InfixedTailOperation(c)) => None
        case Map(a, b, c)                       => Some((a, b, c))
        case _                                  => None
      }
  }

  object DetachableMap {
    def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
      ast match {
        case Map(a: GroupBy, b, c)              => None
        case Map(a, b, InfixedTailOperation(c)) => None
        case Map(a, b, c)                       => Some((a, b, c))
        case _                                  => None
      }
  }

  def unapply(q: Query): Option[Query] =
    q match {

      case Map(a: GroupBy, b, c) if (b == c)    => None
      case Map(a: DistinctOn, b, c)             => None
      case Map(a: Nested, b, c) if (b == c)     => None
      case Nested(DetachableMap(a: Join, b, c)) => None

      // map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) =>
      //    map(i => (i.i, i.l)).distinct
      case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) =>
        Some(Distinct(Map(a, b, c)))

      // a.map(b => c).map(d => e) =>
      //    a.map(b => e[d := c])
      case before @ Map(MapWithoutInfixes(a, b, c), d, e) =>
        val er = BetaReduction(e, d -> c)
        Some(Map(a, b, er))

      // a.map(b => b) =>
      //    a
      case Map(a: Query, b, c) if (b == c) =>
        Some(a)

      // a.map(b => c).flatMap(d => e) =>
      //    a.flatMap(b => e[d := c])
      case FlatMap(DetachableMap(a, b, c), d, e) =>
        val er = BetaReduction(e, d -> c)
        Some(FlatMap(a, b, er))

      // a.map(b => c).filter(d => e) =>
      //    a.filter(b => e[d := c]).map(b => c)
      case Filter(DetachableMap(a, b, c), d, e) =>
        val er = BetaReduction(e, d -> c)
        Some(Map(Filter(a, b, er), b, c))

      // a.map(b => c).sortBy(d => e) =>
      //    a.sortBy(b => e[d := c]).map(b => c)
      case SortBy(DetachableMap(a, b, c), d, e, f) =>
        val er = BetaReduction(e, d -> c)
        Some(Map(SortBy(a, b, er, f), b, c))

      // a.map(b => c).sortBy(d => e).distinct =>
      //    a.sortBy(b => e[d := c]).map(b => c).distinct
      case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) =>
        val er = BetaReduction(e, d -> c)
        Some(Distinct(Map(SortBy(a, b, er, f), b, c)))

      // a.map(b => c).groupBy(d => e) =>
      //    a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c)))
      case GroupBy(DetachableMap(a, b, c), d, e) =>
        val er = BetaReduction(e, d -> c)
        val x = Ident("x")
        val x1 = Property(
          Ident("x"),
          "_1"
        ) // These introduced properties should not be renamed
        val x2 = Property(Ident("x"), "_2") // due to any naming convention.
        val body = Tuple(List(x1, Map(x2, b, c)))
        Some(Map(GroupBy(a, b, er), x, body))

      // a.map(b => c).drop(d) =>
      //    a.drop(d).map(b => c)
      case Drop(DetachableMap(a, b, c), d) =>
        Some(Map(Drop(a, d), b, c))

      // a.map(b => c).take(d) =>
      //    a.take(d).map(b => c)
      case Take(DetachableMap(a, b, c), d) =>
        Some(Map(Take(a, d), b, c))

      // a.map(b => c).nested =>
      //    a.nested.map(b => c)
      case Nested(DetachableMap(a, b, c)) =>
        Some(Map(Nested(a), b, c))

      // a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on) =>
      //    a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2]))
      case Join(
            tpe,
            DetachableMap(a, b, c),
            DetachableMap(d, e, f),
            iA,
            iB,
            on
          ) =>
        val onr = BetaReduction(on, iA -> c, iB -> f)
        val t = Ident("t")
        val t1 = BetaReduction(c, b -> Property(t, "_1"))
        val t2 = BetaReduction(f, e -> Property(t, "_2"))
        Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2))))

      // a.*join(b.map(c => d)).on((iA, iB) => on) =>
      //    a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2]))
      case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) =>
        val onr = BetaReduction(on, iB -> d)
        val t = Ident("t")
        val t1 = Property(t, "_1")
        val t2 = BetaReduction(d, c -> Property(t, "_2"))
        Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2))))

      // a.map(b => c).*join(d).on((iA, iB) => on) =>
      //    a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2))
      case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) =>
        val onr = BetaReduction(on, iA -> c)
        val t = Ident("t")
        val t1 = BetaReduction(c, b -> Property(t, "_1"))
        val t2 = Property(t, "_2")
        Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2))))

      case other => None
    }
}
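The map-fusion rule (`a.map(b => c).map(d => e) => a.map(b => e[d := c])`) on a concrete AST, as a sketch:

import minisql.ast.*

val a  = Entity("Person", Nil)
val m1 = Map(a, Ident("b"), Property(Ident("b"), "name"))
val m2 = Map(m1, Ident("d"), Property(Ident("d"), "length"))

// ApplyMap.unapply(m2) would yield roughly:
//   Map(a, Ident("b"), Property(Property(Ident("b"), "name"), "length"))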
48
src/main/scala/minisql/norm/AttachToEntity.scala
Normal file
@@ -0,0 +1,48 @@
package minisql.norm

import minisql.util.Messages.fail
import minisql.ast._

object AttachToEntity {

  private object IsEntity {
    def unapply(q: Ast): Option[Ast] =
      q match {
        case q: Entity => Some(q)
        case q: Infix  => Some(q)
        case _         => None
      }
  }

  def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
      q: Ast
  ): Ast =
    q match {

      case Map(IsEntity(a), b, c)        => Map(f(a, b), b, c)
      case FlatMap(IsEntity(a), b, c)    => FlatMap(f(a, b), b, c)
      case ConcatMap(IsEntity(a), b, c)  => ConcatMap(f(a, b), b, c)
      case Filter(IsEntity(a), b, c)     => Filter(f(a, b), b, c)
      case SortBy(IsEntity(a), b, c, d)  => SortBy(f(a, b), b, c, d)
      case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)

      case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
          _: FlatJoin =>
        f(q, alias.getOrElse(Ident("x")))

      case Map(a: Query, b, c)        => Map(apply(f, Some(b))(a), b, c)
      case FlatMap(a: Query, b, c)    => FlatMap(apply(f, Some(b))(a), b, c)
      case ConcatMap(a: Query, b, c)  => ConcatMap(apply(f, Some(b))(a), b, c)
      case Filter(a: Query, b, c)     => Filter(apply(f, Some(b))(a), b, c)
      case SortBy(a: Query, b, c, d)  => SortBy(apply(f, Some(b))(a), b, c, d)
      case Take(a: Query, b)          => Take(apply(f, alias)(a), b)
      case Drop(a: Query, b)          => Drop(apply(f, alias)(a), b)
      case Aggregation(op, a: Query)  => Aggregation(op, apply(f, alias)(a))
      case Distinct(a: Query)         => Distinct(apply(f, alias)(a))
      case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)

      case IsEntity(q) => f(q, alias.getOrElse(Ident("x")))

      case other => fail(s"Can't find an 'Entity' in '$q'")
    }
}
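A sketch of `AttachToEntity` walking through an outer `Map` to apply `f` at the root `Entity` (the filter body is a placeholder):

import minisql.ast.*

val q = Map(Entity("Person", Nil), Ident("p"), Property(Ident("p"), "name"))
val attached =
  AttachToEntity((ent, alias) => Filter(ent, alias, Constant(true)))(q)
// => Map(Filter(Entity("Person", Nil), Ident("p"), Constant(true)),
//        Ident("p"), Property(Ident("p"), "name"))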
@@ -1,6 +1,6 @@
package minisql.util
package minisql.norm

import minisql.ast.*
import minisql.ast._
import scala.collection.immutable.{Map => IMap}

case class BetaReduction(replacements: Replacements)
7
src/main/scala/minisql/norm/ConcatBehavior.scala
Normal file
@@ -0,0 +1,7 @@
package minisql.norm

trait ConcatBehavior
object ConcatBehavior {
  case object AnsiConcat extends ConcatBehavior
  case object NonAnsiConcat extends ConcatBehavior
}
7
src/main/scala/minisql/norm/EqualityBehavior.scala
Normal file
@@ -0,0 +1,7 @@
package minisql.norm

trait EqualityBehavior
object EqualityBehavior {
  case object AnsiEquality extends EqualityBehavior
  case object NonAnsiEquality extends EqualityBehavior
}
74  src/main/scala/minisql/norm/ExpandReturning.scala  Normal file

@ -0,0 +1,74 @@
package minisql.norm

import minisql.ReturnAction.ReturnColumns
import minisql.{NamingStrategy, ReturnAction}
import minisql.ast._
import minisql.context.{
  ReturningClauseSupported,
  ReturningMultipleFieldSupported,
  ReturningNotSupported,
  ReturningSingleFieldSupported
}
import minisql.idiom.{Idiom, Statement}

/**
 * Take the `.returning` part in a query that contains it and return the array
 * of columns representing the returning section, together with any other
 * operations etc... that they might contain.
 */
object ExpandReturning {

  def applyMap(
      returning: ReturningAction
  )(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = {
    val initialExpand = ExpandReturning.apply(returning)(idiom, naming)

    idiom.idiomReturningCapability match {
      case ReturningClauseSupported =>
        ReturnAction.ReturnRecord
      case ReturningMultipleFieldSupported =>
        ReturnColumns(initialExpand.map {
          case (ast, statement) => f(ast, statement)
        })
      case ReturningSingleFieldSupported =>
        if (initialExpand.length == 1)
          ReturnColumns(initialExpand.map {
            case (ast, statement) => f(ast, statement)
          })
        else
          throw new IllegalArgumentException(
            s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
          )
      case ReturningNotSupported =>
        throw new IllegalArgumentException(
          s"RETURNING columns are not allowed in the ${idiom} dialect."
        )
    }
  }

  def apply(
      returning: ReturningAction
  )(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
    val ReturningAction(_, alias, properties) = returning: @unchecked

    // Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
    //   => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
    val dePropertized =
      Transform(properties) {
        case `alias` => ExternalIdent(alias.name)
      }

    val aliasName = alias.name

    // Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
    //   => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
    val deTuplified = dePropertized match {
      case Tuple(values)     => values
      case CaseClass(values) => values.map(_._2)
      case other             => List(other)
    }

    implicit val namingStrategy: NamingStrategy = naming
    deTuplified.map(v => idiom.translate(v))
  }
}
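A hedged sketch of what `apply` produces for a two-column returning clause (hypothetical values; `idiom` and `naming` stand for whatever Idiom and NamingStrategy instances are in scope):

val ret = Returning(
  Insert(Entity("person", Nil), Nil),
  Ident("r"),
  Tuple(List(Property(Ident("r"), "name"), Property(Ident("r"), "age")))
)
// ExpandReturning(ret)(idiom, naming) first rewrites `r` to ExternalIdent("r"),
// then splits the Tuple, yielding one (Ast, Statement) pair per column via
// idiom.translate.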
108  src/main/scala/minisql/norm/FlattenOptionOperation.scala  Normal file

@ -0,0 +1,108 @@
package minisql.norm

import minisql.ast.*
import minisql.ast.Implicits.*
import minisql.norm.ConcatBehavior.NonAnsiConcat

class FlattenOptionOperation(concatBehavior: ConcatBehavior)
    extends StatelessTransformer {

  private def emptyOrNot(b: Boolean, ast: Ast) =
    if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)

  def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
    apply(BetaReduction(body, alias -> ast))

  def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
    val reduced = BetaReduction(body, alias -> ast)
    apply((IsNullCheck(ast) +||+ reduced): Ast)
  }

  def containsNonFallthroughElement(ast: Ast) =
    CollectAst(ast) {
      case If(_, _, _)       => true
      case Infix(_, _, _, _) => true
      case BinaryOperation(_, StringOperator.`concat`, _)
          if (concatBehavior == NonAnsiConcat) =>
        true
    }.nonEmpty

  override def apply(ast: Ast): Ast =
    ast match {

      case OptionTableFlatMap(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableMap(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableExists(ast, alias, body) =>
        uncheckedReduction(ast, alias, body)

      case OptionTableForall(ast, alias, body) =>
        uncheckedForall(ast, alias, body)

      case OptionFlatten(ast) =>
        apply(ast)

      case OptionSome(ast) =>
        apply(ast)

      case OptionApply(ast) =>
        apply(ast)

      case OptionOrNull(ast) =>
        apply(ast)

      case OptionGetOrNull(ast) =>
        apply(ast)

      case OptionNone => NullValue

      case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
        apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)

      case OptionGetOrElse(ast, body) =>
        apply(If(IsNotNullCheck(ast), ast, body))

      case OptionFlatMap(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val reduced = BetaReduction(body, alias -> ast)
          apply(IfExistElseNull(ast, reduced))
        } else {
          uncheckedReduction(ast, alias, body)
        }

      case OptionMap(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val reduced = BetaReduction(body, alias -> ast)
          apply(IfExistElseNull(ast, reduced))
        } else {
          uncheckedReduction(ast, alias, body)
        }

      case OptionForall(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val reduction = BetaReduction(body, alias -> ast)
          apply(
            (IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast
          )
        } else {
          uncheckedForall(ast, alias, body)
        }

      case OptionExists(ast, alias, body) =>
        if (containsNonFallthroughElement(body)) {
          val reduction = BetaReduction(body, alias -> ast)
          apply((IsNotNullCheck(ast) +&&+ reduction): Ast)
        } else {
          uncheckedReduction(ast, alias, body)
        }

      case OptionContains(ast, body) =>
        apply((ast +==+ body): Ast)

      case other =>
        super.apply(other)
    }
}
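For instance, the OptionGetOrElse fallthrough case above turns an optional access into a plain conditional (a sketch with hypothetical values):

val value  = Property(Ident("h"), "value")
val result = new FlattenOptionOperation(ConcatBehavior.AnsiConcat)
  .apply(OptionGetOrElse(value, Constant("bar")))
// ==> If(IsNotNullCheck(value), value, Constant("bar"))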
120  src/main/scala/minisql/norm/FreeVariables.scala  Normal file

@ -0,0 +1,120 @@
package minisql.norm

import minisql.ast.*
import collection.immutable.Set

case class State(seen: Set[Ident], free: Set[Ident])

case class FreeVariables(state: State) extends StatefulTransformer[State] {

  override def apply(ast: Ast): (Ast, StatefulTransformer[State]) =
    ast match {
      case ident: Ident if (!state.seen.contains(ident)) =>
        (ident, FreeVariables(State(state.seen, state.free + ident)))
      case f @ Function(params, body) =>
        val (_, t) =
          FreeVariables(State(state.seen ++ params, state.free))(body)
        (f, FreeVariables(State(state.seen, state.free ++ t.state.free)))
      case q @ Foreach(a, b, c) =>
        (q, free(a, b, c))
      case other =>
        super.apply(other)
    }

  override def apply(
      o: OptionOperation
  ): (OptionOperation, StatefulTransformer[State]) =
    o match {
      case q @ OptionTableFlatMap(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionTableMap(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionTableExists(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionTableForall(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionFlatMap(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionMap(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionForall(a, b, c) =>
        (q, free(a, b, c))
      case q @ OptionExists(a, b, c) =>
        (q, free(a, b, c))
      case other =>
        super.apply(other)
    }

  override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) =
    e match {
      case Assignment(a, b, c) =>
        val t = FreeVariables(State(state.seen + a, state.free))
        val (bt, btt) = t(b)
        val (ct, ctt) = t(c)
        (
          Assignment(a, bt, ct),
          FreeVariables(
            State(state.seen, state.free ++ btt.state.free ++ ctt.state.free)
          )
        )
    }

  override def apply(action: Action): (Action, StatefulTransformer[State]) =
    action match {
      case q @ Returning(a, b, c) =>
        (q, free(a, b, c))
      case q @ ReturningGenerated(a, b, c) =>
        (q, free(a, b, c))
      case other =>
        super.apply(other)
    }

  override def apply(
      e: OnConflict.Target
  ): (OnConflict.Target, StatefulTransformer[State]) = (e, this)

  override def apply(query: Query): (Query, StatefulTransformer[State]) =
    query match {
      case q @ Filter(a, b, c)      => (q, free(a, b, c))
      case q @ Map(a, b, c)         => (q, free(a, b, c))
      case q @ DistinctOn(a, b, c)  => (q, free(a, b, c))
      case q @ FlatMap(a, b, c)     => (q, free(a, b, c))
      case q @ ConcatMap(a, b, c)   => (q, free(a, b, c))
      case q @ SortBy(a, b, c, d)   => (q, free(a, b, c))
      case q @ GroupBy(a, b, c)     => (q, free(a, b, c))
      case q @ FlatJoin(t, a, b, c) => (q, free(a, b, c))
      case q @ Join(t, a, b, iA, iB, on) =>
        val (_, freeA) = apply(a)
        val (_, freeB) = apply(b)
        val (_, freeOn) =
          FreeVariables(State(state.seen + iA + iB, Set.empty))(on)
        (
          q,
          FreeVariables(
            State(
              state.seen,
              state.free ++ freeA.state.free ++ freeB.state.free ++ freeOn.state.free
            )
          )
        )
      case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll |
          _: Aggregation | _: Distinct | _: Nested =>
        super.apply(query)
    }

  private def free(a: Ast, ident: Ident, c: Ast) = {
    val (_, ta) = apply(a)
    val (_, tc) = FreeVariables(State(state.seen + ident, state.free))(c)
    FreeVariables(
      State(state.seen, state.free ++ ta.state.free ++ tc.state.free)
    )
  }
}

object FreeVariables {
  def apply(ast: Ast): Set[Ident] =
    new FreeVariables(State(Set.empty, Set.empty))(ast) match {
      case (_, transformer) =>
        transformer.state.free
    }
}
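Usage sketch (hypothetical values): the filter's lambda binds `x`, so only the outer `limit` is reported as free.

val q = Filter(Entity("person", Nil), Ident("x"), Ident("limit"))
FreeVariables(q) // ==> Set(Ident("limit"))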
76  src/main/scala/minisql/norm/NestImpureMappedInfix.scala  Normal file

@ -0,0 +1,76 @@
package minisql.norm

import minisql.ast._

/**
 * A problem occurred in the original way infixes were done in that it was
 * assumed that infix clauses represented pure functions. While this is true of
 * many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others
 * e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations
 * that are done in `ApplyMap` on standard AST `Map` clauses cannot be done
 * here; therefore, additional safety checks were introduced there in order to
 * assure this does not happen. In addition to this however, it is necessary to
 * add this normalization step, which inserts `Nested` AST elements in every map
 * that contains an impure infix. See more information and examples in #1534.
 */
object NestImpureMappedInfix extends StatelessTransformer {

  // Are there any impure infixes that exist inside the specified ASTs
  def hasInfix(asts: Ast*): Boolean =
    asts.exists(ast =>
      CollectAst(ast) {
        case i @ Infix(_, _, false, _) => i
      }.nonEmpty
    )

  // Continue exploring into the Map to see if there are additional impure infix clauses inside.
  private def applyInside(m: Map) =
    Map(apply(m.query), m.alias, m.body)

  override def apply(ast: Ast): Ast =
    ast match {
      // If there is already a nested clause inside the map, there is no reason to insert another one
      case Nested(Map(inner, a, b)) =>
        Nested(Map(apply(inner), a, b))

      case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m)
        Map(
          Nested(applyInside(m)),
          x,
          CaseClass(values.map {
            case (name, _) =>
              (
                name,
                Property(x, name)
              ) // mappings of nested-query case class properties should not be renamed
          })
        )

      case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
        Map(
          Nested(applyInside(m)),
          x,
          Tuple(values.zipWithIndex.map {
            case (_, i) =>
              Property(
                x,
                s"_${i + 1}"
              ) // mappings of nested-query tuple properties should not be renamed
          })
        )

      case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
        Map(Nested(applyInside(m)), x, Property(x, "_1"))

      case other => super.apply(other)
    }
}
51  src/main/scala/minisql/norm/Normalize.scala  Normal file

@ -0,0 +1,51 @@
package minisql.norm

import minisql.ast.Ast
import minisql.ast.Query
import minisql.ast.StatelessTransformer
import minisql.norm.capture.AvoidCapture
import minisql.ast.Action
import minisql.util.Messages.trace
import minisql.util.Messages.TraceType.Normalizations

import scala.annotation.tailrec

object Normalize extends StatelessTransformer {

  override def apply(q: Ast): Ast =
    super.apply(BetaReduction(q))

  override def apply(q: Action): Action =
    NormalizeReturning(super.apply(q))

  override def apply(q: Query): Query =
    norm(AvoidCapture(q))

  private def traceNorm[T](label: String) =
    trace[T](s"${label} (Normalize)", 1, Normalizations)

  @tailrec
  private def norm(q: Query): Query =
    q match {
      case NormalizeNestedStructures(query) =>
        traceNorm("NormalizeNestedStructures")(query)
        norm(query)
      case ApplyMap(query) =>
        traceNorm("ApplyMap")(query)
        norm(query)
      case SymbolicReduction(query) =>
        traceNorm("SymbolicReduction")(query)
        norm(query)
      case AdHocReduction(query) =>
        traceNorm("AdHocReduction")(query)
        norm(query)
      case OrderTerms(query) =>
        traceNorm("OrderTerms")(query)
        norm(query)
      case NormalizeAggregationIdent(query) =>
        traceNorm("NormalizeAggregationIdent")(query)
        norm(query)
      case other =>
        other
    }
}
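Each extractor above returns Some only when it actually rewrote the query, so the tail-recursive `norm` re-applies the six rules until none fires. A property sketch (an assumption implied by the loop, not asserted by this diff):

// Once a fixed point is reached, re-normalizing should change nothing.
def reachedFixedPoint(q: Query): Boolean =
  Normalize(Normalize(q)) == Normalize(q)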
29  src/main/scala/minisql/norm/NormalizeAggregationIdent.scala  Normal file

@ -0,0 +1,29 @@
package minisql.norm

import minisql.ast._

object NormalizeAggregationIdent {

  def unapply(q: Query) =
    q match {

      // a => a.b.map(x => x.c).agg =>
      //   a => a.b.map(a => a.c).agg
      case Aggregation(
            op,
            Map(
              p @ Property(i: Ident, _),
              mi,
              Property.Opinionated(_: Ident, n, renameable, visibility)
            )
          ) if i != mi =>
        Some(
          Aggregation(
            op,
            Map(p, i, Property.Opinionated(i, n, renameable, visibility))
          )
        ) // in the example above, if `c` in `x.c` is fixed then `c` in `a.c` should also be

      case _ => None
    }
}
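Sketch of the rewrite (hypothetical values; `AggregationOperator.max` is assumed as the aggregation op, which this diff does not show):

val agg = Aggregation(
  AggregationOperator.`max`,
  Map(Property(Ident("a"), "b"), Ident("x"), Property(Ident("x"), "c"))
)
NormalizeAggregationIdent.unapply(agg)
// ==> Some(Aggregation(max, Map(Property(Ident("a"), "b"), Ident("a"),
//                               Property(Ident("a"), "c"))))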
47  src/main/scala/minisql/norm/NormalizeNestedStructures.scala  Normal file

@ -0,0 +1,47 @@
package minisql.norm

import minisql.ast._

object NormalizeNestedStructures {

  def unapply(q: Query): Option[Query] =
    q match {
      case e: Entity           => None
      case Map(a, b, c)        => apply(a, c)(Map(_, b, _))
      case FlatMap(a, b, c)    => apply(a, c)(FlatMap(_, b, _))
      case ConcatMap(a, b, c)  => apply(a, c)(ConcatMap(_, b, _))
      case Filter(a, b, c)     => apply(a, c)(Filter(_, b, _))
      case SortBy(a, b, c, d)  => apply(a, c)(SortBy(_, b, _, d))
      case GroupBy(a, b, c)    => apply(a, c)(GroupBy(_, b, _))
      case Aggregation(a, b)   => apply(b)(Aggregation(a, _))
      case Take(a, b)          => apply(a, b)(Take.apply)
      case Drop(a, b)          => apply(a, b)(Drop.apply)
      case Union(a, b)         => apply(a, b)(Union.apply)
      case UnionAll(a, b)      => apply(a, b)(UnionAll.apply)
      case Distinct(a)         => apply(a)(Distinct.apply)
      case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
      case Nested(a)           => apply(a)(Nested.apply)
      case FlatJoin(t, a, iA, on) =>
        (Normalize(a), Normalize(on)) match {
          case (`a`, `on`) => None
          case (a, on)     => Some(FlatJoin(t, a, iA, on))
        }
      case Join(t, a, b, iA, iB, on) =>
        (Normalize(a), Normalize(b), Normalize(on)) match {
          case (`a`, `b`, `on`) => None
          case (a, b, on)       => Some(Join(t, a, b, iA, iB, on))
        }
    }

  private def apply(a: Ast)(f: Ast => Query) =
    (Normalize(a)) match {
      case (`a`) => None
      case (a)   => Some(f(a))
    }

  private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
    (Normalize(a), Normalize(b)) match {
      case (`a`, `b`) => None
      case (a, b)     => Some(f(a, b))
    }
}
154  src/main/scala/minisql/norm/NormalizeReturning.scala  Normal file

@ -0,0 +1,154 @@
package minisql.norm

import minisql.ast._
import minisql.norm.capture.AvoidAliasConflict

/**
 * When actions are used with a `.returning` clause, remove the columns used in
 * the returning clause from the action. E.g. for `insert(Person(id,
 * name)).returning(_.id)` remove the `id` column from the original insert.
 */
object NormalizeReturning {

  def apply(e: Action): Action = {
    e match {
      case ReturningGenerated(a: Action, alias, body) =>
        // De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action.
        // This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres.
        // For example:
        //   query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
        // Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally
        // be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case.
        // In order to fix this, we need to change `t` into a different alias.
        val newBody = dealiasBody(body, alias)
        ReturningGenerated(apply(a, newBody, alias), alias, newBody)

      // For a regular return clause, we do not need to exclude assignments from insertion; however, we still
      // need to de-alias the Action body in case conflicts result. For example the following query:
      //   query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
      // would incorrectly be interpreted as:
      //   INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t
      // whereas it should be:
      //   INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
      case Returning(a: Action, alias, body) =>
        val newBody = dealiasBody(body, alias)
        Returning(a, alias, newBody)

      case _ => e
    }
  }

  /**
   * In some situations, a query can exist inside of a `returning` clause. In
   * this case, we need to rename if the aliases used in that query override
   * the alias used in the `returning` clause, otherwise they will be treated
   * as returning-clause aliases by ExpandReturning (i.e. they will become
   * ExternalAlias instances) and later be tokenized incorrectly.
   */
  private def dealiasBody(body: Ast, alias: Ident): Ast =
    Transform(body) {
      case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
    }

  private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
    e match {
      case Insert(query, assignments) =>
        Insert(query, filterReturnedColumn(assignments, body, returningIdent))
      case Update(query, assignments) =>
        Update(query, filterReturnedColumn(assignments, body, returningIdent))
      case OnConflict(a: Action, target, act) =>
        OnConflict(apply(a, body, returningIdent), target, act)
      case _ => e
    }

  private def filterReturnedColumn(
      assignments: List[Assignment],
      column: Ast,
      returningIdent: Ident
  ): List[Assignment] =
    assignments.flatMap(filterReturnedColumn(_, column, returningIdent))

  /**
   * In situations like Property(Property(ident, foo), bar) pull out the
   * inner-most ident
   */
  object NestedProperty {
    def unapply(ast: Property): Option[Ast] = {
      ast match {
        case p @ Property(subAst, _) => Some(innerMost(subAst))
      }
    }

    private def innerMost(ast: Ast): Ast = ast match {
      case Property(inner, _) => innerMost(inner)
      case other              => other
    }
  }

  /**
   * Remove the specified column from the assignment. For example, in a query
   * like `insert(Person(id, name)).returning(r => r.id)` we need to remove the
   * `id` column from the insertion. The value of the `column:Ast` in this case
   * will be `Property(Ident(r), id)` and the values of the assignment `p1`
   * property will typically be `v.id` and `v.name` (the `v` variable is a
   * default used for `insert` queries).
   */
  private def filterReturnedColumn(
      assignment: Assignment,
      body: Ast,
      returningIdent: Ident
  ): Option[Assignment] =
    assignment match {
      case Assignment(_, p1: Property, _) => {
        // Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but
        // if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2))
        // or it can even be a query since queries are allowed to be in return sections e.g:
        //   query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max))
        // In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment
        // in order to know what to exclude.
        val matchedProps =
          CollectAst(body) {
            // case prop @ NestedProperty(`returningIdent`) => prop
            case prop @ NestedProperty(Ident(name))
                if (name == returningIdent.name) =>
              prop
            case prop @ NestedProperty(ExternalIdent(name))
                if (name == returningIdent.name) =>
              prop
          }

        if (
          matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))
        )
          None
        else
          Some(assignment)
      }
      case assignment => Some(assignment)
    }

  object SomeIdent {
    def unapply(ast: Ast): Option[Ast] =
      ast match {
        case id: Ident         => Some(id)
        case id: ExternalIdent => Some(id)
        case _                 => None
      }
  }

  /**
   * Is it the same property (but possibly of a different identity). E.g.
   * `p.foo.bar` and `v.foo.bar`
   */
  private def isSameProperties(p1: Property, p2: Property): Boolean =
    (p1.ast, p2.ast) match {
      case (SomeIdent(_), SomeIdent(_)) =>
        p1.name == p2.name
      // If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the
      // outer properties are the same before moving on to the inner ones.
      case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
        isSameProperties(pp1, pp2)
      case _ =>
        false
    }
}
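End-to-end sketch of the exclusion (hypothetical values): the `id` assignment is dropped because the returning clause generates it.

val insert = Insert(
  Entity("Person", Nil),
  List(
    Assignment(Ident("v"), Property(Ident("v"), "id"), Constant(1)),
    Assignment(Ident("v"), Property(Ident("v"), "name"), Constant("Joe"))
  )
)
NormalizeReturning(ReturningGenerated(insert, Ident("r"), Property(Ident("r"), "id")))
// ==> ReturningGenerated(Insert(Entity("Person", Nil), List(<name assignment>)),
//                        Ident("r"), Property(Ident("r"), "id"))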
29  src/main/scala/minisql/norm/OrderTerms.scala  Normal file

@ -0,0 +1,29 @@
package minisql.norm

import minisql.ast._

object OrderTerms {

  def unapply(q: Query) =
    q match {

      case Take(Map(a: GroupBy, b, c), d) => None

      // a.sortBy(b => c).filter(d => e) =>
      //   a.filter(d => e).sortBy(b => c)
      case Filter(SortBy(a, b, c, d), e, f) =>
        Some(SortBy(Filter(a, e, f), b, c, d))

      // a.flatMap(b => c).take(n).map(d => e) =>
      //   a.flatMap(b => c).map(d => e).take(n)
      case Map(Take(fm: FlatMap, n), ma, mb) =>
        Some(Take(Map(fm, ma, mb), n))

      // a.flatMap(b => c).drop(n).map(d => e) =>
      //   a.flatMap(b => c).map(d => e).drop(n)
      case Map(Drop(fm: FlatMap, n), ma, mb) =>
        Some(Drop(Map(fm, ma, mb), n))

      case other => None
    }
}
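Sketch of the first rewrite (hypothetical values; the ordering argument's type is not shown in this diff, so it is left as a parameter):

def reorder(ord: Ast) = {
  val q = Filter(
    SortBy(Entity("t", Nil), Ident("b"), Property(Ident("b"), "age"), ord),
    Ident("d"),
    Constant(true)
  )
  OrderTerms.unapply(q)
  // ==> Some(SortBy(Filter(Entity("t", Nil), Ident("d"), Constant(true)),
  //                 Ident("b"), Property(Ident("b"), "age"), ord))
}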
491  src/main/scala/minisql/norm/RenameProperties.scala  Normal file

@ -0,0 +1,491 @@
package minisql.norm

import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Visible
import minisql.ast._
import minisql.util.Interpolator

object RenameProperties extends StatelessTransformer {
  val interp = new Interpolator(3)
  import interp._
  def traceDifferent(one: Any, two: Any) =
    if (one != two)
      trace"Replaced $one with $two".andLog()
    else
      trace"Replacements did not match".andLog()

  override def apply(q: Query): Query =
    applySchemaOnly(q)

  override def apply(q: Action): Action =
    applySchema(q) match {
      case (q, schema) => q
    }

  override def apply(e: Operation): Operation =
    e match {
      case UnaryOperation(o, c: Query) =>
        UnaryOperation(o, applySchemaOnly(apply(c)))
      case _ => super.apply(e)
    }

  private def applySchema(q: Action): (Action, Schema) =
    q match {
      case Insert(q: Query, assignments) =>
        applySchema(q, assignments, Insert.apply)
      case Update(q: Query, assignments) =>
        applySchema(q, assignments, Update.apply)
      case Delete(q: Query) =>
        applySchema(q) match {
          case (q, schema) => (Delete(q), schema)
        }
      case Returning(action: Action, alias, body) =>
        applySchema(action) match {
          case (action, schema) =>
            val replace =
              trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
                replacements(alias, schema)
            val bodyr = BetaReduction(body, replace*)
            traceDifferent(body, bodyr)
            (Returning(action, alias, bodyr), schema)
        }
      case ReturningGenerated(action: Action, alias, body) =>
        applySchema(action) match {
          case (action, schema) =>
            val replace =
              trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
                replacements(alias, schema)
            val bodyr = BetaReduction(body, replace*)
            traceDifferent(body, bodyr)
            (ReturningGenerated(action, alias, bodyr), schema)
        }
      case OnConflict(a: Action, target, act) =>
        applySchema(a) match {
          case (action, schema) =>
            val targetR = target match {
              case OnConflict.Properties(props) =>
                val propsR = props.map { prop =>
                  val replace =
                    trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn`
                      replacements(
                        prop.ast,
                        schema
                      ) // A BetaReduction on a Property will always give back a Property
                  BetaReduction(prop, replace*).asInstanceOf[Property]
                }
                traceDifferent(props, propsR)
                OnConflict.Properties(propsR)
              case OnConflict.NoTarget => target
            }
            val actR = act match {
              case OnConflict.Update(assignments) =>
                OnConflict.Update(replaceAssignments(assignments, schema))
              case _ => act
            }
            (OnConflict(action, targetR, actR), schema)
        }
      case q => (q, TupleSchema.empty)
    }

  private def replaceAssignments(
      a: List[Assignment],
      schema: Schema
  ): List[Assignment] =
    a.map {
      case Assignment(alias, prop, value) =>
        val replace =
          trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn`
            replacements(alias, schema)
        val propR = BetaReduction(prop, replace*)
        traceDifferent(prop, propR)
        val valueR = BetaReduction(value, replace*)
        traceDifferent(value, valueR)
        Assignment(alias, propR, valueR)
    }

  private def applySchema(
      q: Query,
      a: List[Assignment],
      f: (Query, List[Assignment]) => Action
  ): (Action, Schema) =
    applySchema(q) match {
      case (q, schema) =>
        (f(q, replaceAssignments(a, schema)), schema)
    }

  private def applySchemaOnly(q: Query): Query =
    applySchema(q) match {
      case (q, _) => q
    }

  object TupleIndex {
    def unapply(s: String): Option[Int] =
      if (s.matches("_[0-9]*"))
        Some(s.drop(1).toInt - 1)
      else
        None
  }

  sealed trait Schema {
    def lookup(property: List[String]): Option[Schema] =
      (property, this) match {
        case (Nil, schema) =>
          trace"Nil at $property returning " `andReturn`
            Some(schema)
        case (path, e @ EntitySchema(_)) =>
          trace"Entity at $path returning " `andReturn`
            Some(e.subSchemaOrEmpty(path))
        case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) =>
          trace"Case class at $property returning " `andReturn`
            props(head).lookup(tail)
        case (TupleIndex(idx) :: tail, TupleSchema(values))
            if values.contains(idx) =>
          trace"Tuple at $property returning " `andReturn`
            values(idx).lookup(tail)
        case _ =>
          trace"Nothing found at $property returning " `andReturn`
            None
      }
  }

  // Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...))
  object PropertyMatroshka {

    def traverse(initial: Property): Option[(Ident, List[String])] =
      initial match {
        // If it's a nested-property walk inside and append the name to the result (if something is returned)
        case Property(inner: Property, name) =>
          traverse(inner).map { case (id, list) => (id, list :+ name) }
        // If it's a property with ident in the core, return that
        case Property(id: Ident, name) =>
          Some((id, List(name)))
        // Otherwise an ident property is not inside so don't return anything
        case _ =>
          None
      }

    def unapply(ast: Ast): Option[(Ident, List[String])] =
      ast match {
        case p: Property => traverse(p)
        case _           => None
      }

  }

  def protractSchema(
      body: Ast,
      ident: Ident,
      schema: Schema
  ): Option[Schema] = {

    def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] =
      body match {
        // if any values yield a sub-schema which is not an entity, recurse into that
        case cc @ CaseClass(values) =>
          trace"Protracting CaseClass $cc into new schema:" `andReturn`
            CaseClassSchema(
              values.collect {
                case (name, innerBody @ HierarchicalAstEntity()) =>
                  (name, protractSchemaRecurse(innerBody, schema))
                // pass the schema into a recursive call and extract from it when we reach a non tuple/caseclass element
                case (name, innerBody @ PropertyMatroshka(`ident`, path)) =>
                  (name, protractSchemaRecurse(innerBody, schema))
                // we have reached an ident i.e. recurse to pass the current schema into the case class
                case (name, `ident`) =>
                  (name, protractSchemaRecurse(ident, schema))
              }.collect {
                case (name, Some(subSchema)) => (name, subSchema)
              }
            ).notEmpty
        case tup @ Tuple(values) =>
          trace"Protracting Tuple $tup into new schema:" `andReturn`
            TupleSchema
              .fromIndexes(
                values.zipWithIndex.collect {
                  case (innerBody @ HierarchicalAstEntity(), index) =>
                    (index, protractSchemaRecurse(innerBody, schema))
                  // pass the schema into a recursive call and extract from it when we reach a non tuple/caseclass element
                  case (innerBody @ PropertyMatroshka(`ident`, path), index) =>
                    (index, protractSchemaRecurse(innerBody, schema))
                  // we have reached an ident i.e. recurse to pass the current schema into the tuple
                  case (`ident`, index) =>
                    (index, protractSchemaRecurse(ident, schema))
                }.collect {
                  case (index, Some(subSchema)) => (index, subSchema)
                }
              )
              .notEmpty

        case prop @ PropertyMatroshka(`ident`, path) =>
          trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn`
            schema match {
              // case e: EntitySchema => Some(e)
              case _ => schema.lookup(path)
            }
        case `ident` =>
          trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn`
            Some(schema)
        case other =>
          trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn`
            Some(schema)
      }

    protractSchemaRecurse(body, schema)
  }

  case object EmptySchema extends Schema
  case class EntitySchema(e: Entity) extends Schema {
    def noAliases = e.properties.isEmpty

    private def subSchema(path: List[String]) =
      EntitySchema(
        Entity(
          s"sub-${e.name}",
          e.properties.flatMap {
            case PropertyAlias(aliasPath, alias) =>
              if (aliasPath == path)
                List(PropertyAlias(aliasPath, alias))
              else if (aliasPath.startsWith(path))
                List(PropertyAlias(aliasPath.diff(path), alias))
              else
                List()
          }
        )
      )

    def subSchemaOrEmpty(path: List[String]): Schema =
      trace"Creating sub-schema for entity $e at path $path will be" andReturn {
        val sub = subSchema(path)
        if (sub.noAliases) EmptySchema else sub
      }

  }
  case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ )
      extends Schema {
    def list = m.toList.sortBy(_._1)
    def notEmpty =
      if (this.m.nonEmpty) Some(this) else None
  }
  case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema {
    def list = m.toList
    def notEmpty =
      if (this.m.nonEmpty) Some(this) else None
  }
  object CaseClassSchema {
    def apply(property: String, value: Schema): CaseClassSchema =
      CaseClassSchema(collection.Map(property -> value))
    def apply(list: List[(String, Schema)]): CaseClassSchema =
      CaseClassSchema(list.toMap)
  }

  object TupleSchema {
    def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema =
      TupleSchema(schemas.toMap)

    def apply(schemas: List[Schema]): TupleSchema =
      TupleSchema(schemas.zipWithIndex.map(_.swap).toMap)

    def apply(index: Int, schema: Schema): TupleSchema =
      TupleSchema(collection.Map(index -> schema))

    def empty: TupleSchema = TupleSchema(List.empty)
  }

  object HierarchicalAstEntity {
    def unapply(ast: Ast): Boolean =
      ast match {
        case cc: CaseClass => true
        case tup: Tuple    => true
        case _             => false
      }
  }

  private def applySchema(q: Query): (Query, Schema) = {
    q match {

      // Don't understand why this is needed....
      case Map(q: Query, x, p) =>
        applySchema(q) match {
          case (q, subSchema) =>
            val replace =
              trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn`
                replacements(x, subSchema)
            val pr = BetaReduction(p, replace*)
            traceDifferent(p, pr)
            val prr = apply(pr)
            traceDifferent(pr, prr)

            val schema =
              trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` {
                protractSchema(prr, x, subSchema)
              }.getOrElse(EmptySchema)

            (Map(q, x, prr), schema)
        }

      case e: Entity                    => (e, EntitySchema(e))
      case Filter(q: Query, x, p)       => applySchema(q, x, p, Filter.apply)
      case SortBy(q: Query, x, p, o)    => applySchema(q, x, p, SortBy(_, _, _, o))
      case GroupBy(q: Query, x, p)      => applySchema(q, x, p, GroupBy.apply)
      case Aggregation(op, q: Query)    => applySchema(q, Aggregation(op, _))
      case Take(q: Query, n)            => applySchema(q, Take(_, n))
      case Drop(q: Query, n)            => applySchema(q, Drop(_, n))
      case Nested(q: Query)             => applySchema(q, Nested.apply)
      case Distinct(q: Query)           => applySchema(q, Distinct.apply)
      case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on))

      case FlatMap(q: Query, x, p) =>
        applySchema(q, x, p, FlatMap.apply) match {
          case (FlatMap(q, x, p: Query), oldSchema) =>
            val (pr, newSchema) = applySchema(p)
            (FlatMap(q, x, pr), newSchema)
          case (flatMap, oldSchema) =>
            (flatMap, TupleSchema.empty)
        }

      case ConcatMap(q: Query, x, p) =>
        applySchema(q, x, p, ConcatMap.apply) match {
          case (ConcatMap(q, x, p: Query), oldSchema) =>
            val (pr, newSchema) = applySchema(p)
            (ConcatMap(q, x, pr), newSchema)
          case (concatMap, oldSchema) =>
            (concatMap, TupleSchema.empty)
        }

      case Join(typ, a: Query, b: Query, iA, iB, on) =>
        (applySchema(a), applySchema(b)) match {
          case ((a, schemaA), (b, schemaB)) =>
            val combinedReplacements =
              trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn {
                val replaceA = replacements(iA, schemaA)
                val replaceB = replacements(iB, schemaB)
                replaceA ++ replaceB
              }
            val onr = BetaReduction(on, combinedReplacements*)
            traceDifferent(on, onr)
            (Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB)))
        }

      case FlatJoin(typ, a: Query, iA, on) =>
        applySchema(a) match {
          case (a, schemaA) =>
            val replaceA =
              trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn`
                replacements(iA, schemaA)
            val onr = BetaReduction(on, replaceA*)
            traceDifferent(on, onr)
            (FlatJoin(typ, a, iA, onr), schemaA)
        }

      case Map(q: Operation, x, p) if x == p =>
        (Map(apply(q), x, p), TupleSchema.empty)

      case Map(Infix(parts, params, pure, paren), x, p) =>
        val transformed =
          params.map {
            case q: Query =>
              val (qr, schema) = applySchema(q)
              traceDifferent(q, qr)
              (qr, Some(schema))
            case q =>
              (q, None)
          }

        val schema =
          transformed.collect {
            case (_, Some(schema)) => schema
          } match {
            case e :: Nil => e
            case ls       => TupleSchema(ls)
          }
        val replace =
          trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
            replacements(x, schema)
        val pr = BetaReduction(p, replace*)
        traceDifferent(p, pr)
        val prr = apply(pr)
        traceDifferent(pr, prr)

        (Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema)

      case q =>
        (q, TupleSchema.empty)
    }
  }

  private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) =
    applySchema(ast) match {
      case (ast, schema) =>
        (f(ast), schema)
    }

  private def applySchema[T](
      q: Query,
      x: Ident,
      p: Ast,
      f: (Ast, Ident, Ast) => T
  ): (T, Schema) =
    applySchema(q) match {
      case (q, schema) =>
        val replace =
          trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
            replacements(x, schema)
        val pr = BetaReduction(p, replace*)
        traceDifferent(p, pr)
        val prr = apply(pr)
        traceDifferent(pr, prr)
        (f(q, x, prr), schema)
    }

  private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] =
    schema match {
      // The entity renameable property should already have been marked as Fixed
      case EntitySchema(Entity(entity, properties)) =>
        // trace"%4 Entity Schema: " andReturn
        properties.flatMap {
          // A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias)
          // or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit.
          // In this case, we want to rename the properties based on the property aliases as well as mark
          // them Fixed since they should not be renamed based on
          // the naming strategy wherever they are tokenized (e.g. in SqlIdiom)
          case PropertyAlias(path, alias) =>
            def apply(base: Ast, path: List[String]): Ast =
              path match {
                case Nil          => base
                case head :: tail => apply(Property(base, head), tail)
              }
            List(
              apply(base, path) -> Property.Opinionated(
                base,
                alias,
                Fixed,
                Visible
              ) // Hidden properties cannot be renamed
            )
        }
      case tup: TupleSchema =>
        // trace"%4 Tuple Schema: " andReturn
        tup.list.flatMap {
          case (idx, value) =>
            replacements(
              // Should not matter whether property is fixed or variable here
              // since beta reduction ignores that
              Property(base, s"_${idx + 1}"),
              value
            )
        }
      case cc: CaseClassSchema =>
        // trace"%4 CaseClass Schema: " andReturn
        cc.list.flatMap {
          case (property, value) =>
            replacements(
              // Should not matter whether property is fixed or variable here
              // since beta reduction ignores that
              Property(base, property),
              value
            )
        }
      // Do nothing if it is an empty schema
      case EmptySchema => List()
    }
}
@ -1,4 +1,4 @@
-package minisql.util
+package minisql.norm

 import minisql.ast.Ast
 import scala.collection.immutable.Map
124  src/main/scala/minisql/norm/SimplifyNullChecks.scala  Normal file

@ -0,0 +1,124 @@
package minisql.norm

import minisql.ast.*
import minisql.norm.EqualityBehavior.AnsiEquality

/**
 * Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in
 * `FlattenOptionOperation` in order to resolve #1053, as well as to support
 * non-ansi compliant string concatenation as outlined in #1295, large
 * conditional composites became common. For example: <code><pre>
 * case class Holder(value:Option[String])
 *
 * // The following statement
 * query[Holder].map(h => h.value.map(_ + "foo"))
 * // Will yield the following result
 * SELECT CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END FROM Holder h
 * </pre></code>
 * Now, let's add a <code>getOrElse</code> statement to the clause that requires
 * an additional wrapped null check. We cannot rely on there being a
 * <code>map</code> call beforehand since we could be reading <code>value</code>
 * as a nullable field directly from the database. <code><pre>
 * // The following statement
 * query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar"))
 * // Yields the following result:
 * SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END IS NOT NULL
 * THEN CASE WHEN h.value IS NOT NULL THEN h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h
 * </pre></code>
 * This of course is highly redundant and can be reduced to simply: <code><pre>
 * SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL
 * THEN h.value || 'foo' ELSE 'bar' END FROM Holder h
 * </pre></code>
 * This reduction is done by the "Center Rule." There are some other
 * simplification rules as well. Note how we are forced to null-check both
 * `h.value` as well as `(h.value || 'foo')` because a user may use
 * `Option[T].flatMap` and explicitly transform a particular value to `null`.
 */
class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
    extends StatelessTransformer {

  override def apply(ast: Ast): Ast = {
    import minisql.ast.Implicits.*
    ast match {
      // Center rule
      case IfExist(
            IfExistElseNull(condA, thenA),
            IfExistElseNull(condB, thenB),
            otherwise
          ) if (condA == condB && thenA == thenB) =>
        apply(
          If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
        )

      // Left hand rule
      case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
        apply(
          If(
            IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
            value,
            otherwise
          )
        )

      // Right hand rule
      case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
        apply(
          If(
            IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
            innerThen,
            NullValue
          )
        )

      case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
            Optional(b)
          ) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
          if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
        apply(exp)

      case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
            a1
          ) `== or !=` Optional(_)))
          if (a == a1 && equalityBehavior == AnsiEquality) =>
        apply(exp)
      case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
            _
          ) `== or !=` Optional(b1)))
          if (b == b1 && equalityBehavior == AnsiEquality) =>
        apply(exp)

      case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
        apply(other)
      case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
        apply(other)
      case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
        apply(other)
      case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
        apply(other)

      case (left +&&+ OptionIsDefined(Optional(Constant(_))))  => apply(left)
      case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
      case (left +||+ OptionIsEmpty(Optional(Constant(_))))    => apply(left)
      case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)

      case other =>
        super.apply(other)
    }
  }

  object `== or !=` {
    def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
      case a +==+ b => Some((a, b))
      case a +!=+ b => Some((a, b))
      case _        => None
    }
  }

  /**
   * Simple extractor that looks inside of an optional value to see if the
   * thing inside can be pulled out. If not, it just returns whatever element
   * it can find.
   */
  object Optional {
    def unapply(a: Ast): Option[Ast] = a match {
      case OptionApply(value) => Some(value)
      case OptionSome(value)  => Some(value)
      case value              => Some(value)
    }
  }
}
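The "Center rule" above in isolation (hypothetical values; `+&&+` comes from minisql.ast.Implicits as used elsewhere in this diff):

val cond = Property(Ident("h"), "value")
val thn  = BinaryOperation(cond, StringOperator.`concat`, Constant("foo"))
new SimplifyNullChecks(EqualityBehavior.AnsiEquality)
  .apply(IfExist(IfExistElseNull(cond, thn), IfExistElseNull(cond, thn), Constant("bar")))
// ==> If(IsNotNullCheck(cond) +&&+ IsNotNullCheck(thn), thn, Constant("bar"))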
38  src/main/scala/minisql/norm/SymbolicReduction.scala  Normal file

@ -0,0 +1,38 @@
package minisql.norm

import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll

object SymbolicReduction {

  def unapply(q: Query) =
    q match {

      // a.filter(b => c).flatMap(d => e.$) =>
      //   a.flatMap(d => e.filter(_ => c[b := d]).$)
      case FlatMap(Filter(a, b, c), d, e: Query) =>
        val cr = BetaReduction(c, b -> d)
        val er = AttachToEntity(Filter(_, _, cr))(e)
        Some(FlatMap(a, d, er))

      // a.flatMap(b => c).flatMap(d => e) =>
      //   a.flatMap(b => c.flatMap(d => e))
      case FlatMap(FlatMap(a, b, c), d, e) =>
        Some(FlatMap(a, b, FlatMap(c, d, e)))

      // a.union(b).flatMap(c => d) =>
      //   a.flatMap(c => d).union(b.flatMap(c => d))
      case FlatMap(Union(a, b), c, d) =>
        Some(Union(FlatMap(a, c, d), FlatMap(b, c, d)))

      // a.unionAll(b).flatMap(c => d) =>
      //   a.flatMap(c => d).unionAll(b.flatMap(c => d))
      case FlatMap(UnionAll(a, b), c, d) =>
        Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d)))

      case other => None
    }
}
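Sketch of the filter/flatMap rule (hypothetical values). Since the inner query is a bare entity, AttachToEntity lands the filter directly on it with the default alias `x`:

val people    = Entity("people", Nil)
val addresses = Entity("addresses", Nil)
val q = FlatMap(Filter(people, Ident("b"), Constant(true)), Ident("d"), addresses)
SymbolicReduction.unapply(q)
// ==> Some(FlatMap(people, Ident("d"),
//                  Filter(addresses, Ident("x"), Constant(true))))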
174  src/main/scala/minisql/norm/capture/AvoidAliasConflict.scala  Normal file

@ -0,0 +1,174 @@
package minisql.norm.capture

import minisql.ast.{
  Entity,
  Filter,
  FlatJoin,
  FlatMap,
  GroupBy,
  Ident,
  Join,
  Map,
  Query,
  SortBy,
  StatefulTransformer,
  _
}
import minisql.norm.{BetaReduction, Normalize}
import scala.collection.immutable.Set

private[minisql] case class AvoidAliasConflict(state: Set[Ident])
    extends StatefulTransformer[Set[Ident]] {

  object Unaliased {

    private def isUnaliased(q: Ast): Boolean =
      q match {
        case Nested(q: Query)         => isUnaliased(q)
        case Take(q: Query, _)        => isUnaliased(q)
        case Drop(q: Query, _)        => isUnaliased(q)
        case Aggregation(_, q: Query) => isUnaliased(q)
        case Distinct(q: Query)       => isUnaliased(q)
        case _: Entity | _: Infix     => true
        case _                        => false
      }

    def unapply(q: Ast): Option[Ast] =
      q match {
        case q if (isUnaliased(q)) => Some(q)
        case _                     => None
      }
  }

  override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
    q match {

      case FlatMap(Unaliased(q), x, p) =>
        apply(x, p)(FlatMap(q, _, _))

      case ConcatMap(Unaliased(q), x, p) =>
        apply(x, p)(ConcatMap(q, _, _))

      case Map(Unaliased(q), x, p) =>
        apply(x, p)(Map(q, _, _))

      case Filter(Unaliased(q), x, p) =>
        apply(x, p)(Filter(q, _, _))

      case SortBy(Unaliased(q), x, p, o) =>
        apply(x, p)(SortBy(q, _, _, o))

      case GroupBy(Unaliased(q), x, p) =>
        apply(x, p)(GroupBy(q, _, _))

      case DistinctOn(Unaliased(q), x, p) =>
        apply(x, p)(DistinctOn(q, _, _))

      case Join(t, a, b, iA, iB, o) =>
        val (ar, art) = apply(a)
        val (br, brt) = art.apply(b)
        val freshA = freshIdent(iA, brt.state)
        val freshB = freshIdent(iB, brt.state + freshA)
        val or = BetaReduction(o, iA -> freshA, iB -> freshB)
        val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
        (Join(t, ar, br, freshA, freshB, orr), orrt)

      case FlatJoin(t, a, iA, o) =>
        val (ar, art) = apply(a)
        val freshA = freshIdent(iA)
        val or = BetaReduction(o, iA -> freshA)
        val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
        (FlatJoin(t, ar, freshA, orr), orrt)

      case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
          _: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
          _: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
        super.apply(q)
    }

  private def apply(x: Ident, p: Ast)(
      f: (Ident, Ast) => Query
  ): (Query, StatefulTransformer[Set[Ident]]) = {
    val fresh = freshIdent(x)
    val pr = BetaReduction(p, x -> fresh)
    val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
    (f(fresh, prr), t)
  }

  private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
    def loop(x: Ident, n: Int): Ident = {
      val fresh = Ident(s"${x.name}$n")
      if (!state.contains(fresh))
        fresh
      else
        loop(x, n + 1)
    }
    if (!state.contains(x))
      x
    else
      loop(x, 1)
  }

  /**
   * Sometimes we need to change the variables in a function because they
   * might conflict with some variable further up in the macro. Right now,
   * this only happens when you do something like this:
   * <code> val q = quote { (v: Foo) => query[Foo].insert(v) } run(q(lift(v))) </code>
   * Since 'v' is used by actionMeta in order to map keys to values for
   * insertion, using it as a function argument messes up the output SQL like
   * so:
   * <code> INSERT INTO MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?) </code>
   * Therefore, we need to have a method to remove such conflicting variables
   * from Function ASTs
   */
  private def applyFunction(f: Function): Function = {
    val (newBody, _, newParams) =
      f.params.foldLeft((f.body, state, List[Ident]())) {
        case ((body, state, newParams), param) => {
          val fresh = freshIdent(param)
          val pr = BetaReduction(body, param -> fresh)
          val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
          (prr, t.state, newParams :+ fresh)
        }
      }
    Function(newParams, newBody)
  }

  private def applyForeach(f: Foreach): Foreach = {
    val fresh = freshIdent(f.alias)
    val pr = BetaReduction(f.body, f.alias -> fresh)
    val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
    Foreach(f.query, fresh, prr)
  }
}

private[minisql] object AvoidAliasConflict {

  def apply(q: Query): Query =
    AvoidAliasConflict(Set[Ident]())(q) match {
      case (q, _) => q
    }

  /**
   * Make sure query parameters do not collide with parameters of an AST
   * function. Do this by walking through the function's subtree and
   * transforming any queries encountered.
   */
  def sanitizeVariables(
      f: Function,
      dangerousVariables: Set[Ident]
  ): Function = {
    AvoidAliasConflict(dangerousVariables).applyFunction(f)
  }

  /** Same as `sanitizeVariables` but for Foreach */
  def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
    AvoidAliasConflict(dangerousVariables).applyForeach(f)
  }

  def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
    AvoidAliasConflict(dangerousVariables).apply(q) match {
      // Propagate aliasing changes to the rest of the query
      case (q, _) => Normalize(q)
    }
  }
}
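Usage sketch (hypothetical values): with `x` already marked dangerous, the alias and every use of it in the body are renamed to the first free name, `x1`.

val q = Map(Entity("t", Nil), Ident("x"), Property(Ident("x"), "n"))
AvoidAliasConflict(Set(Ident("x")))(q)
// ==> (Map(Entity("t", Nil), Ident("x1"), Property(Ident("x1"), "n")), _)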
9  src/main/scala/minisql/norm/capture/AvoidCapture.scala  Normal file

@ -0,0 +1,9 @@
package minisql.norm.capture

import minisql.ast.Query

object AvoidCapture {

  def apply(q: Query): Query =
    Dealias(AvoidAliasConflict(q))
}
72
src/main/scala/minisql/norm/capture/Dealias.scala
Normal file
72
src/main/scala/minisql/norm/capture/Dealias.scala
Normal file
|
@ -0,0 +1,72 @@
|
|||
package minisql.norm.capture
|
||||
|
||||
import minisql.ast._
|
||||
import minisql.norm.BetaReduction
|
||||
|
||||
case class Dealias(state: Option[Ident])
|
||||
extends StatefulTransformer[Option[Ident]] {
|
||||
|
||||
override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) =
|
||||
q match {
|
||||
case FlatMap(a, b, c) =>
|
||||
dealias(a, b, c)(FlatMap.apply) match {
|
||||
case (FlatMap(a, b, c), _) =>
|
||||
val (cn, cnt) = apply(c)
|
||||
(FlatMap(a, b, cn), cnt)
|
||||
}
|
||||
case ConcatMap(a, b, c) =>
|
||||
dealias(a, b, c)(ConcatMap.apply) match {
|
||||
case (ConcatMap(a, b, c), _) =>
|
||||
val (cn, cnt) = apply(c)
|
||||
(ConcatMap(a, b, cn), cnt)
|
||||
}
|
||||
case Map(a, b, c) =>
|
||||
dealias(a, b, c)(Map.apply)
|
||||
case Filter(a, b, c) =>
|
||||
dealias(a, b, c)(Filter.apply)
|
||||
case SortBy(a, b, c, d) =>
|
||||
dealias(a, b, c)(SortBy(_, _, _, d))
|
||||
case GroupBy(a, b, c) =>
|
||||
dealias(a, b, c)(GroupBy.apply)
|
||||
case DistinctOn(a, b, c) =>
|
||||
dealias(a, b, c)(DistinctOn.apply)
|
||||
case Take(a, b) =>
|
||||
val (an, ant) = apply(a)
|
||||
(Take(an, b), ant)
|
||||
case Drop(a, b) =>
|
||||
val (an, ant) = apply(a)
|
||||
(Drop(an, b), ant)
|
||||
case Union(a, b) =>
|
||||
val (an, _) = apply(a)
|
||||
val (bn, _) = apply(b)
|
||||
(Union(an, bn), Dealias(None))
|
||||
case UnionAll(a, b) =>
|
||||
val (an, _) = apply(a)
|
||||
val (bn, _) = apply(b)
|
||||
(UnionAll(an, bn), Dealias(None))
|
||||
case Join(t, a, b, iA, iB, o) =>
|
||||
val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
|
||||
val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
|
||||
(Join(t, an, bn, iAn, iBn, onn), Dealias(None))
|
||||
case FlatJoin(t, a, iA, o) =>
|
||||
val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
|
||||
(FlatJoin(t, an, iAn, on), Dealias(Some(iA)))
|
||||
case _: Entity | _: Distinct | _: Aggregation | _: Nested =>
|
||||
(q, Dealias(None))
|
||||
}
|
||||
|
||||
private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) =
|
||||
apply(a) match {
|
||||
case (an, t @ Dealias(Some(alias))) =>
|
||||
(f(an, alias, BetaReduction(c, b -> alias)), t)
|
||||
case other =>
|
||||
(f(a, b, c), Dealias(Some(b)))
|
||||
}
|
||||
}
|
||||
|
||||
object Dealias {
|
||||
def apply(query: Query) =
|
||||
new Dealias(None)(query) match {
|
||||
case (q, _) => q
|
||||
}
|
||||
}
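A sketch of the effect, using the AST constructors from this changeset: when a later clause introduces a different alias over the same query, `Dealias` rewrites it to reuse the earlier alias.

```scala
import minisql.ast.*
import minisql.norm.capture.Dealias

val people = Entity("person", Nil)
// The Map clause uses alias `a`; the Filter on top of it uses `b`.
val q = Filter(Map(people, Ident("a"), Ident("a")), Ident("b"), Ident("b"))
Dealias(q)
// ==> Filter(Map(people, Ident("a"), Ident("a")), Ident("a"), Ident("a"))
```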
@ -0,0 +1,98 @@
package minisql.norm.capture

import minisql.ast.*

/**
 * Walk through any Queries that a returning clause has and replace Ident of the
 * returning variable with ExternalIdent so that in later steps involving filter
 * simplification, it will not be mistakenly dealiased with a potential shadow.
 * Take this query for instance: <pre> query[TestEntity]
 * .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r =>
 * (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) ) </pre>
 * The returning clause has an alias `Ident("r")` as well as the first filter
 * clause. These two filters will be combined into one at which point the
 * meaning of `r.i` in the 2nd filter will be confused for the first filter's
 * alias (i.e. the `r` in `filter(r => ...)`). Therefore, we need to change this
 * vulnerable `r.i` in the second filter clause to an `ExternalIdent` before any
 * of the simplifications are done.
 *
 * Note that we only want to do this for Queries inside of a `Returning` clause
 * body. Other places where this needs to be done (e.g. in a Tuple that
 * `Returning` returns) are done in `ExpandReturning`.
 */
private[minisql] case class DemarcateExternalAliases(externalIdent: Ident)
    extends StatelessTransformer {

  def applyNonOverride(idents: Ident*)(ast: Ast) =
    if (idents.forall(_ != externalIdent)) apply(ast)
    else ast

  override def apply(ast: Ast): Ast = ast match {

    case FlatMap(q, i, b) =>
      FlatMap(apply(q), i, applyNonOverride(i)(b))

    case ConcatMap(q, i, b) =>
      ConcatMap(apply(q), i, applyNonOverride(i)(b))

    case Map(q, i, b) =>
      Map(apply(q), i, applyNonOverride(i)(b))

    case Filter(q, i, b) =>
      Filter(apply(q), i, applyNonOverride(i)(b))

    case SortBy(q, i, p, o) =>
      SortBy(apply(q), i, applyNonOverride(i)(p), o)

    case GroupBy(q, i, b) =>
      GroupBy(apply(q), i, applyNonOverride(i)(b))

    case DistinctOn(q, i, b) =>
      DistinctOn(apply(q), i, applyNonOverride(i)(b))

    case Join(t, a, b, iA, iB, o) =>
      Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o))

    case FlatJoin(t, a, iA, o) =>
      FlatJoin(t, a, iA, applyNonOverride(iA)(o))

    case p @ Property.Opinionated(
          id @ Ident(_),
          value,
          renameable,
          visibility
        ) =>
      if (id == externalIdent)
        Property.Opinionated(
          ExternalIdent(externalIdent.name),
          value,
          renameable,
          visibility
        )
      else
        p

    case other =>
      super.apply(other)
  }
}

object DemarcateExternalAliases {

  private def demarcateQueriesInBody(id: Ident, body: Ast) =
    Transform(body) {
      // Apply the AST-level apply method defined above, not the superclass method that takes Query
      case q: Query =>
        new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast])
    }

  def apply(ast: Ast): Ast = ast match {
    case Returning(a, id, body) =>
      Returning(a, id, demarcateQueriesInBody(id, body))
    case ReturningGenerated(a, id, body) =>
      ReturningGenerated(a, id, demarcateQueriesInBody(id, body))
    case other =>
      other
  }
}
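A sketch of the demarcation step (the `Entity` here stands in for a full insert action; only references to the returning alias are rewritten):

```scala
import minisql.ast.*

val r = Ident("r")
val body = Filter(
  Entity("dummy", Nil),
  Ident("d"),
  BinaryOperation(Property(Ident("d"), "i"), EqualityOperator.==, Property(r, "i"))
)
DemarcateExternalAliases(Returning(Entity("t", Nil), r, body))
// In the resulting body, Property(Ident("r"), "i") becomes
// Property(ExternalIdent("r"), "i"); Property(Ident("d"), "i") is untouched.
```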
66 src/main/scala/minisql/parsing/BlockParsing.scala Normal file
@ -0,0 +1,66 @@
package minisql.parsing

import minisql.ast
import scala.quoted.*

type SParser[X] =
  (q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]]

private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using
    Quotes
): SParser[ast.Ast] = {

  import quotes.reflect.*

  @annotation.nowarn
  lazy val valDefParser: SParser[ast.Val] = {
    case ValDef(n, _, Some(b)) =>
      val body = astParser(b.asExpr)
      '{ ast.Val(ast.Ident(${ Expr(n) }), $body) }
  }
  valDefParser
}

private[parsing] def parseBlockList(
    astParser: => Parser[ast.Ast],
    e: Expr[Any]
)(using Quotes): List[Expr[ast.Ast]] = {
  import quotes.reflect.*

  lazy val statementParser = statementParsing(astParser)

  e.asTerm match {
    case Block(st, t) =>
      (st :+ t).map {
        case e if e.isExpr => astParser(e.asExpr)
        case `statementParser`(x) => x
        case o =>
          report.errorAndAbort(s"Cannot parse statement: ${o.show}")
      }
    // A non-block body is a single expression
    case o =>
      List(astParser(o.asExpr))
  }
}

private[parsing] def blockParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {

  import quotes.reflect.*

  lazy val statementParser = statementParsing(astParser)

  termParser {
    case Block(Nil, t) => astParser(t.asExpr)
    case b @ Block(st, t) =>
      val asts = (st :+ t).map {
        case e if e.isExpr => astParser(e.asExpr)
        case `statementParser`(x) => x
        case o =>
          report.errorAndAbort(s"Cannot parse statement: ${o.show}")
      }
      if (asts.size > 1) {
        '{ ast.Block(${ Expr.ofList(asts) }) }
      } else asts(0)
  }
}
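As a sketch of what the block parser produces (via the `Parsing.parse` inline entry point defined later in this changeset, callable from within the `minisql` package):

```scala
import minisql.ast.*

val parsed = Parsing.parse {
  val x = 1
  x + 2
}
// ==> Block(List(
//   Val(Ident("x"), Constant(1)),
//   BinaryOperation(Ident("x"), NumericOperator.+, Constant(2))))
```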
31 src/main/scala/minisql/parsing/BoxingParsing.scala Normal file
@ -0,0 +1,31 @@
package minisql.parsing

import minisql.ast
import scala.quoted.*

private[parsing] def boxingParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
  case '{ BigDecimal.int2bigDecimal($v) } => astParser(v)
  case '{ BigDecimal.long2bigDecimal($v) } => astParser(v)
  case '{ BigDecimal.double2bigDecimal($v) } => astParser(v)
  case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v)
  case '{ Predef.byte2Byte($v) } => astParser(v)
  case '{ Predef.short2Short($v) } => astParser(v)
  case '{ Predef.char2Character($v) } => astParser(v)
  case '{ Predef.int2Integer($v) } => astParser(v)
  case '{ Predef.long2Long($v) } => astParser(v)
  case '{ Predef.float2Float($v) } => astParser(v)
  case '{ Predef.double2Double($v) } => astParser(v)
  case '{ Predef.boolean2Boolean($v) } => astParser(v)
  case '{ Predef.augmentString($v) } => astParser(v)
  case '{ Predef.Byte2byte($v) } => astParser(v)
  case '{ Predef.Short2short($v) } => astParser(v)
  case '{ Predef.Character2char($v) } => astParser(v)
  case '{ Predef.Integer2int($v) } => astParser(v)
  case '{ Predef.Long2long($v) } => astParser(v)
  case '{ Predef.Float2float($v) } => astParser(v)
  case '{ Predef.Double2double($v) } => astParser(v)
  case '{ Predef.Boolean2boolean($v) } => astParser(v)
}
29 src/main/scala/minisql/parsing/InfixParsing.scala Normal file
@ -0,0 +1,29 @@
package minisql.parsing

import minisql.ast
import scala.quoted.*

private[parsing] def infixParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Infix] = {
  import quotes.reflect.*
  {
    case '{ ($x: minisql.InfixValue).as[t] } => infixParsing(astParser)(x)
    case '{
          minisql.infix(StringContext(${ Varargs(partsExprs) }*))(${
            Varargs(argsExprs)
          }*)
        } =>
      val parts = partsExprs.map { p =>
        p.value.getOrElse(
          report.errorAndAbort(
            s"Expected a string literal in StringContext parts, but got: ${p.show}"
          )
        )
      }.toList

      val params = argsExprs.map(arg => astParser(arg)).toList

      '{ ast.Infix(${ Expr(parts) }, ${ Expr.ofList(params) }, true, false) }
  }
}
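A hedged sketch of the shape this parser should produce for a two-hole interpolation (the interpolation test in `MirrorSqlContextSuite` below exercises the same path):

```scala
import minisql.*
import minisql.ast.*

case class Foo(id: Long, name: String)
val f: Foo = Foo(1L, "a")
Parsing.parse(infix"CONCAT(${f.name}, ${f.id})".as[String])
// ==> Infix(List("CONCAT(", ", ", ")"),
//       List(Property(Ident("f"), "name"), Property(Ident("f"), "id")),
//       true, false)
```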
17 src/main/scala/minisql/parsing/LiftParsing.scala Normal file
@ -0,0 +1,17 @@
package minisql.parsing

import scala.quoted.*
import minisql.ParamEncoder
import minisql.ast
import minisql.*
import minisql.util.*

private[parsing] def liftParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Lift] = {
  case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } =>
    import quotes.reflect.*
    val name = x.show
    val liftId = liftIdOfExpr(x)
    '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) }
}
113 src/main/scala/minisql/parsing/OperationParsing.scala Normal file
@ -0,0 +1,113 @@
package minisql.parsing

import minisql.ast
import minisql.ast.{
  EqualityOperator,
  StringOperator,
  NumericOperator,
  BooleanOperator
}
import minisql.*
import scala.quoted._

private[parsing] def operationParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Operation] = {
  import quotes.reflect.*

  def isNumeric(t: TypeRepr) = {
    t <:< TypeRepr.of[Int]
    || t <:< TypeRepr.of[Long]
    || t <:< TypeRepr.of[Byte]
    || t <:< TypeRepr.of[Float]
    || t <:< TypeRepr.of[Double]
    || t <:< TypeRepr.of[java.math.BigDecimal]
    || t <:< TypeRepr.of[scala.math.BigDecimal]
  }

  def parseBinary(
      left: Expr[Any],
      right: Expr[Any],
      op: Expr[ast.BinaryOperator]
  ) = {
    val leftE = astParser(left)
    val rightE = astParser(right)
    '{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) }
  }

  def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = {
    val base = astParser(expr)
    '{ ast.UnaryOperation($op, $base) }
  }

  val universalOpParser: Parser[ast.BinaryOperation] = termParser {
    case Apply(Select(leftT, UniversalOp(op)), List(rightT)) =>
      parseBinary(leftT.asExpr, rightT.asExpr, op)
  }

  val stringOpParser: Parser[ast.Operation] = {
    case '{ ($x: String) + ($y: String) } =>
      parseBinary(x, y, '{ StringOperator.concat })
    case '{ ($x: String).startsWith($y) } =>
      parseBinary(x, y, '{ StringOperator.startsWith })
    case '{ ($x: String).split($y) } =>
      parseBinary(x, y, '{ StringOperator.split })
    case '{ ($x: String).toUpperCase } =>
      parseUnary(x, '{ StringOperator.toUpperCase })
    case '{ ($x: String).toLowerCase } =>
      parseUnary(x, '{ StringOperator.toLowerCase })
    case '{ ($x: String).toLong } =>
      parseUnary(x, '{ StringOperator.toLong })
    case '{ ($x: String).toInt } =>
      parseUnary(x, '{ StringOperator.toInt })
  }

  val numericOpParser = termParser {
    case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) =>
      parseBinary(lt.asExpr, rt.asExpr, op)
    case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) =>
      val leftExpr = astParser(leftTerm.asExpr)
      '{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) }
  }

  val booleanOpParser: Parser[ast.Operation] = {
    case '{ ($x: Boolean) && $y } =>
      parseBinary(x, y, '{ BooleanOperator.&& })
    case '{ ($x: Boolean) || $y } =>
      parseBinary(x, y, '{ BooleanOperator.|| })
    case '{ !($x: Boolean) } =>
      parseUnary(x, '{ BooleanOperator.! })
  }

  universalOpParser
    .orElse(stringOpParser)
    .orElse(numericOpParser)
    .orElse(booleanOpParser)
}

private object UniversalOp {
  def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
    op match {
      case "==" | "equals" => Some('{ EqualityOperator.== })
      case "!=" => Some('{ EqualityOperator.!= })
      case _ => None
    }
}

private object NumericOp {
  def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
    op match {
      case "+" => Some('{ NumericOperator.+ })
      case "-" => Some('{ NumericOperator.- })
      case "*" => Some('{ NumericOperator.* })
      case "/" => Some('{ NumericOperator./ })
      case ">" => Some('{ NumericOperator.> })
      case ">=" => Some('{ NumericOperator.>= })
      case "<" => Some('{ NumericOperator.< })
      case "<=" => Some('{ NumericOperator.<= })
      case "%" => Some('{ NumericOperator.% })
      case _ => None
    }
}
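For reference, a few concrete parses (the first mirrors the assertions in `ParsingSuite` at the end of this changeset; the others follow directly from the parsers above):

```scala
import minisql.ast.*

val a = 1
val b = 2
Parsing.parse(a + b)  // BinaryOperation(Ident("a"), NumericOperator.+, Ident("b"))
Parsing.parse(a == b) // BinaryOperation(Ident("a"), EqualityOperator.==, Ident("b"))
Parsing.parse(-a)     // UnaryOperation(NumericOperator.-, Ident("a"))
```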
48 src/main/scala/minisql/parsing/Parser.scala Normal file
@ -0,0 +1,48 @@
package minisql.parsing

import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
import minisql.util.*

private[minisql] inline def parseParamAt[F](
    inline f: F,
    inline n: Int
): ast.Ident = ${
  parseParamAt('f, 'n)
}

private[minisql] inline def parseBody[X](
    inline f: X
): ast.Ast = ${
  parseBody('f)
}

private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
    Quotes
): Expr[ast.Ident] = {

  import quotes.reflect.*

  val pIdx = n.value.getOrElse(
    report.errorAndAbort(s"Param index ${n.show} is not known")
  )
  extractTerm(f.asTerm) match {
    case Lambda(vals, _) =>
      vals(pIdx) match {
        case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
      }
  }
}

private[minisql] def parseBody[X](
    x: Expr[X]
)(using Quotes): Expr[Ast] = {
  import quotes.reflect.*
  extractTerm(x.asTerm) match {
    case Lambda(vals, body) =>
      Parsing.parseExpr(body.asExpr)
    case o =>
      report.errorAndAbort("Can only parse a function")
  }
}
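A sketch of the two entry points (both are `private[minisql]`, so usable from within the package):

```scala
import minisql.ast

// The n-th lambda parameter, extracted as an Ident:
val p: ast.Ident = parseParamAt((a: Int, b: Int) => a + b, 1)
// p == ast.Ident("b")

// The lambda body, parsed to an Ast:
val body: ast.Ast = parseBody((a: Int, b: Int) => a + b)
// body == ast.BinaryOperation(ast.Ident("a"), ast.NumericOperator.+, ast.Ident("b"))
```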
127 src/main/scala/minisql/parsing/Parsing.scala Normal file
@ -0,0 +1,127 @@
package minisql.parsing

import minisql.ast
import minisql.context.{ReturningMultipleFieldSupported, _}
import minisql.norm.BetaReduction
import minisql.norm.capture.AvoidAliasConflict
import minisql.idiom.Idiom
import scala.annotation.tailrec
import minisql.ast.Implicits._
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.{Hidden, Visible}
import minisql.util.{Interleave, extractTerm}
import scala.quoted.*

type Parser[A] = PartialFunction[Expr[Any], Expr[A]]

private def termParser[A](using q: Quotes)(
    pf: PartialFunction[q.reflect.Term, Expr[A]]
): Parser[A] = {
  import quotes.reflect._
  {
    case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm)
  }
}

private def parser[A](
    f: PartialFunction[Expr[Any], Expr[A]]
)(using Quotes): Parser[A] = {
  case e if f.isDefinedAt(e) => f(e)
}

private[minisql] object Parsing {

  def parseExpr(
      expr: Expr[?]
  )(using q: Quotes): Expr[ast.Ast] = {

    import q.reflect._

    def unwrapped(
        f: Parser[ast.Ast]
    ): Parser[ast.Ast] = {
      case expr =>
        val t = extractTerm(expr.asTerm)
        if (t.isExpr)
          f(t.asExpr)
        else f(expr)
    }

    lazy val astParser: Parser[ast.Ast] =
      unwrapped {
        typedParser
          .orElse(propertyParser)
          .orElse(liftParser)
          .orElse(infixParser)
          .orElse(identParser)
          .orElse(valueParser)
          .orElse(operationParser)
          .orElse(constantParser)
          .orElse(blockParser)
          .orElse(boxingParser)
          .orElse(ifParser)
          .orElse(traversableOperationParser)
          .orElse(patMatchParser)
          .orElse {
            case o =>
              val str = scala.util.Try(o.show).getOrElse("")
              report.errorAndAbort(
                s"cannot parse ${str}",
                o.asTerm.pos
              )
          }
      }

    lazy val typedParser: Parser[ast.Ast] = termParser {
      case (Typed(t, _)) =>
        astParser(t.asExpr)
    }

    lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser)

    lazy val valueParser: Parser[ast.Value] = valueParsing(astParser)

    lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser)

    lazy val constantParser: Parser[ast.Constant] = termParser {
      case Literal(x) =>
        '{ ast.Constant(${ Literal(x).asExpr }) }
    }

    lazy val identParser: Parser[ast.Ident] = termParser {
      case x @ Ident(n) =>
        '{ ast.Ident(${ Expr(n) }) }
    }

    lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser)

    lazy val operationParser: Parser[ast.Operation] = operationParsing(
      astParser
    )

    lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser)

    lazy val ifParser: Parser[ast.If] = {
      case '{ if ($a) $b else $c } =>
        '{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
    }

    lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)

    lazy val traversableOperationParser: Parser[ast.IterableOperation] =
      traversableOperationParsing(astParser)

    lazy val infixParser: Parser[ast.Infix] = infixParsing(
      astParser
    )

    astParser(expr)
  }

  private[minisql] inline def parse[A](
      inline a: A
  ): ast.Ast = ${
    parseExpr('a)
  }
}
50 src/main/scala/minisql/parsing/PatMatchParsing.scala Normal file
@ -0,0 +1,50 @@
package minisql.parsing

import minisql.ast
import scala.quoted.*

private[parsing] def patMatchParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {

  import quotes.reflect.*

  termParser {
    // Val defs that shadow pattern variables would cause an error
    case e @ Match(
          Ident(t),
          List(CaseDef(IsTupleUnapply(binds), None, body))
        ) =>
      // Bind each pattern variable to the corresponding tuple component (_1, _2, ...)
      val bindStmts = binds.zipWithIndex.map {
        case (Bind(bn, _), idx) =>
          '{
            ast.Val(
              ast.Ident(${ Expr(bn) }),
              ast.Property(ast.Ident(${ Expr(t) }), ${ Expr(s"_${idx + 1}") })
            )
          }
      }

      val allStmts = bindStmts ++ parseBlockList(astParser, body.asExpr)
      '{ ast.Block(${ Expr.ofList(allStmts.toList) }) }
  }
}

object IsTupleUnapply {

  def unapply(using
      Quotes
  )(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = {
    import quotes.reflect.*
    def isTupleNUnapply(x: Term) = {
      val fn = x.symbol.fullName
      fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply")
    }
    t match {
      case Unapply(m, _, binds) if isTupleNUnapply(m) =>
        Some(binds)
      case _ => None
    }
  }
}
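With the index-aware binding above, a tuple pattern match parses to a Block of Vals followed by the body, e.g.:

```scala
import minisql.ast.*

val t = (1, "a")
Parsing.parse(t match { case (x, y) => x })
// ==> Block(List(
//   Val(Ident("x"), Property(Ident("t"), "_1")),
//   Val(Ident("y"), Property(Ident("t"), "_2")),
//   Ident("x")))
```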
29 src/main/scala/minisql/parsing/PropertyParsing.scala Normal file
@ -0,0 +1,29 @@
package minisql.parsing

import minisql.ast
import scala.quoted._

private[parsing] def propertyParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Property] = {
  import quotes.reflect.*

  def isAccessor(s: Select) = {
    s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name)
  }

  val parseApply: Parser[ast.Property] = termParser {
    case m @ Select(base, n) if isAccessor(m) =>
      val obj = astParser(base.asExpr)
      '{ ast.Property($obj, ${ Expr(n) }) }
  }

  val parseOptionGet: Parser[ast.Property] = {
    case '{ ($e: Option[t]).get } =>
      report.errorAndAbort(
        "Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead."
      )
  }

  parseApply.orElse(parseOptionGet)
}
@ -0,0 +1,16 @@
package minisql.parsing

import minisql.ast
import scala.quoted.*

private def traversableOperationParsing(
    astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.IterableOperation] = {
  case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } =>
    '{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) }
  case '{ ($s: Set[e]).contains($i) } =>
    '{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) }
  case '{ ($s: Seq[e]).contains($i) } =>
    '{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) }
}
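For example, membership tests on collections should parse to the dedicated nodes:

```scala
import minisql.ast.*

val xs = Seq(1, 2, 3)
val s = Set(1, 2)
Parsing.parse(xs.contains(2)) // ListContains(Ident("xs"), Constant(2))
Parsing.parse(s.contains(1))  // SetContains(Ident("s"), Constant(1))
```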
72 src/main/scala/minisql/parsing/ValueParsing.scala Normal file
@ -0,0 +1,72 @@
package minisql
package parsing

import scala.quoted._
import minisql.util.*

private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using
    Quotes
): Parser[ast.Value] = {

  import quotes.reflect.*

  val parseTupleApply: Parser[ast.Tuple] = {
    case IsTupleApply(args) =>
      val t = args.map(astParser)
      '{ ast.Tuple(${ Expr.ofList(t) }) }
  }

  val parseAssocTuple: Parser[ast.Tuple] = {
    case '{ ($x: tx) -> ($y: ty) } =>
      '{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) }
  }

  parseTupleApply.orElse(parseAssocTuple)
}

private[minisql] object IsTupleApply {

  def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = {

    import quotes.reflect.*

    def isTupleNApply(t: Term) = {
      val fn = t.symbol.fullName
      fn.startsWith("scala.Tuple") && fn.endsWith("$.apply")
    }

    def isTupleXXLApply(t: Term) = {
      t.symbol.fullName == "scala.runtime.TupleXXL$.apply"
    }

    extractTerm(e.asTerm) match {
      // TupleN(0-22).apply
      case Apply(b, args) if isTupleNApply(b) =>
        Some(args.map(_.asExpr))
      case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) =>
        t.asExpr match {
          case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } =>
            Some(args)
        }
      case o =>
        None
    }
  }
}

private[parsing] object IsTuple2 {

  def unapply(using
      Quotes
  )(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] =
    t match {
      case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2))
      case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2))
      case _ => None
    }

  def unapply(using
      Quotes
  )(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] =
    unapply(t.asExpr)
}
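Both tuple constructors and the `->` arrow parse to the same `ast.Tuple` node:

```scala
import minisql.ast.*

val a = 1
val b = 2
Parsing.parse((a, b)) // Tuple(List(Ident("a"), Ident("b")))
Parsing.parse(a -> b) // Tuple(List(Ident("a"), Ident("b")))
```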
@ -1,8 +1,26 @@
package minisql.util

import scala.util.*

object CollectTry {
  extension [A](xs: Iterable[A]) {
    private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = {
      val out = IArray.newBuilder[Any]
      var left: Option[Throwable] = None
      xs.foreach { (v) =>
        if (!left.isDefined) {
          f(v) match {
            case Failure(e) =>
              left = Some(e)
            case Success(r) =>
              out += r
          }
        }
      }
      left.toLeft(out.result().asInstanceOf).toTry
    }
  }
}
@ -11,20 +11,25 @@ import scala.util.matching.Regex
class Interpolator(
    defaultIndent: Int = 0,
    qprint: AstPrinter = AstPrinter(),
    out: PrintStream = System.out
) {

  extension (sc: StringContext) {
    def trace(elements: Any*) = new Traceable(sc, elements)
  }

  class Traceable(sc: StringContext, elementsSeq: Seq[Any]) {

    private val elementPrefix = "| "

    private sealed trait PrintElement
    private case class Str(str: String, first: Boolean) extends PrintElement
    private case class Elem(value: String) extends PrintElement
    private case object Separator extends PrintElement

    private def generateStringForCommand(value: Any, indent: Int) = {
      val objectString = qprint(value)
      val oneLine = objectString.fitsOnOneLine
      oneLine match {
        case true => s"${indent.prefix}> ${objectString}"
        case false =>

@ -42,7 +47,7 @@ class Interpolator(
    private def readBuffers() = {
      def orZero(i: Int): Int = if (i < 0) 0 else i

      val parts = sc.parts.iterator.toList
      val elements = elementsSeq.toList.map(qprint(_))

      val (firstStr, explicitIndent) = readFirst(parts.head)
|
|
@ -5,10 +5,15 @@ import scala.util.Try
|
|||
|
||||
object LoadObject {
|
||||
|
||||
def apply[T](using Quotes, Type[T]): Try[T] = {
|
||||
import quotes.reflect.*
|
||||
apply(TypeRepr.of[T])
|
||||
}
|
||||
|
||||
def apply[T](using Quotes)(ot: quotes.reflect.TypeRepr): Try[T] = Try {
|
||||
import quotes.reflect.*
|
||||
val moduleClsName = ot.typeSymbol.companionModule.moduleClass.fullName
|
||||
val moduleCls = Class.forName(moduleClsName)
|
||||
val moduleCls = Class.forName(moduleClsName)
|
||||
val field = moduleCls
|
||||
.getFields()
|
||||
.find { f =>
|
||||
|
|
75 src/main/scala/minisql/util/Message.scala Normal file
@ -0,0 +1,75 @@
package minisql.util

import minisql.AstPrinter
import minisql.idiom.Idiom
import minisql.util.IndentUtil._

object Messages {

  private def variable(propName: String, envName: String, default: String) =
    Option(System.getProperty(propName))
      .orElse(sys.env.get(envName))
      .getOrElse(default)

  private[util] val prettyPrint =
    variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean
  private[util] val debugEnabled =
    variable("quill.macro.log", "quill_macro_log", "true").toBoolean
  private[util] val traceEnabled =
    variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean
  private[util] val traceColors =
    variable("quill.trace.color", "quill_trace_color,", "false").toBoolean
  private[util] val traceOpinions =
    variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean
  private[util] val traceAstSimple = variable(
    "quill.trace.ast.simple",
    "quill_trace_ast_simple",
    "false"
  ).toBoolean
  private[minisql] val cacheDynamicQueries = variable(
    "quill.query.cacheDaynamic",
    "query_query_cacheDaynamic",
    "true"
  ).toBoolean
  private[util] val traces: List[TraceType] =
    variable("quill.trace.types", "quill_trace_types", "standard")
      .split(",")
      .toList
      .map(_.trim)
      .flatMap(trace =>
        TraceType.values.filter(traceType => trace == traceType.value)
      )

  def tracesEnabled(tt: TraceType) =
    traceEnabled && traces.contains(tt)

  sealed trait TraceType { def value: String }
  object TraceType {
    case object Normalizations extends TraceType { val value = "norm" }
    case object Standard extends TraceType { val value = "standard" }
    case object NestedQueryExpansion extends TraceType { val value = "nest" }

    def values: List[TraceType] =
      List(Standard, Normalizations, NestedQueryExpansion)
  }

  val qprint = AstPrinter()

  def fail(msg: String) =
    throw new IllegalStateException(msg)

  def trace[T](
      label: String,
      numIndent: Int = 0,
      traceType: TraceType = TraceType.Standard
  ) =
    (v: T) => {
      val indent = (0 to numIndent).map(_ => "").mkString(" ")
      if (tracesEnabled(traceType))
        println(s"$indent$label\n${{
          if (traceColors) qprint.apply(v)
          else qprint.apply(v)
        }.split("\n").map(s"$indent " + _).mkString("\n")}")
      v
    }
}
41 src/main/scala/minisql/util/QuotesHelper.scala Normal file
@ -0,0 +1,41 @@
package minisql.util

import scala.quoted.*

private[minisql] def splicePkgPath(using Quotes) = {
  import quotes.reflect.*
  def recurse(sym: Symbol): String =
    sym match {
      case s if s.isPackageDef => s.fullName
      case s if s.isNoSymbol => ""
      case _ =>
        recurse(sym.maybeOwner)
    }
  recurse(Symbol.spliceOwner)
}

private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = {
  import quotes.reflect.*
  val name = x.asTerm.show
  val packageName = splicePkgPath
  val pos = x.asTerm.pos
  val fileName = pos.sourceFile.name
  s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}"
}

private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
  import quotes.reflect.*
  def unwrapTerm(t: Term): Term = t match {
    case Inlined(_, _, o) => unwrapTerm(o)
    case Block(Nil, last) => last
    case Typed(t, _) =>
      unwrapTerm(t)
    case Select(t, "$asInstanceOf$") =>
      unwrapTerm(t)
    case TypeApply(t, _) =>
      unwrapTerm(t)
    case o => o
  }
  val o = unwrapTerm(x)
  o
}
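The lift id produced by `liftIdOfExpr` is a stable, human-readable key combining the shown term, splice package, file, and position. A sketch of its shape (all values illustrative):

```scala
// s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}"
val liftId = "userId@com.example.Queries.scala:12:8"
```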
@ -1,21 +1,20 @@
package minisql.util

object Show {
  trait Show[T] {
    extension (v: T) {
      def show: String
    }
  }

  object Show {

    def apply[T](f: T => String) = new Show[T] {
      extension (v: T) {
        def show: String = f(v)
      }
    }
  }

  given listShow[T](using shower: Show[T]): Show[List[T]] =
    Show[List[T]] {
      case list => list.map(_.show).mkString(", ")
    }
}
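A usage sketch of the extension-based typeclass (note the explicit `given` import; `listShow` is resolved from the enclosing object):

```scala
import minisql.util.Show.{*, given}

given Show[Int] = Show[Int](i => i.toString)

List(1, 2, 3).show // "1, 2, 3"
```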
188 src/test/scala/minisql/ast/FromExprsSuite.scala Normal file
@ -0,0 +1,188 @@
package minisql.ast

import munit.FunSuite
import minisql.ast.*
import scala.quoted.*

class FromExprsSuite extends FunSuite {

  // Helper to test compile-time extraction of Ast values via FromExpr
  inline def testFor[A <: Ast](label: String)(inline ast: A) = {
    test(label) {
      // The compile-time extracted Ast should round-trip to the same value
      val compileTimeResult = minisql.compileTimeAst(ast)
      assert(compileTimeResult.contains(ast.toString))
    }
  }

  testFor("Ident") {
    Ident("test")
  }

  testFor("Ident with visibility") {
    Ident.Opinionated("test", Visibility.Hidden)
  }

  testFor("Property") {
    Property(Ident("a"), "b")
  }

  testFor("Property with opinions") {
    Property.Opinionated(Ident("a"), "b", Renameable.Fixed, Visibility.Visible)
  }

  testFor("BinaryOperation") {
    BinaryOperation(Ident("a"), EqualityOperator.==, Ident("b"))
  }

  testFor("UnaryOperation") {
    UnaryOperation(BooleanOperator.!, Ident("flag"))
  }

  testFor("ScalarValueLift") {
    ScalarValueLift("name", "id", None)
  }

  testFor("Ordering") {
    PropertyOrdering.Asc
  }

  testFor("TupleOrdering") {
    TupleOrdering(List(PropertyOrdering.Asc, PropertyOrdering.Desc))
  }

  testFor("Entity") {
    Entity("people", Nil)
  }

  testFor("Entity with properties") {
    Entity("people", List(PropertyAlias(List("name"), "full_name")))
  }

  testFor("Action - Insert") {
    Insert(
      Ident("table"),
      List(Assignment(Ident("x"), Ident("col"), Ident("val")))
    )
  }

  testFor("Action - Update") {
    Update(
      Ident("table"),
      List(Assignment(Ident("x"), Ident("col"), Ident("val")))
    )
  }

  testFor("Action - Returning") {
    Returning(
      Insert(
        Ident("table"),
        List(Assignment(Ident("x"), Ident("col"), Ident("val")))
      ),
      Ident("x"),
      Property(Ident("x"), "id")
    )
  }

  testFor("Action - ReturningGenerated") {
    ReturningGenerated(
      Insert(
        Ident("table"),
        List(Assignment(Ident("x"), Ident("col"), Ident("val")))
      ),
      Ident("x"),
      Property(Ident("x"), "generatedId")
    )
  }

  testFor("Action - Val outside") {
    val p1 = Update(
      Ident("table"),
      List(Assignment(Ident("x"), Ident("col"), Ident("val")))
    )
    val p2 = Ident("x")
    val p3 = Property(Ident("x"), "id")
    Returning(p1, p2, p3)
  }

  testFor("Action - ReturningGenerated with Update") {
    val p1 = Update(
      Ident("table"),
      List(Assignment(Ident("x"), Ident("col"), Ident("val")))
    )
    val p2 = Ident("x")
    val p3 = Property(Ident("x"), "id")
    ReturningGenerated(p1, p2, p3)
  }

  testFor("If expression") {
    If(Ident("cond"), Ident("then"), Ident("else"))
  }

  testFor("Infix") {
    Infix(
      List("func(", ")"),
      List(Ident("param")),
      pure = true,
      noParen = false
    )
  }

  testFor("Infix with different parameters") {
    Infix(
      List("?", " + ", "?"),
      List(Constant(1), Constant(2)),
      pure = true,
      noParen = true
    )
  }

  testFor("OptionOperation - OptionMap") {
    OptionMap(Ident("opt"), Ident("x"), Ident("x"))
  }

  testFor("OptionOperation - OptionFlatMap") {
    OptionFlatMap(Ident("opt"), Ident("x"), Ident("x"))
  }

  testFor("OptionOperation - OptionGetOrElse") {
    OptionGetOrElse(Ident("opt"), Ident("default"))
  }

  testFor("Join") {
    Join(
      JoinType.InnerJoin,
      Ident("a"),
      Ident("b"),
      Ident("a1"),
      Ident("b1"),
      BinaryOperation(Ident("a1.id"), EqualityOperator.==, Ident("b1.id"))
    )
  }

  testFor("Distinct") {
    Distinct(Ident("query"))
  }

  testFor("GroupBy") {
    GroupBy(Ident("query"), Ident("alias"), Ident("body"))
  }

  testFor("Aggregation") {
    Aggregation(AggregationOperator.avg, Ident("field"))
  }

  testFor("CaseClass") {
    CaseClass(List(("name", Ident("value"))))
  }

  testFor("Block") { // Also tests Val
    Block(
      List(
        Val(Ident("x"), Constant(1)),
        Val(Ident("y"), Constant(2)),
        BinaryOperation(Ident("x"), NumericOperator.+, Ident("y"))
      )
    )
  }
}
@ -0,0 +1,63 @@
package minisql.context.sql

import minisql.*
import minisql.ast.*
import minisql.idiom.*
import minisql.NamingStrategy
import minisql.MirrorContext
import minisql.context.mirror.{*, given}

class MirrorSqlContextSuite extends munit.FunSuite {

  case class Foo(id: Long, name: String)

  inline def Foos = query[Foo]("foo")

  import testContext.given

  test("SimpleQuery") {
    val o = testContext.io(
      query[Foo](
        "foo",
        alias("id", "id1")
      ).filter(x => x.id > 0)
    )
    assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0")
  }

  test("Insert") {
    val v: Foo = Foo(0L, "foo")

    val o = testContext.io(Foos.insert(v))
    assertEquals(
      o.sql,
      "INSERT INTO foo (id,name) VALUES (?, ?)"
    )
  }

  test("InsertReturningGenerated") {
    val v: Foo = Foo(0L, "foo")

    val o = testContext.io(Foos.insert(v).returningGenerated(_.id))
    assertEquals(
      o.sql,
      "INSERT INTO foo (name) VALUES (?) RETURNING id"
    )
  }

  test("Join") {
    val o = testContext
      .io(Foos.join(Foos).on((f1, f2) => f1.id == f2.id).map {
        case (f1, f2) => (f1.id, f2.id)
      })

    println(o)
  }

  test("Infix string interpolation") {
    val o = testContext.io(
      Foos.map(f => infix"CONCAT(${f.name}, ' ', ${f.id})".as[String])
    )
    assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f")
  }
}
5 src/test/scala/minisql/context/sql/context.scala Normal file
@ -0,0 +1,5 @@
package minisql.context.sql

import minisql.*

val testContext = new MirrorSqlContext(Literal)
38 src/test/scala/minisql/parsing/ParsingSuite.scala Normal file
@ -0,0 +1,38 @@
package minisql.parsing

import minisql.ast.*

class ParsingSuite extends munit.FunSuite {

  test("Ident") {
    val x = 1
    assertEquals(Parsing.parse(x), Ident("x"))
  }

  test("NumericOperator.+") {
    val a = 1
    val b = 2
    assertEquals(
      Parsing.parse(a + b),
      BinaryOperation(Ident("a"), NumericOperator.+, Ident("b"))
    )
  }

  test("NumericOperator.-") {
    val a = 1
    val b = 2
    assertEquals(
      Parsing.parse(a - b),
      BinaryOperation(Ident("a"), NumericOperator.-, Ident("b"))
    )
  }

  test("NumericOperator.*") {
    val a = 1
    val b = 2
    assertEquals(
      Parsing.parse(a * b),
      BinaryOperation(Ident("a"), NumericOperator.*, Ident("b"))
    )
  }
}
1 src/test/scala/minisql/parsing/QuerySuite.scala Normal file
@ -0,0 +1 @@