add sql idiom

This commit is contained in:
jilen 2025-06-18 16:59:06 +08:00
parent 63a9a0cad3
commit 1bc6baad68
19 changed files with 2227 additions and 0 deletions

View file

@ -0,0 +1,52 @@
package minisql
import minisql.context.{
CanReturnClause,
CanReturnField,
CanReturnMultiField,
CannotReturn
}
import minisql.context.sql.idiom.SqlIdiom
import minisql.context.sql.idiom.QuestionMarkBindVariables
import minisql.context.sql.idiom.ConcatSupport
trait MirrorSqlDialect
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnField
trait MirrorSqlDialectWithReturnMulti
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnMultiField
trait MirrorSqlDialectWithReturnClause
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CanReturnClause
trait MirrorSqlDialectWithNoReturn
extends SqlIdiom
with QuestionMarkBindVariables
with ConcatSupport
with CannotReturn
object MirrorSqlDialect extends MirrorSqlDialect {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithReturnClause
extends MirrorSqlDialectWithReturnClause {
override def prepareForProbing(string: String) = string
}
object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn {
override def prepareForProbing(string: String) = string
}

View file

@ -0,0 +1,17 @@
package minisql.context.sql.idiom
import minisql.util.Messages
trait ConcatSupport {
this: SqlIdiom =>
override def concatFunction = "UNNEST"
}
trait NoConcatSupport {
this: SqlIdiom =>
override def concatFunction = Messages.fail(
s"`concatMap` not supported by ${this.getClass.getSimpleName}"
)
}

View file

@ -0,0 +1,70 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.idiom.StatementInterpolator._
import minisql.idiom.Token
import minisql.NamingStrategy
import minisql.util.Messages.fail
trait OnConflictSupport {
self: SqlIdiom =>
implicit def conflictTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OnConflict] = {
val customEntityTokenizer = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t"
}
val customAstTokenizer =
Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy)) {
case _: OnConflict.Excluded => stmt"EXCLUDED"
case OnConflict.Existing(a) => stmt"${a.token}"
case a: Action =>
self
.actionTokenizer(customEntityTokenizer)(
actionAstTokenizer,
strategy
)
.token(a)
}
import OnConflict._
def doUpdateStmt(i: Token, t: Token, u: Update) = {
val assignments = u.assignments
.map(a =>
stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(customAstTokenizer)}"
)
.mkStmt()
stmt"$i ON CONFLICT $t DO UPDATE SET $assignments"
}
def doNothingStmt(i: Ast, t: Token) =
stmt"${i.token} ON CONFLICT $t DO NOTHING"
implicit val conflictTargetPropsTokenizer: Tokenizer[Properties] =
Tokenizer[Properties] {
case OnConflict.Properties(props) =>
stmt"(${props.map(n => n.renameable.fixedOr(n.name)(strategy.column(n.name))).mkStmt(",")})"
}
def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
Tokenizer[OnConflict] {
case OnConflict(_, NoTarget, _: Update) =>
fail("'DO UPDATE' statement requires explicit conflict target")
case OnConflict(i, p: Properties, u: Update) =>
doUpdateStmt(i.token, p.token, u)
case OnConflict(i, NoTarget, Ignore) =>
stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING"
case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token)
}
tokenizer(customAstTokenizer)
}
}

View file

@ -0,0 +1,6 @@
package minisql.context.sql.idiom
trait PositionalBindVariables { self: SqlIdiom =>
override def liftingPlaceholder(index: Int): String = s"$$${index + 1}"
}

View file

@ -0,0 +1,6 @@
package minisql.context.sql.idiom
trait QuestionMarkBindVariables { self: SqlIdiom =>
override def liftingPlaceholder(index: Int): String = s"?"
}

View file

