Compare commits

...

8 commits

Author SHA1 Message Date
jilen
47cf808e8f simplify typeclass 2024-12-29 20:19:07 +08:00
jilen
7f5092c396 add mirror context 2024-12-19 12:36:44 +08:00
jilen
87f1b70b27 try implement context 2024-12-19 11:42:23 +08:00
jilen
6be96aba2c save 2024-12-18 19:28:43 +08:00
jilen
59f969a232 test simple quoted ast 2024-12-18 16:09:08 +08:00
jilen
2e7e7df4a3 move package 2024-12-17 19:51:19 +08:00
jilen
a0ceea91a9 add one test case 2024-12-15 21:11:14 +08:00
jilen
8103d45178 try parsing function body 2024-12-15 20:51:38 +08:00
62 changed files with 3772 additions and 215 deletions

View file

@ -5,7 +5,7 @@
大部分场景不用在 `macro` 对 Ast 进行复杂模式匹配来分析代码。
## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作
## 核心思路 使用 inline 和 `FromExpr` 代替部分 parsing 工作
`FromExpr``scala3` 内置的 typeclass用来获取编译期值 。

View file

@ -1,8 +1,9 @@
name := "minisql"
scalaVersion := "3.6.2"
scalaVersion := "3.5.2"
libraryDependencies ++= Seq(
"org.scalameta" %% "munit" % "1.0.3" % Test
)
scalacOptions ++= Seq("-experimental", "-language:experimental.namedTuples")

View file

@ -0,0 +1,3 @@
package minisql
type QueryMeta

View file

@ -1,8 +1,23 @@
package minisql
import scala.util.Try
trait ParamEncoder[E] {
type Stmt
def setParam(s: Stmt, idx: Int, v: E): Unit
}
trait ColumnDecoder[X] {
type DBRow
def decode(row: DBRow, idx: Int): Try[X]
}
object ColumnDecoder {
type Aux[R, X] = ColumnDecoder[X] {
type DBRow = R
}
}

View file

@ -1,127 +0,0 @@
package minisql.parsing
import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
type Parser[O <: Ast] = PartialFunction[Expr[?], Expr[O]]
private[minisql] inline def parseParamAt[F](
inline f: F,
inline n: Int
): ast.Ident = ${
parseParamAt('f, 'n)
}
private[minisql] inline def parseBody[X](
inline f: X
): ast.Ast = ${
parseBody('f)
}
private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
Quotes
): Expr[ast.Ident] = {
import quotes.reflect.*
val pIdx = n.value.getOrElse(
report.errorAndAbort(s"Param index ${n.show} is not know")
)
extractTerm(f.asTerm) match {
case Lambda(vals, _) =>
vals(pIdx) match {
case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
}
}
}
private[minisql] def parseBody[X](
x: Expr[X]
)(using Quotes): Expr[Ast] = {
import quotes.reflect.*
extractTerm(x.asTerm) match {
case Lambda(vals, body) =>
astParser(body.asExpr)
case o =>
report.errorAndAbort(s"Can only parse function")
}
}
private def isNumeric(x: Expr[?])(using Quotes): Boolean = {
import quotes.reflect.*
x.asTerm.tpe match {
case t if t <:< TypeRepr.of[Int] => true
case t if t <:< TypeRepr.of[Long] => true
case t if t <:< TypeRepr.of[Float] => true
case t if t <:< TypeRepr.of[Double] => true
case t if t <:< TypeRepr.of[BigDecimal] => true
case t if t <:< TypeRepr.of[BigInt] => true
case t if t <:< TypeRepr.of[java.lang.Integer] => true
case t if t <:< TypeRepr.of[java.lang.Long] => true
case t if t <:< TypeRepr.of[java.lang.Float] => true
case t if t <:< TypeRepr.of[java.lang.Double] => true
case t if t <:< TypeRepr.of[java.math.BigDecimal] => true
case _ => false
}
}
private def identParser(using Quotes): Parser[ast.Ident] = {
import quotes.reflect.*
{ (x: Expr[?]) =>
extractTerm(x.asTerm) match {
case Ident(n) => Some('{ ast.Ident(${ Expr(n) }) })
case _ => None
}
}.unlift
}
private lazy val astParser: Quotes ?=> Parser[Ast] = {
identParser.orElse(propertyParser(astParser))
}
private object IsPropertySelect {
def unapply(x: Expr[?])(using Quotes): Option[(Expr[?], String)] = {
import quotes.reflect.*
x.asTerm match {
case Select(x, n) => Some(x.asExpr, n)
case _ => None
}
}
}
def propertyParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.Property] = {
case IsPropertySelect(expr, n) =>
'{ ast.Property(${ astParser(expr) }, ${ Expr(n) }) }
}
def optionOperationParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.OptionOperation] = {
case '{ ($x: Option[t]).isEmpty } =>
'{ ast.OptionIsEmpty(${ astParser(x) }) }
}
def binaryOperationParser(
astParser: => Parser[Ast]
)(using Quotes): Parser[ast.BinaryOperation] = {
???
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
unwrapTerm(x)
}

View file

@ -0,0 +1,101 @@
package minisql
import minisql.*
import minisql.idiom.*
import minisql.parsing.*
import minisql.util.*
import minisql.ast.{Ast, Entity, Map, Property, Ident, Filter, given}
import scala.quoted.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type EntityQuery[E] <: Query[E] = Query[E]
object EntityQuery {
extension [E](inline e: EntityQuery[E]) {
inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
transform(e)(f)(Map.apply)
}
inline def filter(inline f: E => Boolean): EntityQuery[E] = {
transform(e)(f)(Filter.apply)
}
}
}
private inline def transform[A, B](inline q1: Quoted)(
inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
fast(q1, f.param0, f.body)
}
inline def query[E](inline table: String): EntityQuery[E] =
Entity(table, Nil)
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def body = parsing.parseBody(f1)
}
extension [A1, A2, B](inline f1: (A1, A2) => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def param1 = parsing.parseParamAt(f1, 1)
private inline def body = parsing.parseBody(f1)
}
def lift[X](x: X)(using e: ParamEncoder[X]): X = throw NonQuotedException()
class NonQuotedException extends Exception("Cannot be used at runtime")
private[minisql] inline def compileTimeAst(inline q: Quoted): Option[String] =
${
compileTimeAstImpl('q)
}
private def compileTimeAstImpl(e: Expr[Quoted])(using
Quotes
): Expr[Option[String]] = {
import quotes.reflect.*
e.value match {
case Some(v) => '{ Some(${ Expr(v.toString()) }) }
case None => '{ None }
}
}
private[minisql] inline def compile[I <: Idiom, N <: NamingStrategy](
inline q: Quoted,
inline idiom: I,
inline naming: N
): Statement = ${ compileImpl[I, N]('q, 'idiom, 'naming) }
private def compileImpl[I <: Idiom, N <: NamingStrategy](
q: Expr[Quoted],
idiom: Expr[I],
n: Expr[N]
)(using Quotes, Type[I], Type[N]): Expr[Statement] = {
import quotes.reflect.*
q.value match {
case Some(ast) =>
val idiom = LoadObject[I].getOrElse(
report.errorAndAbort(s"Idiom not known at compile")
)
val naming = LoadNaming
.static[N]
.getOrElse(report.errorAndAbort(s"NamingStrategy not known at compile"))
val stmt = idiom.translate(ast)(using naming)
Expr(stmt._2)
case None =>
report.info("Dynamic Query")
'{
$idiom.translate($q)(using $n)._2
}
}
}

View file

@ -0,0 +1,7 @@
package minisql
enum ReturnAction {
case ReturnNothing
case ReturnColumns(columns: List[String])
case ReturnRecord
}

View file

@ -1,6 +1,7 @@
package minisql.ast
import minisql.NamingStrategy
import minisql.ParamEncoder
import scala.quoted.*
@ -378,14 +379,21 @@ sealed trait ScalarLift extends Lift
case class ScalarValueLift(
name: String,
liftId: String
liftId: String,
value: Option[(Any, ParamEncoder[?])]
) extends ScalarLift
case class ScalarQueryLift(
name: String,
liftId: String,
value: Option[(Seq[Any], ParamEncoder[?])]
) extends ScalarLift
object ScalarLift {
given ToExpr[ScalarLift] with {
def apply(l: ScalarLift)(using Quotes) = l match {
case ScalarValueLift(n, id) =>
'{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) }
case ScalarValueLift(n, id, v) =>
'{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, None) }
}
}
}

View file

@ -0,0 +1,94 @@
package minisql.ast
object Implicits {
implicit class AstOps(body: Ast) {
private[minisql] def +||+(other: Ast) =
BinaryOperation(body, BooleanOperator.`||`, other)
private[minisql] def +&&+(other: Ast) =
BinaryOperation(body, BooleanOperator.`&&`, other)
private[minisql] def +==+(other: Ast) =
BinaryOperation(body, EqualityOperator.`==`, other)
private[minisql] def +!=+(other: Ast) =
BinaryOperation(body, EqualityOperator.`!=`, other)
}
}
object +||+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, BooleanOperator.`||`, two) => Some((one, two))
case _ => None
}
}
}
object +&&+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, BooleanOperator.`&&`, two) => Some((one, two))
case _ => None
}
}
}
val EqOp = EqualityOperator.`==`
val NeqOp = EqualityOperator.`!=`
object +==+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, EqOp, two) => Some((one, two))
case _ => None
}
}
}
object +!=+ {
def unapply(a: Ast): Option[(Ast, Ast)] = {
a match {
case BinaryOperation(one, NeqOp, two) => Some((one, two))
case _ => None
}
}
}
object IsNotNullCheck {
def apply(ast: Ast) = BinaryOperation(ast, EqualityOperator.`!=`, NullValue)
def unapply(ast: Ast): Option[Ast] = {
ast match {
case BinaryOperation(cond, NeqOp, NullValue) => Some(cond)
case _ => None
}
}
}
object IsNullCheck {
def apply(ast: Ast) = BinaryOperation(ast, EqOp, NullValue)
def unapply(ast: Ast): Option[Ast] = {
ast match {
case BinaryOperation(cond, EqOp, NullValue) => Some(cond)
case _ => None
}
}
}
object IfExistElseNull {
def apply(exists: Ast, `then`: Ast) =
If(IsNotNullCheck(exists), `then`, NullValue)
def unapply(ast: Ast) = ast match {
case If(IsNotNullCheck(exists), t, NullValue) => Some((exists, t))
case _ => None
}
}
object IfExist {
def apply(exists: Ast, `then`: Ast, otherwise: Ast) =
If(IsNotNullCheck(exists), `then`, otherwise)
def unapply(ast: Ast) = ast match {
case If(IsNotNullCheck(exists), t, e) => Some((exists, t, e))
case _ => None
}
}

View file

@ -45,8 +45,9 @@ private given FromExpr[Infix] with {
private given FromExpr[ScalarValueLift] with {
def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] =
x match {
case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }) } =>
Some(ScalarValueLift(n, id))
case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } =>
// don't cared about value here, a little tricky
Some(ScalarValueLift(n, id, null))
}
}
@ -122,6 +123,13 @@ private given FromExpr[Query] with {
Some(FlatMap(b, id, body))
case '{ ConcatMap(${ Expr(b) }, ${ Expr(id) }, ${ Expr(body) }) } =>
Some(ConcatMap(b, id, body))
case '{
val x: Ast = ${ Expr(b) }
val y: Ident = ${ Expr(id) }
val z: Ast = ${ Expr(body) }
ConcatMap(x, y, z)
} =>
Some(ConcatMap(b, id, body))
case '{ Drop(${ Expr(b) }, ${ Expr(n) }) } =>
Some(Drop(b, n))
case '{ Take(${ Expr(b) }, ${ Expr[Ast](n) }) } =>
@ -129,7 +137,7 @@ private given FromExpr[Query] with {
case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } =>
Some(SortBy(b, p, s, o))
case o =>
println(s"Cannot extract ${o.show}")
println(s"Cannot extract ${o}")
None
}
}
@ -145,6 +153,7 @@ private given FromExpr[BinaryOperator] with {
case '{ NumericOperator.- } => Some(NumericOperator.-)
case '{ NumericOperator.* } => Some(NumericOperator.*)
case '{ NumericOperator./ } => Some(NumericOperator./)
case '{ NumericOperator.> } => Some(NumericOperator.>)
case '{ StringOperator.split } => Some(StringOperator.split)
case '{ StringOperator.startsWith } => Some(StringOperator.startsWith)
case '{ StringOperator.concat } => Some(StringOperator.concat)

View file

@ -0,0 +1,90 @@
package minisql.context
import scala.deriving.*
import scala.compiletime.*
import scala.util.Try
import minisql.util.*
import minisql.idiom.{Idiom, Statement, ReifyStatement}
import minisql.{NamingStrategy, ParamEncoder}
import minisql.ColumnDecoder
import minisql.ast.{Ast, ScalarValueLift, CollectAst}
trait Context[I <: Idiom, N <: NamingStrategy] { selft =>
val idiom: I
val naming: NamingStrategy
type DBStatement
type DBRow
type DBResultSet
trait RowExtract[A] {
def extract(row: DBRow): Try[A]
}
object RowExtract {
private class ExtractorImpl[A](
decoders: IArray[Any],
m: Mirror.ProductOf[A]
) extends RowExtract[A] {
def extract(row: DBRow): Try[A] = {
val decodedFields = decoders.zipWithIndex.traverse {
case (d, i) =>
d.asInstanceOf[Decoder[?]].decode(row, i)
}
decodedFields.map { vs =>
m.fromProduct(Tuple.fromIArray(vs))
}
}
}
inline given [P <: Product](using m: Mirror.ProductOf[P]): RowExtract[P] = {
val decoders = summonAll[Tuple.Map[m.MirroredElemTypes, Decoder]]
ExtractorImpl(decoders.toIArray.asInstanceOf, m)
}
}
type Encoder[X] = ParamEncoder[X] {
type Stmt = DBStatement
}
type Decoder[X] = ColumnDecoder.Aux[DBRow, X]
type DBIO[E] = (
sql: String,
params: List[(Any, Encoder[?])],
mapper: Iterable[DBRow] => Try[E]
)
extension (ast: Ast) {
private def liftMap = {
val lifts = CollectAst.byType[ScalarValueLift](ast)
lifts.map(l => l.liftId -> l.value.get).toMap
}
}
extension (stmt: Statement) {
def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) =
ReifyStatement(
idiom.liftingPlaceholder,
idiom.emptySetContainsToken,
stmt,
liftMap
)
}
inline def io[E](
inline q: minisql.Query[E]
)(using r: RowExtract[E]): DBIO[IArray[E]] = {
val lifts = q.liftMap
val stmt = minisql.compile(q, idiom, naming)
val (sql, params) = stmt.expand(lifts)
(
sql = sql,
params = params.map(_.value.get.asInstanceOf),
mapper = (rows) => rows.traverse(r.extract)
)
}
}

View file

@ -0,0 +1,15 @@
package minisql
import minisql.context.mirror.*
class MirrorContext[Idiom <: idiom.Idiom, Naming <: NamingStrategy](
val idiom: Idiom,
val naming: Naming
) extends context.Context[Idiom, Naming] {
type DBRow = Row
type DBResultSet = Iterable[DBRow]
type DBStatement = IArray[Any]
}

View file

