diff --git a/src/main/scala/minisql/ast/Ast.scala b/src/main/scala/minisql/ast/Ast.scala index 48e3fae..52446e3 100644 --- a/src/main/scala/minisql/ast/Ast.scala +++ b/src/main/scala/minisql/ast/Ast.scala @@ -385,8 +385,9 @@ case class ScalarValueLift( case class ScalarQueryLift( name: String, - liftId: String -) extends ScalarLift {} + liftId: String, + value: Option[(Seq[Any], ParamEncoder[?])] +) extends ScalarLift object ScalarLift { given ToExpr[ScalarLift] with { diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 8b3b96f..af33d80 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -4,9 +4,10 @@ import scala.deriving.* import scala.compiletime.* import scala.util.Try import minisql.util.* -import minisql.idiom.{Idiom, Statement} +import minisql.idiom.{Idiom, Statement, ReifyStatement} import minisql.{NamingStrategy, ParamEncoder} import minisql.ColumnDecoder +import minisql.ast.{Ast, ScalarValueLift, CollectAst} trait Context[I <: Idiom, N <: NamingStrategy] { selft => @@ -50,21 +51,40 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => type Decoder[X] = ColumnDecoder.Aux[DBRow, X] - type DBIO[X] = ( - statement: Statement, - params: (Any, Encoder[?]), - extract: RowExtract[X] + type DBIO[E] = ( + sql: String, + params: List[(Any, Encoder[?])], + mapper: Iterable[DBRow] => Try[E] ) extension (ast: Ast) { - extractParams + private def liftMap = { + val lifts = CollectAst.byType[ScalarValueLift](ast) + lifts.map(l => l.liftId -> l.value.get).toMap + } + } + + extension (stmt: Statement) { + def expand(liftMap: Map[String, (Any, ParamEncoder[?])]) = + ReifyStatement( + idiom.liftingPlaceholder, + idiom.emptySetContainsToken, + stmt, + liftMap + ) } inline def io[E]( inline q: minisql.Query[E] - )(using r: RowExtract[E]): DBIO[Seq[E]] = { - val statement = minisql.compile(q, idiom, naming) - ??? + )(using r: RowExtract[E]): DBIO[IArray[E]] = { + val lifts = q.liftMap + val stmt = minisql.compile(q, idiom, naming) + val (sql, params) = stmt.expand(lifts) + ( + sql = sql, + params = params.map(_.value.get.asInstanceOf), + mapper = (rows) => rows.traverse(r.extract) + ) } } diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index 7c5b9a8..7a4a07a 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -1,8 +1,9 @@ package minisql.idiom -import minisql.ast._ +import minisql.ParamEncoder +import minisql.ast.* import minisql.util.Interleave -import minisql.idiom.StatementInterpolator._ +import minisql.idiom.StatementInterpolator.* import scala.annotation.tailrec import scala.collection.immutable.{Map => SMap} @@ -12,7 +13,7 @@ object ReifyStatement { liftingPlaceholder: Int => String, emptySetContainsToken: Token => Token, statement: Statement, - liftMap: SMap[String, (Any, Any)] + liftMap: SMap[String, (Any, ParamEncoder[?])] ): (String, List[ScalarValueLift]) = { val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) token2string(expanded, liftingPlaceholder) @@ -61,8 +62,39 @@ object ReifyStatement { private def expandLiftings( statement: Statement, emptySetContainsToken: Token => Token, - liftMap: SMap[String, (Any, Any)] - ): Token = { - ??? + liftMap: SMap[String, (Any, ParamEncoder[?])] + ): (Token) = { + Statement { + val lb = List.newBuilder[Token] + statement.tokens.foldLeft(lb) { + case ( + tokens, + SetContainsToken(a, op, ScalarLiftToken(lift: ScalarQueryLift)) + ) => + val (lv, le) = liftMap(lift.liftId) + lv.asInstanceOf[Iterable[Any]].toVector match { + case Vector() => tokens += emptySetContainsToken(a) + case values => + val liftings = values.zipWithIndex.map { + case (v, i) => + ScalarLiftToken( + ScalarValueLift( + s"${lift.name}[${i}]", + s"${lift.liftId}[${i}]", + Some(v -> le) + ) + ) + } + val separators = Vector.fill(liftings.size - 1)(StringToken(", ")) + (tokens += stmt"$a $op (") ++= Interleave( + liftings, + separators + ) += StringToken(")") + } + case (tokens, token) => + tokens += token + } + lb.result() + } } } diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index 74a6984..6b466be 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -2,7 +2,7 @@ package minisql.util import scala.util.* -extension [A](xs: IArray[A]) { +extension [A](xs: Iterable[A]) { private[minisql] def traverse[B](f: A => Try[B]): Try[IArray[B]] = { val out = IArray.newBuilder[Any] var left: Option[Throwable] = None