@ -0,0 +1,700 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.ast.BooleanOperator._
import minisql.ast.Lift
import minisql.context.sql._
import minisql.context.sql.norm._
import minisql.idiom._
import minisql.idiom.StatementInterpolator._
import minisql.NamingStrategy
import minisql.ast.Renameable.Fixed
import minisql.ast.Visibility.Hidden
import minisql.context.{ReturningCapability, ReturningClauseSupported}
import minisql.util.Interleave
import minisql.util.Messages.{fail, trace}
import minisql.idiom.Token
import minisql.norm.EqualityBehavior
import minisql.norm.ConcatBehavior
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.ExpandReturning
trait SqlIdiom extends Idiom {
override def prepareForProbing(string: String): String
protected def concatBehavior: ConcatBehavior = AnsiConcat
protected def equalityBehavior: EqualityBehavior = AnsiEquality
protected def actionAlias: Option[Ident] = None
override def format(queryString: String): String = queryString
def querifyAst(ast: Ast) = SqlQuery(ast)
private def doTranslate(ast: Ast, cached: Boolean)(implicit
naming: NamingStrategy
): (Ast, Statement) = {
val normalizedAst =
SqlNormalize(ast, concatBehavior, equalityBehavior)
implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer
val token =
normalizedAst match {
case q: Query =>
val sql = querifyAst(q)
trace("sql")(sql)
VerifySqlQuery(sql).map(fail)
val expanded = new ExpandNestedQueries(naming)(sql, List())
trace("expanded sql")(expanded)
val tokenized = expanded.token
trace("tokenized sql")(tokenized)
tokenized
case other =>
other.token
}
(normalizedAst, stmt"$token")
}
override def translate(
ast: Ast
)(implicit naming: NamingStrategy): (Ast, Statement) = {
doTranslate(ast, false)
}
def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] =
new Tokenizer[Ast] {
private val stableTokenizer = astTokenizer(this, naming)
extension (v: Ast) {
def token = stableTokenizer.token(v)
}
}
def astTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Ast] =
Tokenizer[Ast] {
case a: Query => SqlQuery(a).token
case a: Operation => a.token
case a: Infix => a.token
case a: Action => a.token
case a: Ident => a.token
case a: ExternalIdent => a.token
case a: Property => a.token
case a: Value => a.token
case a: If => a.token
case a: Lift => a.token
case a: Assignment => a.token
case a: OptionOperation => a.token
case a @ (
_: Function | _: FunctionApply | _: Dynamic | _: OptionOperation |
_: Block | _: Val | _: Ordering | _: IterableOperation |
_: OnConflict.Excluded | _: OnConflict.Existing
) =>
fail(s"Malformed or unsupported construct: $a.")
}
implicit def ifTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[If] = Tokenizer[If] {
case ast: If =>
def flatten(ast: Ast): (List[(Ast, Ast)], Ast) =
ast match {
case If(cond, a, b) =>
val (l, e) = flatten(b)
((cond, a) +: l, e)
case other =>
(List(), other)
}
val (l, e) = flatten(ast)
val conditions =
for ((cond, body) <- l) yield {
stmt"WHEN ${cond.token} THEN ${body.token}"
}
stmt"CASE ${conditions.mkStmt(" ")} ELSE ${e.token} END"
}
def concatFunction: String
protected def tokenizeGroupBy(
values: Ast
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token =
values.token
protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
) {
import q._
def selectTokenizer =
select match {
case Nil => stmt"*"
case _ => select.token
}
def distinctTokenizer = (
distinct match {
case DistinctKind.Distinct => stmt"DISTINCT "
case DistinctKind.DistinctOn(props) =>
stmt"DISTINCT ON (${props.token}) "
case DistinctKind.None => stmt""
}
)
def withDistinct = stmt"$distinctTokenizer${selectTokenizer}"
def withFrom =
from match {
case Nil => withDistinct
case head :: tail =>
val t = tail.foldLeft(stmt"${head.token}") {
case (a, b: FlatJoinContext) =>
stmt"$a ${(b: FromContext).token}"
case (a, b) =>
stmt"$a, ${b.token}"
}
stmt"$withDistinct FROM $t"
}
def withWhere =
where match {
case None => withFrom
case Some(where) => stmt"$withFrom WHERE ${where.token}"
}
def withGroupBy =
groupBy match {
case None => withWhere
case Some(groupBy) =>
stmt"$withWhere GROUP BY ${tokenizeGroupBy(groupBy)}"
}
def withOrderBy =
orderBy match {
case Nil => withGroupBy
case orderBy => stmt"$withGroupBy ${tokenOrderBy(orderBy)}"
}
def withLimitOffset = limitOffsetToken(withOrderBy).token((limit, offset))
def apply = stmt"SELECT $withLimitOffset"
}
implicit def sqlQueryTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] {
case q: FlattenSqlQuery =>
new FlattenSqlQueryTokenizerHelper(q).apply
case SetOperationSqlQuery(a, op, b) =>
stmt"(${a.token}) ${op.token} (${b.token})"
case UnaryOperationSqlQuery(op, q) =>
stmt"SELECT ${op.token} (${q.token})"
}
protected def tokenizeColumn(
strategy: NamingStrategy,
column: String,
renameable: Renameable
) =
renameable match {
case Fixed => column
case _ => strategy.column(column)
}
protected def tokenizeTable(
strategy: NamingStrategy,
table: String,
renameable: Renameable
) =
renameable match {
case Fixed => table
case _ => strategy.table(table)
}
protected def tokenizeAlias(strategy: NamingStrategy, table: String) =
strategy.default(table)
implicit def selectValueTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[SelectValue] = {
def tokenizer(implicit astTokenizer: Tokenizer[Ast]) =
Tokenizer[SelectValue] {
case SelectValue(ast, Some(alias), false) => {
stmt"${ast.token} AS ${alias.token}"
}
case SelectValue(ast, Some(alias), true) =>
stmt"${concatFunction.token}(${ast.token}) AS ${alias.token}"
case selectValue =>
val value =
selectValue match {
case SelectValue(Ident("?"), _, _) => "?".token
case SelectValue(Ident(name), _, _) =>
stmt"${strategy.default(name).token}.*"
case SelectValue(ast, _, _) => ast.token
}
selectValue.concat match {
case true => stmt"${concatFunction.token}(${value.token})"
case false => value
}
}
val customAstTokenizer =
Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) {
case Aggregation(op, Ident(_) | Tuple(_)) => stmt"${op.token}(*)"
case Aggregation(op, Distinct(ast)) =>
stmt"${op.token}(DISTINCT ${ast.token})"
case ast @ Aggregation(op, _: Query) => scopedTokenizer(ast)
case Aggregation(op, ast) => stmt"${op.token}(${ast.token})"
}
tokenizer(customAstTokenizer)
}
implicit def operationTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Operation] = Tokenizer[Operation] {
case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})"
case BinaryOperation(a, EqualityOperator.`==`, NullValue) =>
stmt"${scopedTokenizer(a)} IS NULL"
case BinaryOperation(NullValue, EqualityOperator.`==`, b) =>
stmt"${scopedTokenizer(b)} IS NULL"
case BinaryOperation(a, EqualityOperator.`!=`, NullValue) =>
stmt"${scopedTokenizer(a)} IS NOT NULL"
case BinaryOperation(NullValue, EqualityOperator.`!=`, b) =>
stmt"${scopedTokenizer(b)} IS NOT NULL"
case BinaryOperation(a, StringOperator.`startsWith`, b) =>
stmt"${scopedTokenizer(a)} LIKE (${(BinaryOperation(b, StringOperator.`concat`, Constant("%")): Ast).token})"
case BinaryOperation(a, op @ StringOperator.`split`, b) =>
stmt"${op.token}(${scopedTokenizer(a)}, ${scopedTokenizer(b)})"
case BinaryOperation(a, op @ SetOperator.`contains`, b) =>
SetContainsToken(scopedTokenizer(b), op.token, a.token)
case BinaryOperation(a, op @ `&&`, b) =>
(a, b) match {
case (BinaryOperation(_, `||`, _), BinaryOperation(_, `||`, _)) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case (BinaryOperation(_, `||`, _), _) =>
stmt"${scopedTokenizer(a)} ${op.token} ${b.token}"
case (_, BinaryOperation(_, `||`, _)) =>
stmt"${a.token} ${op.token} ${scopedTokenizer(b)}"
case _ => stmt"${a.token} ${op.token} ${b.token}"
}
case BinaryOperation(a, op @ `||`, b) =>
stmt"${a.token} ${op.token} ${b.token}"
case BinaryOperation(a, op, b) =>
stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}"
case e: FunctionApply => fail(s"Can't translate the ast to sql: '$e'")
}
implicit def optionOperationTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] {
case OptionIsEmpty(ast) => stmt"${ast.token} IS NULL"
case OptionNonEmpty(ast) => stmt"${ast.token} IS NOT NULL"
case OptionIsDefined(ast) => stmt"${ast.token} IS NOT NULL"
case other => fail(s"Malformed or unsupported construct: $other.")
}
implicit val setOperationTokenizer: Tokenizer[SetOperation] =
Tokenizer[SetOperation] {
case UnionOperation => stmt"UNION"
case UnionAllOperation => stmt"UNION ALL"
}
protected def limitOffsetToken(
query: Statement
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
Tokenizer[(Option[Ast], Option[Ast])] {
case (None, None) => query
case (Some(limit), None) => stmt"$query LIMIT ${limit.token}"
case (Some(limit), Some(offset)) =>
stmt"$query LIMIT ${limit.token} OFFSET ${offset.token}"
case (None, Some(offset)) => stmt"$query OFFSET ${offset.token}"
}
protected def tokenOrderBy(
criterias: List[OrderByCriteria]
)(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) =
stmt"ORDER BY ${criterias.token}"
implicit def sourceTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[FromContext] = Tokenizer[FromContext] {
case TableContext(name, alias) =>
stmt"${name.token} ${tokenizeAlias(strategy, alias).token}"
case QueryContext(query, alias) =>
stmt"(${query.token}) AS ${tokenizeAlias(strategy, alias).token}"
case InfixContext(infix, alias) if infix.noParen =>
stmt"${(infix: Ast).token} AS ${strategy.default(alias).token}"
case InfixContext(infix, alias) =>
stmt"(${(infix: Ast).token}) AS ${strategy.default(alias).token}"
case JoinContext(t, a, b, on) =>
stmt"${a.token} ${t.token} ${b.token} ON ${on.token}"
case FlatJoinContext(t, a, on) => stmt"${t.token} ${a.token} ON ${on.token}"
}
implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] {
case InnerJoin => stmt"INNER JOIN"
case LeftJoin => stmt"LEFT JOIN"
case RightJoin => stmt"RIGHT JOIN"
case FullJoin => stmt"FULL JOIN"
}
implicit def orderByCriteriaTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] {
case OrderByCriteria(ast, Asc) => stmt"${scopedTokenizer(ast)} ASC"
case OrderByCriteria(ast, Desc) => stmt"${scopedTokenizer(ast)} DESC"
case OrderByCriteria(ast, AscNullsFirst) =>
stmt"${scopedTokenizer(ast)} ASC NULLS FIRST"
case OrderByCriteria(ast, DescNullsFirst) =>
stmt"${scopedTokenizer(ast)} DESC NULLS FIRST"
case OrderByCriteria(ast, AscNullsLast) =>
stmt"${scopedTokenizer(ast)} ASC NULLS LAST"
case OrderByCriteria(ast, DescNullsLast) =>
stmt"${scopedTokenizer(ast)} DESC NULLS LAST"
}
implicit val unaryOperatorTokenizer: Tokenizer[UnaryOperator] =
Tokenizer[UnaryOperator] {
case NumericOperator.`-` => stmt"-"
case BooleanOperator.`!` => stmt"NOT"
case StringOperator.`toUpperCase` => stmt"UPPER"
case StringOperator.`toLowerCase` => stmt"LOWER"
case StringOperator.`toLong` => stmt"" // cast is implicit
case StringOperator.`toInt` => stmt"" // cast is implicit
case SetOperator.`isEmpty` => stmt"NOT EXISTS"
case SetOperator.`nonEmpty` => stmt"EXISTS"
}
implicit val aggregationOperatorTokenizer: Tokenizer[AggregationOperator] =
Tokenizer[AggregationOperator] {
case AggregationOperator.`min` => stmt"MIN"
case AggregationOperator.`max` => stmt"MAX"
case AggregationOperator.`avg` => stmt"AVG"
case AggregationOperator.`sum` => stmt"SUM"
case AggregationOperator.`size` => stmt"COUNT"
}
implicit val binaryOperatorTokenizer: Tokenizer[BinaryOperator] =
Tokenizer[BinaryOperator] {
case EqualityOperator.`==` => stmt"="
case EqualityOperator.`!=` => stmt"<>"
case BooleanOperator.`&&` => stmt"AND"
case BooleanOperator.`||` => stmt"OR"
case StringOperator.`concat` => stmt"||"
case StringOperator.`startsWith` =>
fail("bug: this code should be unreachable")
case StringOperator.`split` => stmt"SPLIT"
case NumericOperator.`-` => stmt"-"
case NumericOperator.`+` => stmt"+"
case NumericOperator.`*` => stmt"*"
case NumericOperator.`>` => stmt">"
case NumericOperator.`>=` => stmt">="
case NumericOperator.`<` => stmt"<"
case NumericOperator.`<=` => stmt"<="
case NumericOperator.`/` => stmt"/"
case NumericOperator.`%` => stmt"%"
case SetOperator.`contains` => stmt"IN"
}
implicit def propertyTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Property] = {
def unnest(ast: Ast): (Ast, List[String]) =
ast match {
case Property.Opinionated(a, _, _, Hidden) =>
unnest(a) match {
case (a, nestedName) => (a, nestedName)
}
// Append the property name. This includes tuple indexes.
case Property(a, name) =>
unnest(a) match {
case (ast, nestedName) =>
(ast, nestedName :+ name)
}
case a => (a, Nil)
}
def tokenizePrefixedProperty(
name: String,
prefix: List[String],
strategy: NamingStrategy,
renameable: Renameable
) =
renameable.fixedOr(
(prefix.mkString + name).token
)(tokenizeColumn(strategy, prefix.mkString + name, renameable).token)
Tokenizer[Property] {
case Property.Opinionated(
ast,
name,
renameable,
_ /* Top level property cannot be invisible */
) =>
// When we have things like Embedded tables, properties inside of one another needs to be un-nested.
// E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner
// property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias`
// as opposed to `realTable.embeddedTableAlias.realPropertyAlias`.
unnest(ast) match {
// When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL
// returning clauses have no alias for the original table. I.e. INSERT [...] RETURNING idColumn there's no
// alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'.
// In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
// should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`.
// The exception to this is when a Query inside of a RETURNING clause is used. In that case, assume
// that there is an alias for the inserted table (i.e. `INSERT ... as theAlias values ... RETURNING`)
// and the instances of ExternalIdent use it.
case (ExternalIdent(_), prefix) =>
stmt"${actionAlias
.map(alias => stmt"${scopedTokenizer(alias)}.")
.getOrElse(stmt"")}${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
// In the rare case that the Ident is invisible, do not show it. See the Ident documentation for more info.
case (Ident.Opinionated(_, Hidden), prefix) =>
stmt"${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
// The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)`
// becomes `realTable.realPropertyAlias`.
case (ast, prefix) =>
stmt"${scopedTokenizer(ast)}.${tokenizePrefixedProperty(name, prefix, strategy, renameable)}"
}
}
}
implicit def valueTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Value] = Tokenizer[Value] {
case Constant(v: String) => stmt"'${v.token}'"
case Constant(()) => stmt"1"
case Constant(v) => stmt"${v.toString.token}"
case NullValue => stmt"null"
case Tuple(values) => stmt"${values.token}"
case CaseClass(values) => stmt"${values.map(_._2).token}"
}
implicit def infixTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Infix] = Tokenizer[Infix] {
case Infix(parts, params, _, _) =>
val pt = parts.map(_.token)
val pr = params.map(_.token)
Statement(Interleave(pt, pr))
}
implicit def identTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Ident] =
Tokenizer[Ident](e => strategy.default(e.name).token)
implicit def externalIdentTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[ExternalIdent] =
Tokenizer[ExternalIdent](e => strategy.default(e.name).token)
implicit def assignmentTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Assignment] = Tokenizer[Assignment] {
case Assignment(alias, prop, value) =>
stmt"${prop.token} = ${scopedTokenizer(value)}"
}
implicit def defaultAstTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Action] = {
val insertEntityTokenizer = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
stmt"INTO ${tokenizeTable(strategy, name, renameable).token}"
}
actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy)
}
protected def actionAstTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
) =
Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) {
case q: Query => astTokenizer.token(q)
case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL"
case Property(
Property.Opinionated(_, name, renameable, _),
"isDefined"
) =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
case Property(Property.Opinionated(_, name, renameable, _), "nonEmpty") =>
stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL"
case Property.Opinionated(_, name, renameable, _) =>
renameable.fixedOr(name.token)(
tokenizeColumn(strategy, name, renameable).token
)
}
def returnListTokenizer(implicit
tokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[List[Ast]] = {
val customAstTokenizer =
Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) {
case sq: Query =>
stmt"(${tokenizer.token(sq)})"
}
Tokenizer[List[Ast]] {
case list =>
list.mkStmt(", ")(customAstTokenizer)
}
}
protected def actionTokenizer(
insertEntityTokenizer: Tokenizer[Entity]
)(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Action] =
Tokenizer[Action] {
case Insert(entity: Entity, assignments) =>
val table = insertEntityTokenizer.token(entity)
val columns = assignments.map(_.property.token)
val values = assignments.map(_.value)
stmt"INSERT $table${actionAlias.map(alias => stmt" AS ${alias.token}").getOrElse(stmt"")} (${columns
.mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})"
case Update(table: Entity, assignments) =>
stmt"UPDATE ${table.token}${actionAlias
.map(alias => stmt" AS ${alias.token}")
.getOrElse(stmt"")} SET ${assignments.token}"
case Update(Filter(table: Entity, x, where), assignments) =>
stmt"UPDATE ${table.token}${actionAlias
.map(alias => stmt" AS ${alias.token}")
.getOrElse(stmt"")} SET ${assignments.token} WHERE ${where.token}"
case Delete(Filter(table: Entity, x, where)) =>
stmt"DELETE FROM ${table.token} WHERE ${where.token}"
case Delete(table: Entity) =>
stmt"DELETE FROM ${table.token}"
case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
idiomReturningCapability match {
// If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
// a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
case ReturningClauseSupported
if (CollectAst.byType[Entity](prop).nonEmpty) =>
SqlIdiom.withActionAlias(this, r)
case ReturningClauseSupported =>
stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
.token(ExpandReturning(r)(this, strategy).map(_._1))}"
case other =>
stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}"
}
case r @ ReturningAction(action, alias, prop) =>
idiomReturningCapability match {
// If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as
// a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5)
case ReturningClauseSupported
if (CollectAst.byType[Entity](prop).nonEmpty) =>
SqlIdiom.withActionAlias(this, r)
case ReturningClauseSupported =>
stmt"${action.token} RETURNING ${returnListTokenizer.token(
ExpandReturning(r)(this, strategy).map(_._1)
)}"
case other =>
stmt"${action.token}"
}
case other =>
fail(s"Action ast can't be translated to sql: '$other'")
}
implicit def entityTokenizer(implicit
astTokenizer: Tokenizer[Ast],
strategy: NamingStrategy
): Tokenizer[Entity] = Tokenizer[Entity] {
case Entity.Opinionated(name, _, renameable) =>
tokenizeTable(strategy, name, renameable).token
}
protected def scopedTokenizer(ast: Ast)(implicit tokenizer: Tokenizer[Ast]) =
ast match {
case _: Query => stmt"(${ast.token})"
case _: BinaryOperation => stmt"(${ast.token})"
case _: Tuple => stmt"(${ast.token})"
case _ => ast.token
}
}
object SqlIdiom {
private[minisql] def copyIdiom(
parent: SqlIdiom,
newActionAlias: Option[Ident]
) =
new SqlIdiom {
override protected def actionAlias: Option[Ident] = newActionAlias
override def prepareForProbing(string: String): String =
parent.prepareForProbing(string)
override def concatFunction: String = parent.concatFunction
override def liftingPlaceholder(index: Int): String =
parent.liftingPlaceholder(index)
override def idiomReturningCapability: ReturningCapability =
parent.idiomReturningCapability
}
/**
* Construct a new instance of the specified idiom with `newActionAlias`
* variable specified so that actions (i.e. insert, and update) will be
* rendered with the specified alias. This is needed for RETURNING clauses
* that have queries inside. See #1509 for details.
*/
private[minisql] def withActionAlias(
parentIdiom: SqlIdiom,
query: ReturningAction
)(implicit strategy: NamingStrategy) = {
val idiom = copyIdiom(parentIdiom, Some(query.alias))
import idiom._
implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer(
new Tokenizer[Ast] { self =>
extension (v: Ast) {
def token = astTokenizer(self, strategy).token(v)
}
},
strategy
)
query match {
case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) =>
stmt"INSERT INTO ${table.token} AS ${alias.name.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer
.token(ExpandReturning(r)(idiom, strategy).map(_._1))}"
case r @ ReturningAction(action, alias, prop) =>
stmt"${action.token} RETURNING ${returnListTokenizer.token(
ExpandReturning(r)(idiom, strategy).map(_._1)
)}"
}
}
}