@ -0,0 +1,64 @@
package minisql.context
sealed trait ReturningCapability
/**
* Data cannot be returned Insert/Update/etc... clauses in the target database.
*/
sealed trait ReturningNotSupported extends ReturningCapability
/**
* Returning a single field from Insert/Update/etc... clauses is supported. This
* is the most common databases e.g. MySQL, Sqlite, and H2 (although as of
* h2database/h2database#1972 this may change. See #1496 regarding this.
* Typically this needs to be setup in the JDBC
* `connection.prepareStatement(sql, Array("returnColumn"))`.
*/
sealed trait ReturningSingleFieldSupported extends ReturningCapability
/**
* Returning multiple columns from Insert/Update/etc... clauses is supported.
* This generally means that columns besides auto-incrementing ones can be
* returned. This is supported by Oracle. In JDBC, the following is done:
* `connection.prepareStatement(sql, Array("column1, column2, ..."))`.
*/
sealed trait ReturningMultipleFieldSupported extends ReturningCapability
/**
* An actual `RETURNING` clause is supported in the SQL dialect of the specified
* database e.g. Postgres. this typically means that columns returned from
* Insert/Update/etc... clauses can have other database operations done on them
* such as arithmetic `RETURNING id + 1`, UDFs `RETURNING udf(id)` or others. In
* JDBC, the following is done: `connection.prepareStatement(sql,
* Statement.RETURN_GENERATED_KEYS))`.
*/
sealed trait ReturningClauseSupported extends ReturningCapability
object ReturningNotSupported extends ReturningNotSupported
object ReturningSingleFieldSupported extends ReturningSingleFieldSupported
object ReturningMultipleFieldSupported extends ReturningMultipleFieldSupported
object ReturningClauseSupported extends ReturningClauseSupported
trait Capabilities {
def idiomReturningCapability: ReturningCapability
}
trait CanReturnClause extends Capabilities {
override def idiomReturningCapability: ReturningClauseSupported =
ReturningClauseSupported
}
trait CanReturnField extends Capabilities {
override def idiomReturningCapability: ReturningSingleFieldSupported =
ReturningSingleFieldSupported
}
trait CanReturnMultiField extends Capabilities {
override def idiomReturningCapability: ReturningMultipleFieldSupported =
ReturningMultipleFieldSupported
}
trait CannotReturn extends Capabilities {
override def idiomReturningCapability: ReturningNotSupported =
ReturningNotSupported
}

View file

@ -0,0 +1,35 @@
package minisql.context.mirror
import minisql.{MirrorContext, NamingStrategy}
import minisql.idiom.Idiom
import minisql.util.Messages.fail
import scala.reflect.ClassTag
/**
* No extra class defined
*/
opaque type Row = IArray[Any] *: EmptyTuple
extension (r: Row) {
def data: IArray[Any] = r._1
def add(value: Any): Row = (r.data :+ value) *: EmptyTuple
def apply[T](idx: Int)(using t: ClassTag[T]): T = {
r.data(idx) match {
case v: T => v
case other =>
fail(
s"Invalid column type. Expected '${t.runtimeClass}', but got '$other'"
)
}
}
}
trait MirrorCodecs[I <: Idiom, N <: NamingStrategy] {
this: MirrorContext[I, N] =>
given byteEncoder: Encoder[Byte]
}

View file

@ -1,59 +0,0 @@
package minisql.dsl
import minisql.*
import minisql.parsing.*
import minisql.ast.{Ast, Entity, Map, Property, Ident, given}
import scala.quoted.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
opaque type Quoted <: Ast = Ast
opaque type Query[E] <: Quoted = Quoted
opaque type EntityQuery[E] <: Query[E] = Query[E]
extension [E](inline e: EntityQuery[E]) {
inline def map[E1](inline f: E => E1): EntityQuery[E1] = {
transform(e)(f)(Map.apply)
}
}
private inline def transform[A, B](inline q1: Quoted)(
inline f: A => B
)(inline fast: (Ast, Ident, Ast) => Ast): Quoted = {
fast(q1, f.param0, f.body)
}
inline def query[E](inline table: String): EntityQuery[E] =
Entity(table, Nil)
inline def compile(inline x: Ast): Option[String] = ${
compileImpl('{ x })
}
private def compileImpl(
x: Expr[Ast]
)(using Quotes): Expr[Option[String]] = {
import quotes.reflect.*
x.value match {
case Some(xv) => '{ Some(${ Expr(xv.toString()) }) }
case None => '{ None }
}
}
extension [A, B](inline f1: A => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def body = parsing.parseBody(f1)
}
extension [A1, A2, B](inline f1: (A1, A2) => B) {
private inline def param0 = parsing.parseParamAt(f1, 0)
private inline def param1 = parsing.parseParamAt(f1, 1)
private inline def body = parsing.parseBody(f1)
}
case class Foo(id: Int)
inline def queryFooId = query[Foo]("foo").map(_.id)

View file

@ -0,0 +1,23 @@
package minisql.idiom
import minisql.NamingStrategy
import minisql.ast._
import minisql.context.Capabilities
trait Idiom extends Capabilities {
def emptySetContainsToken(field: Token): Token = StringToken("FALSE")
def defaultAutoGeneratedToken(field: Token): Token = StringToken(
"DEFAULT VALUES"
)
def liftingPlaceholder(index: Int): String
def translate(ast: Ast)(using naming: NamingStrategy): (Ast, Statement)
def format(queryString: String): String = queryString
def prepareForProbing(string: String): String
}

View file

@ -0,0 +1,28 @@
package minisql.idiom
import scala.util.Try
import scala.quoted._
import minisql.NamingStrategy
import minisql.util.CollectTry
import minisql.util.LoadObject
import minisql.CompositeNamingStrategy
object LoadNaming {
def static[C](using Quotes, Type[C]): Try[NamingStrategy] = CollectTry {
strategies[C].map(LoadObject[NamingStrategy](_))
}.map(NamingStrategy(_))
private def strategies[C](using Quotes, Type[C]) = {
import quotes.reflect.*
val isComposite = TypeRepr.of[C] <:< TypeRepr.of[CompositeNamingStrategy]
val ct = TypeRepr.of[C]
if (isComposite) {
ct.typeArgs.filterNot { t =>
t =:= TypeRepr.of[NamingStrategy] && t =:= TypeRepr.of[Nothing]
}
} else {
List(ct)
}
}
}

View file

@ -0,0 +1,355 @@
package minisql
import minisql.ast.Renameable.{ByStrategy, Fixed}
import minisql.ast.Visibility.Hidden
import minisql.ast._
import minisql.context.CanReturnClause
import minisql.idiom.{Idiom, SetContainsToken, Statement}
import minisql.idiom.StatementInterpolator.*
import minisql.norm.Normalize
import minisql.util.Interleave
object MirrorIdiom extends MirrorIdiom
class MirrorIdiom extends MirrorIdiomBase with CanReturnClause
object MirrorIdiomPrinting extends MirrorIdiom {
override def distinguishHidden: Boolean = true
}
trait MirrorIdiomBase extends Idiom {
def distinguishHidden: Boolean = false
override def prepareForProbing(string: String) = string
override def liftingPlaceholder(index: Int): String = "?"
override def translate(
ast: Ast
)(implicit naming: NamingStrategy): (Ast, Statement) = {
val normalizedAst = Normalize(ast)
(normalizedAst, stmt"${normalizedAst.token}")
}
implicit def astTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Ast] = Tokenizer[Ast] {
case ast: Query => ast.token
case ast: Function => ast.token
case ast: Value => ast.token
case ast: Operation => ast.token
case ast: Action => ast.token
case ast: Ident => ast.token
case ast: ExternalIdent => ast.token
case ast: Property => ast.token
case ast: Infix => ast.token
case ast: OptionOperation => ast.token
case ast: IterableOperation => ast.token
case ast: Dynamic => ast.token
case ast: If => ast.token
case ast: Block => ast.token
case ast: Val => ast.token
case ast: Ordering => ast.token
case ast: Lift => ast.token
case ast: Assignment => ast.token
case ast: OnConflict.Excluded => ast.token
case ast: OnConflict.Existing => ast.token
}
implicit def ifTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[If] = Tokenizer[If] {
case If(a, b, c) => stmt"if(${a.token}) ${b.token} else ${c.token}"
}
implicit val dynamicTokenizer: Tokenizer[Dynamic] = Tokenizer[Dynamic] {
case Dynamic(tree) => stmt"${tree.toString.token}"
}
implicit def blockTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Block] = Tokenizer[Block] {
case Block(statements) => stmt"{ ${statements.map(_.token).mkStmt("; ")} }"
}
implicit def valTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Val] = Tokenizer[Val] {
case Val(name, body) => stmt"val ${name.token} = ${body.token}"
}
implicit def queryTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Query] = Tokenizer[Query] {
case Entity.Opinionated(name, Nil, renameable) =>
stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token})"
case Entity.Opinionated(name, prop, renameable) =>
val properties =
prop.map(p => stmt"""_.${p.path.mkStmt(".")} -> "${p.alias.token}"""")
stmt"${tokenizeName("querySchema", renameable).token}(${s""""$name"""".token}, ${properties.token})"
case Filter(source, alias, body) =>
stmt"${source.token}.filter(${alias.token} => ${body.token})"
case Map(source, alias, body) =>
stmt"${source.token}.map(${alias.token} => ${body.token})"
case FlatMap(source, alias, body) =>
stmt"${source.token}.flatMap(${alias.token} => ${body.token})"
case ConcatMap(source, alias, body) =>
stmt"${source.token}.concatMap(${alias.token} => ${body.token})"
case SortBy(source, alias, body, ordering) =>
stmt"${source.token}.sortBy(${alias.token} => ${body.token})(${ordering.token})"
case GroupBy(source, alias, body) =>
stmt"${source.token}.groupBy(${alias.token} => ${body.token})"
case Aggregation(op, ast) =>
stmt"${scopedTokenizer(ast)}.${op.token}"
case Take(source, n) =>
stmt"${source.token}.take(${n.token})"
case Drop(source, n) =>
stmt"${source.token}.drop(${n.token})"
case Union(a, b) =>
stmt"${a.token}.union(${b.token})"
case UnionAll(a, b) =>
stmt"${a.token}.unionAll(${b.token})"
case Join(t, a, b, iA, iB, on) =>
stmt"${a.token}.${t.token}(${b.token}).on((${iA.token}, ${iB.token}) => ${on.token})"
case FlatJoin(t, a, iA, on) =>
stmt"${a.token}.${t.token}((${iA.token}) => ${on.token})"
case Distinct(a) =>
stmt"${a.token}.distinct"
case DistinctOn(source, alias, body) =>
stmt"${source.token}.distinctOn(${alias.token} => ${body.token})"
case Nested(a) =>
stmt"${a.token}.nested"
}
implicit val orderingTokenizer: Tokenizer[Ordering] = Tokenizer[Ordering] {
case TupleOrdering(elems) => stmt"Ord(${elems.token})"
case Asc => stmt"Ord.asc"
case Desc => stmt"Ord.desc"
case AscNullsFirst => stmt"Ord.ascNullsFirst"
case DescNullsFirst => stmt"Ord.descNullsFirst"
case AscNullsLast => stmt"Ord.ascNullsLast"
case DescNullsLast => stmt"Ord.descNullsLast"
}
implicit def optionOperationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
case OptionTableFlatMap(ast, alias, body) =>
stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
case OptionTableMap(ast, alias, body) =>
stmt"${ast.token}.map((${alias.token}) => ${body.token})"
case OptionTableExists(ast, alias, body) =>
stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
case OptionTableForall(ast, alias, body) =>
stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
case OptionFlatten(ast) => stmt"${ast.token}.flatten"
case OptionGetOrElse(ast, body) =>
stmt"${ast.token}.getOrElse(${body.token})"
case OptionFlatMap(ast, alias, body) =>
stmt"${ast.token}.flatMap((${alias.token}) => ${body.token})"
case OptionMap(ast, alias, body) =>
stmt"${ast.token}.map((${alias.token}) => ${body.token})"
case OptionForall(ast, alias, body) =>
stmt"${ast.token}.forall((${alias.token}) => ${body.token})"
case OptionExists(ast, alias, body) =>
stmt"${ast.token}.exists((${alias.token}) => ${body.token})"
case OptionContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case OptionIsEmpty(ast) => stmt"${ast.token}.isEmpty"
case OptionNonEmpty(ast) => stmt"${ast.token}.nonEmpty"
case OptionIsDefined(ast) => stmt"${ast.token}.isDefined"
case OptionSome(ast) => stmt"Some(${ast.token})"
case OptionApply(ast) => stmt"Option(${ast.token})"
case OptionOrNull(ast) => stmt"${ast.token}.orNull"
case OptionGetOrNull(ast) => stmt"${ast.token}.getOrNull"
case OptionNone => stmt"None"
}
implicit def traversableOperationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[IterableOperation] = Tokenizer[IterableOperation] {
case MapContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case SetContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
case ListContains(ast, body) => stmt"${ast.token}.contains(${body.token})"
}
implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
case InnerJoin => stmt"join"
case LeftJoin => stmt"leftJoin"
case RightJoin => stmt"rightJoin"
case FullJoin => stmt"fullJoin"
}
implicit def functionTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Function] = Tokenizer[Function] {
case Function(params, body) => stmt"(${params.token}) => ${body.token}"
}
implicit def operationTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Operation] = Tokenizer[Operation] {
case UnaryOperation(op: PrefixUnaryOperator, ast) =>
stmt"${op.token}${scopedTokenizer(ast)}"
case UnaryOperation(op: PostfixUnaryOperator, ast) =>
stmt"${scopedTokenizer(ast)}.${op.token}"
case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
SetContainsToken(scopedTokenizer(b), op.token, a.token)
case BinaryOperation(a, op, b) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case FunctionApply(function, values) =>
stmt"${scopedTokenizer(function)}.apply(${values.token})"
}
implicit def operatorTokenizer[T <: Operator]: Tokenizer[T] = Tokenizer[T] {
case o => stmt"${o.toString.token}"
}
def tokenizeName(name: String, renameable: Renameable) =
renameable match {
case ByStrategy => name
case Fixed => s"`${name}`"
}
def bracketIfHidden(name: String, visibility: Visibility) =
(distinguishHidden, visibility) match {
case (true, Hidden) => s"[$name]"
case _ => name
}
implicit def propertyTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Property] = Tokenizer[Property] {
case Property.Opinionated(ExternalIdent(_), name, renameable, visibility) =>
stmt"${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
case Property.Opinionated(ref, name, renameable, visibility) =>
stmt"${scopedTokenizer(ref)}.${bracketIfHidden(tokenizeName(name, renameable), visibility).token}"
}
implicit val valueTokenizer: Tokenizer[Value] = Tokenizer[Value] {
case Constant(v: String) => stmt""""${v.token}""""
case Constant(()) => stmt"{}"
case Constant(v) => stmt"${v.toString.token}"
case NullValue => stmt"null"
case Tuple(values) => stmt"(${values.token})"
case CaseClass(values) =>
stmt"CaseClass(${values.map { case (k, v) => s"${k.token}: ${v.token}" }.mkString(", ").token})"
}
implicit val identTokenizer: Tokenizer[Ident] = Tokenizer[Ident] {
case Ident.Opinionated(name, visibility) =>
stmt"${bracketIfHidden(name, visibility).token}"
}
implicit val typeTokenizer: Tokenizer[ExternalIdent] =
Tokenizer[ExternalIdent] {
case e => stmt"${e.name.token}"
}
implicit val excludedTokenizer: Tokenizer[OnConflict.Excluded] =
Tokenizer[OnConflict.Excluded] {
case OnConflict.Excluded(ident) => stmt"${ident.token}"
}
implicit val existingTokenizer: Tokenizer[OnConflict.Existing] =
Tokenizer[OnConflict.Existing] {
case OnConflict.Existing(ident) => stmt"${ident.token}"
}
implicit def actionTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Action] = Tokenizer[Action] {
case Update(query, assignments) =>
stmt"${query.token}.update(${assignments.token})"
case Insert(query, assignments) =>
stmt"${query.token}.insert(${assignments.token})"
case Delete(query) => stmt"${query.token}.delete"
case Returning(query, alias, body) =>
stmt"${query.token}.returning((${alias.token}) => ${body.token})"
case ReturningGenerated(query, alias, body) =>
stmt"${query.token}.returningGenerated((${alias.token}) => ${body.token})"
case Foreach(query, alias, body) =>
stmt"${query.token}.foreach((${alias.token}) => ${body.token})"
case c: OnConflict => stmt"${c.token}"
}
implicit def conflictTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[OnConflict] = {
def targetProps(l: List[Property]) = l.map(p =>
Transform(p) {
case Ident(_) => Ident("_")
}
)
implicit val conflictTargetTokenizer: Tokenizer[OnConflict.Target] =
Tokenizer[OnConflict.Target] {
case OnConflict.NoTarget => stmt""
case OnConflict.Properties(props) =>
val listTokens = listTokenizer(astTokenizer).token(props)
stmt"(${listTokens})"
}
val updateAssignsTokenizer = Tokenizer[Assignment] {
case Assignment(i, p, v) =>
stmt"(${i.token}, e) => ${p.token} -> ${scopedTokenizer(v)}"
}
Tokenizer[OnConflict] {
case OnConflict(i, t, OnConflict.Update(assign)) =>
stmt"${i.token}.onConflictUpdate${t.token}(${assign.map(updateAssignsTokenizer.token).mkStmt()})"
case OnConflict(i, t, OnConflict.Ignore) =>
stmt"${i.token}.onConflictIgnore${t.token}"
}
}
implicit def assignmentTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(ident, property, value) =>
stmt"${ident.token} => ${property.token} -> ${value.token}"
}
implicit def infixTokenizer(implicit
liftTokenizer: Tokenizer[Lift]
): Tokenizer[Infix] = Tokenizer[Infix] {
case Infix(parts, params, _, _) =>
def tokenParam(ast: Ast) =
ast match {
case ast: Ident => stmt"$$${ast.token}"
case other => stmt"$${${ast.token}}"
}
val pt = parts.map(_.token)
val pr = params.map(tokenParam)
val body = Statement(Interleave(pt, pr))
stmt"""infix"${body.token}""""
}
private def scopedTokenizer(
ast: Ast
)(implicit liftTokenizer: Tokenizer[Lift]) =
ast match {
case _: Function => stmt"(${ast.token})"
case _: BinaryOperation => stmt"(${ast.token})"
case other => ast.token
}
}

