diff --git a/src/main/scala/minisql/MirrorSqlDialect.scala b/src/main/scala/minisql/MirrorSqlDialect.scala new file mode 100644 index 0000000..563f770 --- /dev/null +++ b/src/main/scala/minisql/MirrorSqlDialect.scala @@ -0,0 +1,52 @@ +package minisql + +import minisql.context.{ + CanReturnClause, + CanReturnField, + CanReturnMultiField, + CannotReturn +} +import minisql.context.sql.idiom.SqlIdiom +import minisql.context.sql.idiom.QuestionMarkBindVariables +import minisql.context.sql.idiom.ConcatSupport + + /** Mirror SQL dialect: SqlIdiom combined with QuestionMarkBindVariables, ConcatSupport and the CanReturnField capability. */ trait MirrorSqlDialect + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnField + + /** Same mixins as MirrorSqlDialect, but with CanReturnMultiField as the returning capability. */ trait MirrorSqlDialectWithReturnMulti + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnMultiField + + /** Same mixins as MirrorSqlDialect, but with CanReturnClause as the returning capability. */ trait MirrorSqlDialectWithReturnClause + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CanReturnClause + + /** Same mixins as MirrorSqlDialect, but with CannotReturn as the returning capability. */ trait MirrorSqlDialectWithNoReturn + extends SqlIdiom + with QuestionMarkBindVariables + with ConcatSupport + with CannotReturn + + /** Singleton instances of the four dialect traits follow; each prepareForProbing returns its input unchanged. */ object MirrorSqlDialect extends MirrorSqlDialect { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithReturnMulti extends MirrorSqlDialectWithReturnMulti { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithReturnClause + extends MirrorSqlDialectWithReturnClause { + override def prepareForProbing(string: String) = string +} + +object MirrorSqlDialectWithNoReturn extends MirrorSqlDialectWithNoReturn { + override def prepareForProbing(string: String) = string +} diff --git a/src/main/scala/minisql/context/sql/ConcatSupport.scala b/src/main/scala/minisql/context/sql/ConcatSupport.scala new file mode 100644 index 0000000..39a2e64 --- /dev/null +++ b/src/main/scala/minisql/context/sql/ConcatSupport.scala @@ -0,0 +1,17 @@ +package minisql.context.sql.idiom + +import minisql.util.Messages + + /** Mixin (self-typed to SqlIdiom) that sets concatFunction to UNNEST. */ trait ConcatSupport { + this: SqlIdiom => + + override def concatFunction = 
"UNNEST" +} + + /** Mixin (self-typed to SqlIdiom) for idioms without concatMap support: concatFunction calls Messages.fail with a message naming the runtime class. */ trait NoConcatSupport { + this: SqlIdiom => + + override def concatFunction = Messages.fail( + s"`concatMap` not supported by ${this.getClass.getSimpleName}" + ) +} diff --git a/src/main/scala/minisql/context/sql/OnConflictSupport.scala b/src/main/scala/minisql/context/sql/OnConflictSupport.scala new file mode 100644 index 0000000..940d5bf --- /dev/null +++ b/src/main/scala/minisql/context/sql/OnConflictSupport.scala @@ -0,0 +1,70 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.idiom.StatementInterpolator._ +import minisql.idiom.Token +import minisql.NamingStrategy +import minisql.util.Messages.fail + + /** Mixin (self-typed to SqlIdiom) providing an implicit Tokenizer[OnConflict] that renders ON CONFLICT clauses. */ trait OnConflictSupport { + self: SqlIdiom => + + implicit def conflictTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OnConflict] = { + + /* Renders the target entity as INTO, the table name, then the fixed alias t. */ val customEntityTokenizer = Tokenizer[Entity] { + case Entity.Opinionated(name, _, renameable) => + stmt"INTO ${renameable.fixedOr(name.token)(strategy.table(name).token)} AS t" + } + + /* Renders OnConflict.Excluded as EXCLUDED, OnConflict.Existing(a) as the token of a, and Actions through actionTokenizer with the aliased entity tokenizer; every other node falls back to self.astTokenizer. */ val customAstTokenizer = + Tokenizer.withFallback[Ast](self.astTokenizer(_, strategy)) { + case _: OnConflict.Excluded => stmt"EXCLUDED" + case OnConflict.Existing(a) => stmt"${a.token}" + case a: Action => + self + .actionTokenizer(customEntityTokenizer)( + actionAstTokenizer, + strategy + ) + .token(a) + } + + import OnConflict._ + + /* Builds the DO UPDATE form: insert token i, conflict target t, then one assignment per entry of u.assignments. */ def doUpdateStmt(i: Token, t: Token, u: Update) = { + val assignments = u.assignments + .map(a => + stmt"${actionAstTokenizer.token(a.property)} = ${scopedTokenizer(a.value)(customAstTokenizer)}" + ) + .mkStmt() + + stmt"$i ON CONFLICT $t DO UPDATE SET $assignments" + } + + /* Builds the DO NOTHING form for an explicit conflict target t. */ def doNothingStmt(i: Ast, t: Token) = + stmt"${i.token} ON CONFLICT $t DO NOTHING" + + /* Renders the conflict target as a parenthesized, comma-separated column list. */ implicit val conflictTargetPropsTokenizer: Tokenizer[Properties] = + Tokenizer[Properties] { + case OnConflict.Properties(props) => + stmt"(${props.map(n => n.renameable.fixedOr(n.name)(strategy.column(n.name))).mkStmt(",")})" + } + + def tokenizer(implicit astTokenizer: 
Tokenizer[Ast]) = + Tokenizer[OnConflict] { + /* DO UPDATE without an explicit conflict target is rejected via fail. */ case OnConflict(_, NoTarget, _: Update) => + fail("'DO UPDATE' statement requires explicit conflict target") + case OnConflict(i, p: Properties, u: Update) => + doUpdateStmt(i.token, p.token, u) + + case OnConflict(i, NoTarget, Ignore) => + stmt"${astTokenizer.token(i)} ON CONFLICT DO NOTHING" + case OnConflict(i, p: Properties, Ignore) => doNothingStmt(i, p.token) + } + + /* The resulting tokenizer is built with customAstTokenizer as its Ast tokenizer. */ tokenizer(customAstTokenizer) + } +} diff --git a/src/main/scala/minisql/context/sql/PositionalBindVariables.scala b/src/main/scala/minisql/context/sql/PositionalBindVariables.scala new file mode 100644 index 0000000..fcd42ea --- /dev/null +++ b/src/main/scala/minisql/context/sql/PositionalBindVariables.scala @@ -0,0 +1,6 @@ +package minisql.context.sql.idiom + + /** Mixin (self-typed to SqlIdiom) that renders lift placeholders as a dollar sign followed by the 1-based position (index + 1). */ trait PositionalBindVariables { self: SqlIdiom => + + override def liftingPlaceholder(index: Int): String = s"$$${index + 1}" +} diff --git a/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala b/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala new file mode 100644 index 0000000..f7ccf27 --- /dev/null +++ b/src/main/scala/minisql/context/sql/QuestionMarkBindVariables.scala @@ -0,0 +1,6 @@ +package minisql.context.sql.idiom + + /** Mixin (self-typed to SqlIdiom) that renders every lift placeholder as a question mark; the index is not used. */ trait QuestionMarkBindVariables { self: SqlIdiom => + + override def liftingPlaceholder(index: Int): String = s"?" 
+} diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala new file mode 100644 index 0000000..c593153 --- /dev/null +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -0,0 +1,700 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.ast.BooleanOperator._ +import minisql.ast.Lift +import minisql.context.sql._ +import minisql.context.sql.norm._ +import minisql.idiom._ +import minisql.idiom.StatementInterpolator._ +import minisql.NamingStrategy +import minisql.ast.Renameable.Fixed +import minisql.ast.Visibility.Hidden +import minisql.context.{ReturningCapability, ReturningClauseSupported} +import minisql.util.Interleave +import minisql.util.Messages.{fail, trace} +import minisql.idiom.Token +import minisql.norm.EqualityBehavior +import minisql.norm.ConcatBehavior +import minisql.norm.ConcatBehavior.AnsiConcat +import minisql.norm.EqualityBehavior.AnsiEquality +import minisql.norm.ExpandReturning + + /** Base SQL idiom: normalizes an Ast and tokenizes it into a Statement using the Tokenizer instances declared in this trait. */ trait SqlIdiom extends Idiom { + + override def prepareForProbing(string: String): String + + /* Behaviors handed to SqlNormalize; they default to the ANSI variants. */ protected def concatBehavior: ConcatBehavior = AnsiConcat + protected def equalityBehavior: EqualityBehavior = AnsiEquality + + /* Optional alias rendered after the target table of INSERT and UPDATE statements; None by default. */ protected def actionAlias: Option[Ident] = None + + /* Returns the query string unchanged. */ override def format(queryString: String): String = queryString + + def querifyAst(ast: Ast) = SqlQuery(ast) + + /* Normalizes ast and tokenizes the result. A Query is first converted to a SqlQuery, checked with VerifySqlQuery and passed through ExpandNestedQueries; anything else is tokenized directly. NOTE(review): the cached parameter is not read anywhere in this method. */ private def doTranslate(ast: Ast, cached: Boolean)(implicit + naming: NamingStrategy + ): (Ast, Statement) = { + val normalizedAst = + SqlNormalize(ast, concatBehavior, equalityBehavior) + + implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer + + val token = + normalizedAst match { + case q: Query => + val sql = querifyAst(q) + trace("sql")(sql) + /* Any value produced by VerifySqlQuery is passed to fail. */ VerifySqlQuery(sql).map(fail) + val expanded = new ExpandNestedQueries(naming)(sql, List()) + trace("expanded sql")(expanded) + val tokenized = expanded.token + trace("tokenized sql")(tokenized) + tokenized + case other => + other.token + } + + 
(normalizedAst, stmt"$token") + } + + /* Delegates to doTranslate with cached = false. */ override def translate( + ast: Ast + )(implicit naming: NamingStrategy): (Ast, Statement) = { + doTranslate(ast, false) + } + + /* Tokenizer[Ast] that hands itself to astTokenizer, so nested nodes are tokenized through the same instance. */ def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] = + new Tokenizer[Ast] { + private val stableTokenizer = astTokenizer(this, naming) + + extension (v: Ast) { + def token = stableTokenizer.token(v) + } + + } + + /* Dispatches each Ast subtype to its own tokenizer; the constructs listed in the last case call fail. NOTE(review): OptionOperation appears both in its own case and in the failing alternative; the earlier case matches first. */ def astTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Ast] = + Tokenizer[Ast] { + case a: Query => SqlQuery(a).token + case a: Operation => a.token + case a: Infix => a.token + case a: Action => a.token + case a: Ident => a.token + case a: ExternalIdent => a.token + case a: Property => a.token + case a: Value => a.token + case a: If => a.token + case a: Lift => a.token + case a: Assignment => a.token + case a: OptionOperation => a.token + case a @ ( + _: Function | _: FunctionApply | _: Dynamic | _: OptionOperation | + _: Block | _: Val | _: Ordering | _: IterableOperation | + _: OnConflict.Excluded | _: OnConflict.Existing + ) => + fail(s"Malformed or unsupported construct: $a.") + } + + /* Renders an If as one CASE expression: If nodes nested in the else position are flattened into successive WHEN ... THEN branches and the innermost else becomes the ELSE value. */ implicit def ifTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[If] = Tokenizer[If] { + case ast: If => + def flatten(ast: Ast): (List[(Ast, Ast)], Ast) = + ast match { + case If(cond, a, b) => + val (l, e) = flatten(b) + ((cond, a) +: l, e) + case other => + (List(), other) + } + + val (l, e) = flatten(ast) + val conditions = + for ((cond, body) <- l) yield { + stmt"WHEN ${cond.token} THEN ${body.token}" + } + stmt"CASE ${conditions.mkStmt(" ")} ELSE ${e.token} END" + } + + /* Name of the SQL function wrapped around select values whose concat flag is set; supplied by the concrete idiom. */ def concatFunction: String + + /* Hook for rendering the GROUP BY expression; by default it is tokenized as-is. */ protected def tokenizeGroupBy( + values: Ast + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy): Token = + values.token + + /* Builds the SELECT statement for a FlattenSqlQuery clause by clause; each withX member extends the previous one. */ protected class FlattenSqlQueryTokenizerHelper(q: FlattenSqlQuery)(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ) { + import q._ + + 
/* An empty select list renders as a star. */ def selectTokenizer = + select match { + case Nil => stmt"*" + case _ => select.token + } + + /* DISTINCT prefix (with trailing space) or an empty statement. */ def distinctTokenizer = ( + distinct match { + case DistinctKind.Distinct => stmt"DISTINCT " + case DistinctKind.DistinctOn(props) => + stmt"DISTINCT ON (${props.token}) " + case DistinctKind.None => stmt"" + } + ) + + def withDistinct = stmt"$distinctTokenizer${selectTokenizer}" + + /* Joins the FROM contexts with commas, except FlatJoinContext entries, which are appended after a single space. */ def withFrom = + from match { + case Nil => withDistinct + case head :: tail => + val t = tail.foldLeft(stmt"${head.token}") { + case (a, b: FlatJoinContext) => + stmt"$a ${(b: FromContext).token}" + case (a, b) => + stmt"$a, ${b.token}" + } + + stmt"$withDistinct FROM $t" + } + + def withWhere = + where match { + case None => withFrom + case Some(where) => stmt"$withFrom WHERE ${where.token}" + } + def withGroupBy = + groupBy match { + case None => withWhere + case Some(groupBy) => + stmt"$withWhere GROUP BY ${tokenizeGroupBy(groupBy)}" + } + def withOrderBy = + orderBy match { + case Nil => withGroupBy + case orderBy => stmt"$withGroupBy ${tokenOrderBy(orderBy)}" + } + def withLimitOffset = limitOffsetToken(withOrderBy).token((limit, offset)) + + /* Final statement: the SELECT keyword followed by everything assembled above. */ def apply = stmt"SELECT $withLimitOffset" + } + + /* Renders flattened queries through the helper class, set operations as two parenthesized operands around the operator, and unary operations as SELECT, the operator, then the parenthesized query. */ implicit def sqlQueryTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[SqlQuery] = Tokenizer[SqlQuery] { + case q: FlattenSqlQuery => + new FlattenSqlQueryTokenizerHelper(q).apply + case SetOperationSqlQuery(a, op, b) => + stmt"(${a.token}) ${op.token} (${b.token})" + case UnaryOperationSqlQuery(op, q) => + stmt"SELECT ${op.token} (${q.token})" + } + + /* Returns the column name unchanged when renameable is Fixed, otherwise applies strategy.column. */ protected def tokenizeColumn( + strategy: NamingStrategy, + column: String, + renameable: Renameable + ) = + renameable match { + case Fixed => column + case _ => strategy.column(column) + } + + /* Returns the table name unchanged when renameable is Fixed, otherwise applies strategy.table. */ protected def tokenizeTable( + strategy: NamingStrategy, + table: String, + renameable: Renameable + ) = + renameable match { + case Fixed => table + case _ => strategy.table(table) + } + + protected def 
tokenizeAlias(strategy: NamingStrategy, table: String) = + strategy.default(table) + + /* Renders select-list entries. Aliased values become value AS alias, wrapped in concatFunction(...) when the concat flag is set. Without an alias, a bare Ident renders as name.* (the question-mark Ident stays a question mark) and any other Ast is tokenized directly. */ implicit def selectValueTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[SelectValue] = { + + def tokenizer(implicit astTokenizer: Tokenizer[Ast]) = + Tokenizer[SelectValue] { + case SelectValue(ast, Some(alias), false) => { + stmt"${ast.token} AS ${alias.token}" + } + case SelectValue(ast, Some(alias), true) => + stmt"${concatFunction.token}(${ast.token}) AS ${alias.token}" + case selectValue => + val value = + selectValue match { + case SelectValue(Ident("?"), _, _) => "?".token + case SelectValue(Ident(name), _, _) => + stmt"${strategy.default(name).token}.*" + case SelectValue(ast, _, _) => ast.token + } + selectValue.concat match { + case true => stmt"${concatFunction.token}(${value.token})" + case false => value + } + } + + /* Aggregation rendering used inside select values: aggregating an Ident or Tuple becomes op(*), Distinct becomes op(DISTINCT ...), an aggregated Query goes through scopedTokenizer; other nodes fall back to astTokenizer. */ val customAstTokenizer = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case Aggregation(op, Ident(_) | Tuple(_)) => stmt"${op.token}(*)" + case Aggregation(op, Distinct(ast)) => + stmt"${op.token}(DISTINCT ${ast.token})" + case ast @ Aggregation(op, _: Query) => scopedTokenizer(ast) + case Aggregation(op, ast) => stmt"${op.token}(${ast.token})" + } + + tokenizer(customAstTokenizer) + } + + /* Renders unary and binary operations. Equality and inequality against NullValue become IS NULL and IS NOT NULL; startsWith becomes a LIKE against the argument concatenated with a percent sign. */ implicit def operationTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Operation] = Tokenizer[Operation] { + case UnaryOperation(op, ast) => stmt"${op.token} (${ast.token})" + case BinaryOperation(a, EqualityOperator.`==`, NullValue) => + stmt"${scopedTokenizer(a)} IS NULL" + case BinaryOperation(NullValue, EqualityOperator.`==`, b) => + stmt"${scopedTokenizer(b)} IS NULL" + case BinaryOperation(a, EqualityOperator.`!=`, NullValue) => + stmt"${scopedTokenizer(a)} IS NOT NULL" + case BinaryOperation(NullValue, EqualityOperator.`!=`, b) => + stmt"${scopedTokenizer(b)} IS NOT NULL" + case BinaryOperation(a, StringOperator.`startsWith`, b) => + 
stmt"${scopedTokenizer(a)} LIKE (${(BinaryOperation(b, StringOperator.`concat`, Constant("%")): Ast).token})" + case BinaryOperation(a, op @ StringOperator.`split`, b) => + stmt"${op.token}(${scopedTokenizer(a)}, ${scopedTokenizer(b)})" + /* contains is rendered with the operands swapped: the right operand b comes first, then the operator, then a. */ case BinaryOperation(a, op @ SetOperator.`contains`, b) => + SetContainsToken(scopedTokenizer(b), op.token, a.token) + /* For AND, only operands that are themselves OR operations are parenthesized through scopedTokenizer. */ case BinaryOperation(a, op @ `&&`, b) => + (a, b) match { + case (BinaryOperation(_, `||`, _), BinaryOperation(_, `||`, _)) => + stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}" + case (BinaryOperation(_, `||`, _), _) => + stmt"${scopedTokenizer(a)} ${op.token} ${b.token}" + case (_, BinaryOperation(_, `||`, _)) => + stmt"${a.token} ${op.token} ${scopedTokenizer(b)}" + case _ => stmt"${a.token} ${op.token} ${b.token}" + } + /* OR operands are emitted without added parentheses. */ case BinaryOperation(a, op @ `||`, b) => + stmt"${a.token} ${op.token} ${b.token}" + case BinaryOperation(a, op, b) => + stmt"${scopedTokenizer(a)} ${op.token} ${scopedTokenizer(b)}" + case e: FunctionApply => fail(s"Can't translate the ast to sql: '$e'") + } + + /* Renders OptionIsEmpty as IS NULL and OptionNonEmpty / OptionIsDefined as IS NOT NULL; any other OptionOperation fails. */ implicit def optionOperationTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OptionOperation] = Tokenizer[OptionOperation] { + case OptionIsEmpty(ast) => stmt"${ast.token} IS NULL" + case OptionNonEmpty(ast) => stmt"${ast.token} IS NOT NULL" + case OptionIsDefined(ast) => stmt"${ast.token} IS NOT NULL" + case other => fail(s"Malformed or unsupported construct: $other.") + } + + implicit val setOperationTokenizer: Tokenizer[SetOperation] = + Tokenizer[SetOperation] { + case UnionOperation => stmt"UNION" + case UnionAllOperation => stmt"UNION ALL" + } + + /* Appends LIMIT and/or OFFSET to query, depending on which of the two optional values (limit, offset) is present. */ protected def limitOffsetToken( + query: Statement + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + Tokenizer[(Option[Ast], Option[Ast])] { + case (None, None) => query + case (Some(limit), None) => stmt"$query LIMIT ${limit.token}" + case (Some(limit), Some(offset)) => + stmt"$query LIMIT ${limit.token} 
OFFSET ${offset.token}" + case (None, Some(offset)) => stmt"$query OFFSET ${offset.token}" + } + + /* Renders the ORDER BY keyword followed by the tokenized criteria list. */ protected def tokenOrderBy( + criterias: List[OrderByCriteria] + )(implicit astTokenizer: Tokenizer[Ast], strategy: NamingStrategy) = + stmt"ORDER BY ${criterias.token}" + + /* Renders FROM sources: tables as name then alias, sub-queries as a parenthesized query with AS alias, infixes likewise (without the parentheses when infix.noParen is set), and joins with their ON condition. */ implicit def sourceTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[FromContext] = Tokenizer[FromContext] { + case TableContext(name, alias) => + stmt"${name.token} ${tokenizeAlias(strategy, alias).token}" + case QueryContext(query, alias) => + stmt"(${query.token}) AS ${tokenizeAlias(strategy, alias).token}" + case InfixContext(infix, alias) if infix.noParen => + stmt"${(infix: Ast).token} AS ${strategy.default(alias).token}" + case InfixContext(infix, alias) => + stmt"(${(infix: Ast).token}) AS ${strategy.default(alias).token}" + case JoinContext(t, a, b, on) => + stmt"${a.token} ${t.token} ${b.token} ON ${on.token}" + case FlatJoinContext(t, a, on) => stmt"${t.token} ${a.token} ON ${on.token}" + } + + implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { + case InnerJoin => stmt"INNER JOIN" + case LeftJoin => stmt"LEFT JOIN" + case RightJoin => stmt"RIGHT JOIN" + case FullJoin => stmt"FULL JOIN" + } + + /* Renders one ORDER BY criterion: the scoped expression followed by its direction and, where given, NULLS FIRST or NULLS LAST. */ implicit def orderByCriteriaTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[OrderByCriteria] = Tokenizer[OrderByCriteria] { + case OrderByCriteria(ast, Asc) => stmt"${scopedTokenizer(ast)} ASC" + case OrderByCriteria(ast, Desc) => stmt"${scopedTokenizer(ast)} DESC" + case OrderByCriteria(ast, AscNullsFirst) => + stmt"${scopedTokenizer(ast)} ASC NULLS FIRST" + case OrderByCriteria(ast, DescNullsFirst) => + stmt"${scopedTokenizer(ast)} DESC NULLS FIRST" + case OrderByCriteria(ast, AscNullsLast) => + stmt"${scopedTokenizer(ast)} ASC NULLS LAST" + case OrderByCriteria(ast, DescNullsLast) => + stmt"${scopedTokenizer(ast)} DESC NULLS LAST" + } + + implicit val unaryOperatorTokenizer: 
Tokenizer[UnaryOperator] = + Tokenizer[UnaryOperator] { + /* SQL spelling of each unary operator. */ case NumericOperator.`-` => stmt"-" + case BooleanOperator.`!` => stmt"NOT" + case StringOperator.`toUpperCase` => stmt"UPPER" + case StringOperator.`toLowerCase` => stmt"LOWER" + case StringOperator.`toLong` => stmt"" // cast is implicit + case StringOperator.`toInt` => stmt"" // cast is implicit + case SetOperator.`isEmpty` => stmt"NOT EXISTS" + case SetOperator.`nonEmpty` => stmt"EXISTS" + } + + /* Maps aggregation operators to SQL function names; size becomes COUNT. */ implicit val aggregationOperatorTokenizer: Tokenizer[AggregationOperator] = + Tokenizer[AggregationOperator] { + case AggregationOperator.`min` => stmt"MIN" + case AggregationOperator.`max` => stmt"MAX" + case AggregationOperator.`avg` => stmt"AVG" + case AggregationOperator.`sum` => stmt"SUM" + case AggregationOperator.`size` => stmt"COUNT" + } + + /* SQL spelling of each binary operator. startsWith has no operator token of its own (operationTokenizer renders it as LIKE), so reaching it here calls fail. */ implicit val binaryOperatorTokenizer: Tokenizer[BinaryOperator] = + Tokenizer[BinaryOperator] { + case EqualityOperator.`==` => stmt"=" + case EqualityOperator.`!=` => stmt"<>" + case BooleanOperator.`&&` => stmt"AND" + case BooleanOperator.`||` => stmt"OR" + case StringOperator.`concat` => stmt"||" + case StringOperator.`startsWith` => + fail("bug: this code should be unreachable") + case StringOperator.`split` => stmt"SPLIT" + case NumericOperator.`-` => stmt"-" + case NumericOperator.`+` => stmt"+" + case NumericOperator.`*` => stmt"*" + case NumericOperator.`>` => stmt">" + case NumericOperator.`>=` => stmt">=" + case NumericOperator.`<` => stmt"<" + case NumericOperator.`<=` => stmt"<=" + case NumericOperator.`/` => stmt"/" + case NumericOperator.`%` => stmt"%" + case SetOperator.`contains` => stmt"IN" + } + + /* Renders a Property as a possibly prefixed column reference. */ implicit def propertyTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Property] = { + + /* Walks down nested properties to the root Ast, skipping Hidden properties and collecting the names of the others, innermost first. */ def unnest(ast: Ast): (Ast, List[String]) = + ast match { + case Property.Opinionated(a, _, _, Hidden) => + unnest(a) match { + case (a, nestedName) => (a, nestedName) + } + // Append the property name. 
This includes tuple indexes. + case Property(a, name) => + unnest(a) match { + case (ast, nestedName) => + (ast, nestedName :+ name) + } + case a => (a, Nil) + } + + /* Builds the column token from the prefix names concatenated with name (no separator), choosing between the raw and the strategy-mapped form through renameable.fixedOr. */ def tokenizePrefixedProperty( + name: String, + prefix: List[String], + strategy: NamingStrategy, + renameable: Renameable + ) = + renameable.fixedOr( + (prefix.mkString + name).token + )(tokenizeColumn(strategy, prefix.mkString + name, renameable).token) + + Tokenizer[Property] { + case Property.Opinionated( + ast, + name, + renameable, + _ /* Top level property cannot be invisible */ + ) => + // When we have things like Embedded tables, properties inside of one another need to be un-nested. + // E.g. in `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` the inner + // property needs to be unwrapped and the result of this should only be `realTable.realPropertyAlias` + // as opposed to `realTable.embeddedTableAlias.realPropertyAlias`. + unnest(ast) match { + // When using ExternalIdent such as .returning(eid => eid.idColumn) clauses drop the 'eid' since SQL + // returning clauses have no alias for the original table. I.e. INSERT [...] RETURNING idColumn there's no + // alias you can assign to the INSERT [...] clause that can be used as a prefix to 'idColumn'. + // In this case, `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // should just be `realPropertyAlias` as opposed to `realTable.realPropertyAlias`. + // The exception to this is when a Query inside of a RETURNING clause is used. In that case, assume + // that there is an alias for the inserted table (i.e. `INSERT ... as theAlias values ... RETURNING`) + // and the instances of ExternalIdent use it. + /* The prefix is the actionAlias followed by a dot when an alias is set, otherwise nothing. */ case (ExternalIdent(_), prefix) => + stmt"${actionAlias + .map(alias => stmt"${scopedTokenizer(alias)}.") + .getOrElse(stmt"")}${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + + // In the rare case that the Ident is invisible, do not show it. 
See the Ident documentation for more info. + case (Ident.Opinionated(_, Hidden), prefix) => + stmt"${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + + // The normal case where `Property(Property(Ident("realTable"), embeddedTableAlias), realPropertyAlias)` + // becomes `realTable.realPropertyAlias`. + case (ast, prefix) => + stmt"${scopedTokenizer(ast)}.${tokenizePrefixedProperty(name, prefix, strategy, renameable)}" + } + } + } + + /* Renders literal values: string constants wrapped in single quotes, the unit constant as 1, other constants through toString, NullValue as null, and tuples / case classes as their tokenized element lists. NOTE(review): no escaping of the string content is visible in this block; confirm the String tokenizer handles embedded quotes. */ implicit def valueTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Value] = Tokenizer[Value] { + case Constant(v: String) => stmt"'${v.token}'" + case Constant(()) => stmt"1" + case Constant(v) => stmt"${v.toString.token}" + case NullValue => stmt"null" + case Tuple(values) => stmt"${values.token}" + case CaseClass(values) => stmt"${values.map(_._2).token}" + } + + /* Interleaves the literal parts of the infix with its tokenized parameters. */ implicit def infixTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Infix] = Tokenizer[Infix] { + case Infix(parts, params, _, _) => + val pt = parts.map(_.token) + val pr = params.map(_.token) + Statement(Interleave(pt, pr)) + } + + /* Identifiers are rendered through strategy.default. */ implicit def identTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Ident] = + Tokenizer[Ident](e => strategy.default(e.name).token) + + implicit def externalIdentTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[ExternalIdent] = + Tokenizer[ExternalIdent](e => strategy.default(e.name).token) + + /* Renders an assignment as property = value; the alias component is not used. */ implicit def assignmentTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Assignment] = Tokenizer[Assignment] { + case Assignment(alias, prop, value) => + stmt"${prop.token} = ${scopedTokenizer(value)}" + } + + /* Default Tokenizer[Action]: actionTokenizer configured so that the INSERT target is written as INTO followed by the table name. */ implicit def defaultAstTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Action] = { + val insertEntityTokenizer = Tokenizer[Entity] { + case Entity.Opinionated(name, _, 
renameable) => + stmt"INTO ${tokenizeTable(strategy, name, renameable).token}" + } + actionTokenizer(insertEntityTokenizer)(actionAstTokenizer, strategy) + } + + /* Ast tokenizer used inside actions: queries go to the implicit astTokenizer, properties render as bare column names without any prefix, and isEmpty / isDefined / nonEmpty on a property become IS NULL / IS NOT NULL; other nodes fall back to SqlIdiom.astTokenizer. */ protected def actionAstTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ) = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case q: Query => astTokenizer.token(q) + case Property(Property.Opinionated(_, name, renameable, _), "isEmpty") => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NULL" + case Property( + Property.Opinionated(_, name, renameable, _), + "isDefined" + ) => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL" + case Property(Property.Opinionated(_, name, renameable, _), "nonEmpty") => + stmt"${renameable.fixedOr(name)(tokenizeColumn(strategy, name, renameable)).token} IS NOT NULL" + case Property.Opinionated(_, name, renameable, _) => + renameable.fixedOr(name.token)( + tokenizeColumn(strategy, name, renameable).token + ) + } + + /* Renders a list of returned expressions separated by a comma and a space; Query elements are wrapped in parentheses. */ def returnListTokenizer(implicit + tokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[List[Ast]] = { + val customAstTokenizer = + Tokenizer.withFallback[Ast](SqlIdiom.this.astTokenizer(_, strategy)) { + case sq: Query => + stmt"(${tokenizer.token(sq)})" + } + + Tokenizer[List[Ast]] { + case list => + list.mkStmt(", ")(customAstTokenizer) + } + } + + /* Renders Insert, Update, Delete and ReturningAction nodes; insertEntityTokenizer controls how the INSERT target entity is written. */ protected def actionTokenizer( + insertEntityTokenizer: Tokenizer[Entity] + )(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Action] = + Tokenizer[Action] { + + /* INSERT: target table, optional AS alias taken from actionAlias, the column list, then the scoped values. */ case Insert(entity: Entity, assignments) => + val table = insertEntityTokenizer.token(entity) + val columns = assignments.map(_.property.token) + val values = assignments.map(_.value) + stmt"INSERT $table${actionAlias.map(alias => stmt" AS ${alias.token}").getOrElse(stmt"")} (${columns + .mkStmt(",")}) VALUES (${values.map(scopedTokenizer(_)).mkStmt(", ")})" + + 
/* UPDATE of a whole entity: optional AS alias from actionAlias, then the SET assignments. */ case Update(table: Entity, assignments) => + stmt"UPDATE ${table.token}${actionAlias + .map(alias => stmt" AS ${alias.token}") + .getOrElse(stmt"")} SET ${assignments.token}" + + /* UPDATE of a filtered entity: the filter predicate becomes the WHERE clause. */ case Update(Filter(table: Entity, x, where), assignments) => + stmt"UPDATE ${table.token}${actionAlias + .map(alias => stmt" AS ${alias.token}") + .getOrElse(stmt"")} SET ${assignments.token} WHERE ${where.token}" + + /* DELETE with and without a filter predicate; no alias is rendered for deletes. */ case Delete(Filter(table: Entity, x, where)) => + stmt"DELETE FROM ${table.token} WHERE ${where.token}" + + case Delete(table: Entity) => + stmt"DELETE FROM ${table.token}" + + /* Returning action over an INSERT with no assignments: the column part comes from defaultAutoGeneratedToken, and RETURNING is emitted only when the idiom reports ReturningClauseSupported. */ case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) => + idiomReturningCapability match { + // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as + // a last resort since it is not even supported in all Postgres versions (i.e. only after 9.5) + case ReturningClauseSupported + if (CollectAst.byType[Entity](prop).nonEmpty) => + SqlIdiom.withActionAlias(this, r) + case ReturningClauseSupported => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer + .token(ExpandReturning(r)(this, strategy).map(_._1))}" + case other => + stmt"INSERT INTO ${table.token} ${defaultAutoGeneratedToken(prop.token)}" + } + + case r @ ReturningAction(action, alias, prop) => + idiomReturningCapability match { + // If there are queries inside of the returning clause we are forced to alias the inserted table (see #1509). Only do this as + // a last resort since it is not even supported in all Postgres versions (i.e. 
only after 9.5) + case ReturningClauseSupported + if (CollectAst.byType[Entity](prop).nonEmpty) => + SqlIdiom.withActionAlias(this, r) + case ReturningClauseSupported => + stmt"${action.token} RETURNING ${returnListTokenizer.token( + ExpandReturning(r)(this, strategy).map(_._1) + )}" + case other => + stmt"${action.token}" + } + + case other => + fail(s"Action ast can't be translated to sql: '$other'") + } + + /* Renders an entity as its table name via tokenizeTable. */ implicit def entityTokenizer(implicit + astTokenizer: Tokenizer[Ast], + strategy: NamingStrategy + ): Tokenizer[Entity] = Tokenizer[Entity] { + case Entity.Opinionated(name, _, renameable) => + tokenizeTable(strategy, name, renameable).token + } + + /* Wraps Query, BinaryOperation and Tuple nodes in parentheses; every other node is tokenized as-is. */ protected def scopedTokenizer(ast: Ast)(implicit tokenizer: Tokenizer[Ast]) = + ast match { + case _: Query => stmt"(${ast.token})" + case _: BinaryOperation => stmt"(${ast.token})" + case _: Tuple => stmt"(${ast.token})" + case _ => ast.token + } +} + +object SqlIdiom { + /* Creates an anonymous SqlIdiom that overrides actionAlias and forwards prepareForProbing, concatFunction, liftingPlaceholder and idiomReturningCapability to parent. NOTE(review): no other member of parent is forwarded, so any further overrides in the parent idiom are not carried over by this copy; confirm that is intended. */ private[minisql] def copyIdiom( + parent: SqlIdiom, + newActionAlias: Option[Ident] + ) = + new SqlIdiom { + override protected def actionAlias: Option[Ident] = newActionAlias + override def prepareForProbing(string: String): String = + parent.prepareForProbing(string) + override def concatFunction: String = parent.concatFunction + override def liftingPlaceholder(index: Int): String = + parent.liftingPlaceholder(index) + override def idiomReturningCapability: ReturningCapability = + parent.idiomReturningCapability + } + + /** + * Construct a new instance of the specified idiom with `newActionAlias` + * variable specified so that actions (i.e. insert, and update) will be + * rendered with the specified alias. This is needed for RETURNING clauses + * that have queries inside. See #1509 for details. 
+ */ + private[minisql] def withActionAlias( + parentIdiom: SqlIdiom, + query: ReturningAction + )(implicit strategy: NamingStrategy) = { + val idiom = copyIdiom(parentIdiom, Some(query.alias)) + import idiom._ + + /* Self-referential tokenizer built from the copied idiom, so nested nodes are rendered by the alias-aware copy rather than by parentIdiom. */ implicit val stableTokenizer: Tokenizer[Ast] = idiom.astTokenizer( + new Tokenizer[Ast] { self => + extension (v: Ast) { + def token = astTokenizer(self, strategy).token(v) + } + }, + strategy + ) + + query match { + /* INSERT with no assignments: the alias is written explicitly after the table name. */ case r @ ReturningAction(Insert(table: Entity, Nil), alias, prop) => + stmt"INSERT INTO ${table.token} AS ${alias.name.token} ${defaultAutoGeneratedToken(prop.token)} RETURNING ${returnListTokenizer + .token(ExpandReturning(r)(idiom, strategy).map(_._1))}" + case r @ ReturningAction(action, alias, prop) => + stmt"${action.token} RETURNING ${returnListTokenizer.token( + ExpandReturning(r)(idiom, strategy).map(_._1) + )}" + } + } +} diff --git a/src/main/scala/minisql/context/sql/SqlQuery.scala b/src/main/scala/minisql/context/sql/SqlQuery.scala new file mode 100644 index 0000000..06ec412 --- /dev/null +++ b/src/main/scala/minisql/context/sql/SqlQuery.scala @@ -0,0 +1,326 @@ +package minisql.context.sql + +import minisql.ast._ +import minisql.context.sql.norm.FlattenGroupByAggregation +import minisql.norm.BetaReduction +import minisql.util.Messages.fail +import minisql.{Literal, PseudoAst, NamingStrategy} + + /** One ORDER BY entry: the expression and its ordering. */ case class OrderByCriteria(ast: Ast, ordering: PropertyOrdering) + + /** A source that can appear in the FROM part of a flattened query. */ sealed trait FromContext +case class TableContext(entity: Entity, alias: String) extends FromContext +case class QueryContext(query: SqlQuery, alias: String) extends FromContext +case class InfixContext(infix: Infix, alias: String) extends FromContext +case class JoinContext(t: JoinType, a: FromContext, b: FromContext, on: Ast) + extends FromContext +case class FlatJoinContext(t: JoinType, a: FromContext, on: Ast) + extends FromContext + + /** SQL-shaped query model; toString renders it with MirrorSqlDialect and the Literal naming strategy. */ sealed trait SqlQuery { + override def toString = { + import minisql.MirrorSqlDialect._ + import minisql.idiom.StatementInterpolator.* + 
given Tokenizer[SqlQuery] = sqlQueryTokenizer(using + defaultTokenizer(using Literal), + Literal + ) + summon[Tokenizer[SqlQuery]].token(this).toString() + } +} + + /** Set operation combining two queries. */ sealed trait SetOperation +case object UnionOperation extends SetOperation +case object UnionAllOperation extends SetOperation + + /** Kind of DISTINCT applied to a select: None, plain Distinct, or DistinctOn the given expressions. */ sealed trait DistinctKind { def isDistinct: Boolean } +case object DistinctKind { + case object Distinct extends DistinctKind { val isDistinct: Boolean = true } + case class DistinctOn(props: List[Ast]) extends DistinctKind { + val isDistinct: Boolean = true + } + case object None extends DistinctKind { val isDistinct: Boolean = false } +} + + /** Two queries combined by a set operation. */ case class SetOperationSqlQuery( + a: SqlQuery, + op: SetOperation, + b: SqlQuery +) extends SqlQuery + + /** A unary operator applied to a whole query. */ case class UnaryOperationSqlQuery( + op: UnaryOperator, + q: SqlQuery +) extends SqlQuery + + /** One select-list entry: the expression, an optional alias and a concat flag; toString shows the expression followed by an arrow and the alias when one is set. */ case class SelectValue( + ast: Ast, + alias: Option[String] = None, + concat: Boolean = false +) extends PseudoAst { + override def toString: String = + s"${ast.toString}${alias.map("->" + _).getOrElse("")}" +} + + /** A single flat SELECT: sources, filter, grouping, ordering, limit, offset, select list and distinct kind. Only select has no default. */ case class FlattenSqlQuery( + from: List[FromContext] = List(), + where: Option[Ast] = None, + groupBy: Option[Ast] = None, + orderBy: List[OrderByCriteria] = Nil, + limit: Option[Ast] = None, + offset: Option[Ast] = None, + select: List[SelectValue], + distinct: DistinctKind = DistinctKind.None +) extends SqlQuery + + /** Extractor matching Take or Drop applied directly to a FlatMap; yields the FlatMap together with the limit (for Take) or the offset (for Drop). */ object TakeDropFlatten { + def unapply(q: Query): Option[(Query, Option[Ast], Option[Ast])] = q match { + case Take(q: FlatMap, n) => Some((q, Some(n), None)) + case Drop(q: FlatMap, n) => Some((q, None, Some(n))) + case _ => None + } +} + +object SqlQuery { + + /* Converts an Ast into a SqlQuery: unions become set operations, a unary operation over a query becomes a UnaryOperationSqlQuery, other operations and values become a one-column select, and remaining queries are flattened. */ def apply(query: Ast): SqlQuery = + query match { + case Union(a, b) => + SetOperationSqlQuery(apply(a), UnionOperation, apply(b)) + case UnionAll(a, b) => + SetOperationSqlQuery(apply(a), UnionAllOperation, apply(b)) + case UnaryOperation(op, q: Query) => UnaryOperationSqlQuery(op, apply(q)) + case _: Operation | _: Value => + FlattenSqlQuery(select = 
List(SelectValue(query))) + case Map(q, a, b) if a == b => apply(q) + case TakeDropFlatten(q, limit, offset) => + flatten(q, "x").copy(limit = limit, offset = offset) + case q: Query => flatten(q, "x") + case infix: Infix => flatten(infix, "x") + case other => + fail( + s"Query not properly normalized. Please open a bug report. Ast: '$other'" + ) + } + + private def flatten(query: Ast, alias: String): FlattenSqlQuery = { + val (sources, finalFlatMapBody) = flattenContexts(query) + flatten(sources, finalFlatMapBody, alias) + } + + private def flattenContexts(query: Ast): (List[FromContext], Ast) = + query match { + case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Query) => + val source = this.source(q, alias) + val (nestedContexts, finalFlatMapBody) = flattenContexts(p) + (source +: nestedContexts, finalFlatMapBody) + case FlatMap(q @ (_: Query | _: Infix), Ident(alias), p: Infix) => + fail(s"Infix can't be use as a `flatMap` body. $query") + case other => + (List.empty, other) + } + + object NestedNest { + def unapply(q: Ast): Option[Ast] = + q match { + case _: Nested => recurse(q) + case _ => None + } + + private def recurse(q: Ast): Option[Ast] = + q match { + case Nested(qn) => recurse(qn) + case other => Some(other) + } + } + + private def flatten( + sources: List[FromContext], + finalFlatMapBody: Ast, + alias: String + ): FlattenSqlQuery = { + + def select(alias: String) = SelectValue(Ident(alias), None) :: Nil + + def base(q: Ast, alias: String) = { + def nest(ctx: FromContext) = + FlattenSqlQuery(from = sources :+ ctx, select = select(alias)) + q match { + case Map(_: GroupBy, _, _) => nest(source(q, alias)) + case NestedNest(q) => nest(QueryContext(apply(q), alias)) + case q: ConcatMap => nest(QueryContext(apply(q), alias)) + case Join(tpe, a, b, iA, iB, on) => + val ctx = source(q, alias) + def aliases(ctx: FromContext): List[String] = + ctx match { + case TableContext(_, alias) => alias :: Nil + case QueryContext(_, alias) => alias :: Nil + case 
InfixContext(_, alias) => alias :: Nil + case JoinContext(_, a, b, _) => aliases(a) ::: aliases(b) + case FlatJoinContext(_, a, _) => aliases(a) + } + FlattenSqlQuery( + from = ctx :: Nil, + select = aliases(ctx).map(a => SelectValue(Ident(a), None)) + ) + case q @ (_: Map | _: Filter | _: Entity) => flatten(sources, q, alias) + case q if (sources == Nil) => flatten(sources, q, alias) + case other => nest(source(q, alias)) + } + } + + finalFlatMapBody match { + + case ConcatMap(q, Ident(alias), p) => + FlattenSqlQuery( + from = source(q, alias) :: Nil, + select = selectValues(p).map(_.copy(concat = true)) + ) + + case Map(GroupBy(q, x @ Ident(alias), g), a, p) => + val b = base(q, alias) + val select = BetaReduction(p, a -> Tuple(List(g, x))) + val flattenSelect = FlattenGroupByAggregation(x)(select) + b.copy(groupBy = Some(g), select = this.selectValues(flattenSelect)) + + case GroupBy(q, Ident(alias), p) => + fail("A `groupBy` clause must be followed by `map`.") + + case Map(q, Ident(alias), p) => + val b = base(q, alias) + val agg = b.select.collect { + case s @ SelectValue(_: Aggregation, _, _) => s + } + if (!b.distinct.isDistinct && agg.isEmpty) + b.copy(select = selectValues(p)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = selectValues(p) + ) + + case Filter(q, Ident(alias), p) => + val b = base(q, alias) + if (b.where.isEmpty) + b.copy(where = Some(p)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + where = Some(p), + select = select(alias) + ) + + case SortBy(q, Ident(alias), p, o) => + val b = base(q, alias) + val criterias = orderByCriterias(p, o) + if (b.orderBy.isEmpty) + b.copy(orderBy = criterias) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + orderBy = criterias, + select = select(alias) + ) + + case Aggregation(op, q: Query) => + val b = flatten(q, alias) + b.select match { + case head :: Nil if !b.distinct.isDistinct => + b.copy(select = 
List(head.copy(ast = Aggregation(op, head.ast)))) + case other => + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = List(SelectValue(Aggregation(op, Ident("*")))) + ) + } + + case Take(q, n) => + val b = base(q, alias) + if (b.limit.isEmpty) + b.copy(limit = Some(n)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + limit = Some(n), + select = select(alias) + ) + + case Drop(q, n) => + val b = base(q, alias) + if (b.offset.isEmpty && b.limit.isEmpty) + b.copy(offset = Some(n)) + else + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + offset = Some(n), + select = select(alias) + ) + + case Distinct(q: Query) => + val b = base(q, alias) + b.copy(distinct = DistinctKind.Distinct) + + case DistinctOn(q, Ident(alias), fields) => + val distinctList = + fields match { + case Tuple(values) => values + case other => List(other) + } + + q match { + // Ideally we don't need to make an extra sub-query for every single case of + // distinct-on but it only works when the parent AST is an entity. That's because DistinctOn + // selects from an alias of an outer clause. For example, query[Person].map(p => Name(p.firstName, p.lastName)).distinctOn(_.name) + // (Let's say Person(firstName, lastName, age), Name(first, last)) will turn into + // SELECT DISTINCT ON (p.name), p.firstName AS first, p.lastName AS last, p.age FROM Person + // This doesn't work beause `name` in `p.name` doesn't exist yet. Therefore we have to nest this in a subquery: + // SELECT DISTINCT ON (p.name) FROM (SELECT p.firstName AS first, p.lastName AS last, p.age FROM Person p) AS p + // The only exception to this is if we are directly selecting from an entity: + // query[Person].distinctOn(_.firstName) which should be fine: SELECT (x.firstName), x.firstName, x.lastName, a.age FROM Person x + // since all the fields inside the (...) of the DISTINCT ON must be contained in the entity. 
+ case _: Entity => + val b = base(q, alias) + b.copy(distinct = DistinctKind.DistinctOn(distinctList)) + case _ => + FlattenSqlQuery( + from = QueryContext(apply(q), alias) :: Nil, + select = select(alias), + distinct = DistinctKind.DistinctOn(distinctList) + ) + } + + case other => + FlattenSqlQuery( + from = sources :+ source(other, alias), + select = select(alias) + ) + } + } + + private def selectValues(ast: Ast) = + ast match { + case Tuple(values) => values.map(SelectValue(_)) + case other => SelectValue(ast) :: Nil + } + + private def source(ast: Ast, alias: String): FromContext = + ast match { + case entity: Entity => TableContext(entity, alias) + case infix: Infix => InfixContext(infix, alias) + case Join(t, a, b, ia, ib, on) => + JoinContext(t, source(a, ia.name), source(b, ib.name), on) + case FlatJoin(t, a, ia, on) => FlatJoinContext(t, source(a, ia.name), on) + case Nested(q) => QueryContext(apply(q), alias) + case other => QueryContext(apply(other), alias) + } + + private def orderByCriterias(ast: Ast, ordering: Ast): List[OrderByCriteria] = + (ast, ordering) match { + case (Tuple(properties), ord: PropertyOrdering) => + properties.flatMap(orderByCriterias(_, ord)) + case (Tuple(properties), TupleOrdering(ord)) => + properties.zip(ord).flatMap { case (a, o) => orderByCriterias(a, o) } + case (a, o: PropertyOrdering) => List(OrderByCriteria(a, o)) + case other => fail(s"Invalid order by criteria $ast") + } +} diff --git a/src/main/scala/minisql/context/sql/VerifySqlQuery.scala b/src/main/scala/minisql/context/sql/VerifySqlQuery.scala new file mode 100644 index 0000000..82a3d59 --- /dev/null +++ b/src/main/scala/minisql/context/sql/VerifySqlQuery.scala @@ -0,0 +1,122 @@ +package minisql.context.sql.idiom + +import minisql.ast._ +import minisql.context.sql._ +import minisql.norm.FreeVariables + +case class Error(free: List[Ident], ast: Ast) +case class InvalidSqlQuery(errors: List[Error]) { + override def toString = + s"The monad composition can't be 
expressed using applicative joins. " + + errors + .map(error => + s"Faulty expression: '${error.ast}'. Free variables: '${error.free}'." + ) + .mkString(", ") +} + +object VerifySqlQuery { + + def apply(query: SqlQuery): Option[String] = + verify(query).map(_.toString) + + private def verify(query: SqlQuery): Option[InvalidSqlQuery] = + query match { + case q: FlattenSqlQuery => verify(q) + case SetOperationSqlQuery(a, op, b) => verify(a).orElse(verify(b)) + case UnaryOperationSqlQuery(op, q) => verify(q) + } + + private def verifyFlatJoins(q: FlattenSqlQuery) = { + + def loop(l: List[FromContext], available: Set[String]): Set[String] = + l.foldLeft(available) { + case (av, TableContext(_, alias)) => Set(alias) + case (av, InfixContext(_, alias)) => Set(alias) + case (av, QueryContext(_, alias)) => Set(alias) + case (av, JoinContext(_, a, b, on)) => + av ++ loop(a :: Nil, av) ++ loop(b :: Nil, av) + case (av, FlatJoinContext(_, a, on)) => + val nav = av ++ loop(a :: Nil, av) + val free = FreeVariables(on).map(_.name) + val invalid = free -- nav + require( + invalid.isEmpty, + s"Found an `ON` table reference of a table that is not available: $invalid. " + + "The `ON` condition can only use tables defined through explicit joins." + ) + nav + } + loop(q.from, Set()) + } + + private def verify(query: FlattenSqlQuery): Option[InvalidSqlQuery] = { + + verifyFlatJoins(query) + + val aliases = + query.from.flatMap(this.aliases).map(Ident(_)) :+ Ident("*") :+ Ident("?") + + def verifyAst(ast: Ast) = { + val freeVariables = + (FreeVariables(ast) -- aliases).toList + val freeIdents = + (CollectAst(ast) { + case ast: Property => None + case Aggregation(_, _: Ident) => None + case ast: Ident => Some(ast) + }).flatten + (freeVariables ++ freeIdents) match { + case Nil => None + case free => Some(Error(free, ast)) + } + } + + // Recursively expand children until values are fully flattened. Identities in all these should + // be skipped during verification. 
+ def expandSelect(sv: SelectValue): List[SelectValue] = + sv.ast match { + case Tuple(values) => + values.map(v => SelectValue(v)).flatMap(expandSelect(_)) + case CaseClass(values) => + values.map(v => SelectValue(v._2)).flatMap(expandSelect(_)) + case _ => List(sv) + } + + val freeVariableErrors: List[Error] = + query.where.flatMap(verifyAst).toList ++ + query.orderBy.map(_.ast).flatMap(verifyAst) ++ + query.limit.flatMap(verifyAst) ++ + query.select + .flatMap( + expandSelect(_) + ) // Expand tuple select clauses so their top-level identities are skipped + .map(_.ast) + .filterNot(_.isInstanceOf[Ident]) + .flatMap(verifyAst) ++ + query.from.flatMap { + case j: JoinContext => verifyAst(j.on) + case j: FlatJoinContext => verifyAst(j.on) + case _ => Nil + } + + val nestedErrors = + query.from.collect { + case QueryContext(query, alias) => verify(query).map(_.errors) + }.flatten.flatten + + (freeVariableErrors ++ nestedErrors) match { + case Nil => None + case errors => Some(InvalidSqlQuery(errors)) + } + } + + private def aliases(s: FromContext): List[String] = + s match { + case s: TableContext => List(s.alias) + case s: QueryContext => List(s.alias) + case s: InfixContext => List(s.alias) + case s: JoinContext => aliases(s.a) ++ aliases(s.b) + case s: FlatJoinContext => aliases(s.a) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala b/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala new file mode 100644 index 0000000..8fb0d20 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/AddDropToNestedOrderBy.scala @@ -0,0 +1,47 @@ +package minisql.context.sql.norm + +import minisql.ast.Constant +import minisql.context.sql.{FlattenSqlQuery, SqlQuery, _} + +/** + * In SQL Server, `Order By` clauses are only allowed in sub-queries if the + * sub-query has a `TOP` or `OFFSET` modifier. Otherwise an exception will be + * thrown. 
This transformation adds a 'dummy' `OFFSET 0` in this scenario (if an + * `Offset` clause does not exist already). + */ +object AddDropToNestedOrderBy { + + def applyInner(q: SqlQuery): SqlQuery = + q match { + case q: FlattenSqlQuery => + q.copy( + offset = + if (q.orderBy.nonEmpty) q.offset.orElse(Some(Constant(0))) + else q.offset, + from = q.from.map(applyInner(_)) + ) + + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(applyInner(a), op, applyInner(b)) + case UnaryOperationSqlQuery(op, a) => + UnaryOperationSqlQuery(op, applyInner(a)) + } + + private def applyInner(f: FromContext): FromContext = + f match { + case QueryContext(a, alias) => QueryContext(applyInner(a), alias) + case JoinContext(t, a, b, on) => + JoinContext(t, applyInner(a), applyInner(b), on) + case FlatJoinContext(t, a, on) => FlatJoinContext(t, applyInner(a), on) + case other => other + } + + def apply(q: SqlQuery): SqlQuery = + q match { + case q: FlattenSqlQuery => q.copy(from = q.from.map(applyInner(_))) + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(applyInner(a), op, applyInner(b)) + case UnaryOperationSqlQuery(op, a) => + UnaryOperationSqlQuery(op, applyInner(a)) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala b/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala new file mode 100644 index 0000000..9d03c1f --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandDistinct.scala @@ -0,0 +1,68 @@ +package minisql.context.sql.norm + +import minisql.ast.Visibility.Hidden +import minisql.ast._ + +object ExpandDistinct { + + @annotation.tailrec + def hasJoin(q: Ast): Boolean = { + q match { + case _: Join => true + case Map(q, _, _) => hasJoin(q) + case Filter(q, _, _) => hasJoin(q) + case _ => false + } + } + + def apply(q: Ast): Ast = + q match { + case Distinct(q) => + Distinct(apply(q)) + case q => + Transform(q) { + case Aggregation(op, Distinct(q)) => + Aggregation(op, Distinct(apply(q))) + case 
Distinct(Map(q, x, cc @ Tuple(values))) => + Map( + Distinct(Map(q, x, cc)), + x, + Tuple(values.zipWithIndex.map { + case (_, i) => Property(x, s"_${i + 1}") + }) + ) + + // Situations like this: + // case class AdHocCaseClass(id: Int, name: String) + // val q = quote { + // query[SomeTable].map(st => AdHocCaseClass(st.id, st.name)).distinct + // } + // ... need some special treatment. Otherwise their values will not be correctly expanded. + case Distinct(Map(q, x, cc @ CaseClass(values))) => + Map( + Distinct(Map(q, x, cc)), + x, + CaseClass(values.map { + case (name, _) => (name, Property(x, name)) + }) + ) + + // Need some special handling to address issues with distinct returning a single embedded entity i.e: + // query[Parent].map(p => p.emb).distinct.map(e => (e.name, e.id)) + // cannot treat such a case normally or "confused" queries will result e.g: + // SELECT p.embname, p.embid FROM (SELECT DISTINCT emb.name /* Where the heck is 'emb' coming from? */ AS embname, emb.id AS embid FROM Parent p) AS p + case d @ Distinct( + Map(q, x, p @ Property.Opinionated(_, _, _, Hidden)) + ) => + d + + // Problems with distinct were first discovered in #1032. Basically, unless + // the distinct is "expanded" adding an outer map, Ident's representing a Table will end up in invalid places + // such as "ORDER BY tableIdent" etc... 
+ case Distinct(Map(q, x, p)) => + val newMap = Map(q, x, Tuple(List(p))) + val newIdent = Ident(x.name) + Map(Distinct(newMap), newIdent, Property(newIdent, "_1")) + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala b/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala new file mode 100644 index 0000000..1677cf7 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandJoin.scala @@ -0,0 +1,49 @@ +package minisql.context.sql.norm + +import minisql.ast._ +import minisql.norm.BetaReduction +import minisql.norm.Normalize + +object ExpandJoin { + + def apply(q: Ast) = expand(q, None) + + def expand(q: Ast, id: Option[Ident]) = + Transform(q) { + case q @ Join(_, _, _, Ident(a), Ident(b), _) => + val (qr, tuple) = expandedTuple(q) + Map(qr, id.getOrElse(Ident(s"$a$b")), tuple) + } + + private def expandedTuple(q: Join): (Join, Tuple) = + q match { + + case Join(t, a: Join, b: Join, tA, tB, o) => + val (ar, at) = expandedTuple(a) + val (br, bt) = expandedTuple(b) + val or = BetaReduction(o, tA -> at, tB -> bt) + (Join(t, ar, br, tA, tB, or), Tuple(List(at, bt))) + + case Join(t, a: Join, b, tA, tB, o) => + val (ar, at) = expandedTuple(a) + val or = BetaReduction(o, tA -> at) + (Join(t, ar, b, tA, tB, or), Tuple(List(at, tB))) + + case Join(t, a, b: Join, tA, tB, o) => + val (br, bt) = expandedTuple(b) + val or = BetaReduction(o, tB -> bt) + (Join(t, a, br, tA, tB, or), Tuple(List(tA, bt))) + + case q @ Join(t, a, b, tA, tB, on) => + ( + Join(t, nestedExpand(a, tA), nestedExpand(b, tB), tA, tB, on), + Tuple(List(tA, tB)) + ) + } + + private def nestedExpand(q: Ast, id: Ident) = + Normalize(expand(q, Some(id))) match { + case Map(q, _, _) => q + case q => q + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala b/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala new file mode 100644 index 0000000..b1cc186 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandMappedInfix.scala 
@@ -0,0 +1,12 @@ +package minisql.context.sql.norm + +import minisql.ast._ + +object ExpandMappedInfix { + def apply(q: Ast): Ast = { + Transform(q) { + case Map(Infix("" :: parts, (q: Query) :: params, pure, noParen), x, p) => + Infix("" :: parts, Map(q, x, p) :: params, pure, noParen) + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala new file mode 100644 index 0000000..56095e9 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/ExpandNestedQueries.scala @@ -0,0 +1,147 @@ +package minisql.context.sql.norm + +import minisql.NamingStrategy +import minisql.ast.Ast +import minisql.ast.Ident +import minisql.ast._ +import minisql.ast.StatefulTransformer +import minisql.ast.Visibility.Visible +import minisql.context.sql._ + +import scala.collection.mutable.LinkedHashSet +import minisql.util.Interpolator +import minisql.util.Messages.TraceType.NestedQueryExpansion +import minisql.context.sql.norm.nested.ExpandSelect +import minisql.norm.BetaReduction + +import scala.collection.mutable + +class ExpandNestedQueries(strategy: NamingStrategy) { + + val interp = new Interpolator(3) + import interp._ + + def apply(q: SqlQuery, references: List[Property]): SqlQuery = + apply(q, LinkedHashSet.empty ++ references) + + // Using LinkedHashSet despite the fact that it is mutable because it has better characteristics then ListSet. + // Also this collection is strictly internal to ExpandNestedQueries and exposed anywhere else. 
+ private def apply( + q: SqlQuery, + references: LinkedHashSet[Property] + ): SqlQuery = + q match { + case q: FlattenSqlQuery => + val expand = expandNested( + q.copy(select = ExpandSelect(q.select, references, strategy)) + ) + trace"Expanded Nested Query $q into $expand".andLog() + expand + case SetOperationSqlQuery(a, op, b) => + SetOperationSqlQuery(apply(a, references), op, apply(b, references)) + case UnaryOperationSqlQuery(op, q) => + UnaryOperationSqlQuery(op, apply(q, references)) + } + + private def expandNested(q: FlattenSqlQuery): SqlQuery = + q match { + case FlattenSqlQuery( + from, + where, + groupBy, + orderBy, + limit, + offset, + select, + distinct + ) => + val asts = Nil ++ select.map(_.ast) ++ where ++ groupBy ++ orderBy.map( + _.ast + ) ++ limit ++ offset + val expansions = q.from.map(expandContext(_, asts)) + val from = expansions.map(_._1) + val references = expansions.flatMap(_._2) + + val replacedRefs = references.map(ref => (ref, unhideAst(ref))) + + // Need to unhide properties that were used during the query + def replaceProps(ast: Ast) = + BetaReduction(ast, replacedRefs: _*) + def replacePropsOption(ast: Option[Ast]) = + ast.map(replaceProps(_)) + + val distinctKind = + q.distinct match { + case DistinctKind.DistinctOn(props) => + DistinctKind.DistinctOn(props.map(p => replaceProps(p))) + case other => other + } + + q.copy( + select = select.map(sv => sv.copy(ast = replaceProps(sv.ast))), + from = from, + where = replacePropsOption(where), + groupBy = replacePropsOption(groupBy), + orderBy = orderBy.map(ob => ob.copy(ast = replaceProps(ob.ast))), + limit = replacePropsOption(limit), + offset = replacePropsOption(offset), + distinct = distinctKind + ) + + } + + def unhideAst(ast: Ast): Ast = + Transform(ast) { + case Property.Opinionated(a, n, r, v) => + Property.Opinionated(unhideAst(a), n, r, Visible) + } + + private def unhideProperties(sv: SelectValue) = + sv.copy(ast = unhideAst(sv.ast)) + + private def expandContext( + s: 
FromContext, + asts: List[Ast] + ): (FromContext, LinkedHashSet[Property]) = + s match { + case QueryContext(q, alias) => + val refs = references(alias, asts) + (QueryContext(apply(q, refs), alias), refs) + case JoinContext(t, a, b, on) => + val (left, leftRefs) = expandContext(a, asts :+ on) + val (right, rightRefs) = expandContext(b, asts :+ on) + (JoinContext(t, left, right, on), leftRefs ++ rightRefs) + case FlatJoinContext(t, a, on) => + val (next, refs) = expandContext(a, asts :+ on) + (FlatJoinContext(t, next, on), refs) + case _: TableContext | _: InfixContext => + (s, new mutable.LinkedHashSet[Property]()) + } + + private def references(alias: String, asts: List[Ast]) = + LinkedHashSet.empty ++ (References(State(Ident(alias), Nil))(asts)( + _.apply + )._2.state.references) +} + +case class State(ident: Ident, references: List[Property]) + +case class References(val state: State) extends StatefulTransformer[State] { + + import state._ + + override def apply(a: Ast) = + a match { + case `reference`(p) => (p, References(State(ident, references :+ p))) + case other => super.apply(a) + } + + object reference { + def unapply(p: Property): Option[Property] = + p match { + case Property(`ident`, name) => Some(p) + case Property(reference(_), name) => Some(p) + case other => None + } + } +} diff --git a/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala new file mode 100644 index 0000000..30abb53 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/FlattenGroupByAggregation.scala @@ -0,0 +1,58 @@ +package minisql.context.sql.norm + +import minisql.ast.Aggregation +import minisql.ast.Ast +import minisql.ast.Drop +import minisql.ast.Filter +import minisql.ast.FlatMap +import minisql.ast.Ident +import minisql.ast.Join +import minisql.ast.Map +import minisql.ast.Query +import minisql.ast.SortBy +import minisql.ast.StatelessTransformer +import minisql.ast.Take +import 
minisql.ast.Union +import minisql.ast.UnionAll +import minisql.norm.BetaReduction +import minisql.util.Messages.fail +import minisql.ast.ConcatMap + +case class FlattenGroupByAggregation(agg: Ident) extends StatelessTransformer { + + override def apply(ast: Ast) = + ast match { + case q: Query if (isGroupByAggregation(q)) => + q match { + case Aggregation(op, Map(`agg`, ident, body)) => + Aggregation(op, BetaReduction(body, ident -> agg)) + case Map(`agg`, ident, body) => + BetaReduction(body, ident -> agg) + case q @ Aggregation(op, `agg`) => + q + case other => + fail(s"Invalid group by aggregation: '$other'") + } + case other => + super.apply(other) + } + + private[this] def isGroupByAggregation(ast: Ast): Boolean = + ast match { + case Aggregation(a, b) => isGroupByAggregation(b) + case Map(a, b, c) => isGroupByAggregation(a) + case FlatMap(a, b, c) => isGroupByAggregation(a) + case ConcatMap(a, b, c) => isGroupByAggregation(a) + case Filter(a, b, c) => isGroupByAggregation(a) + case SortBy(a, b, c, d) => isGroupByAggregation(a) + case Take(a, b) => isGroupByAggregation(a) + case Drop(a, b) => isGroupByAggregation(a) + case Union(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b) + case UnionAll(a, b) => isGroupByAggregation(a) || isGroupByAggregation(b) + case Join(t, a, b, ta, tb, on) => + isGroupByAggregation(a) || isGroupByAggregation(b) + case `agg` => true + case other => false + } + +} diff --git a/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala new file mode 100644 index 0000000..c239b63 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/SqlNormalize.scala @@ -0,0 +1,53 @@ +package minisql.context.sql.norm + +import minisql.norm._ +import minisql.ast.Ast +import minisql.norm.ConcatBehavior.AnsiConcat +import minisql.norm.EqualityBehavior.AnsiEquality +import minisql.norm.capture.DemarcateExternalAliases +import minisql.util.Messages.trace + +object SqlNormalize { + 
def apply( + ast: Ast, + concatBehavior: ConcatBehavior = AnsiConcat, + equalityBehavior: EqualityBehavior = AnsiEquality + ) = + new SqlNormalize(concatBehavior, equalityBehavior)(ast) +} + +class SqlNormalize( + concatBehavior: ConcatBehavior, + equalityBehavior: EqualityBehavior +) { + + private val normalize = + (identity[Ast] _) + .andThen(trace("original")) + .andThen(DemarcateExternalAliases.apply _) + .andThen(trace("DemarcateReturningAliases")) + .andThen(new FlattenOptionOperation(concatBehavior).apply _) + .andThen(trace("FlattenOptionOperation")) + .andThen(new SimplifyNullChecks(equalityBehavior).apply _) + .andThen(trace("SimplifyNullChecks")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + // Need to do RenameProperties before ExpandJoin which normalizes-out all the tuple indexes + // on which RenameProperties relies + .andThen(RenameProperties.apply _) + .andThen(trace("RenameProperties")) + .andThen(ExpandDistinct.apply _) + .andThen(trace("ExpandDistinct")) + .andThen(NestImpureMappedInfix.apply _) + .andThen(trace("NestMappedInfix")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + .andThen(ExpandJoin.apply _) + .andThen(trace("ExpandJoin")) + .andThen(ExpandMappedInfix.apply _) + .andThen(trace("ExpandMappedInfix")) + .andThen(Normalize.apply _) + .andThen(trace("Normalize")) + + def apply(ast: Ast) = normalize(ast) +} diff --git a/src/main/scala/minisql/context/sql/norm/nested/Elements.scala b/src/main/scala/minisql/context/sql/norm/nested/Elements.scala new file mode 100644 index 0000000..1cdd629 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/Elements.scala @@ -0,0 +1,29 @@ +package minisql.context.sql.norm.nested + +import minisql.PseudoAst +import minisql.context.sql.SelectValue + +object Elements { + + /** + * In order to be able to reconstruct the original ordering of elements inside + * of a select clause, we need to keep track of their order, not only within + * the top-level select but 
also its order within any possible + * tuples/case-classes in which it is embedded. For example, in the + * query:
 query[Person].map(p => (p.id, (p.name, p.age))).nested
+   * // SELECT p.id, p.name, p.age FROM (SELECT x.id, x.name, x.age FROM person
+   * x) AS p 
Since the `p.name` and `p.age` elements are selected + * inside of a sub-tuple, their "order" is `List(2,1)` and `List(2,2)` + * respectively as opposed to `p.id` whose "order" is just `List(1)`. + * + * This class keeps track of the values needed in order to perform do this. + */ + case class OrderedSelect(order: List[Int], selectValue: SelectValue) + extends PseudoAst { + override def toString: String = s"[${order.mkString(",")}]${selectValue}" + } + object OrderedSelect { + def apply(order: Int, selectValue: SelectValue) = + new OrderedSelect(List(order), selectValue) + } +} diff --git a/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala new file mode 100644 index 0000000..a8fd5d6 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/ExpandSelect.scala @@ -0,0 +1,262 @@ +package minisql.context.sql.norm.nested + +import minisql.NamingStrategy +import minisql.ast.Property +import minisql.context.sql.SelectValue +import minisql.util.Interpolator +import minisql.util.Messages.TraceType.NestedQueryExpansion + +import scala.collection.mutable.LinkedHashSet +import minisql.context.sql.norm.nested.Elements._ +import minisql.ast._ +import minisql.norm.BetaReduction + +/** + * Takes the `SelectValue` elements inside of a sub-query (if a super/sub-query + * constrct exists) and flattens them from a nested-hiearchical structure (i.e. + * tuples inside case classes inside tuples etc..) into into a single series of + * top-level select elements where needed. In cases where a user wants to select + * an element that contains an entire tuple (i.e. a sub-tuple of the outer + * select clause) we pull out the entire tuple that is being selected and leave + * it to the tokenizer to flatten later. 
+ * + * The tricky part of this operation is the situation where + * there are infix clauses in a sub-query representing an element that has not + * been selected by the outer query but, in order to ensure the SQL operation has + * the same meaning, we need to keep track of it. For example:
 val
+ * q = quote { query[Person].map(p => (infix"DISTINCT ON (${p.other})".as[Int],
+ * p.name, p.id)).map(t => (t._2, t._3)) } run(q) // SELECT p._2, p._3 FROM
+ * (SELECT DISTINCT ON (p.other), p.name AS _2, p.id AS _3 FROM Person p) AS p
+ * 
Since `DISTINCT ON` significantly changes the behavior of the
 * outer query, we need to keep track of it inside of the inner query. In order
 * to do this, we need to keep track of the location of the infix in the inner
 * query so that we can reconstruct it. This is why the `OrderedSelect` and
 * `DoubleOrderedSelect` objects are used. See the notes on these classes for
 * more detail.
 *
 * See issue #1597 for more details and another example.
 */
private class ExpandSelect(
    selectValues: List[SelectValue],
    references: LinkedHashSet[Property],
    strategy: NamingStrategy
) {
  val interp = new Interpolator(3)
  import interp._

  /**
   * Extractor for tuple accessor names (`_1`, `_2`, ...). Yields the zero-based
   * index, i.e. `_1` gives `0`.
   *
   * Names that are not a valid 1-based tuple accessor (a bare `_`, `_0`, or a
   * digit run that does not fit into an `Int`) do not match, so they are
   * handled as ordinary property names instead of failing during parsing or
   * producing a negative index.
   */
  object TupleIndex {
    def unapply(s: String): Option[Int] =
      if (s.matches("_[0-9]+"))
        // toIntOption guards against digit runs that overflow Int.
        s.drop(1).toIntOption.map(_ - 1).filter(_ >= 0)
      else
        None
  }

  /** Boolean extractor for one or more concatenated tuple accessors, e.g. `_1_2`. */
  object MultiTupleIndex {
    def unapply(s: String): Boolean =
      s.matches("(_[0-9]+)+")
  }

  /** The select values paired with their position in the select clause. */
  val select =
    selectValues.zipWithIndex.map {
      case (value, index) => OrderedSelect(index, value)
    }

  /** The column name as-is when it is fixed, otherwise run through the naming strategy. */
  def expandColumn(name: String, renameable: Renameable): String =
    renameable.fixedOr(name)(strategy.column(name))

  /**
   * Expands every referenced property into a select value. When there are no
   * references, the original select values are returned unchanged. Otherwise
   * the expanded references, together with any infix-bearing selects that were
   * not referenced, are returned sorted by their select order.
   */
  def apply: List[SelectValue] =
    trace"Expanding Select values: $selectValues into references: $references" andReturn {

      def expandReference(ref: Property): OrderedSelect =
        trace"Expanding: $ref from $select" andReturn {

          // Appends the 1-based tuple position to the alias built so far.
          def concat(alias: Option[String], idx: Int) =
            Some(s"${alias.getOrElse("")}_${idx + 1}")

          val orderedSelect = ref match {
            case Property(ast: Property, TupleIndex(idx)) =>
              trace"Reference is a sub-property of a tuple index: $idx. Walking inside." andReturn
                expandReference(ast) match {
                  case OrderedSelect(o, SelectValue(Tuple(elems), alias, c)) =>
                    trace"Expressing Element $idx of $elems " andReturn
                      OrderedSelect(
                        o :+ idx,
                        SelectValue(elems(idx), concat(alias, idx), c)
                      )
                  case OrderedSelect(o, SelectValue(ast, alias, c)) =>
                    trace"Appending $idx to $alias " andReturn
                      OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
                }
            case Property.Opinionated(
                  ast: Property,
                  name,
                  renameable,
                  visible
                ) =>
              trace"Reference is a sub-property. Walking inside." andReturn
                expandReference(ast) match {
                  case OrderedSelect(o, SelectValue(ast, nested, _)) =>
                    // Alias is the name of the column after the naming strategy.
                    // The clauses in `SqlIdiom` that use `Tokenizer[SelectValue]` select the
                    // alias field when its value is Some(T).
                    // Technically the aliases of a column should not be using naming strategies
                    // but this is an issue to fix at a later date.

                    // In the current implementation, we add nested tuple names to aliases e.g.
                    //   SELECT foo from
                    //   SELECT x, y FROM (SELECT foo, bar, red, orange FROM baz JOIN colors)
                    // Typically becomes SELECT foo _1foo, _1bar, _2red, _2orange when
                    // this kind of query is the result of an applicative join that looks like this:
                    //   query[baz].join(query[colors]).nested
                    // this may need to change based on how distinct appends table names instead of just tuple indexes
                    // into the property path.

                    trace"...inside walk completed, continuing to return: " andReturn
                      OrderedSelect(
                        o,
                        SelectValue(
                          // Note: Pass invisible properties to be tokenized by the idiom, they should be excluded there
                          Property.Opinionated(ast, name, renameable, visible),
                          // Skip concatenation of invisible properties into the alias e.g. so it will be
                          Some(
                            s"${nested.getOrElse("")}${expandColumn(name, renameable)}"
                          )
                        )
                      )
                }
            case Property(_, TupleIndex(idx)) =>
              trace"Reference is a tuple index: $idx from $select." andReturn
                select(idx) match {
                  case OrderedSelect(o, SelectValue(ast, alias, c)) =>
                    OrderedSelect(o, SelectValue(ast, concat(alias, idx), c))
                }
            case Property.Opinionated(_, name, renameable, visible) =>
              select match {
                case List(
                      OrderedSelect(o, SelectValue(cc: CaseClass, _, c))
                    ) =>
                  // Currently case class element name is not being appended. Need to change that in order to ensure
                  // path name uniqueness in future.
                  val ((_, ast), index) =
                    cc.values.zipWithIndex.find(_._1._1 == name) match {
                      case Some(v) => v
                      case None =>
                        throw new IllegalArgumentException(
                          s"Cannot find element $name in $cc"
                        )
                    }
                  trace"Reference is a case class member: " andReturn
                    OrderedSelect(
                      o :+ index,
                      SelectValue(ast, Some(expandColumn(name, renameable)), c)
                    )
                case List(OrderedSelect(o, SelectValue(i: Ident, _, c))) =>
                  trace"Reference is an identifier: " andReturn
                    OrderedSelect(
                      o,
                      SelectValue(
                        Property.Opinionated(i, name, renameable, visible),
                        Some(name),
                        c
                      )
                    )
                case other =>
                  trace"Reference is unidentified: $other returning:" andReturn
                    OrderedSelect(
                      Integer.MAX_VALUE,
                      SelectValue(
                        Ident.Opinionated(name, visible),
                        Some(expandColumn(name, renameable)),
                        false
                      )
                    )
              }
          }

          // For certain very large queries where entities are unwrapped and then re-wrapped into CaseClass/Tuple constructs,
          // the actual row-types can contain Tuple/CaseClass values. For this reason, they need to be beta-reduced again.
          val normalizedOrderedSelect = orderedSelect.copy(selectValue =
            orderedSelect.selectValue.copy(ast =
              BetaReduction(orderedSelect.selectValue.ast)
            )
          )

          trace"Expanded $ref into $orderedSelect then Normalized to $normalizedOrderedSelect" andReturn
            normalizedOrderedSelect
        }

      // Drops an alias that merely repeats the property name of an `ident.property` select value.
      def deAliasWhenUneeded(os: OrderedSelect) =
        os match {
          case OrderedSelect(
                _,
                sv @ SelectValue(Property(Ident(_), propName), Some(alias), _)
              ) if (propName == alias) =>
            trace"Detected select value with un-needed alias: $os removing it:" andReturn
              os.copy(selectValue = sv.copy(alias = None))
          case _ => os
        }

      references.toList match {
        case Nil => select.map(_.selectValue)
        case refs => {
          // elements first need to be sorted by their order in the select clause. Since some may map to multiple
          // properties when expanded, we want to maintain this order of properties as a secondary value.
          val mappedRefs =
            refs
              // Expand all the references to properties that we have selected in the super query
              .map(expandReference)
              // Once all the recursive calls of expandReference are done, remove the alias if it is not needed.
              // We cannot do this earlier because during recursive calls, the aliases of outer clauses are used for inner ones.
              .map(deAliasWhenUneeded(_))

          trace"Mapped Refs: $mappedRefs".andLog()

          // are there any selects that have infix values which we have not already selected? We need to include
          // them because they could be doing essential things e.g. RANK ... ORDER BY
          val remainingSelectsWithInfixes =
            trace"Searching Selects with Infix:" andReturn
              new FindUnexpressedInfixes(select)(mappedRefs)

          // Element-wise comparison of select paths. When one path is a strict prefix
          // of the other, the longer path sorts first.
          implicit val ordering: scala.math.Ordering[List[Int]] =
            new scala.math.Ordering[List[Int]] {
              override def compare(x: List[Int], y: List[Int]): Int =
                (x, y) match {
                  case (head1 :: tail1, head2 :: tail2) =>
                    // Integer.compare cannot overflow, unlike head1 - head2.
                    val diff = Integer.compare(head1, head2)
                    if (diff != 0) diff
                    else compare(tail1, tail2)
                  case (Nil, Nil) => 0  // List(1,2,3) == List(1,2,3)
                  case (_, Nil)   => -1 // List(1,2,3) < List(1,2)
                  case (Nil, _)   => 1  // List(1,2) > List(1,2,3)
                }
            }

          val sortedRefs =
            (mappedRefs ++ remainingSelectsWithInfixes).sortBy(ref =>
              ref.order
            ) // (ref.order, ref.secondaryOrder)

          sortedRefs.map(_.selectValue)
        }
      }
    }
}

object ExpandSelect {

  /** Builds an `ExpandSelect` for the given arguments and returns its expanded select values. */
  def apply(
      selectValues: List[SelectValue],
      references: LinkedHashSet[Property],
      strategy: NamingStrategy
  ): List[SelectValue] =
    new ExpandSelect(selectValues, references, strategy).apply
}
diff --git a/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala new file mode 100644 index 0000000..2ea1320 --- /dev/null +++ b/src/main/scala/minisql/context/sql/norm/nested/FindUnexpressedInfixes.scala @@ -0,0 +1,83 @@
package minisql.context.sql.norm.nested

import minisql.context.sql.norm.nested.Elements._
import minisql.util.Interpolator
import minisql.util.Messages.TraceType.NestedQueryExpansion
import minisql.ast._
import minisql.context.sql.SelectValue

/**
 * The challenge with appending infixes (that have not been used but are still
 * needed) back into the query, is that they could be inside of
 * tuples/case-classes that have already been selected, or inside of sibling
 * elements which have been selected. Take for instance a query that looks like
 * this:
 query[Person].map(p => (p.name, (p.id,
+ * infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2._1)) 
In + * this situation, `p.id`, which is the sibling of the non-selected infix, has + * been selected via `p._2._1` (whose select-order is List(1,0), representing the 1st + * element in the 2nd tuple). We need to add its sibling infix. + * + * Or take the following situation:
 query[Person].map(p => (p.name,
+ * (p.id, infix"foo(\${p.other})".as[Int]))).map(p => (p._1, p._2))
+ * 
In this case, we have selected the entire 2nd element including
 * the infix. We need to know that `p._2._2` does not need to be selected since
 * `p._2` was.
 *
 * In order to do these things, we use the `order` property from `OrderedSelect`
 * in order to see which sub-sub-...-element has been selected. If `p._2` (that
 * has order `List(1)`) has been selected, we know that any infixes inside of it
 * e.g. `p._2._1` (ordering `List(1,0)`) does not need to be.
 */
class FindUnexpressedInfixes(select: List[OrderedSelect]) {
  val interp = new Interpolator(3)
  import interp._

  /**
   * Collects the infix-containing parts of `select` whose path is not already
   * covered by one of `refs`, each wrapped as an un-aliased `OrderedSelect`
   * located at the path where it was found.
   */
  def apply(refs: List[OrderedSelect]) = {

    // True when some reference has exactly this select path.
    def pathExists(path: List[Int]) =
      refs.exists(_.order == path)

    // True when at least one Infix node occurs anywhere inside `ast`.
    def containsInfix(ast: Ast) =
      CollectAst.byType[Infix](ast).nonEmpty

    // Build paths to every infix and see whether these paths were already selected.
    def findMissingInfixes(
        ast: Ast,
        parentOrder: List[Int]
    ): List[(Ast, List[Int])] = {
      trace"Searching for infix: $ast in the sub-path $parentOrder".andLog()
      if (pathExists(parentOrder))
        trace"No infixes found" andContinue
          List()
      else
        ast match {
          case Tuple(values) =>
            val withInfix =
              values.zipWithIndex.filter { case (element, _) =>
                containsInfix(element)
              }
            withInfix.flatMap { case (element, position) =>
              findMissingInfixes(element, parentOrder :+ position)
            }
          case CaseClass(values) =>
            val withInfix =
              values.zipWithIndex.filter { case ((_, element), _) =>
                containsInfix(element)
              }
            withInfix.flatMap { case ((_, element), position) =>
              findMissingInfixes(element, parentOrder :+ position)
            }
          case other if (containsInfix(other)) =>
            trace"Found unexpressed infix inside $other in $parentOrder"
              .andLog()
            List((other, parentOrder))
          case _ =>
            List()
        }
    }

    // Search every select value first, then wrap the findings.
    val missing = select.flatMap { entry =>
      findMissingInfixes(entry.selectValue.ast, entry.order)
    }
    missing.map { case (infixAst, path) =>
      OrderedSelect(path, SelectValue(infixAst))
    }
  }
}
diff --git a/src/main/scala/minisql/norm/FreeVariables.scala b/src/main/scala/minisql/norm/FreeVariables.scala new file mode 100644 index 0000000..9c63437 --- /dev/null +++ 
b/src/main/scala/minisql/norm/FreeVariables.scala @@ -0,0 +1,120 @@
package minisql.norm

import minisql.ast.*
import collection.immutable.Set

/**
 * Traversal state: `seen` holds the identifiers bound in the current scope,
 * `free` the identifiers encountered while not being in `seen`.
 */
case class State(seen: Set[Ident], free: Set[Ident])

/**
 * Stateful traversal that accumulates free identifiers in `state.free`.
 * The visited AST nodes are returned as they were passed in; only the state of
 * the returned transformer carries the result.
 */
case class FreeVariables(state: State) extends StatefulTransformer[State] {

  override def apply(ast: Ast): (Ast, StatefulTransformer[State]) =
    ast match {
      // An identifier that is not bound in the current scope is free.
      case ident: Ident if (!state.seen.contains(ident)) =>
        (ident, FreeVariables(state.copy(free = state.free + ident)))
      // Function parameters are bound only while visiting the body.
      case Function(params, body) =>
        val (_, inBody) =
          FreeVariables(state.copy(seen = state.seen ++ params))(body)
        (ast, FreeVariables(state.copy(free = state.free ++ inBody.state.free)))
      case Foreach(a, b, c) =>
        (ast, free(a, b, c))
      case other =>
        super.apply(other)
    }

  override def apply(
      o: OptionOperation
  ): (OptionOperation, StatefulTransformer[State]) =
    o match {
      case OptionTableFlatMap(a, b, c) => (o, free(a, b, c))
      case OptionTableMap(a, b, c)     => (o, free(a, b, c))
      case OptionTableExists(a, b, c)  => (o, free(a, b, c))
      case OptionTableForall(a, b, c)  => (o, free(a, b, c))
      case OptionFlatMap(a, b, c)      => (o, free(a, b, c))
      case OptionMap(a, b, c)          => (o, free(a, b, c))
      case OptionForall(a, b, c)       => (o, free(a, b, c))
      case OptionExists(a, b, c)       => (o, free(a, b, c))
      case other =>
        super.apply(other)
    }

  override def apply(e: Assignment): (Assignment, StatefulTransformer[State]) =
    e match {
      case Assignment(alias, property, value) =>
        // The assignment alias is bound while visiting both sides.
        val scoped = FreeVariables(State(state.seen + alias, state.free))
        val (propertyAst, afterProperty) = scoped(property)
        val (valueAst, afterValue) = scoped(value)
        val collected =
          state.free ++ afterProperty.state.free ++ afterValue.state.free
        (
          Assignment(alias, propertyAst, valueAst),
          FreeVariables(State(state.seen, collected))
        )
    }

  override def apply(action: Action): (Action, StatefulTransformer[State]) =
    action match {
      case Returning(a, b, c)          => (action, free(a, b, c))
      case ReturningGenerated(a, b, c) => (action, free(a, b, c))
      case other =>
        super.apply(other)
    }

  // Conflict targets are returned untouched and leave the state unchanged.
  override def apply(
      e: OnConflict.Target
  ): (OnConflict.Target, StatefulTransformer[State]) = (e, this)

  override def apply(query: Query): (Query, StatefulTransformer[State]) =
    query match {
      case Filter(a, b, c)      => (query, free(a, b, c))
      case Map(a, b, c)         => (query, free(a, b, c))
      case DistinctOn(a, b, c)  => (query, free(a, b, c))
      case FlatMap(a, b, c)     => (query, free(a, b, c))
      case ConcatMap(a, b, c)   => (query, free(a, b, c))
      case SortBy(a, b, c, _)   => (query, free(a, b, c))
      case GroupBy(a, b, c)     => (query, free(a, b, c))
      case FlatJoin(_, a, b, c) => (query, free(a, b, c))
      case Join(_, a, b, iA, iB, on) =>
        val (_, inLeft) = apply(a)
        val (_, inRight) = apply(b)
        // Both join aliases are bound while visiting the `on` clause.
        val onScope = State(state.seen + iA + iB, Set.empty)
        val (_, inOn) = FreeVariables(onScope)(on)
        val collected =
          state.free ++ inLeft.state.free ++ inRight.state.free ++ inOn.state.free
        (query, FreeVariables(State(state.seen, collected)))
      case _: Entity | _: Take | _: Drop | _: Union | _: UnionAll |
          _: Aggregation | _: Distinct | _: Nested =>
        super.apply(query)
    }

  /**
   * Visits `a` in the current scope and `c` with `ident` additionally bound,
   * and returns a transformer whose scope is unchanged and whose free set also
   * holds whatever both visits found.
   */
  private def free(a: Ast, ident: Ident, c: Ast) = {
    val (_, inSource) = apply(a)
    val bodyScope = State(state.seen + ident, state.free)
    val (_, inBody) = FreeVariables(bodyScope)(c)
    FreeVariables(
      State(state.seen, state.free ++ inSource.state.free ++ inBody.state.free)
    )
  }
}

object FreeVariables {

  /** Returns the free identifiers of `ast`, starting from an empty scope. */
  def apply(ast: Ast): Set[Ident] = {
    val (_, transformer) = new FreeVariables(State(Set.empty, Set.empty))(ast)
    transformer.state.free
  }
}