View file

@ -0,0 +1,326 @@
package minisql.context.sql
import minisql.ast._
import minisql.context.sql.norm.FlattenGroupByAggregation
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.{Literal, PseudoAst, NamingStrategy}
case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering)
sealed trait FromContext
case class TableContext(entity: Entity, alias: String) extends FromContext
case class QueryContext(query: SqlQuery, alias: String) extends FromContext
case class InfixContext(infix: Infix, alias: String) extends FromContext
case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast)
extends FromContext
case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast)
extends FromContext
sealed trait SqlQuery {
override def toString = {
import minisql.MirrorSqlDialect._
import minisql.idiom.StatementInterpolator.*
given Tokenizer[SqlQuery] = sqlQueryTokenizer(using
defaultTokenizer(using Literal),
Literal
)
summon[Tokenizer[SqlQuery]].token(this).toString()
}
}
sealed trait SetOperation
case object UnionOperation extends SetOperation
case object UnionAllOperation extends SetOperation
sealed trait DistinctKind { def isDistinct: Boolean }
case object DistinctKind {
case object Distinct extends DistinctKind { val isDistinct: Boolean = true }
case class DistinctOn(props: List[Ast]) extends DistinctKind {
val isDistinct: Boolean = true
}
case object None extends DistinctKind { val isDistinct: Boolean = false }
}
case class SetOperationSqlQuery(
a: SqlQuery,
op: SetOperation,
b: SqlQuery
) extends SqlQuery
case class UnaryOperationSqlQuery(
op: UnaryOperator,
q: SqlQuery
) extends SqlQuery
case class SelectValue(
ast: Ast,
alias: Option[String] = None,
concat: Boolean = false
) extends PseudoAst {
override def toString: String =
s"${ast.toString}${alias.map("->" + _).getOrElse("")}"
}
case class FlattenSqlQuery(
from: List[FromContext] = List(),
where: Option[Ast] = None,
groupBy: Option[Ast] = None,
orderBy: List[OrderByCriteria] = Nil,
limit: Option[Ast] = None,
offset: Option[Ast] = None,
select: List[SelectValue],
distinct: DistinctKind = DistinctKind.None
) extends SqlQuery
object TakeDropFlatten {
def unapply(q: Query): Option[(Query, Option[Ast], Option[Ast])] = q match {
case Take(q: FlatMap, n) => Some((q, Some(n), None))
case Drop(q: FlatMap, n) => Some((q, None, Some(n)))
case _ => None
}
}
object SqlQuery {
def apply(query: Ast): SqlQuery =
query match {
case Union(a, b) =>
SetOperationSqlQuery(apply(a), UnionOperation, apply(b))
case UnionAll(a, b) =>
SetOperationSqlQuery(apply(a), UnionAllOperation, apply(b))
case UnaryOperation(op, q: Query) => UnaryOperationSqlQuery(op, apply(q))
case _: Operation | _: Value =>
FlattenSqlQuery(select = List(SelectValue(query)))
case Map(q, a, b) if a == b => apply(q)
case TakeDropFlatten(q, limit, offset) =>
flatten(q, "x").copy(limit = limit, offset = offset)
case q: Query => flatten(q, "x")
case infix: Infix => flatten(infix, "x")
case other =>
fail(
s"Query not properly normalized. Please open a bug report. Ast: '$other'"
)
}
private def flatten(query: Ast, alias: String): FlattenSqlQuery = {
val (sources, finalFlatMapBody) = flattenContexts(query)
flatten(sources, finalFlatMapBody, alias)
}
private def flattenContexts(query: Ast): (List[FromContext], Ast) =
query match {
case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Query) =>
val source = this.source(q, alias)
val (nestedContexts, finalFlatMapBody) = flattenContexts(p)
(source +: nestedContexts, finalFlatMapBody)
case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Infix) =>
fail(s"Infix can't be use as a `flatMap` body. $query")
case other =>
(List.empty, other)
}
object NestedNest {
def unapply(q: Ast): Option[Ast] =
q match {
case _: Nested => recurse(q)
case _ => None
}
private def recurse(q: Ast): Option[Ast] =
q match {
case Nested(qn) => recurse(qn)
case other => Some(other)
}
}
private def flatten(
sources: List[FromContext],
finalFlatMapBody: Ast,
alias: String
): FlattenSqlQuery = {
def select(alias: String) = SelectValue(Ident(alias), None) :: Nil
def base(q: Ast, alias: String) = {
def nest(ctx: FromContext) =
FlattenSqlQuery(from = sources :+ ctx, select = select(alias))
q match {
case Map(_: GroupBy, _, _) => nest(source(q, alias))
case NestedNest(q) => nest(QueryContext(apply(q), alias))
case q: ConcatMap => nest(QueryContext(apply(q), alias))
case Join(tpe, a, b, iA, iB, on) =>
val ctx = source(q, alias)
def aliases(ctx: FromContext): List[String] =
ctx match {
case TableContext(_, alias) => alias :: Nil
case QueryContext(_, alias) => alias :: Nil
case InfixContext(_, alias) => alias :: Nil
case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b)
case FlatJoinContext(_, a, _) => aliases(a)
}
FlattenSqlQuery(
from = ctx :: Nil,
select = aliases(ctx).map(a => SelectValue(Ident(a), None))
)
case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias)
case q if (sources == Nil) => flatten(sources, q, alias)
case other => nest(source(q, alias))
}
}
finalFlatMapBody match {
case ConcatMap(q, Ident(alias), p) =>
FlattenSqlQuery(
from = source(q, alias) :: Nil,
select = selectValues(p).map(_.copy(concat = true))
)
case Map(GroupBy(q, x @ Ident(alias), g), a, p) =>
val b = base(q, alias)
val select = BetaReduction(p, a -> Tuple(List(g, x)))
val flattenSelect = FlattenGroupByAggregation(x)(select)
b.copy(groupBy = Some(g), select = this.selectValues(flattenSelect))
case GroupBy(q, Ident(alias), p) =>
fail("A `groupBy` clause must be followed by `map`.")
case Map(q, Ident(alias), p) =>
val b = base(q, alias)
val agg = b.select.collect {
case s @ SelectValue(_: Aggregation, _, _) => s
}
if (!b.distinct.isDistinct && agg.isEmpty)
b.copy(select = selectValues(p))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = selectValues(p)
)
case Filter(q, Ident(alias), p) =>
val b = base(q, alias)
if (b.where.isEmpty)
b.copy(where = Some(p))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
where = Some(p),
select = select(alias)
)
case SortBy(q, Ident(alias), p, o) =>
val b = base(q, alias)
val criterias = orderByCriterias(p, o)
if (b.orderBy.isEmpty)
b.copy(orderBy = criterias)
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
orderBy = criterias,
select = select(alias)
)
case Aggregation(op, q: Query) =>
val b = flatten(q, alias)
b.select match {
case head :: Nil if !b.distinct.isDistinct =>
b.copy(select = List(head.copy(ast = Aggregation(op, head.ast))))
case other =>
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = List(SelectValue(Aggregation(op, Ident("*"))))
)
}
case Take(q, n) =>
val b = base(q, alias)
if (b.limit.isEmpty)
b.copy(limit = Some(n))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
limit = Some(n),
select = select(alias)
)
case Drop(q, n) =>
val b = base(q, alias)
if (b.offset.isEmpty && b.limit.isEmpty)
b.copy(offset = Some(n))
else
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
offset = Some(n),
select = select(alias)
)
case Distinct(q: Query) =>
val b = base(q, alias)
b.copy(distinct = DistinctKind.Distinct)
case DistinctOn(q, Ident(alias), fields) =>
val distinctList =
fields match {
case Tuple(values) => values
case other => List(other)
}
q match {
// Ideally we don't need to make an extra sub-query for every single case of
// distinct-on but it only works when the parent AST is an entity. That's because DistinctOn
// selects from an alias of an outer clause. For example, query[Person].map(p => Name(p.firstName, p.lastName)).distinctOn(_.name)
// (Let's say Person(firstName, lastName, age), Name(first, last)) will turn into
// SELECT DISTINCT ON (p.name), p.firstName AS first, p.lastName AS last, p.age FROM Person
// This doesn't work beause `name` in `p.name` doesn't exist yet. Therefore we have to nest this in a subquery:
// SELECT DISTINCT ON (p.name) FROM (SELECT p.firstName AS first, p.lastName AS last, p.age FROM Person p) AS p
// The only exception to this is if we are directly selecting from an entity:
// query[Person].distinctOn(_.firstName) which should be fine: SELECT (x.firstName), x.firstName, x.lastName, a.age FROM Person x
// since all the fields inside the (...) of the DISTINCT ON must be contained in the entity.
case _: Entity =>
val b = base(q, alias)
b.copy(distinct = DistinctKind.DistinctOn(distinctList))
case _ =>
FlattenSqlQuery(
from = QueryContext(apply(q), alias) :: Nil,
select = select(alias),
distinct = DistinctKind.DistinctOn(distinctList)
)
}
case other =>
FlattenSqlQuery(
from = sources :+ source(other, alias),
select = select(alias)
)
}
}
private def selectValues(ast: Ast) =
ast match {
case Tuple(values) => values.map(SelectValue(_))
case other => SelectValue(ast) :: Nil
}
private def source(ast: Ast, alias: String): FromContext =
ast match {
case entity: Entity => TableContext(entity, alias)
case infix: Infix => InfixContext(infix, alias)
case Join(t, a, b, ia, ib, on) =>
JoinContext(t, source(a, ia.name), source(b, ib.name), on)
case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on)
case Nested(q) => QueryContext(apply(q), alias)
case other => QueryContext(apply(other), alias)
}
private def orderByCriterias(ast: Ast, ordering: Ast): List[OrderByCriteria] =
(ast, ordering) match {
case (Tuple(properties), ord: PropertyOrdering) =>
properties.flatMap(orderByCriterias(_, ord))
case (Tuple(properties), TupleOrdering(ord)) =>
properties.zip(ord).flatMap { case (a, o) => orderByCriterias(a, o) }
case (a, o: PropertyOrdering) => List(OrderByCriteria(a, o))
case other => fail(s"Invalid order by criteria $ast")
}
}