View file

@ -0,0 +1,100 @@
package minisql.idiom
import minisql.ParamEncoder
import minisql.ast.*
import minisql.util.Interleave
import minisql.idiom.StatementInterpolator.*
import scala.annotation.tailrec
import scala.collection.immutable.{Map => SMap}
object ReifyStatement {
def apply(
liftingPlaceholder: Int => String,
emptySetContainsToken: Token => Token,
statement: Statement,
liftMap: SMap[String, (Any, ParamEncoder[?])]
): (String, List[ScalarValueLift]) = {
val expanded = expandLiftings(statement, emptySetContainsToken, liftMap)
token2string(expanded, liftingPlaceholder)
}
private def token2string(
token: Token,
liftingPlaceholder: Int => String
): (String, List[ScalarValueLift]) = {
val liftBuilder = List.newBuilder[ScalarValueLift]
val sqlBuilder = StringBuilder()
@tailrec
def loop(
workList: Seq[Token],
liftingSize: Int
): Unit = workList match {
case Seq() => ()
case head +: tail =>
head match {
case StringToken(s2) =>
sqlBuilder ++= s2
loop(tail, liftingSize)
case SetContainsToken(a, op, b) =>
loop(
stmt"$a $op ($b)" +: tail,
liftingSize
)
case ScalarLiftToken(lift: ScalarValueLift) =>
sqlBuilder ++= liftingPlaceholder(liftingSize)
liftBuilder += lift
loop(tail, liftingSize + 1)
case ScalarLiftToken(o) =>
throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}")
case Statement(tokens) =>
loop(
tokens.foldRight(tail)(_ +: _),
liftingSize
)
}
}
loop(Vector(token), 0)
sqlBuilder.toString() -> liftBuilder.result()
}
private def expandLiftings(
statement: Statement,
emptySetContainsToken: Token => Token,
liftMap: SMap[String, (Any, ParamEncoder[?])]
): (Token) = {
Statement {
val lb = List.newBuilder[Token]
statement.tokens.foldLeft(lb) {
case (
tokens,
SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift))
) =>
val (lv, le) = liftMap(lift.liftId)
lv.asInstanceOf[Iterable[Any]].toVector match {
case Vector() => tokens += emptySetContainsToken(a)
case values =>
val liftings = values.zipWithIndex.map {
case (v, i) =>
ScalarLiftToken(
ScalarValueLift(
s"${lift.name}[${i}]",
s"${lift.liftId}[${i}]",
Some(v -> le)
)
)
}
val separators = Vector.fill(liftings.size - 1)(StringToken(", "))
(tokens += stmt"$a $op (") ++= Interleave(
liftings,
separators
) += StringToken(")")
}
case (tokens, token) =>
tokens += token
}
lb.result()
}
}
}

View file

@ -0,0 +1,47 @@
package minisql.idiom
import scala.quoted._
import minisql.ast._
sealed trait Token
case class StringToken(string: String) extends Token {
override def toString = string
}
case class ScalarLiftToken(lift: ScalarLift) extends Token {
override def toString = s"lift(${lift.name})"
}
case class Statement(tokens: List[Token]) extends Token {
override def toString = tokens.mkString
}
case class SetContainsToken(a: Token, op: Token, b: Token) extends Token {
override def toString = s"${a.toString} ${op.toString} (${b.toString})"
}
object Statement {
given ToExpr[Statement] with {
def apply(t: Statement)(using Quotes): Expr[Statement] = {
'{ Statement(${ Expr(t.tokens) }) }
}
}
}
object Token {
given ToExpr[Token] with {
def apply(t: Token)(using Quotes): Expr[Token] = {
t match {
case StringToken(s) =>
'{ StringToken(${ Expr(s) }) }
case ScalarLiftToken(l) =>
'{ ScalarLiftToken(${ Expr(l) }) }
case SetContainsToken(a, op, b) =>
'{ SetContainsToken(${ Expr(a) }, ${ Expr(op) }, ${ Expr(b) }) }
case s: Statement => Expr(s)
}
}
}
}

View file

@ -0,0 +1,152 @@
package minisql.idiom
import minisql.ast._
import minisql.util.Interleave
import minisql.util.Messages._
import scala.collection.mutable.ListBuffer
object StatementInterpolator {
trait Tokenizer[T] {
extension (v: T) {
def token: Token
}
}
object Tokenizer {
def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] {
extension (v: T) {
def token: Token = f(v)
}
}
def withFallback[T](
fallback: Tokenizer[T] => Tokenizer[T]
)(pf: PartialFunction[T, Token]) =
new Tokenizer[T] {
extension (v: T) {
private def stable = fallback(this)
override def token = pf.applyOrElse(v, stable.token)
}
}
}
implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) {
def token = tokenizer.token(v)
}
implicit def stringTokenizer: Tokenizer[String] =
Tokenizer[String] {
case string => StringToken(string)
}
implicit def liftTokenizer: Tokenizer[Lift] =
Tokenizer[Lift] {
case lift: ScalarLift => ScalarLiftToken(lift)
case lift =>
fail(
s"Can't tokenize a non-scalar lifting. ${lift.name}\n" +
s"\n" +
s"This might happen because:\n" +
s"* You are trying to insert or update an `Option[A]` field, but Scala infers the type\n" +
s" to `Some[A]` or `None.type`. For example:\n" +
s" run(query[Users].update(_.optionalField -> lift(Some(value))))" +
s" In that case, make sure the type is `Option`:\n" +
s" run(query[Users].update(_.optionalField -> lift(Some(value): Option[Int])))\n" +
s" or\n" +
s" run(query[Users].update(_.optionalField -> lift(Option(value))))\n" +
s"\n" +
s"* You are trying to insert or update whole Embedded case class. For example:\n" +
s" run(query[Users].update(_.embeddedCaseClass -> lift(someInstance)))\n" +
s" In that case, make sure you are updating individual columns, for example:\n" +
s" run(query[Users].update(\n" +
s" _.embeddedCaseClass.a -> lift(someInstance.a),\n" +
s" _.embeddedCaseClass.b -> lift(someInstance.b)\n" +
s" ))"
)
}
implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity)
implicit def statementTokenizer: Tokenizer[Statement] =
Tokenizer[Statement](identity)
implicit def stringTokenTokenizer: Tokenizer[StringToken] =
Tokenizer[StringToken](identity)
implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] =
Tokenizer[ScalarLiftToken](identity)
extension [T](list: List[T]) {
def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = {
val l1 = list.map(_.token)
val l2 = List.fill(l1.size - 1)(StringToken(sep))
Statement(Interleave(l1, l2))
}
}
implicit def listTokenizer[T](implicit
tokenize: Tokenizer[T]
): Tokenizer[List[T]] =
Tokenizer[List[T]] {
case list => list.mkStmt()
}
extension (sc: StringContext) {
def flatten(tokens: List[Token]): List[Token] = {
def unestStatements(tokens: List[Token]): List[Token] = {
tokens.flatMap {
case Statement(innerTokens) => unestStatements(innerTokens)
case token => token :: Nil
}
}
def mergeStringTokens(tokens: List[Token]): List[Token] = {
val (resultBuilder, leftTokens) =
tokens.foldLeft((new ListBuffer[Token], new ListBuffer[String])) {
case ((builder, acc), stringToken: StringToken) =>
val str = stringToken.string
if (str.nonEmpty)
acc += stringToken.string
(builder, acc)
case ((builder, prev), b) if prev.isEmpty =>
(builder += b.token, prev)
case ((builder, prev), b) /* if prev.nonEmpty */ =>
builder += StringToken(prev.result().mkString)
builder += b.token
(builder, new ListBuffer[String])
}
if (leftTokens.nonEmpty)
resultBuilder += StringToken(leftTokens.result().mkString)
resultBuilder.result()
}
(unestStatements)
.andThen(mergeStringTokens)
.apply(tokens)
}
def checkLengths(
args: scala.collection.Seq[Any],
parts: Seq[String]
): Unit =
if (parts.length != args.length + 1)
throw new IllegalArgumentException(
"wrong number of arguments (" + args.length
+ ") for interpolated string with " + parts.length + " parts"
)
def stmt(args: Token*): Statement = {
checkLengths(args, sc.parts)
val partsIterator = sc.parts.iterator
val argsIterator = args.iterator
val bldr = List.newBuilder[Token]
bldr += StringToken(partsIterator.next())
while (argsIterator.hasNext) {
bldr += argsIterator.next()
bldr += StringToken(partsIterator.next())
}
val tokens = flatten(bldr.result())
Statement(tokens)
}
}
}

View file

@ -0,0 +1,52 @@
package minisql.norm
import minisql.ast.BinaryOperation
import minisql.ast.BooleanOperator
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll
object AdHocReduction {
def unapply(q: Query) =
q match {
// ---------------------------
// *.filter
// a.filter(b => c).filter(d => e) =>
// a.filter(b => c && e[d := b])
case Filter(Filter(a, b, c), d, e) =>
val er = BetaReduction(e, d -> b)
Some(Filter(a, b, BinaryOperation(c, BooleanOperator.`&&`, er)))
// ---------------------------
// flatMap.*
// a.flatMap(b => c).map(d => e) =>
// a.flatMap(b => c.map(d => e))
case Map(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, Map(c, d, e)))
// a.flatMap(b => c).filter(d => e) =>
// a.flatMap(b => c.filter(d => e))
case Filter(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, Filter(c, d, e)))
// a.flatMap(b => c.union(d))
// a.flatMap(b => c).union(a.flatMap(b => d))
case FlatMap(a, b, Union(c, d)) =>
Some(Union(FlatMap(a, b, c), FlatMap(a, b, d)))
// a.flatMap(b => c.unionAll(d))
// a.flatMap(b => c).unionAll(a.flatMap(b => d))
case FlatMap(a, b, UnionAll(c, d)) =>
Some(UnionAll(FlatMap(a, b, c), FlatMap(a, b, d)))
case other => None
}
}

View file