View file

@ -0,0 +1,122 @@
package minisql.context.sql.idiom
import minisql.ast._
import minisql.context.sql._
import minisql.norm.FreeVariables
case class Error(free: List[Ident], ast: Ast)
case class InvalidSqlQuery(errors: List[Error]) {
override def toString =
s"The monad composition can't be expressed using applicative joins. " +
errors
.map(error =>
s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'."
)
.mkString(", ")
}
object VerifySqlQuery {
def apply(query: SqlQuery): Option[String] =
verify(query).map(_.toString)
private def verify(query: SqlQuery): Option[InvalidSqlQuery] =
query match {
case q: FlattenSqlQuery => verify(q)
case SetOperationSqlQuery(a, op, b) => verify(a).orElse(verify(b))
case UnaryOperationSqlQuery(op, q) => verify(q)
}
private def verifyFlatJoins(q: FlattenSqlQuery) = {
def loop(l: List[FromContext], available: Set[String]): Set[String] =
l.foldLeft(available) {
case (av, TableContext(_, alias)) => Set(alias)
case (av, InfixContext(_, alias)) => Set(alias)
case (av, QueryContext(_, alias)) => Set(alias)
case (av, JoinContext(_, a, b, on)) =>
av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av)
case (av, FlatJoinContext(_, a, on)) =>
val nav = av ++ loop(a :: Nil, av)
val free = FreeVariables(on).map(_.name)
val invalid = free -- nav
require(
invalid.isEmpty,
s"Found an `ON` table reference of a table that is not available: $invalid. " +
"The `ON` condition can only use tables defined through explicit joins."
)
nav
}
loop(q.from, Set())
}
private def verify(query: FlattenSqlQuery): Option[InvalidSqlQuery] = {
verifyFlatJoins(query)
val aliases =
query.from.flatMap(this.aliases).map(Ident(_)) :+ Ident("*") :+ Ident("?")
def verifyAst(ast: Ast) = {
val freeVariables =
(FreeVariables(ast) -- aliases).toList
val freeIdents =
(CollectAst(ast) {
case ast: Property => None
case Aggregation(_, _: Ident) => None
case ast: Ident => Some(ast)
}).flatten
(freeVariables ++ freeIdents) match {
case Nil => None
case free => Some(Error(free, ast))
}
}
// Recursively expand children until values are fully flattened. Identities in all these should
// be skipped during verification.
def expandSelect(sv: SelectValue): List[SelectValue] =
sv.ast match {
case Tuple(values) =>
values.map(v => SelectValue(v)).flatMap(expandSelect(_))
case CaseClass(values) =>
values.map(v => SelectValue(v._2)).flatMap(expandSelect(_))
case _ => List(sv)
}
val freeVariableErrors: List[Error] =
query.where.flatMap(verifyAst).toList ++
query.orderBy.map(_.ast).flatMap(verifyAst) ++
query.limit.flatMap(verifyAst) ++
query.select
.flatMap(
expandSelect(_)
) // Expand tuple select clauses so their top-level identities are skipped
.map(_.ast)
.filterNot(_.isInstanceOf[Ident])
.flatMap(verifyAst) ++
query.from.flatMap {
case j: JoinContext => verifyAst(j.on)
case j: FlatJoinContext => verifyAst(j.on)
case _ => Nil
}
val nestedErrors =
query.from.collect {
case QueryContext(query, alias) => verify(query).map(_.errors)
}.flatten.flatten
(freeVariableErrors ++ nestedErrors) match {
case Nil => None
case errors => Some(InvalidSqlQuery(errors))
}
}
private def aliases(s: FromContext): List[String] =
s match {
case s: TableContext => List(s.alias)
case s: QueryContext => List(s.alias)
case s: InfixContext => List(s.alias)
case s: JoinContext => aliases(s.a) ++ aliases(s.b)
case s: FlatJoinContext => aliases(s.a)
}
}

View file

@ -0,0 +1,47 @@
package minisql.context.sql.norm
import minisql.ast.Constant
import minisql.context.sql.{FlattenSqlQuery, SqlQuery, _}
/**
* In SQL Server, `Order By` clauses are only allowed in sub-queries if the
* sub-query has a `TOP` or `OFFSET` modifier. Otherwise an exception will be
* thrown. This transformation adds a 'dummy' `OFFSET 0` in this scenario (if an
* `Offset` clause does not exist already).
*/
object AddDropToNestedOrderBy {
def applyInner(q: SqlQuery): SqlQuery =
q match {
case q: FlattenSqlQuery =>
q.copy(
offset =
if (q.orderBy.nonEmpty) q.offset.orElse(Some(Constant(0)))
else q.offset,
from = q.from.map(applyInner(_))
)
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(applyInner(a), op, applyInner(b))
case UnaryOperationSqlQuery(op, a) =>
UnaryOperationSqlQuery(op, applyInner(a))
}
private def applyInner(f: FromContext): FromContext =
f match {
case QueryContext(a, alias) => QueryContext(applyInner(a), alias)
case JoinContext(t, a, b, on) =>
JoinContext(t, applyInner(a), applyInner(b), on)
case FlatJoinContext(t, a, on) => FlatJoinContext(t, applyInner(a), on)
case other => other
}
def apply(q: SqlQuery): SqlQuery =
q match {
case q: FlattenSqlQuery => q.copy(from = q.from.map(applyInner(_)))
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(applyInner(a), op, applyInner(b))
case UnaryOperationSqlQuery(op, a) =>
UnaryOperationSqlQuery(op, applyInner(a))
}
}

View file

@ -0,0 +1,68 @@
package minisql.context.sql.norm
import minisql.ast.Visibility.Hidden
import minisql.ast._
object ExpandDistinct {
@annotation.tailrec
def hasJoin(q: Ast): Boolean = {
q match {
case _: Join => true
case Map(q, _, _) => hasJoin(q)
case Filter(q, _, _) => hasJoin(q)
case _ => false
}
}
def apply(q: Ast): Ast =
q match {
case Distinct(q) =>
Distinct(apply(q))
case q =>
Transform(q) {
case Aggregation(op, Distinct(q)) =>
Aggregation(op, Distinct(apply(q)))
case Distinct(Map(q, x, cc @ Tuple(values))) =>
Map(
Distinct(Map(q, x, cc)),
x,
Tuple(values.zipWithIndex.map {
case (_, i) => Property(x, s"_${i + 1}")
})
)
// Situations like this:
// case class AdHocCaseClass(id: Int, name: String)
// val q = quote {
// query[SomeTable].map(st => AdHocCaseClass(st.id, st.name)).distinct
// }
// ... need some special treatment. Otherwise their values will not be correctly expanded.
case Distinct(Map(q, x, cc @ CaseClass(values))) =>
Map(
Distinct(Map(q, x, cc)),
x,
CaseClass(values.map {
case (name, _) => (name, Property(x, name))
})
)
// Need some special handling to address issues with distinct returning a single embedded entity i.e:
// query[Parent].map(p => p.emb).distinct.map(e => (e.name, e.id))
// cannot treat such a case normally or "confused" queries will result e.g:
// SELECT p.embname, p.embid FROM (SELECT DISTINCT emb.name /* Where the heck is 'emb' coming from? */ AS embname, emb.id AS embid FROM Parent p) AS p
case d @ Distinct(
Map(q, x, p @ Property.Opinionated(_, _, _, Hidden))
) =>
d
// Problems with distinct were first discovered in #1032. Basically, unless
// the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places
// such as "ORDER BY tableIdent" etc...
case Distinct(Map(q, x, p)) =>
val newMap = Map(q, x, Tuple(List(p)))
val newIdent = Ident(x.name)
Map(Distinct(newMap), newIdent, Property(newIdent, "_1"))
}
}
}