@ -0,0 +1,160 @@
package minisql.norm
import minisql.ast._
object ApplyMap {
private def isomorphic(e: Ast, c: Ast, alias: Ident) =
BetaReduction(e, alias -> c) == c
object InfixedTailOperation {
def hasImpureInfix(ast: Ast) =
CollectAst(ast) {
case i @ Infix(_, _, false, _) => i
}.nonEmpty
def unapply(ast: Ast): Option[Ast] =
ast match {
case cc: CaseClass if hasImpureInfix(cc) => Some(cc)
case tup: Tuple if hasImpureInfix(tup) => Some(tup)
case p: Property if hasImpureInfix(p) => Some(p)
case b: BinaryOperation if hasImpureInfix(b) => Some(b)
case u: UnaryOperation if hasImpureInfix(u) => Some(u)
case i @ Infix(_, _, false, _) => Some(i)
case _ => None
}
}
object MapWithoutInfixes {
def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
ast match {
case Map(a, b, InfixedTailOperation(c)) => None
case Map(a, b, c) => Some((a, b, c))
case _ => None
}
}
object DetachableMap {
def unapply(ast: Ast): Option[(Ast, Ident, Ast)] =
ast match {
case Map(a: GroupBy, b, c) => None
case Map(a, b, InfixedTailOperation(c)) => None
case Map(a, b, c) => Some((a, b, c))
case _ => None
}
}
def unapply(q: Query): Option[Query] =
q match {
case Map(a: GroupBy, b, c) if (b == c) => None
case Map(a: DistinctOn, b, c) => None
case Map(a: Nested, b, c) if (b == c) => None
case Nested(DetachableMap(a: Join, b, c)) => None
// map(i => (i.i, i.l)).distinct.map(x => (x._1, x._2)) =>
// map(i => (i.i, i.l)).distinct
case Map(Distinct(DetachableMap(a, b, c)), d, e) if isomorphic(e, c, d) =>
Some(Distinct(Map(a, b, c)))
// a.map(b => c).map(d => e) =>
// a.map(b => e[d := c])
case before @ Map(MapWithoutInfixes(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(Map(a, b, er))
// a.map(b => b) =>
// a
case Map(a: Query, b, c) if (b == c) =>
Some(a)
// a.map(b => c).flatMap(d => e) =>
// a.flatMap(b => e[d := c])
case FlatMap(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(FlatMap(a, b, er))
// a.map(b => c).filter(d => e) =>
// a.filter(b => e[d := c]).map(b => c)
case Filter(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
Some(Map(Filter(a, b, er), b, c))
// a.map(b => c).sortBy(d => e) =>
// a.sortBy(b => e[d := c]).map(b => c)
case SortBy(DetachableMap(a, b, c), d, e, f) =>
val er = BetaReduction(e, d -> c)
Some(Map(SortBy(a, b, er, f), b, c))
// a.map(b => c).sortBy(d => e).distinct =>
// a.sortBy(b => e[d := c]).map(b => c).distinct
case SortBy(Distinct(DetachableMap(a, b, c)), d, e, f) =>
val er = BetaReduction(e, d -> c)
Some(Distinct(Map(SortBy(a, b, er, f), b, c)))
// a.map(b => c).groupBy(d => e) =>
// a.groupBy(b => e[d := c]).map(x => (x._1, x._2.map(b => c)))
case GroupBy(DetachableMap(a, b, c), d, e) =>
val er = BetaReduction(e, d -> c)
val x = Ident("x")
val x1 = Property(
Ident("x"),
"_1"
) // These introduced property should not be renamed
val x2 = Property(Ident("x"), "_2") // due to any naming convention.
val body = Tuple(List(x1, Map(x2, b, c)))
Some(Map(GroupBy(a, b, er), x, body))
// a.map(b => c).drop(d) =>
// a.drop(d).map(b => c)
case Drop(DetachableMap(a, b, c), d) =>
Some(Map(Drop(a, d), b, c))
// a.map(b => c).take(d) =>
// a.drop(d).map(b => c)
case Take(DetachableMap(a, b, c), d) =>
Some(Map(Take(a, d), b, c))
// a.map(b => c).nested =>
// a.nested.map(b => c)
case Nested(DetachableMap(a, b, c)) =>
Some(Map(Nested(a), b, c))
// a.map(b => c).*join(d.map(e => f)).on((iA, iB) => on)
// a.*join(d).on((b, e) => on[iA := c, iB := f]).map(t => (c[b := t._1], f[e := t._2]))
case Join(
tpe,
DetachableMap(a, b, c),
DetachableMap(d, e, f),
iA,
iB,
on
) =>
val onr = BetaReduction(on, iA -> c, iB -> f)
val t = Ident("t")
val t1 = BetaReduction(c, b -> Property(t, "_1"))
val t2 = BetaReduction(f, e -> Property(t, "_2"))
Some(Map(Join(tpe, a, d, b, e, onr), t, Tuple(List(t1, t2))))
// a.*join(b.map(c => d)).on((iA, iB) => on)
// a.*join(b).on((iA, c) => on[iB := d]).map(t => (t._1, d[c := t._2]))
case Join(tpe, a, DetachableMap(b, c, d), iA, iB, on) =>
val onr = BetaReduction(on, iB -> d)
val t = Ident("t")
val t1 = Property(t, "_1")
val t2 = BetaReduction(d, c -> Property(t, "_2"))
Some(Map(Join(tpe, a, b, iA, c, onr), t, Tuple(List(t1, t2))))
// a.map(b => c).*join(d).on((iA, iB) => on)
// a.*join(d).on((b, iB) => on[iA := c]).map(t => (c[b := t._1], t._2))
case Join(tpe, DetachableMap(a, b, c), d, iA, iB, on) =>
val onr = BetaReduction(on, iA -> c)
val t = Ident("t")
val t1 = BetaReduction(c, b -> Property(t, "_1"))
val t2 = Property(t, "_2")
Some(Map(Join(tpe, a, d, b, iB, onr), t, Tuple(List(t1, t2))))
case other => None
}
}

View file

@ -0,0 +1,48 @@
package minisql.norm
import minisql.util.Messages.fail
import minisql.ast._
object AttachToEntity {
private object IsEntity {
def unapply(q: Ast): Option[Ast] =
q match {
case q: Entity => Some(q)
case q: Infix => Some(q)
case _ => None
}
}
def apply(f: (Ast, Ident) => Query, alias: Option[Ident] = None)(
q: Ast
): Ast =
q match {
case Map(IsEntity(a), b, c) => Map(f(a, b), b, c)
case FlatMap(IsEntity(a), b, c) => FlatMap(f(a, b), b, c)
case ConcatMap(IsEntity(a), b, c) => ConcatMap(f(a, b), b, c)
case Filter(IsEntity(a), b, c) => Filter(f(a, b), b, c)
case SortBy(IsEntity(a), b, c, d) => SortBy(f(a, b), b, c, d)
case DistinctOn(IsEntity(a), b, c) => DistinctOn(f(a, b), b, c)
case Map(_: GroupBy, _, _) | _: Union | _: UnionAll | _: Join |
_: FlatJoin =>
f(q, alias.getOrElse(Ident("x")))
case Map(a: Query, b, c) => Map(apply(f, Some(b))(a), b, c)
case FlatMap(a: Query, b, c) => FlatMap(apply(f, Some(b))(a), b, c)
case ConcatMap(a: Query, b, c) => ConcatMap(apply(f, Some(b))(a), b, c)
case Filter(a: Query, b, c) => Filter(apply(f, Some(b))(a), b, c)
case SortBy(a: Query, b, c, d) => SortBy(apply(f, Some(b))(a), b, c, d)
case Take(a: Query, b) => Take(apply(f, alias)(a), b)
case Drop(a: Query, b) => Drop(apply(f, alias)(a), b)
case Aggregation(op, a: Query) => Aggregation(op, apply(f, alias)(a))
case Distinct(a: Query) => Distinct(apply(f, alias)(a))
case DistinctOn(a: Query, b, c) => DistinctOn(apply(f, Some(b))(a), b, c)
case IsEntity(q) => f(q, alias.getOrElse(Ident("x")))
case other => fail(s"Can't find an 'Entity' in '$q'")
}
}

View file

@ -1,6 +1,6 @@
package minisql.util
package minisql.norm
import minisql.ast.*
import minisql.ast._
import scala.collection.immutable.{Map => IMap}
case class BetaReduction(replacements: Replacements)

View file

@ -0,0 +1,7 @@
package minisql.norm
trait ConcatBehavior
object ConcatBehavior {
case object AnsiConcat extends ConcatBehavior
case object NonAnsiConcat extends ConcatBehavior
}

View file

@ -0,0 +1,7 @@
package minisql.norm
trait EqualityBehavior
object EqualityBehavior {
case object AnsiEquality extends EqualityBehavior
case object NonAnsiEquality extends EqualityBehavior
}

View file

@ -0,0 +1,74 @@
package minisql.norm
import minisql.ReturnAction.ReturnColumns
import minisql.{NamingStrategy, ReturnAction}
import minisql.ast._
import minisql.context.{
ReturningClauseSupported,
ReturningMultipleFieldSupported,
ReturningNotSupported,
ReturningSingleFieldSupported
}
import minisql.idiom.{Idiom, Statement}
/**
* Take the `.returning` part in a query that contains it and return the array
* of columns representing of the returning seccovtion with any other operations
* etc... that they might contain.
*/
object ExpandReturning {
def applyMap(
returning: ReturningAction
)(f: (Ast, Statement) => String)(idiom: Idiom, naming: NamingStrategy) = {
val initialExpand = ExpandReturning.apply(returning)(idiom, naming)
idiom.idiomReturningCapability match {
case ReturningClauseSupported =>
ReturnAction.ReturnRecord
case ReturningMultipleFieldSupported =>
ReturnColumns(initialExpand.map {
case (ast, statement) => f(ast, statement)
})
case ReturningSingleFieldSupported =>
if (initialExpand.length == 1)
ReturnColumns(initialExpand.map {
case (ast, statement) => f(ast, statement)
})
else
throw new IllegalArgumentException(
s"Only one RETURNING column is allowed in the ${idiom} dialect but ${initialExpand.length} were specified."
)
case ReturningNotSupported =>
throw new IllegalArgumentException(
s"RETURNING columns are not allowed in the ${idiom} dialect."
)
}
}
def apply(
returning: ReturningAction
)(idiom: Idiom, naming: NamingStrategy): List[(Ast, Statement)] = {
val ReturningAction(_, alias, properties) = returning: @unchecked
// Ident("j"), Tuple(List(Property(Ident("j"), "name"), BinaryOperation(Property(Ident("j"), "age"), +, Constant(1))))
// => Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
val dePropertized =
Transform(properties) {
case `alias` => ExternalIdent(alias.name)
}
val aliasName = alias.name
// Tuple(List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1))))
// => List(ExternalIdent("name"), BinaryOperation(ExternalIdent("age"), +, Constant(1)))
val deTuplified = dePropertized match {
case Tuple(values) => values
case CaseClass(values) => values.map(_._2)
case other => List(other)
}
implicit val namingStrategy: NamingStrategy = naming
deTuplified.map(v => idiom.translate(v))
}
}

View file

@ -0,0 +1,108 @@
package minisql.norm
import minisql.ast.*
import minisql.ast.Implicits.*
import minisql.norm.ConcatBehavior.NonAnsiConcat
class FlattenOptionOperation(concatBehavior: ConcatBehavior)
extends StatelessTransformer {
private def emptyOrNot(b: Boolean, ast: Ast) =
if (b) OptionIsEmpty(ast) else OptionNonEmpty(ast)
def uncheckedReduction(ast: Ast, alias: Ident, body: Ast) =
apply(BetaReduction(body, alias -> ast))
def uncheckedForall(ast: Ast, alias: Ident, body: Ast) = {
val reduced = BetaReduction(body, alias -> ast)
apply((IsNullCheck(ast) +||+ reduced): Ast)
}
def containsNonFallthroughElement(ast: Ast) =
CollectAst(ast) {
case If(_, _, _) => true
case Infix(_, _, _, _) => true
case BinaryOperation(_, StringOperator.`concat`, _)
if (concatBehavior == NonAnsiConcat) =>
true
}.nonEmpty
override def apply(ast: Ast): Ast =
ast match {
case OptionTableFlatMap(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableMap(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableExists(ast, alias, body) =>
uncheckedReduction(ast, alias, body)
case OptionTableForall(ast, alias, body) =>
uncheckedForall(ast, alias, body)
case OptionFlatten(ast) =>
apply(ast)
case OptionSome(ast) =>
apply(ast)
case OptionApply(ast) =>
apply(ast)
case OptionOrNull(ast) =>
apply(ast)
case OptionGetOrNull(ast) =>
apply(ast)
case OptionNone => NullValue
case OptionGetOrElse(OptionMap(ast, alias, body), Constant(b: Boolean)) =>
apply((BetaReduction(body, alias -> ast) +||+ emptyOrNot(b, ast)): Ast)
case OptionGetOrElse(ast, body) =>
apply(If(IsNotNullCheck(ast), ast, body))
case OptionFlatMap(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduced = BetaReduction(body, alias -> ast)
apply(IfExistElseNull(ast, reduced))
} else {
uncheckedReduction(ast, alias, body)
}
case OptionMap(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduced = BetaReduction(body, alias -> ast)
apply(IfExistElseNull(ast, reduced))
} else {
uncheckedReduction(ast, alias, body)
}
case OptionForall(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduction = BetaReduction(body, alias -> ast)
apply(
(IsNullCheck(ast) +||+ (IsNotNullCheck(ast) +&&+ reduction)): Ast
)
} else {
uncheckedForall(ast, alias, body)
}
case OptionExists(ast, alias, body) =>
if (containsNonFallthroughElement(body)) {
val reduction = BetaReduction(body, alias -> ast)
apply((IsNotNullCheck(ast) +&&+ reduction): Ast)
} else {
uncheckedReduction(ast, alias, body)
}
case OptionContains(ast, body) =>
apply((ast +==+ body): Ast)
case other =>
super.apply(other)
}
}

View file

@ -0,0 +1,76 @@
package minisql.norm
import minisql.ast._
/**
* A problem occurred in the original way infixes were done in that it was
* assumed that infix clauses represented pure functions. While this is true of
* many UDFs (e.g. `CONCAT`, `GETDATE`) it is certainly not true of many others
* e.g. `RAND()`, and most importantly `RANK()`. For this reason, the operations
* that are done in `ApplyMap` on standard AST `Map` clauses cannot be done
* therefore additional safety checks were introduced there in order to assure
* this does not happen. In addition to this however, it is necessary to add
* this normalization step which inserts `Nested` AST elements in every map that
* contains impure infix. See more information and examples in #1534.
*/
object NestImpureMappedInfix extends StatelessTransformer {
// Are there any impure infixes that exist inside the specified ASTs
def hasInfix(asts: Ast*): Boolean =
asts.exists(ast =>
CollectAst(ast) {
case i @ Infix(_, _, false, _) => i
}.nonEmpty
)
// Continue exploring into the Map to see if there are additional impure infix clauses inside.
private def applyInside(m: Map) =
Map(apply(m.query), m.alias, m.body)
override def apply(ast: Ast): Ast =
ast match {
// If there is already a nested clause inside the map, there is no reason to insert another one
case Nested(Map(inner, a, b)) =>
Nested(Map(apply(inner), a, b))
case m @ Map(_, x, cc @ CaseClass(values)) if hasInfix(cc) => // Nested(m)
Map(
Nested(applyInside(m)),
x,
CaseClass(values.map {
case (name, _) =>
(
name,
Property(x, name)
) // mappings of nested-query case class properties should not be renamed
})
)
case m @ Map(_, x, tup @ Tuple(values)) if hasInfix(tup) =>
Map(
Nested(applyInside(m)),
x,
Tuple(values.zipWithIndex.map {
case (_, i) =>
Property(
x,
s"_${i + 1}"
) // mappings of nested-query tuple properties should not be renamed
})
)
case m @ Map(_, x, i @ Infix(_, _, false, _)) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, Property(prop, _)) if hasInfix(prop) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, BinaryOperation(a, _, b)) if hasInfix(a, b) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case m @ Map(_, x, UnaryOperation(_, a)) if hasInfix(a) =>
Map(Nested(applyInside(m)), x, Property(x, "_1"))
case other => super.apply(other)
}
}

View file

@ -0,0 +1,51 @@
package minisql.norm
import minisql.ast.Ast
import minisql.ast.Query
import minisql.ast.StatelessTransformer
import minisql.norm.capture.AvoidCapture
import minisql.ast.Action
import minisql.util.Messages.trace
import minisql.util.Messages.TraceType.Normalizations
import scala.annotation.tailrec
object Normalize extends StatelessTransformer {
override def apply(q: Ast): Ast =
super.apply(BetaReduction(q))
override def apply(q: Action): Action =
NormalizeReturning(super.apply(q))
override def apply(q: Query): Query =
norm(AvoidCapture(q))
private def traceNorm[T](label: String) =
trace[T](s"${label} (Normalize)", 1, Normalizations)
@tailrec
private def norm(q: Query): Query =
q match {
case NormalizeNestedStructures(query) =>
traceNorm("NormalizeNestedStructures")(query)
norm(query)
case ApplyMap(query) =>
traceNorm("ApplyMap")(query)
norm(query)
case SymbolicReduction(query) =>
traceNorm("SymbolicReduction")(query)
norm(query)
case AdHocReduction(query) =>
traceNorm("AdHocReduction")(query)
norm(query)
case OrderTerms(query) =>
traceNorm("OrderTerms")(query)
norm(query)
case NormalizeAggregationIdent(query) =>
traceNorm("NormalizeAggregationIdent")(query)
norm(query)
case other =>
other
}
}

View file

@ -0,0 +1,29 @@
package minisql.norm
import minisql.ast._
object NormalizeAggregationIdent {
def unapply(q: Query) =
q match {
// a => a.b.map(x => x.c).agg =>
// a => a.b.map(a => a.c).agg
case Aggregation(
op,
Map(
p @ Property(i: Ident, _),
mi,
Property.Opinionated(_: Ident, n, renameable, visibility)
)
) if i != mi =>
Some(
Aggregation(
op,
Map(p, i, Property.Opinionated(i, n, renameable, visibility))
)
) // in example aove, if c in x.c is fixed c in a.c should also be
case _ => None
}
}

View file

@ -0,0 +1,47 @@
package minisql.norm
import minisql.ast._
object NormalizeNestedStructures {
def unapply(q: Query): Option[Query] =
q match {
case e: Entity => None
case Map(a, b, c) => apply(a, c)(Map(_, b, _))
case FlatMap(a, b, c) => apply(a, c)(FlatMap(_, b, _))
case ConcatMap(a, b, c) => apply(a, c)(ConcatMap(_, b, _))
case Filter(a, b, c) => apply(a, c)(Filter(_, b, _))
case SortBy(a, b, c, d) => apply(a, c)(SortBy(_, b, _, d))
case GroupBy(a, b, c) => apply(a, c)(GroupBy(_, b, _))
case Aggregation(a, b) => apply(b)(Aggregation(a, _))
case Take(a, b) => apply(a, b)(Take.apply)
case Drop(a, b) => apply(a, b)(Drop.apply)
case Union(a, b) => apply(a, b)(Union.apply)
case UnionAll(a, b) => apply(a, b)(UnionAll.apply)
case Distinct(a) => apply(a)(Distinct.apply)
case DistinctOn(a, b, c) => apply(a, c)(DistinctOn(_, b, _))
case Nested(a) => apply(a)(Nested.apply)
case FlatJoin(t, a, iA, on) =>
(Normalize(a), Normalize(on)) match {
case (`a`, `on`) => None
case (a, on) => Some(FlatJoin(t, a, iA, on))
}
case Join(t, a, b, iA, iB, on) =>
(Normalize(a), Normalize(b), Normalize(on)) match {
case (`a`, `b`, `on`) => None
case (a, b, on) => Some(Join(t, a, b, iA, iB, on))
}
}
private def apply(a: Ast)(f: Ast => Query) =
(Normalize(a)) match {
case (`a`) => None
case (a) => Some(f(a))
}
private def apply(a: Ast, b: Ast)(f: (Ast, Ast) => Query) =
(Normalize(a), Normalize(b)) match {
case (`a`, `b`) => None
case (a, b) => Some(f(a, b))
}
}

View file

@ -0,0 +1,154 @@
package minisql.norm
import minisql.ast._
import minisql.norm.capture.AvoidAliasConflict
/**
* When actions are used with a `.returning` clause, remove the columns used in
* the returning clause from the action. E.g. for `insert(Person(id,
* name)).returning(_.id)` remove the `id` column from the original insert.
*/
object NormalizeReturning {
def apply(e: Action): Action = {
e match {
case ReturningGenerated(a: Action, alias, body) =>
// De-alias the body first so variable shadows won't accidentally be interpreted as columns to remove from the insert/update action.
// This typically occurs in advanced cases where actual queries are used in the return clauses which is only supported in Postgres.
// For example:
// query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
// Since the property `t.id` is used both for the `returning` clause and the query inside, it can accidentally
// be seen as a variable used in `returning` hence excluded from insertion which is clearly not the case.
// In order to fix this, we need to change `t` into a different alias.
val newBody = dealiasBody(body, alias)
ReturningGenerated(apply(a, newBody, alias), alias, newBody)
// For a regular return clause, do not need to exclude assignments from insertion however, we still
// need to de-alias the Action body in case conflicts result. For example the following query:
// query[Entity].insert(lift(Person(id, name))).returning(t => (query[Dummy].map(t => t.id).max))
// would incorrectly be interpreted as:
// INSERT INTO Person (id, name) VALUES (1, 'Joe') RETURNING (SELECT MAX(id) FROM Dummy t) -- Note the 'id' in max which is coming from the inserted table instead of t
// whereas it should be:
// INSERT INTO Entity (id) VALUES (1) RETURNING (SELECT MAX(t.id) FROM Dummy t1)
case Returning(a: Action, alias, body) =>
val newBody = dealiasBody(body, alias)
Returning(a, alias, newBody)
case _ => e
}
}
/**
* In some situations, a query can exist inside of a `returning` clause. In
* this case, we need to rename if the aliases used in that query override the
* alias used in the `returning` clause otherwise they will be treated as
* returning-clause aliases ExpandReturning (i.e. they will become
* ExternalAlias instances) and later be tokenized incorrectly.
*/
private def dealiasBody(body: Ast, alias: Ident): Ast =
Transform(body) {
case q: Query => AvoidAliasConflict.sanitizeQuery(q, Set(alias))
}
private def apply(e: Action, body: Ast, returningIdent: Ident): Action =
e match {
case Insert(query, assignments) =>
Insert(query, filterReturnedColumn(assignments, body, returningIdent))
case Update(query, assignments) =>
Update(query, filterReturnedColumn(assignments, body, returningIdent))
case OnConflict(a: Action, target, act) =>
OnConflict(apply(a, body, returningIdent), target, act)
case _ => e
}
private def filterReturnedColumn(
assignments: List[Assignment],
column: Ast,
returningIdent: Ident
): List[Assignment] =
assignments.flatMap(filterReturnedColumn(_, column, returningIdent))
/**
* In situations like Property(Property(ident, foo), bar) pull out the
* inner-most ident
*/
object NestedProperty {
def unapply(ast: Property): Option[Ast] = {
ast match {
case p @ Property(subAst, _) => Some(innerMost(subAst))
}
}
private def innerMost(ast: Ast): Ast = ast match {
case Property(inner, _) => innerMost(inner)
case other => other
}
}
/**
* Remove the specified column from the assignment. For example, in a query
* like `insert(Person(id, name)).returning(r => r.id)` we need to remove the
* `id` column from the insertion. The value of the `column:Ast` in this case
* will be `Property(Ident(r), id)` and the values fo the assignment `p1`
* property will typically be `v.id` and `v.name` (the `v` variable is a
* default used for `insert` queries).
*/
private def filterReturnedColumn(
assignment: Assignment,
body: Ast,
returningIdent: Ident
): Option[Assignment] =
assignment match {
case Assignment(_, p1: Property, _) => {
// Pull out instance of the column usage. The `column` ast will typically be Property(table, field) but
// if the user wants to return multiple things it can also be a tuple Tuple(List(Property(table, field1), Property(table, field2))
// or it can even be a query since queries are allowed to be in return sections e.g:
// query[Entity].insert(lift(Person(id, name))).returning(r => (query[Dummy].filter(t => t.id == r.id).max))
// In all of these cases, we need to pull out the Property (e.g. t.id) in order to compare it to the assignment
// in order to know what to exclude.
val matchedProps =
CollectAst(body) {
// case prop @ NestedProperty(`returningIdent`) => prop
case prop @ NestedProperty(Ident(name))
if (name == returningIdent.name) =>
prop
case prop @ NestedProperty(ExternalIdent(name))
if (name == returningIdent.name) =>
prop
}
if (
matchedProps.exists(matchedProp => isSameProperties(p1, matchedProp))
)
None
else
Some(assignment)
}
case assignment => Some(assignment)
}
object SomeIdent {
def unapply(ast: Ast): Option[Ast] =
ast match {
case id: Ident => Some(id)
case id: ExternalIdent => Some(id)
case _ => None
}
}
/**
* Is it the same property (but possibly of a different identity). E.g.
* `p.foo.bar` and `v.foo.bar`
*/
private def isSameProperties(p1: Property, p2: Property): Boolean =
(p1.ast, p2.ast) match {
case (SomeIdent(_), SomeIdent(_)) =>
p1.name == p2.name
// If it's Property(Property(Id), name) == Property(Property(Id), name) we need to check that the
// outer properties are the same before moving on to the inner ones.
case (pp1: Property, pp2: Property) if (p1.name == p2.name) =>
isSameProperties(pp1, pp2)
case _ =>
false
}
}

View file

@ -0,0 +1,29 @@
package minisql.norm
import minisql.ast._
object OrderTerms {
def unapply(q: Query) =
q match {
case Take(Map(a: GroupBy, b, c), d) => None
// a.sortBy(b => c).filter(d => e) =>
// a.filter(d => e).sortBy(b => c)
case Filter(SortBy(a, b, c, d), e, f) =>
Some(SortBy(Filter(a, e, f), b, c, d))
// a.flatMap(b => c).take(n).map(d => e) =>
// a.flatMap(b => c).map(d => e).take(n)
case Map(Take(fm: FlatMap, n), ma, mb) =>
Some(Take(Map(fm, ma, mb), n))
// a.flatMap(b => c).drop(n).map(d => e) =>
// a.flatMap(b => c).map(d => e).drop(n)
case Map(Drop(fm: FlatMap, n), ma, mb) =>
Some(Drop(Map(fm, ma, mb), n))
case other => None
}
}

View file

@ -0,0 +1,491 @@
package minisql.norm
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Visible
import minisql.ast._
import minisql.util.Interpolator
object RenameProperties extends StatelessTransformer {
val interp = new Interpolator(3)
import interp._
def traceDifferent(one: Any, two: Any) =
if (one != two)
trace"Replaced $one with $two".andLog()
else
trace"Replacements did not match".andLog()
override def apply(q: Query): Query =
applySchemaOnly(q)
override def apply(q: Action): Action =
applySchema(q) match {
case (q, schema) => q
}
override def apply(e: Operation): Operation =
e match {
case UnaryOperation(o, c: Query) =>
UnaryOperation(o, applySchemaOnly(apply(c)))
case _ => super.apply(e)
}
private def applySchema(q: Action): (Action, Schema) =
q match {
case Insert(q: Query, assignments) =>
applySchema(q, assignments, Insert.apply)
case Update(q: Query, assignments) =>
applySchema(q, assignments, Update.apply)
case Delete(q: Query) =>
applySchema(q) match {
case (q, schema) => (Delete(q), schema)
}
case Returning(action: Action, alias, body) =>
applySchema(action) match {
case (action, schema) =>
val replace =
trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val bodyr = BetaReduction(body, replace*)
traceDifferent(body, bodyr)
(Returning(action, alias, bodyr), schema)
}
case ReturningGenerated(action: Action, alias, body) =>
applySchema(action) match {
case (action, schema) =>
val replace =
trace"Finding Replacements for $body inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val bodyr = BetaReduction(body, replace*)
traceDifferent(body, bodyr)
(ReturningGenerated(action, alias, bodyr), schema)
}
case OnConflict(a: Action, target, act) =>
applySchema(a) match {
case (action, schema) =>
val targetR = target match {
case OnConflict.Properties(props) =>
val propsR = props.map { prop =>
val replace =
trace"Finding Replacements for $props inside ${prop.ast} using schema $schema:" `andReturn`
replacements(
prop.ast,
schema
) // A BetaReduction on a Property will always give back a Property
BetaReduction(prop, replace*).asInstanceOf[Property]
}
traceDifferent(props, propsR)
OnConflict.Properties(propsR)
case OnConflict.NoTarget => target
}
val actR = act match {
case OnConflict.Update(assignments) =>
OnConflict.Update(replaceAssignments(assignments, schema))
case _ => act
}
(OnConflict(action, targetR, actR), schema)
}
case q => (q, TupleSchema.empty)
}
private def replaceAssignments(
a: List[Assignment],
schema: Schema
): List[Assignment] =
a.map {
case Assignment(alias, prop, value) =>
val replace =
trace"Finding Replacements for $prop inside $alias using schema $schema:" `andReturn`
replacements(alias, schema)
val propR = BetaReduction(prop, replace*)
traceDifferent(prop, propR)
val valueR = BetaReduction(value, replace*)
traceDifferent(value, valueR)
Assignment(alias, propR, valueR)
}
private def applySchema(
q: Query,
a: List[Assignment],
f: (Query, List[Assignment]) => Action
): (Action, Schema) =
applySchema(q) match {
case (q, schema) =>
(f(q, replaceAssignments(a, schema)), schema)
}
private def applySchemaOnly(q: Query): Query =
applySchema(q) match {
case (q, _) => q
}
object TupleIndex {
def unapply(s: String): Option[Int] =
if (s.matches("_[0-9]*"))
Some(s.drop(1).toInt - 1)
else
None
}
sealed trait Schema {
def lookup(property: List[String]): Option[Schema] =
(property, this) match {
case (Nil, schema) =>
trace"Nil at $property returning " `andReturn`
Some(schema)
case (path, e @ EntitySchema(_)) =>
trace"Entity at $path returning " `andReturn`
Some(e.subSchemaOrEmpty(path))
case (head :: tail, CaseClassSchema(props)) if (props.contains(head)) =>
trace"Case class at $property returning " `andReturn`
props(head).lookup(tail)
case (TupleIndex(idx) :: tail, TupleSchema(values))
if values.contains(idx) =>
trace"Tuple at at $property returning " `andReturn`
values(idx).lookup(tail)
case _ =>
trace"Nothing found at $property returning " `andReturn`
None
}
}
// Represents a nested property path to an identity i.e. Property(Property(... Ident(), ...))
object PropertyMatroshka {
def traverse(initial: Property): Option[(Ident, List[String])] =
initial match {
// If it's a nested-property walk inside and append the name to the result (if something is returned)
case Property(inner: Property, name) =>
traverse(inner).map { case (id, list) => (id, list :+ name) }
// If it's a property with ident in the core, return that
case Property(id: Ident, name) =>
Some((id, List(name)))
// Otherwise an ident property is not inside so don't return anything
case _ =>
None
}
def unapply(ast: Ast): Option[(Ident, List[String])] =
ast match {
case p: Property => traverse(p)
case _ => None
}
}
def protractSchema(
body: Ast,
ident: Ident,
schema: Schema
): Option[Schema] = {
def protractSchemaRecurse(body: Ast, schema: Schema): Option[Schema] =
body match {
// if any values yield a sub-schema which is not an entity, recurse into that
case cc @ CaseClass(values) =>
trace"Protracting CaseClass $cc into new schema:" `andReturn`
CaseClassSchema(
values.collect {
case (name, innerBody @ HierarchicalAstEntity()) =>
(name, protractSchemaRecurse(innerBody, schema))
// pass the schema into a recursive call an extract from it when we non tuple/caseclass element
case (name, innerBody @ PropertyMatroshka(`ident`, path)) =>
(name, protractSchemaRecurse(innerBody, schema))
// we have reached an ident i.e. recurse to pass the current schema into the case class
case (name, `ident`) =>
(name, protractSchemaRecurse(ident, schema))
}.collect {
case (name, Some(subSchema)) => (name, subSchema)
}
).notEmpty
case tup @ Tuple(values) =>
trace"Protracting Tuple $tup into new schema:" `andReturn`
TupleSchema
.fromIndexes(
values.zipWithIndex.collect {
case (innerBody @ HierarchicalAstEntity(), index) =>
(index, protractSchemaRecurse(innerBody, schema))
// pass the schema into a recursive call an extract from it when we non tuple/caseclass element
case (innerBody @ PropertyMatroshka(`ident`, path), index) =>
(index, protractSchemaRecurse(innerBody, schema))
// we have reached an ident i.e. recurse to pass the current schema into the tuple
case (`ident`, index) =>
(index, protractSchemaRecurse(ident, schema))
}.collect {
case (index, Some(subSchema)) => (index, subSchema)
}
)
.notEmpty
case prop @ PropertyMatroshka(`ident`, path) =>
trace"Protraction completed schema path $prop at the schema $schema pointing to:" `andReturn`
schema match {
// case e: EntitySchema => Some(e)
case _ => schema.lookup(path)
}
case `ident` =>
trace"Protraction completed with the mapping identity $ident at the schema:" `andReturn`
Some(schema)
case other =>
trace"Protraction DID NOT find a sub schema, it completed with $other at the schema:" `andReturn`
Some(schema)
}
protractSchemaRecurse(body, schema)
}
case object EmptySchema extends Schema
case class EntitySchema(e: Entity) extends Schema {
def noAliases = e.properties.isEmpty
private def subSchema(path: List[String]) =
EntitySchema(
Entity(
s"sub-${e.name}",
e.properties.flatMap {
case PropertyAlias(aliasPath, alias) =>
if (aliasPath == path)
List(PropertyAlias(aliasPath, alias))
else if (aliasPath.startsWith(path))
List(PropertyAlias(aliasPath.diff(path), alias))
else
List()
}
)
)
def subSchemaOrEmpty(path: List[String]): Schema =
trace"Creating sub-schema for entity $e at path $path will be" andReturn {
val sub = subSchema(path)
if (sub.noAliases) EmptySchema else sub
}
}
case class TupleSchema(m: collection.Map[Int, Schema] /* Zero Indexed */ )
extends Schema {
def list = m.toList.sortBy(_._1)
def notEmpty =
if (this.m.nonEmpty) Some(this) else None
}
case class CaseClassSchema(m: collection.Map[String, Schema]) extends Schema {
def list = m.toList
def notEmpty =
if (this.m.nonEmpty) Some(this) else None
}
object CaseClassSchema {
def apply(property: String, value: Schema): CaseClassSchema =
CaseClassSchema(collection.Map(property -> value))
def apply(list: List[(String, Schema)]): CaseClassSchema =
CaseClassSchema(list.toMap)
}
object TupleSchema {
def fromIndexes(schemas: List[(Int, Schema)]): TupleSchema =
TupleSchema(schemas.toMap)
def apply(schemas: List[Schema]): TupleSchema =
TupleSchema(schemas.zipWithIndex.map(_.swap).toMap)
def apply(index: Int, schema: Schema): TupleSchema =
TupleSchema(collection.Map(index -> schema))
def empty: TupleSchema = TupleSchema(List.empty)
}
object HierarchicalAstEntity {
def unapply(ast: Ast): Boolean =
ast match {
case cc: CaseClass => true
case tup: Tuple => true
case _ => false
}
}
private def applySchema(q: Query): (Query, Schema) = {
q match {
// Don't understand why this is needed....
case Map(q: Query, x, p) =>
applySchema(q) match {
case (q, subSchema) =>
val replace =
trace"Looking for possible replacements for $p inside $x using schema $subSchema:" `andReturn`
replacements(x, subSchema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
val schema =
trace"Protracting Hierarchical Entity $prr into sub-schema: $subSchema" `andReturn` {
protractSchema(prr, x, subSchema)
}.getOrElse(EmptySchema)
(Map(q, x, prr), schema)
}
case e: Entity => (e, EntitySchema(e))
case Filter(q: Query, x, p) => applySchema(q, x, p, Filter.apply)
case SortBy(q: Query, x, p, o) => applySchema(q, x, p, SortBy(_, _, _, o))
case GroupBy(q: Query, x, p) => applySchema(q, x, p, GroupBy.apply)
case Aggregation(op, q: Query) => applySchema(q, Aggregation(op, _))
case Take(q: Query, n) => applySchema(q, Take(_, n))
case Drop(q: Query, n) => applySchema(q, Drop(_, n))
case Nested(q: Query) => applySchema(q, Nested.apply)
case Distinct(q: Query) => applySchema(q, Distinct.apply)
case DistinctOn(q: Query, iA, on) => applySchema(q, DistinctOn(_, iA, on))
case FlatMap(q: Query, x, p) =>
applySchema(q, x, p, FlatMap.apply) match {
case (FlatMap(q, x, p: Query), oldSchema) =>
val (pr, newSchema) = applySchema(p)
(FlatMap(q, x, pr), newSchema)
case (flatMap, oldSchema) =>
(flatMap, TupleSchema.empty)
}
case ConcatMap(q: Query, x, p) =>
applySchema(q, x, p, ConcatMap.apply) match {
case (ConcatMap(q, x, p: Query), oldSchema) =>
val (pr, newSchema) = applySchema(p)
(ConcatMap(q, x, pr), newSchema)
case (concatMap, oldSchema) =>
(concatMap, TupleSchema.empty)
}
case Join(typ, a: Query, b: Query, iA, iB, on) =>
(applySchema(a), applySchema(b)) match {
case ((a, schemaA), (b, schemaB)) =>
val combinedReplacements =
trace"Finding Replacements for $on inside ${(iA, iB)} using schemas ${(schemaA, schemaB)}:" andReturn {
val replaceA = replacements(iA, schemaA)
val replaceB = replacements(iB, schemaB)
replaceA ++ replaceB
}
val onr = BetaReduction(on, combinedReplacements*)
traceDifferent(on, onr)
(Join(typ, a, b, iA, iB, onr), TupleSchema(List(schemaA, schemaB)))
}
case FlatJoin(typ, a: Query, iA, on) =>
applySchema(a) match {
case (a, schemaA) =>
val replaceA =
trace"Finding Replacements for $on inside $iA using schema $schemaA:" `andReturn`
replacements(iA, schemaA)
val onr = BetaReduction(on, replaceA*)
traceDifferent(on, onr)
(FlatJoin(typ, a, iA, onr), schemaA)
}
case Map(q: Operation, x, p) if x == p =>
(Map(apply(q), x, p), TupleSchema.empty)
case Map(Infix(parts, params, pure, paren), x, p) =>
val transformed =
params.map {
case q: Query =>
val (qr, schema) = applySchema(q)
traceDifferent(q, qr)
(qr, Some(schema))
case q =>
(q, None)
}
val schema =
transformed.collect {
case (_, Some(schema)) => schema
} match {
case e :: Nil => e
case ls => TupleSchema(ls)
}
val replace =
trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
replacements(x, schema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
(Map(Infix(parts, transformed.map(_._1), pure, paren), x, prr), schema)
case q =>
(q, TupleSchema.empty)
}
}
private def applySchema(ast: Query, f: Ast => Query): (Query, Schema) =
applySchema(ast) match {
case (ast, schema) =>
(f(ast), schema)
}
private def applySchema[T](
q: Query,
x: Ident,
p: Ast,
f: (Ast, Ident, Ast) => T
): (T, Schema) =
applySchema(q) match {
case (q, schema) =>
val replace =
trace"Finding Replacements for $p inside $x using schema $schema:" `andReturn`
replacements(x, schema)
val pr = BetaReduction(p, replace*)
traceDifferent(p, pr)
val prr = apply(pr)
traceDifferent(pr, prr)
(f(q, x, prr), schema)
}
private def replacements(base: Ast, schema: Schema): Seq[(Ast, Ast)] =
schema match {
// The entity renameable property should already have been marked as Fixed
case EntitySchema(Entity(entity, properties)) =>
// trace"%4 Entity Schema: " andReturn
properties.flatMap {
// A property alias means that there was either a querySchema(tableName, _.propertyName -> PropertyAlias)
// or a schemaMeta (which ultimately gets turned into a querySchema) which is the same thing but implicit.
// In this case, we want to rename the properties based on the property aliases as well as mark
// them Fixed since they should not be renamed based on
// the naming strategy wherever they are tokenized (e.g. in SqlIdiom)
case PropertyAlias(path, alias) =>
def apply(base: Ast, path: List[String]): Ast =
path match {
case Nil => base
case head :: tail => apply(Property(base, head), tail)
}
List(
apply(base, path) -> Property.Opinionated(
base,
alias,
Fixed,
Visible
) // Hidden properties cannot be renamed
)
}
case tup: TupleSchema =>
// trace"%4 Tuple Schema: " andReturn
tup.list.flatMap {
case (idx, value) =>
replacements(
// Should not matter whether property is fixed or variable here
// since beta reduction ignores that
Property(base, s"_${idx + 1}"),
value
)
}
case cc: CaseClassSchema =>
// trace"%4 CaseClass Schema: " andReturn
cc.list.flatMap {
case (property, value) =>
replacements(
// Should not matter whether property is fixed or variable here
// since beta reduction ignores that
Property(base, property),
value
)
}
// Do nothing if it is an empty schema
case EmptySchema => List()
}
}

View file

@ -1,4 +1,4 @@
package minisql.util
package minisql.norm
import minisql.ast.Ast
import scala.collection.immutable.Map

View file

@ -0,0 +1,124 @@
package minisql.norm
import minisql.ast.*
import minisql.norm.EqualityBehavior.AnsiEquality
/**
* Due to the introduction of null checks in `map`, `flatMap`, and `exists`, in
* `FlattenOptionOperation` in order to resolve #1053, as well as to support
* non-ansi compliant string concatenation as outlined in #1295, large
* conditional composites became common. For example: <code><pre> case class
* Holder(value:Option[String])
*
* // The following statement query[Holder].map(h => h.value.map(_ + "foo")) //
* Will yield the following result SELECT CASE WHEN h.value IS NOT NULL THEN
* h.value || 'foo' ELSE null END FROM Holder h </pre></code> Now, let's add a
* <code>getOrElse</code> statement to the clause that requires an additional
* wrapped null check. We cannot rely on there being a <code>map</code> call
* beforehand since we could be reading <code>value</code> as a nullable field
* directly from the database). <code><pre> // The following statement
* query[Holder].map(h => h.value.map(_ + "foo").getOrElse("bar")) // Yields the
* following result: SELECT CASE WHEN CASE WHEN h.value IS NOT NULL THEN h.value
* \|| 'foo' ELSE null END IS NOT NULL THEN CASE WHEN h.value IS NOT NULL THEN
* h.value || 'foo' ELSE null END ELSE 'bar' END FROM Holder h </pre></code>
* This of course is highly redundant and can be reduced to simply: <code><pre>
* SELECT CASE WHEN h.value IS NOT NULL AND (h.value || 'foo') IS NOT NULL THEN
* h.value || 'foo' ELSE 'bar' END FROM Holder h </pre></code> This reduction is
* done by the "Center Rule." There are some other simplification rules as well.
* Note how we are force to null-check both `h.value` as well as `(h.value ||
* 'foo')` because a user may use `Option[T].flatMap` and explicitly transform a
* particular value to `null`.
*/
class SimplifyNullChecks(equalityBehavior: EqualityBehavior)
extends StatelessTransformer {
override def apply(ast: Ast): Ast = {
import minisql.ast.Implicits.*
ast match {
// Center rule
case IfExist(
IfExistElseNull(condA, thenA),
IfExistElseNull(condB, thenB),
otherwise
) if (condA == condB && thenA == thenB) =>
apply(
If(IsNotNullCheck(condA) +&&+ IsNotNullCheck(thenA), thenA, otherwise)
)
// Left hand rule
case IfExist(IfExistElseNull(check, affirm), value, otherwise) =>
apply(
If(
IsNotNullCheck(check) +&&+ IsNotNullCheck(affirm),
value,
otherwise
)
)
// Right hand rule
case IfExistElseNull(cond, IfExistElseNull(innerCond, innerThen)) =>
apply(
If(
IsNotNullCheck(cond) +&&+ IsNotNullCheck(innerCond),
innerThen,
NullValue
)
)
case OptionIsDefined(Optional(a)) +&&+ OptionIsDefined(
Optional(b)
) +&&+ (exp @ (Optional(a1) `== or !=` Optional(b1)))
if (a == a1 && b == b1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case OptionIsDefined(Optional(a)) +&&+ (exp @ (Optional(
a1
) `== or !=` Optional(_)))
if (a == a1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case OptionIsDefined(Optional(b)) +&&+ (exp @ (Optional(
_
) `== or !=` Optional(b1)))
if (b == b1 && equalityBehavior == AnsiEquality) =>
apply(exp)
case (left +&&+ OptionIsEmpty(Optional(Constant(_)))) +||+ other =>
apply(other)
case (OptionIsEmpty(Optional(Constant(_))) +&&+ right) +||+ other =>
apply(other)
case other +||+ (left +&&+ OptionIsEmpty(Optional(Constant(_)))) =>
apply(other)
case other +||+ (OptionIsEmpty(Optional(Constant(_))) +&&+ right) =>
apply(other)
case (left +&&+ OptionIsDefined(Optional(Constant(_)))) => apply(left)
case (OptionIsDefined(Optional(Constant(_))) +&&+ right) => apply(right)
case (left +||+ OptionIsEmpty(Optional(Constant(_)))) => apply(left)
case (OptionIsEmpty(OptionSome(Optional(_))) +||+ right) => apply(right)
case other =>
super.apply(other)
}
}
object `== or !=` {
def unapply(ast: Ast): Option[(Ast, Ast)] = ast match {
case a +==+ b => Some((a, b))
case a +!=+ b => Some((a, b))
case _ => None
}
}
/**
* Simple extractor that looks inside of an optional values to see if the
* thing inside can be pulled out. If not, it just returns whatever element it
* can find.
*/
object Optional {
def unapply(a: Ast): Option[Ast] = a match {
case OptionApply(value) => Some(value)
case OptionSome(value) => Some(value)
case value => Some(value)
}
}
}

View file

@ -0,0 +1,38 @@
package minisql.norm
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Query
import minisql.ast.Union
import minisql.ast.UnionAll
object SymbolicReduction {
def unapply(q: Query) =
q match {
// a.filter(b => c).flatMap(d => e.$) =>
// a.flatMap(d => e.filter(_ => c[b := d]).$)
case FlatMap(Filter(a, b, c), d, e: Query) =>
val cr = BetaReduction(c, b -> d)
val er = AttachToEntity(Filter(_, _, cr))(e)
Some(FlatMap(a, d, er))
// a.flatMap(b => c).flatMap(d => e) =>
// a.flatMap(b => c.flatMap(d => e))
case FlatMap(FlatMap(a, b, c), d, e) =>
Some(FlatMap(a, b, FlatMap(c, d, e)))
// a.union(b).flatMap(c => d)
// a.flatMap(c => d).union(b.flatMap(c => d))
case FlatMap(Union(a, b), c, d) =>
Some(Union(FlatMap(a, c, d), FlatMap(b, c, d)))
// a.unionAll(b).flatMap(c => d)
// a.flatMap(c => d).unionAll(b.flatMap(c => d))
case FlatMap(UnionAll(a, b), c, d) =>
Some(UnionAll(FlatMap(a, c, d), FlatMap(b, c, d)))
case other => None
}
}

View file

@ -0,0 +1,174 @@
package minisql.norm.capture
import minisql.ast.{
Entity,
Filter,
FlatJoin,
FlatMap,
GroupBy,
Ident,
Join,
Map,
Query,
SortBy,
StatefulTransformer,
_
}
import minisql.norm.{BetaReduction, Normalize}
import scala.collection.immutable.Set
private[minisql] case class AvoidAliasConflict(state: Set[Ident])
extends StatefulTransformer[Set[Ident]] {
object Unaliased {
private def isUnaliased(q: Ast): Boolean =
q match {
case Nested(q: Query) => isUnaliased(q)
case Take(q: Query, _) => isUnaliased(q)
case Drop(q: Query, _) => isUnaliased(q)
case Aggregation(_, q: Query) => isUnaliased(q)
case Distinct(q: Query) => isUnaliased(q)
case _: Entity | _: Infix => true
case _ => false
}
def unapply(q: Ast): Option[Ast] =
q match {
case q if (isUnaliased(q)) => Some(q)
case _ => None
}
}
override def apply(q: Query): (Query, StatefulTransformer[Set[Ident]]) =
q match {
case FlatMap(Unaliased(q), x, p) =>
apply(x, p)(FlatMap(q, _, _))
case ConcatMap(Unaliased(q), x, p) =>
apply(x, p)(ConcatMap(q, _, _))
case Map(Unaliased(q), x, p) =>
apply(x, p)(Map(q, _, _))
case Filter(Unaliased(q), x, p) =>
apply(x, p)(Filter(q, _, _))
case SortBy(Unaliased(q), x, p, o) =>
apply(x, p)(SortBy(q, _, _, o))
case GroupBy(Unaliased(q), x, p) =>
apply(x, p)(GroupBy(q, _, _))
case DistinctOn(Unaliased(q), x, p) =>
apply(x, p)(DistinctOn(q, _, _))
case Join(t, a, b, iA, iB, o) =>
val (ar, art) = apply(a)
val (br, brt) = art.apply(b)
val freshA = freshIdent(iA, brt.state)
val freshB = freshIdent(iB, brt.state + freshA)
val or = BetaReduction(o, iA -> freshA, iB -> freshB)
val (orr, orrt) = AvoidAliasConflict(brt.state + freshA + freshB)(or)
(Join(t, ar, br, freshA, freshB, orr), orrt)
case FlatJoin(t, a, iA, o) =>
val (ar, art) = apply(a)
val freshA = freshIdent(iA)
val or = BetaReduction(o, iA -> freshA)
val (orr, orrt) = AvoidAliasConflict(art.state + freshA)(or)
(FlatJoin(t, ar, freshA, orr), orrt)
case _: Entity | _: FlatMap | _: ConcatMap | _: Map | _: Filter |
_: SortBy | _: GroupBy | _: Aggregation | _: Take | _: Drop |
_: Union | _: UnionAll | _: Distinct | _: DistinctOn | _: Nested =>
super.apply(q)
}
private def apply(x: Ident, p: Ast)(
f: (Ident, Ast) => Query
): (Query, StatefulTransformer[Set[Ident]]) = {
val fresh = freshIdent(x)
val pr = BetaReduction(p, x -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(f(fresh, prr), t)
}
private def freshIdent(x: Ident, state: Set[Ident] = state): Ident = {
def loop(x: Ident, n: Int): Ident = {
val fresh = Ident(s"${x.name}$n")
if (!state.contains(fresh))
fresh
else
loop(x, n + 1)
}
if (!state.contains(x))
x
else
loop(x, 1)
}
/**
* Sometimes we need to change the variables in a function because they will
* might conflict with some variable further up in the macro. Right now, this
* only happens when you do something like this: <code> val q = quote { (v:
* Foo) => query[Foo].insert(v) } run(q(lift(v))) </code> Since 'v' is used by
* actionMeta in order to map keys to values for insertion, using it as a
* function argument messes up the output SQL like so: <code> INSERT INTO
* MyTestEntity (s,i,l,o) VALUES (s,i,l,o) instead of (?,?,?,?) </code>
* Therefore, we need to have a method to remove such conflicting variables
* from Function ASTs
*/
private def applyFunction(f: Function): Function = {
val (newBody, _, newParams) =
f.params.foldLeft((f.body, state, List[Ident]())) {
case ((body, state, newParams), param) => {
val fresh = freshIdent(param)
val pr = BetaReduction(body, param -> fresh)
val (prr, t) = AvoidAliasConflict(state + fresh)(pr)
(prr, t.state, newParams :+ fresh)
}
}
Function(newParams, newBody)
}
private def applyForeach(f: Foreach): Foreach = {
val fresh = freshIdent(f.alias)
val pr = BetaReduction(f.body, f.alias -> fresh)
val (prr, _) = AvoidAliasConflict(state + fresh)(pr)
Foreach(f.query, fresh, prr)
}
}
private[minisql] object AvoidAliasConflict {
def apply(q: Query): Query =
AvoidAliasConflict(Set[Ident]())(q) match {
case (q, _) => q
}
/**
* Make sure query parameters do not collide with paramters of a AST function.
* Do this by walkning through the function's subtree and transforming and
* queries encountered.
*/
def sanitizeVariables(
f: Function,
dangerousVariables: Set[Ident]
): Function = {
AvoidAliasConflict(dangerousVariables).applyFunction(f)
}
/** Same is `sanitizeVariables` but for Foreach * */
def sanitizeVariables(f: Foreach, dangerousVariables: Set[Ident]): Foreach = {
AvoidAliasConflict(dangerousVariables).applyForeach(f)
}
def sanitizeQuery(q: Query, dangerousVariables: Set[Ident]): Query = {
AvoidAliasConflict(dangerousVariables).apply(q) match {
// Propagate aliasing changes to the rest of the query
case (q, _) => Normalize(q)
}
}
}

View file

@ -0,0 +1,9 @@
package minisql.norm.capture
import minisql.ast.Query
object AvoidCapture {
def apply(q: Query): Query =
Dealias(AvoidAliasConflict(q))
}

View file

@ -0,0 +1,72 @@
package minisql.norm.capture
import minisql.ast._
import minisql.norm.BetaReduction
case class Dealias(state: Option[Ident])
extends StatefulTransformer[Option[Ident]] {
override def apply(q: Query): (Query, StatefulTransformer[Option[Ident]]) =
q match {
case FlatMap(a, b, c) =>
dealias(a, b, c)(FlatMap.apply) match {
case (FlatMap(a, b, c), _) =>
val (cn, cnt) = apply(c)
(FlatMap(a, b, cn), cnt)
}
case ConcatMap(a, b, c) =>
dealias(a, b, c)(ConcatMap.apply) match {
case (ConcatMap(a, b, c), _) =>
val (cn, cnt) = apply(c)
(ConcatMap(a, b, cn), cnt)
}
case Map(a, b, c) =>
dealias(a, b, c)(Map.apply)
case Filter(a, b, c) =>
dealias(a, b, c)(Filter.apply)
case SortBy(a, b, c, d) =>
dealias(a, b, c)(SortBy(_, _, _, d))
case GroupBy(a, b, c) =>
dealias(a, b, c)(GroupBy.apply)
case DistinctOn(a, b, c) =>
dealias(a, b, c)(DistinctOn.apply)
case Take(a, b) =>
val (an, ant) = apply(a)
(Take(an, b), ant)
case Drop(a, b) =>
val (an, ant) = apply(a)
(Drop(an, b), ant)
case Union(a, b) =>
val (an, _) = apply(a)
val (bn, _) = apply(b)
(Union(an, bn), Dealias(None))
case UnionAll(a, b) =>
val (an, _) = apply(a)
val (bn, _) = apply(b)
(UnionAll(an, bn), Dealias(None))
case Join(t, a, b, iA, iB, o) =>
val ((an, iAn, on), _) = dealias(a, iA, o)((_, _, _))
val ((bn, iBn, onn), _) = dealias(b, iB, on)((_, _, _))
(Join(t, an, bn, iAn, iBn, onn), Dealias(None))
case FlatJoin(t, a, iA, o) =>
val ((an, iAn, on), ont) = dealias(a, iA, o)((_, _, _))
(FlatJoin(t, an, iAn, on), Dealias(Some(iA)))
case _: Entity | _: Distinct | _: Aggregation | _: Nested =>
(q, Dealias(None))
}
private def dealias[T](a: Ast, b: Ident, c: Ast)(f: (Ast, Ident, Ast) => T) =
apply(a) match {
case (an, t @ Dealias(Some(alias))) =>
(f(an, alias, BetaReduction(c, b -> alias)), t)
case other =>
(f(a, b, c), Dealias(Some(b)))
}
}
object Dealias {
def apply(query: Query) =
new Dealias(None)(query) match {
case (q, _) => q
}
}

View file

@ -0,0 +1,98 @@
package minisql.norm.capture
import minisql.ast.*
/**
* Walk through any Queries that a returning clause has and replace Ident of the
* returning variable with ExternalIdent so that in later steps involving filter
* simplification, it will not be mistakenly dealiased with a potential shadow.
* Take this query for instance: <pre> query[TestEntity]
* .insert(lift(TestEntity("s", 0, 1L, None))) .returningGenerated( r =>
* (query[Dummy].filter(r => r.i == r.i).filter(d => d.i == r.i).max) ) </pre>
* The returning clause has an alias `Ident("r")` as well as the first filter
* clause. These two filters will be combined into one at which point the
* meaning of `r.i` in the 2nd filter will be confused for the first filter's
* alias (i.e. the `r` in `filter(r => ...)`. Therefore, we need to change this
* vunerable `r.i` in the second filter clause to an `ExternalIdent` before any
* of the simplifications are done.
*
* Note that we only want to do this for Queries inside of a `Returning` clause
* body. Other places where this needs to be done (e.g. in a Tuple that
* `Returning` returns) are done in `ExpandReturning`.
*/
private[minisql] case class DemarcateExternalAliases(externalIdent: Ident)
extends StatelessTransformer {
def applyNonOverride(idents: Ident*)(ast: Ast) =
if (idents.forall(_ != externalIdent)) apply(ast)
else ast
override def apply(ast: Ast): Ast = ast match {
case FlatMap(q, i, b) =>
FlatMap(apply(q), i, applyNonOverride(i)(b))
case ConcatMap(q, i, b) =>
ConcatMap(apply(q), i, applyNonOverride(i)(b))
case Map(q, i, b) =>
Map(apply(q), i, applyNonOverride(i)(b))
case Filter(q, i, b) =>
Filter(apply(q), i, applyNonOverride(i)(b))
case SortBy(q, i, p, o) =>
SortBy(apply(q), i, applyNonOverride(i)(p), o)
case GroupBy(q, i, b) =>
GroupBy(apply(q), i, applyNonOverride(i)(b))
case DistinctOn(q, i, b) =>
DistinctOn(apply(q), i, applyNonOverride(i)(b))
case Join(t, a, b, iA, iB, o) =>
Join(t, a, b, iA, iB, applyNonOverride(iA, iB)(o))
case FlatJoin(t, a, iA, o) =>
FlatJoin(t, a, iA, applyNonOverride(iA)(o))
case p @ Property.Opinionated(
id @ Ident(_),
value,
renameable,
visibility
) =>
if (id == externalIdent)
Property.Opinionated(
ExternalIdent(externalIdent.name),
value,
renameable,
visibility
)
else
p
case other =>
super.apply(other)
}
}
object DemarcateExternalAliases {
private def demarcateQueriesInBody(id: Ident, body: Ast) =
Transform(body) {
// Apply to the AST defined apply method about, not to the superclass method that takes Query
case q: Query =>
new DemarcateExternalAliases(id).apply(q.asInstanceOf[Ast])
}
def apply(ast: Ast): Ast = ast match {
case Returning(a, id, body) =>
Returning(a, id, demarcateQueriesInBody(id, body))
case ReturningGenerated(a, id, body) =>
val d = demarcateQueriesInBody(id, body)
ReturningGenerated(a, id, demarcateQueriesInBody(id, body))
case other =>
other
}
}

View file

@ -0,0 +1,47 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
type SParser[X] =
(q: Quotes) ?=> PartialFunction[q.reflect.Statement, Expr[X]]
private[parsing] def statementParsing(astParser: => Parser[ast.Ast])(using
Quotes
): SParser[ast.Ast] = {
import quotes.reflect.*
@annotation.nowarn
lazy val valDefParser: SParser[ast.Val] = {
case ValDef(n, _, Some(b)) =>
val body = astParser(b.asExpr)
'{ ast.Val(ast.Ident(${ Expr(n) }), $body) }
}
valDefParser
}
private[parsing] def blockParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
import quotes.reflect.*
lazy val statementParser = statementParsing(astParser)
termParser {
case Block(Nil, t) => astParser(t.asExpr)
case b @ Block(st, t) =>
val asts = (st :+ t).map {
case e if e.isExpr => astParser(e.asExpr)
case `statementParser`(x) => x
case o =>
report.errorAndAbort(s"Cannot parse statement: ${o.show}")
}
if (asts.size > 1) {
'{ ast.Block(${ Expr.ofList(asts) }) }
} else asts(0)
}
}

View file

@ -0,0 +1,31 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def boxingParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
case '{ BigDecimal.int2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.long2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.double2bigDecimal($v) } => astParser(v)
case '{ BigDecimal.javaBigDecimal2bigDecimal($v) } => astParser(v)
case '{ Predef.byte2Byte($v) } => astParser(v)
case '{ Predef.short2Short($v) } => astParser(v)
case '{ Predef.char2Character($v) } => astParser(v)
case '{ Predef.int2Integer($v) } => astParser(v)
case '{ Predef.long2Long($v) } => astParser(v)
case '{ Predef.float2Float($v) } => astParser(v)
case '{ Predef.double2Double($v) } => astParser(v)
case '{ Predef.boolean2Boolean($v) } => astParser(v)
case '{ Predef.augmentString($v) } => astParser(v)
case '{ Predef.Byte2byte($v) } => astParser(v)
case '{ Predef.Short2short($v) } => astParser(v)
case '{ Predef.Character2char($v) } => astParser(v)
case '{ Predef.Integer2int($v) } => astParser(v)
case '{ Predef.Long2long($v) } => astParser(v)
case '{ Predef.Float2float($v) } => astParser(v)
case '{ Predef.Double2double($v) } => astParser(v)
case '{ Predef.Boolean2boolean($v) } => astParser(v)
}

View file

@ -0,0 +1,13 @@
package minisql.parsing
import minisql.ast
import minisql.dsl.*
import scala.quoted.*
private[parsing] def infixParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Infix] = {
import quotes.reflect.*
???
}

View file

@ -0,0 +1,16 @@
package minisql.parsing
import scala.quoted.*
import minisql.ParamEncoder
import minisql.ast
import minisql.*
private[parsing] def liftParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Lift] = {
case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } =>
import quotes.reflect.*
val name = x.asTerm.symbol.fullName
val liftId = x.asTerm.symbol.owner.fullName + "@" + name
'{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) }
}

View file

@ -0,0 +1,113 @@
package minisql.parsing
import minisql.ast
import minisql.ast.{
EqualityOperator,
StringOperator,
NumericOperator,
BooleanOperator
}
import minisql.*
import scala.quoted._
private[parsing] def operationParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Operation] = {
import quotes.reflect.*
def isNumeric(t: TypeRepr) = {
t <:< TypeRepr.of[Int]
|| t <:< TypeRepr.of[Long]
|| t <:< TypeRepr.of[Byte]
|| t <:< TypeRepr.of[Float]
|| t <:< TypeRepr.of[Double]
|| t <:< TypeRepr.of[java.math.BigDecimal]
|| t <:< TypeRepr.of[scala.math.BigDecimal]
}
def parseBinary(
left: Expr[Any],
right: Expr[Any],
op: Expr[ast.BinaryOperator]
) = {
val leftE = astParser(left)
val rightE = astParser(right)
'{ ast.BinaryOperation(${ leftE }, ${ op }, ${ rightE }) }
}
def parseUnary(expr: Expr[Any], op: Expr[ast.UnaryOperator]) = {
val base = astParser(expr)
'{ ast.UnaryOperation($op, $base) }
}
val universalOpParser: Parser[ast.BinaryOperation] = termParser {
case Apply(Select(leftT, UniversalOp(op)), List(rightT)) =>
parseBinary(leftT.asExpr, rightT.asExpr, op)
}
val stringOpParser: Parser[ast.Operation] = {
case '{ ($x: String) + ($y: String) } =>
parseBinary(x, y, '{ StringOperator.concat })
case '{ ($x: String).startsWith($y) } =>
parseBinary(x, y, '{ StringOperator.startsWith })
case '{ ($x: String).split($y) } =>
parseBinary(x, y, '{ StringOperator.split })
case '{ ($x: String).toUpperCase } =>
parseUnary(x, '{ StringOperator.toUpperCase })
case '{ ($x: String).toLowerCase } =>
parseUnary(x, '{ StringOperator.toLowerCase })
case '{ ($x: String).toLong } =>
parseUnary(x, '{ StringOperator.toLong })
case '{ ($x: String).toInt } =>
parseUnary(x, '{ StringOperator.toInt })
}
val numericOpParser = termParser {
case (Apply(Select(lt, NumericOp(op)), List(rt))) if isNumeric(lt.tpe) =>
parseBinary(lt.asExpr, rt.asExpr, op)
case Select(leftTerm, "unary_-") if isNumeric(leftTerm.tpe) =>
val leftExpr = astParser(leftTerm.asExpr)
'{ ast.UnaryOperation(NumericOperator.-, ${ leftExpr }) }
}
val booleanOpParser: Parser[ast.Operation] = {
case '{ ($x: Boolean) && $y } =>
parseBinary(x, y, '{ BooleanOperator.&& })
case '{ ($x: Boolean) || $y } =>
parseBinary(x, y, '{ BooleanOperator.|| })
case '{ !($x: Boolean) } =>
parseUnary(x, '{ BooleanOperator.! })
}
universalOpParser
.orElse(stringOpParser)
.orElse(numericOpParser)
.orElse(booleanOpParser)
}
private object UniversalOp {
def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
op match {
case "==" | "equals" => Some('{ EqualityOperator.== })
case "!=" => Some('{ EqualityOperator.!= })
case _ => None
}
}
private object NumericOp {
def unapply(op: String)(using Quotes): Option[Expr[ast.BinaryOperator]] =
op match {
case "+" => Some('{ NumericOperator.+ })
case "-" => Some('{ NumericOperator.- })
case "*" => Some('{ NumericOperator.* })
case "/" => Some('{ NumericOperator./ })
case ">" => Some('{ NumericOperator.> })
case ">=" => Some('{ NumericOperator.>= })
case "<" => Some('{ NumericOperator.< })
case "<=" => Some('{ NumericOperator.<= })
case "%" => Some('{ NumericOperator.% })
case _ => None
}
}

View file

@ -0,0 +1,47 @@
package minisql.parsing
import minisql.ast
import minisql.ast.Ast
import scala.quoted.*
private[minisql] inline def parseParamAt[F](
inline f: F,
inline n: Int
): ast.Ident = ${
parseParamAt('f, 'n)
}
private[minisql] inline def parseBody[X](
inline f: X
): ast.Ast = ${
parseBody('f)
}
private[minisql] def parseParamAt(f: Expr[?], n: Expr[Int])(using
Quotes
): Expr[ast.Ident] = {
import quotes.reflect.*
val pIdx = n.value.getOrElse(
report.errorAndAbort(s"Param index ${n.show} is not know")
)
extractTerm(f.asTerm) match {
case Lambda(vals, _) =>
vals(pIdx) match {
case ValDef(n, _, _) => '{ ast.Ident(${ Expr(n) }) }
}
}
}
private[minisql] def parseBody[X](
x: Expr[X]
)(using Quotes): Expr[Ast] = {
import quotes.reflect.*
extractTerm(x.asTerm) match {
case Lambda(vals, body) =>
Parsing.parseExpr(body.asExpr)
case o =>
report.errorAndAbort(s"Can only parse function")
}
}

View file

@ -0,0 +1,139 @@
package minisql.parsing
import minisql.ast
import minisql.context.{ReturningMultipleFieldSupported, _}
import minisql.norm.BetaReduction
import minisql.norm.capture.AvoidAliasConflict
import minisql.idiom.Idiom
import scala.annotation.tailrec
import minisql.ast.Implicits._
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.{Hidden, Visible}
import minisql.util.Interleave
import scala.quoted.*
type Parser[A] = PartialFunction[Expr[Any], Expr[A]]
private def termParser[A](using q: Quotes)(
pf: PartialFunction[q.reflect.Term, Expr[A]]
): Parser[A] = {
import quotes.reflect._
{
case e if pf.isDefinedAt(e.asTerm) => pf(e.asTerm)
}
}
private def parser[A](
f: PartialFunction[Expr[Any], Expr[A]]
)(using Quotes): Parser[A] = {
case e if f.isDefinedAt(e) => f(e)
}
private[minisql] def extractTerm(using Quotes)(x: quotes.reflect.Term) = {
import quotes.reflect.*
def unwrapTerm(t: Term): Term = t match {
case Inlined(_, _, o) => unwrapTerm(o)
case Block(Nil, last) => last
case Typed(t, _) =>
unwrapTerm(t)
case Select(t, "$asInstanceOf$") =>
unwrapTerm(t)
case TypeApply(t, _) =>
unwrapTerm(t)
case o => o
}
unwrapTerm(x)
}
private[minisql] object Parsing {
def parseExpr(
expr: Expr[?]
)(using q: Quotes): Expr[ast.Ast] = {
import q.reflect._
def unwrapped(
f: Parser[ast.Ast]
): Parser[ast.Ast] = {
case expr =>
val t = expr.asTerm
f(extractTerm(t).asExpr)
}
lazy val astParser: Parser[ast.Ast] =
unwrapped {
typedParser
.orElse(propertyParser)
.orElse(liftParser)
.orElse(identParser)
.orElse(valueParser)
.orElse(operationParser)
.orElse(constantParser)
.orElse(blockParser)
.orElse(boxingParser)
.orElse(ifParser)
.orElse(traversableOperationParser)
.orElse(patMatchParser)
// .orElse(infixParser)
.orElse {
case o =>
val str = scala.util.Try(o.show).getOrElse("")
report.errorAndAbort(
s"cannot parse ${str}",
o.asTerm.pos
)
}
}
lazy val typedParser: Parser[ast.Ast] = termParser {
case (Typed(t, _)) =>
astParser(t.asExpr)
}
lazy val blockParser: Parser[ast.Ast] = blockParsing(astParser)
lazy val valueParser: Parser[ast.Value] = valueParsing(astParser)
lazy val liftParser: Parser[ast.Lift] = liftParsing(astParser)
lazy val constantParser: Parser[ast.Constant] = termParser {
case Literal(x) =>
'{ ast.Constant(${ Literal(x).asExpr }) }
}
lazy val identParser: Parser[ast.Ident] = termParser {
case x @ Ident(n) if x.symbol.isValDef =>
'{ ast.Ident(${ Expr(n) }) }
}
lazy val propertyParser: Parser[ast.Property] = propertyParsing(astParser)
lazy val operationParser: Parser[ast.Operation] = operationParsing(
astParser
)
lazy val boxingParser: Parser[ast.Ast] = boxingParsing(astParser)
lazy val ifParser: Parser[ast.If] = {
case '{ if ($a) $b else $c } =>
'{ ast.If(${ astParser(a) }, ${ astParser(b) }, ${ astParser(c) }) }
}
lazy val patMatchParser: Parser[ast.Ast] = patMatchParsing(astParser)
// lazy val infixParser: Parser[ast.Infix] = infixParsing(astParser)
lazy val traversableOperationParser: Parser[ast.IterableOperation] =
traversableOperationParsing(astParser)
astParser(expr)
}
private[minisql] inline def parse[A](
inline a: A
): ast.Ast = ${
parseExpr('a)
}
}

View file

@ -0,0 +1,49 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private[parsing] def patMatchParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Ast] = {
import quotes.reflect.*
termParser {
// Val defs that showd pattern variables will cause error
case e @ Match(t, List(CaseDef(IsTupleUnapply(binds), None, body))) =>
val bm = binds.zipWithIndex.map {
case (Bind(n, ident), idx) =>
n -> Select.unique(t, s"_${idx + 1}")
}.toMap
val tm = new TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = {
tree match {
case Ident(n) => bm(n)
case o => super.transformTerm(o)(owner)
}
}
}
val newBody = tm.transformTree(body)(e.symbol)
astParser(newBody.asExpr)
}
}
object IsTupleUnapply {
def unapply(using
Quotes
)(t: quotes.reflect.Tree): Option[List[quotes.reflect.Tree]] = {
import quotes.reflect.*
def isTupleNUnapply(x: Term) = {
val fn = x.symbol.fullName
fn.startsWith("scala.Tuple") && fn.endsWith("$.unapply")
}
t match {
case Unapply(m, _, binds) if isTupleNUnapply(m) =>
Some(binds)
case _ => None
}
}
}

View file

@ -0,0 +1,30 @@
package minisql.parsing
import minisql.ast
import minisql.dsl.*
import scala.quoted._
private[parsing] def propertyParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.Property] = {
import quotes.reflect.*
def isAccessor(s: Select) = {
s.qualifier.tpe.typeSymbol.caseFields.exists(cf => cf.name == s.name)
}
val parseApply: Parser[ast.Property] = termParser {
case m @ Select(base, n) if isAccessor(m) =>
val obj = astParser(base.asExpr)
'{ ast.Property($obj, ${ Expr(n) }) }
}
val parseOptionGet: Parser[ast.Property] = {
case '{ ($e: Option[t]).get } =>
report.errorAndAbort(
"Option.get is not supported since it's an unsafe operation. Use `forall` or `exists` instead."
)
}
parseApply.orElse(parseOptionGet)
}

View file

@ -0,0 +1,16 @@
package minisql.parsing
import minisql.ast
import scala.quoted.*
private def traversableOperationParsing(
astParser: => Parser[ast.Ast]
)(using Quotes): Parser[ast.IterableOperation] = {
case '{ type k; type v; (${ m }: Map[`k`, `v`]).contains($key) } =>
'{ ast.MapContains(${ astParser(m) }, ${ astParser(key) }) }
case '{ ($s: Set[e]).contains($i) } =>
'{ ast.SetContains(${ astParser(s) }, ${ astParser(i) }) }
case '{ ($s: Seq[e]).contains($i) } =>
'{ ast.ListContains(${ astParser(s) }, ${ astParser(i) }) }
}

View file

@ -0,0 +1,71 @@
package minisql
package parsing
import scala.quoted._
private[parsing] def valueParsing(astParser: => Parser[ast.Ast])(using
Quotes
): Parser[ast.Value] = {
import quotes.reflect.*
val parseTupleApply: Parser[ast.Tuple] = {
case IsTupleApply(args) =>
val t = args.map(astParser)
'{ ast.Tuple(${ Expr.ofList(t) }) }
}
val parseAssocTuple: Parser[ast.Tuple] = {
case '{ ($x: tx) -> ($y: ty) } =>
'{ ast.Tuple(List(${ astParser(x) }, ${ astParser(y) })) }
}
parseTupleApply.orElse(parseAssocTuple)
}
private[minisql] object IsTupleApply {
def unapply(e: Expr[Any])(using Quotes): Option[Seq[Expr[Any]]] = {
import quotes.reflect.*
def isTupleNApply(t: Term) = {
val fn = t.symbol.fullName
fn.startsWith("scala.Tuple") && fn.endsWith("$.apply")
}
def isTupleXXLApply(t: Term) = {
t.symbol.fullName == "scala.runtime.TupleXXL$.apply"
}
extractTerm(e.asTerm) match {
// TupleN(0-22).apply
case Apply(b, args) if isTupleNApply(b) =>
Some(args.map(_.asExpr))
case TypeApply(Select(t, "$asInstanceOf$"), tt) if isTupleXXLApply(t) =>
t.asExpr match {
case '{ scala.runtime.TupleXXL.apply(${ Varargs(args) }*) } =>
Some(args)
}
case o =>
None
}
}
}
private[parsing] object IsTuple2 {
def unapply(using
Quotes
)(t: Expr[Any]): Option[(Expr[Any], Expr[Any])] =
t match {
case '{ scala.Tuple2.apply($x1, $x2) } => Some((x1, x2))
case '{ ($x1: t1) -> ($x2: t2) } => Some((x1, x2))
case _ => None
}
def unapply(using
Quotes
)(t: quotes.reflect.Term): Option[(Expr[Any], Expr[Any])] =
unapply(t.asExpr)
}

View file

@ -1,6 +1,24 @@
package minisql.util
import scala.util.Try
import scala.util.*
extension [A](xs: Iterable[A]) {
private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = {
val out = IArray.newBuilder[Any]
var left: Option[Throwable] = None
xs.foreach { (v) =>
if (!left.isDefined) {
f(v) match {
case Failure(e) =>
left = Some(e)
case Success(r) =>
out += r
}
}
}
left.toLeft(out.result().asInstanceOf).toTry
}
}
object CollectTry {
def apply[T](list: List[Try[T]]): Try[List[T]] =

View file

@ -11,20 +11,25 @@ import scala.util.matching.Regex
class Interpolator(
defaultIndent: Int = 0,
qprint: AstPrinter = AstPrinter(),
out: PrintStream = System.out,
out: PrintStream = System.out
) {
extension (sc: StringContext) {
def trace(elements: Any*) = new Traceable(sc, elements)
}
class Traceable(sc: StringContext, elementsSeq: Seq[Any]) {
private val elementPrefix = "| "
private sealed trait PrintElement
private case class Str(str: String, first: Boolean) extends PrintElement
private case class Elem(value: String) extends PrintElement
private case object Separator extends PrintElement
private case class Elem(value: String) extends PrintElement
private case object Separator extends PrintElement
private def generateStringForCommand(value: Any, indent: Int) = {
val objectString = qprint(value)
val oneLine = objectString.fitsOnOneLine
val oneLine = objectString.fitsOnOneLine
oneLine match {
case true => s"${indent.prefix}> ${objectString}"
case false =>
@ -42,7 +47,7 @@ class Interpolator(
private def readBuffers() = {
def orZero(i: Int): Int = if (i < 0) 0 else i
val parts = sc.parts.iterator.toList
val parts = sc.parts.iterator.toList
val elements = elementsSeq.toList.map(qprint(_))
val (firstStr, explicitIndent) = readFirst(parts.head)

View file

@ -5,10 +5,15 @@ import scala.util.Try
object LoadObject {
def apply[T](using Quotes, Type[T]): Try[T] = {
import quotes.reflect.*
apply(TypeRepr.of[T])
}
def apply[T](using Quotes)(ot: quotes.reflect.TypeRepr): Try[T] = Try {
import quotes.reflect.*
val moduleClsName = ot.typeSymbol.companionModule.moduleClass.fullName
val moduleCls = Class.forName(moduleClsName)
val moduleCls = Class.forName(moduleClsName)
val field = moduleCls
.getFields()
.find { f =>

View file

@ -0,0 +1,75 @@
package minisql.util
import minisql.AstPrinter
import minisql.idiom.Idiom
import minisql.util.IndentUtil._
object Messages {
private def variable(propName: String, envName: String, default: String) =
Option(System.getProperty(propName))
.orElse(sys.env.get(envName))
.getOrElse(default)
private[util] val prettyPrint =
variable("quill.macro.log.pretty", "quill_macro_log", "false").toBoolean
private[util] val debugEnabled =
variable("quill.macro.log", "quill_macro_log", "true").toBoolean
private[util] val traceEnabled =
variable("quill.trace.enabled", "quill_trace_enabled", "false").toBoolean
private[util] val traceColors =
variable("quill.trace.color", "quill_trace_color,", "false").toBoolean
private[util] val traceOpinions =
variable("quill.trace.opinion", "quill_trace_opinion", "false").toBoolean
private[util] val traceAstSimple = variable(
"quill.trace.ast.simple",
"quill_trace_ast_simple",
"false"
).toBoolean
private[minisql] val cacheDynamicQueries = variable(
"quill.query.cacheDaynamic",
"query_query_cacheDaynamic",
"true"
).toBoolean
private[util] val traces: List[TraceType] =
variable("quill.trace.types", "quill_trace_types", "standard")
.split(",")
.toList
.map(_.trim)
.flatMap(trace =>
TraceType.values.filter(traceType => trace == traceType.value)
)
def tracesEnabled(tt: TraceType) =
traceEnabled && traces.contains(tt)
sealed trait TraceType { def value: String }
object TraceType {
case object Normalizations extends TraceType { val value = "norm" }
case object Standard extends TraceType { val value = "standard" }
case object NestedQueryExpansion extends TraceType { val value = "nest" }
def values: List[TraceType] =
List(Standard, Normalizations, NestedQueryExpansion)
}
val qprint = AstPrinter()
def fail(msg: String) =
throw new IllegalStateException(msg)
def trace[T](
label: String,
numIndent: Int = 0,
traceType: TraceType = TraceType.Standard
) =
(v: T) => {
val indent = (0 to numIndent).map(_ => "").mkString(" ")
if (tracesEnabled(traceType))
println(s"$indent$label\n${{
if (traceColors) qprint.apply(v)
else qprint.apply(v)
}.split("\n").map(s"$indent " + _).mkString("\n")}")
v
}
}

View file

@ -1,21 +1,20 @@
package minisql.util
object Show {
trait Show[T] {
def show(v: T): String
trait Show[T] {
extension (v: T) {
def show: String
}
}
object Show {
def apply[T](f: T => String) = new Show[T] {
def show(v: T) = f(v)
object Show {
def apply[T](f: T => String) = new Show[T] {
extension (v: T) {
def show: String = f(v)
}
}
implicit class Shower[T](v: T)(implicit shower: Show[T]) {
def show = shower.show(v)
}
implicit def listShow[T](implicit shower: Show[T]): Show[List[T]] =
given listShow[T](using shower: Show[T]): Show[List[T]] =
Show[List[T]] {
case list => list.map(_.show).mkString(", ")
}

View file

@ -0,0 +1,38 @@
package minisql.parsing
import minisql.ast.*
class ParsingSuite extends munit.FunSuite {
test("Ident") {
val x = 1
assertEquals(Parsing.parse(x), Ident("x"))
}
test("NumericOperator.+") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a + b),
BinaryOperation(Ident("a"), NumericOperator.+, Ident("b"))
)
}
test("NumericOperator.-") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a - b),
BinaryOperation(Ident("a"), NumericOperator.-, Ident("b"))
)
}
test("NumericOperator.*") {
val a = 1
val b = 2
assertEquals(
Parsing.parse(a * b),
BinaryOperation(Ident("a"), NumericOperator.*, Ident("b"))
)
}
}

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,35 @@
package minisql
import minisql.ast.*
class QuotedSuite extends munit.FunSuite {
private inline def testQuoted(label: String)(
inline x: Quoted,
expect: Ast
) = test(label) {
assertEquals(compileTimeAst(x), Some(expect.toString()))
}
case class Foo(id: Long)
inline def Foos = query[Foo]("foo")
val entityFoo = Entity("foo", Nil)
val idx = Ident("x")
testQuoted("EntityQuery")(Foos, entityFoo)
testQuoted("Query/filter")(
Foos.filter(x => x.id > 0),
Filter(
entityFoo,
idx,
BinaryOperation(Property(idx, "id"), NumericOperator.>, Constant(0))
)
)
testQuoted("Query/map")(
Foos.map(x => x.id),
Map(entityFoo, idx, Property(idx, "id"))
)
}