View file

@ -0,0 +1,49 @@
package minisql.context.sql.norm
import minisql.ast._
import minisql.norm.BetaReduction
import minisql.norm.Normalize
object ExpandJoin {
def apply(q: Ast) = expand(q, None)
def expand(q: Ast, id: Option[Ident]) =
Transform(q) {
case q @ Join(_, _, _, Ident(a), Ident(b), _) =>
val (qr, tuple) = expandedTuple(q)
Map(qr, id.getOrElse(Ident(s"$a$b")), tuple)
}
private def expandedTuple(q: Join): (Join, Tuple) =
q match {
case Join(t, a: Join, b: Join, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tA -> at, tB -> bt)
(Join(t, ar, br, tA, tB, or), Tuple(List(at, bt)))
case Join(t, a: Join, b, tA, tB, o) =>
val (ar, at) = expandedTuple(a)
val or = BetaReduction(o, tA -> at)
(Join(t, ar, b, tA, tB, or), Tuple(List(at, tB)))
case Join(t, a, b: Join, tA, tB, o) =>
val (br, bt) = expandedTuple(b)
val or = BetaReduction(o, tB -> bt)
(Join(t, a, br, tA, tB, or), Tuple(List(tA, bt)))
case q @ Join(t, a, b, tA, tB, on) =>
(
Join(t, nestedExpand(a, tA), nestedExpand(b, tB), tA, tB, on),
Tuple(List(tA, tB))
)
}
private def nestedExpand(q: Ast, id: Ident) =
Normalize(expand(q, Some(id))) match {
case Map(q, _, _) => q
case q => q
}
}

View file

@ -0,0 +1,12 @@
package minisql.context.sql.norm
import minisql.ast._
object ExpandMappedInfix {
def apply(q: Ast): Ast = {
Transform(q) {
case Map(Infix("" :: parts, (q: Query) :: params, pure, noParen), x, p) =>
Infix("" :: parts, Map(q, x, p) :: params, pure, noParen)
}
}
}

View file

@ -0,0 +1,147 @@
package minisql.context.sql.norm
import minisql.NamingStrategy
import minisql.ast.Ast
import minisql.ast.Ident
import minisql.ast._
import minisql.ast.StatefulTransformer
import minisql.ast.Visibility.Visible
import minisql.context.sql._
import scala.collection.mutable.LinkedHashSet
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.context.sql.norm.nested.ExpandSelect
import minisql.norm.BetaReduction
import scala.collection.mutable
class ExpandNestedQueries(strategy: NamingStrategy) {
val interp = new Interpolator(3)
import interp._
def apply(q: SqlQuery, references: List[Property]): SqlQuery =
apply(q, LinkedHashSet.empty ++ references)
// Using LinkedHashSet despite the fact that it is mutable because it has better characteristics then ListSet.
// Also this collection is strictly internal to ExpandNestedQueries and exposed anywhere else.
private def apply(
q: SqlQuery,
references: LinkedHashSet[Property]
): SqlQuery =
q match {
case q: FlattenSqlQuery =>
val expand = expandNested(
q.copy(select = ExpandSelect(q.select, references, strategy))
)
trace"Expanded Nested Query $q into $expand".andLog()
expand
case SetOperationSqlQuery(a, op, b) =>
SetOperationSqlQuery(apply(a, references), op, apply(b, references))
case UnaryOperationSqlQuery(op, q) =>
UnaryOperationSqlQuery(op, apply(q, references))
}
private def expandNested(q: FlattenSqlQuery): SqlQuery =
q match {
case FlattenSqlQuery(
from,
where,
groupBy,
orderBy,
limit,
offset,
select,
distinct
) =>
val asts = Nil ++ select.map(_.ast) ++ where ++ groupBy ++ orderBy.map(
_.ast
) ++ limit ++ offset
val expansions = q.from.map(expandContext(_, asts))
val from = expansions.map(_._1)
val references = expansions.flatMap(_._2)
val replacedRefs = references.map(ref => (ref, unhideAst(ref)))
// Need to unhide properties that were used during the query
def replaceProps(ast: Ast) =
BetaReduction(ast, replacedRefs: _*)
def replacePropsOption(ast: Option[Ast]) =
ast.map(replaceProps(_))
val distinctKind =
q.distinct match {
case DistinctKind.DistinctOn(props) =>
DistinctKind.DistinctOn(props.map(p => replaceProps(p)))
case other => other
}
q.copy(
select = select.map(sv => sv.copy(ast = replaceProps(sv.ast))),
from = from,
where = replacePropsOption(where),
groupBy = replacePropsOption(groupBy),
orderBy = orderBy.map(ob => ob.copy(ast = replaceProps(ob.ast))),
limit = replacePropsOption(limit),
offset = replacePropsOption(offset),
distinct = distinctKind
)
}
def unhideAst(ast: Ast): Ast =
Transform(ast) {
case Property.Opinionated(a, n, r, v) =>
Property.Opinionated(unhideAst(a), n, r, Visible)
}
private def unhideProperties(sv: SelectValue) =
sv.copy(ast = unhideAst(sv.ast))
private def expandContext(
s: FromContext,
asts: List[Ast]
): (FromContext, LinkedHashSet[Property]) =
s match {
case QueryContext(q, alias) =>
val refs = references(alias, asts)
(QueryContext(apply(q, refs), alias), refs)
case JoinContext(t, a, b, on) =>
val (left, leftRefs) = expandContext(a, asts :+ on)
val (right, rightRefs) = expandContext(b, asts :+ on)
(JoinContext(t, left, right, on), leftRefs ++ rightRefs)
case FlatJoinContext(t, a, on) =>
val (next, refs) = expandContext(a, asts :+ on)
(FlatJoinContext(t, next, on), refs)
case _: TableContext | _: InfixContext =>
(s, new mutable.LinkedHashSet[Property]())
}
private def references(alias: String, asts: List[Ast]) =
LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)(
_.apply
)._2.state.references)
}
case class State(ident: Ident, references: List[Property])
case class References(val state: State) extends StatefulTransformer[State] {
import state._
override def apply(a: Ast) =
a match {
case `reference`(p) => (p, References(State(ident, references :+ p)))
case other => super.apply(a)
}
object reference {
def unapply(p: Property): Option[Property] =
p match {
case Property(`ident`, name) => Some(p)
case Property(reference(_), name) => Some(p)
case other => None
}
}
}

View file

@ -0,0 +1,58 @@
package minisql.context.sql.norm
import minisql.ast.Aggregation
import minisql.ast.Ast
import minisql.ast.Drop
import minisql.ast.Filter
import minisql.ast.FlatMap
import minisql.ast.Ident
import minisql.ast.Join
import minisql.ast.Map
import minisql.ast.Query
import minisql.ast.SortBy
import minisql.ast.StatelessTransformer
import minisql.ast.Take
import minisql.ast.Union
import minisql.ast.UnionAll
import minisql.norm.BetaReduction
import minisql.util.Messages.fail
import minisql.ast.ConcatMap
case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer {
override def apply(ast: Ast) =
ast match {
case q: Query if (isGroupByAggregation(q)) =>
q match {
case Aggregation(op, Map(`agg`, ident, body)) =>
Aggregation(op, BetaReduction(body, ident -> agg))
case Map(`agg`, ident, body) =>
BetaReduction(body, ident -> agg)
case q @ Aggregation(op, `agg`) =>
q
case other =>
fail(s"Invalid group by aggregation: '$other'")
}
case other =>
super.apply(other)
}
private[this] def isGroupByAggregation(ast: Ast): Boolean =
ast match {
case Aggregation(a, b) => isGroupByAggregation(b)
case Map(a, b, c) => isGroupByAggregation(a)
case FlatMap(a, b, c) => isGroupByAggregation(a)
case ConcatMap(a, b, c) => isGroupByAggregation(a)
case Filter(a, b, c) => isGroupByAggregation(a)
case SortBy(a, b, c, d) => isGroupByAggregation(a)
case Take(a, b) => isGroupByAggregation(a)
case Drop(a, b) => isGroupByAggregation(a)
case Union(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
case UnionAll(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b)
case Join(t, a, b, ta, tb, on) =>
isGroupByAggregation(a) || isGroupByAggregation(b)
case `agg` => true
case other => false
}
}

View file

@ -0,0 +1,53 @@
package minisql.context.sql.norm
import minisql.norm._
import minisql.ast.Ast
import minisql.norm.ConcatBehavior.AnsiConcat
import minisql.norm.EqualityBehavior.AnsiEquality
import minisql.norm.capture.DemarcateExternalAliases
import minisql.util.Messages.trace
object SqlNormalize {
def apply(
ast: Ast,
concatBehavior: ConcatBehavior = AnsiConcat,
equalityBehavior: EqualityBehavior = AnsiEquality
) =
new SqlNormalize(concatBehavior, equalityBehavior)(ast)
}
class SqlNormalize(
concatBehavior: ConcatBehavior,
equalityBehavior: EqualityBehavior
) {
private val normalize =
(identity[Ast] _)
.andThen(trace("original"))
.andThen(DemarcateExternalAliases.apply _)
.andThen(trace("DemarcateReturningAliases"))
.andThen(new FlattenOptionOperation(concatBehavior).apply _)
.andThen(trace("FlattenOptionOperation"))
.andThen(new SimplifyNullChecks(equalityBehavior).apply _)
.andThen(trace("SimplifyNullChecks"))
.andThen(Normalize.apply _)
.andThen(trace("Normalize"))
// Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes
// on which RenameProperties relies
.andThen(RenameProperties.apply _)
.andThen(trace("RenameProperties"))
.andThen(ExpandDistinct.apply _)
.andThen(trace("ExpandDistinct"))
.andThen(NestImpureMappedInfix.apply _)
.andThen(trace("NestMappedInfix"))
.andThen(Normalize.apply _)
.andThen(trace("Normalize"))
.andThen(ExpandJoin.apply _)
.andThen(trace("ExpandJoin"))
.andThen(ExpandMappedInfix.apply _)
.andThen(trace("ExpandMappedInfix"))
.andThen(Normalize.apply _)
.andThen(trace("Normalize"))
def apply(ast: Ast) = normalize(ast)
}

View file

@ -0,0 +1,29 @@
package minisql.context.sql.norm.nested
import minisql.PseudoAst
import minisql.context.sql.SelectValue
object Elements {
/**
* In order to be able to reconstruct the original ordering of elements inside
* of a select clause, we need to keep track of their order, not only within
* the top-level select but also it's order within any possible
* tuples/case-classes that in which it is embedded. For example, in the
* query: <pre><code> query[Person].map(p => (p.id, (p.name, p.age))).nested
* // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person
* x) AS p </code></pre> Since the `p.name` and `p.age` elements are selected
* inside of a sub-tuple, their "order" is `List(2,1)` and `List(2,2)`
* respectively as opposed to `p.id` whose "order" is just `List(1)`.
*
* This class keeps track of the values needed in order to perform do this.
*/
case class OrderedSelect(order: List[Int], selectValue: SelectValue)
extends PseudoAst {
override def toString: String = s"[${order.mkString(",")}]${selectValue}"
}
object OrderedSelect {
def apply(order: Int, selectValue: SelectValue) =
new OrderedSelect(List(order), selectValue)
}
}

View file

@ -0,0 +1,262 @@
package minisql.context.sql.norm.nested
import minisql.NamingStrategy
import minisql.ast.Property
import minisql.context.sql.SelectValue
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import scala.collection.mutable.LinkedHashSet
import minisql.context.sql.norm.nested.Elements._
import minisql.ast._
import minisql.norm.BetaReduction
/**
* Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query
* constrct exists) and flattens them from a nested-hiearchical structure (i.e.
* tuples inside case classes inside tuples etc..) into into a single series of
* top-level select elements where needed. In cases where a user wants to select
* an element that contains an entire tuple (i.e. a sub-tuple of the outer
* select clause) we pull out the entire tuple that is being selected and leave
* it to the tokenizer to flatten later.
*
* The part about this operation that is tricky is if there are situations where
* there are infix clauses in a sub-query representing an element that has not
* been selected by the query-query but in order to ensure the SQL operation has
* the same meaning, we need to keep track for it. For example: <pre><code> val
* q = quote { query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int],
* p.name, p.id)).map(t => (t._2, t._3)) } run(q) // SELECT p._2, p._3 FROM
* (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
* </code></pre> Since `DISTINCT ON` significantly changes the behavior of the
* outer query, we need to keep track of it inside of the inner query. In order
* to do this, we need to keep track of the location of the infix in the inner
* query so that we can reconstruct it. This is why the `OrderedSelect` and
* `DoubleOrderedSelect` objects are used. See the notes on these classes for
* more detail.
*
* See issue #1597 for more details and another example.
*/
private class ExpandSelect(
selectValues: List[SelectValue],
references: LinkedHashSet[Property],
strategy: NamingStrategy
) {
val interp = new Interpolator(3)
import interp._
object TupleIndex {
def unapply(s: String): Option[Int] =
if (s.matches("_[0-9]*"))
Some(s.drop(1).toInt - 1)
else
None
}
object MultiTupleIndex {
def unapply(s: String): Boolean =
if (s.matches("(_[0-9]+)+"))
true
else
false
}
val select =
selectValues.zipWithIndex.map {
case (value, index) => OrderedSelect(index, value)
}
def expandColumn(name: String, renameable: Renameable): String =
renameable.fixedOr(name)(strategy.column(name))
def apply: List[SelectValue] =
trace"Expanding Select values: $selectValues into references: $references" andReturn {
def expandReference(ref: Property): OrderedSelect =
trace"Expanding: $ref from $select" andReturn {
def expressIfTupleIndex(str: String) =
str match {
case MultiTupleIndex() => Some(str)
case _ => None
}
def concat(alias: Option[String], idx: Int) =
Some(s"${alias.getOrElse("")}_${idx + 1}")
val orderedSelect = ref match {
case pp @ Property(ast: Property, TupleIndex(idx)) =>
trace"Reference is a sub-property of a tuple index: $idx. Walking inside." andReturn
expandReference(ast) match {
case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) =>
trace"Expressing Element $idx of $elems " andReturn
OrderedSelect(
o :+ idx,
SelectValue(elems(idx), concat(alias, idx), c)
)
case OrderedSelect(o, SelectValue(ast, alias, c)) =>
trace"Appending $idx to $alias " andReturn
OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
}
case pp @ Property.Opinionated(
ast: Property,
name,
renameable,
visible
) =>
trace"Reference is a sub-property. Walking inside." andReturn
expandReference(ast) match {
case OrderedSelect(o, SelectValue(ast, nested, c)) =>
// Alias is the name of the column after the naming strategy
// The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the
// alias field when it's value is Some(T).
// Technically the aliases of a column should not be using naming strategies
// but this is an issue to fix at a later date.
// In the current implementation, aliases we add nested tuple names to queries e.g.
// SELECT foo from
// SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors)
// Typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when
// this kind of query is the result of an applicative join that looks like this:
// query[baz].join(query[colors]).nested
// this may need to change based on how distinct appends table names instead of just tuple indexes
// into the property path.
trace"...inside walk completed, continuing to return: " andReturn
OrderedSelect(
o,
SelectValue(
// Note: Pass invisible properties to be tokenized by the idiom, they should be excluded there
Property.Opinionated(ast, name, renameable, visible),
// Skip concatonation of invisible properties into the alias e.g. so it will be
Some(
s"${nested.getOrElse("")}${expandColumn(name, renameable)}"
)
)
)
}
case pp @ Property(_, TupleIndex(idx)) =>
trace"Reference is a tuple index: $idx from $select." andReturn
select(idx) match {
case OrderedSelect(o, SelectValue(ast, alias, c)) =>
OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
}
case pp @ Property.Opinionated(_, name, renameable, visible) =>
select match {
case List(
OrderedSelect(o, SelectValue(cc: CaseClass, alias, c))
) =>
// Currently case class element name is not being appended. Need to change that in order to ensure
// path name uniqueness in future.
val ((_, ast), index) =
cc.values.zipWithIndex.find(_._1._1 == name) match {
case Some(v) => v
case None =>
throw new IllegalArgumentException(
s"Cannot find element $name in $cc"
)
}
trace"Reference is a case class member: " andReturn
OrderedSelect(
o :+ index,
SelectValue(ast, Some(expandColumn(name, renameable)), c)
)
case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) =>
trace"Reference is an identifier: " andReturn
OrderedSelect(
o,
SelectValue(
Property.Opinionated(i, name, renameable, visible),
Some(name),
c
)
)
case other =>
trace"Reference is unidentified: $other returning:" andReturn
OrderedSelect(
Integer.MAX_VALUE,
SelectValue(
Ident.Opinionated(name, visible),
Some(expandColumn(name, renameable)),
false
)
)
}
}
// For certain very large queries where entities are unwrapped and then re-wrapped into CaseClass/Tuple constructs,
// the actual row-types can contain Tuple/CaseClass values. For this reason. They need to be beta-reduced again.
val normalizedOrderedSelect = orderedSelect.copy(selectValue =
orderedSelect.selectValue.copy(ast =
BetaReduction(orderedSelect.selectValue.ast)
)
)
trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" andReturn
normalizedOrderedSelect
}
def deAliasWhenUneeded(os: OrderedSelect) =
os match {
case OrderedSelect(
_,
sv @ SelectValue(Property(Ident(_), propName), Some(alias), _)
) if (propName == alias) =>
trace"Detected select value with un-needed alias: $os removing it:" andReturn
os.copy(selectValue = sv.copy(alias = None))
case _ => os
}
references.toList match {
case Nil => select.map(_.selectValue)
case refs => {
// elements first need to be sorted by their order in the select clause. Since some may map to multiple
// properties when expanded, we want to maintain this order of properties as a secondary value.
val mappedRefs =
refs
// Expand all the references to properties that we have selected in the super query
.map(expandReference)
// Once all the recursive calls of expandReference are done, remove the alias if it is not needed.
// We cannot do this because during recursive calls, the aliases of outer clauses are used for inner ones.
.map(deAliasWhenUneeded(_))
trace"Mapped Refs: $mappedRefs".andLog()
// are there any selects that have infix values which we have not already selected? We need to include
// them because they could be doing essential things e.g. RANK ... ORDER BY
val remainingSelectsWithInfixes =
trace"Searching Selects with Infix:" andReturn
new FindUnexpressedInfixes(select)(mappedRefs)
implicit val ordering: scala.math.Ordering[List[Int]] =
new scala.math.Ordering[List[Int]] {
override def compare(x: List[Int], y: List[Int]): Int =
(x, y) match {
case (head1 :: tail1, head2 :: tail2) =>
val diff = head1 - head2
if (diff != 0) diff
else compare(tail1, tail2)
case (Nil, Nil) => 0 // List(1,2,3) == List(1,2,3)
case (head1, Nil) => -1 // List(1,2,3) < List(1,2)
case (Nil, head2) => 1 // List(1,2) > List(1,2,3)
}
}
val sortedRefs =
(mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref =>
ref.order
) // (ref.order, ref.secondaryOrder)
sortedRefs.map(_.selectValue)
}
}
}
}
object ExpandSelect {
def apply(
selectValues: List[SelectValue],
references: LinkedHashSet[Property],
strategy: NamingStrategy
): List[SelectValue] =
new ExpandSelect(selectValues, references, strategy).apply
}

View file

@ -0,0 +1,83 @@
package minisql.context.sql.norm.nested
import minisql.context.sql.norm.nested.Elements._
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.ast._
import minisql.context.sql.SelectValue
/**
* The challenge with appeneding infixes (that have not been used but are still
* needed) back into the query, is that they could be inside of
* tuples/case-classes that have already been selected, or inside of sibling
* elements which have been selected. Take for instance a query that looks like
* this: <pre><code> query[Person].map(p => (p.name, (p.id,
* infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1)) </code></pre> In
* this situation, `p.id` which is the sibling of the non-selected infix has
* been selected via `p._2._1` (whose select-order is List(1,0) to represent 1st
* element in 2nd tuple. We need to add it's sibling infix.
*
* Or take the following situation: <pre><code> query[Person].map(p => (p.name,
* (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
* </code></pre> In this case, we have selected the entire 2nd element including
* the infix. We need to know that `P._2._2` does not need to be selected since
* `p._2` was.
*
* In order to do these things, we use the `order` property from `OrderedSelect`
* in order to see which sub-sub-...-element has been selected. If `p._2` (that
* has order `List(1)`) has been selected, we know that any infixes inside of it
* e.g. `p._2._1` (ordering `List(1,0)`) does not need to be.
*/
class FindUnexpressedInfixes(select: List[OrderedSelect]) {
val interp = new Interpolator(3)
import interp._
def apply(refs: List[OrderedSelect]) = {
def pathExists(path: List[Int]) =
refs.map(_.order).contains(path)
def containsInfix(ast: Ast) =
CollectAst.byType[Infix](ast).length > 0
// build paths to every infix and see these paths were not selected already
def findMissingInfixes(
ast: Ast,
parentOrder: List[Int]
): List[(Ast, List[Int])] = {
trace"Searching for infix: $ast in the sub-path $parentOrder".andLog()
if (pathExists(parentOrder))
trace"No infixes found" andContinue
List()
else
ast match {
case Tuple(values) =>
values.zipWithIndex
.filter(v => containsInfix(v._1))
.flatMap {
case (ast, index) =>
findMissingInfixes(ast, parentOrder :+ index)
}
case CaseClass(values) =>
values.zipWithIndex
.filter(v => containsInfix(v._1._2))
.flatMap {
case ((_, ast), index) =>
findMissingInfixes(ast, parentOrder :+ index)
}
case other if (containsInfix(other)) =>
trace"Found unexpressed infix inside $other in $parentOrder"
.andLog()
List((other, parentOrder))
case _ =>
List()
}
}
select.flatMap {
case OrderedSelect(o, sv) => findMissingInfixes(sv.ast, o)
}.map {
case (ast, order) => OrderedSelect(order, SelectValue(ast))
}
}
}

View file

@ -0,0 +1,120 @@
package minisql.norm
import minisql.ast.*
import collection.immutable.Set
case class State(seen: Set[Ident], free: Set[Ident])
case class FreeVariables(state: State) extends StatefulTransformer[State] {
override def apply(ast: Ast): (Ast, StatefulTransformer[State]) =
ast match {
case ident: Ident if (!state.seen.contains(ident)) =>
(ident, FreeVariables(State(state.seen, state.free + ident)))
case f @ Function(params, body) =>
val (_, t) =
FreeVariables(State(state.seen ++ params, state.free))(body)
(f, FreeVariables(State(state.seen, state.free ++ t.state.free)))
case q @ Foreach(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(
o: OptionOperation
): (OptionOperation, StatefulTransformer[State]) =
o match {
case q @ OptionTableFlatMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableExists(a, b, c) =>
(q, free(a, b, c))
case q @ OptionTableForall(a, b, c) =>
(q, free(a, b, c))
case q @ OptionFlatMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionMap(a, b, c) =>
(q, free(a, b, c))
case q @ OptionForall(a, b, c) =>
(q, free(a, b, c))
case q @ OptionExists(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) =
e match {
case Assignment(a, b, c) =>
val t = FreeVariables(State(state.seen + a, state.free))
val (bt, btt) = t(b)
val (ct, ctt) = t(c)
(
Assignment(a, bt, ct),
FreeVariables(
State(state.seen, state.free ++ btt.state.free ++ ctt.state.free)
)
)
}
override def apply(action: Action): (Action, StatefulTransformer[State]) =
action match {
case q @ Returning(a, b, c) =>
(q, free(a, b, c))
case q @ ReturningGenerated(a, b, c) =>
(q, free(a, b, c))
case other =>
super.apply(other)
}
override def apply(
e: OnConflict.Target
): (OnConflict.Target, StatefulTransformer[State]) = (e, this)
override def apply(query: Query): (Query, StatefulTransformer[State]) =
query match {
case q @ Filter(a, b, c) => (q, free(a, b, c))
case q @ Map(a, b, c) => (q, free(a, b, c))
case q @ DistinctOn(a, b, c) => (q, free(a, b, c))
case q @ FlatMap(a, b, c) => (q, free(a, b, c))
case q @ ConcatMap(a, b, c) => (q, free(a, b, c))
case q @ SortBy(a, b, c, d) => (q, free(a, b, c))
case q @ GroupBy(a, b, c) => (q, free(a, b, c))
case q @ FlatJoin(t, a, b, c) => (q, free(a, b, c))
case q @ Join(t, a, b, iA, iB, on) =>
val (_, freeA) = apply(a)
val (_, freeB) = apply(b)
val (_, freeOn) =
FreeVariables(State(state.seen + iA + iB, Set.empty))(on)
(
q,
FreeVariables(
State(
state.seen,
state.free ++ freeA.state.free ++ freeB.state.free ++ freeOn.state.free
)
)
)
case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll |
_: Aggregation | _: Distinct | _: Nested =>
super.apply(query)
}
private def free(a: Ast, ident: Ident, c: Ast) = {
val (_, ta) = apply(a)
val (_, tc) = FreeVariables(State(state.seen + ident, state.free))(c)
FreeVariables(
State(state.seen, state.free ++ ta.state.free ++ tc.state.free)
)
}
}
object FreeVariables {
def apply(ast: Ast): Set[Ident] =
new FreeVariables(State(Set.empty, Set.empty))(ast) match {
case (_, transformer) =>
transformer.state.free
}
}