diff --git a/build.sbt b/build.sbt index 86502a5..d869492 100644 --- a/build.sbt +++ b/build.sbt @@ -6,6 +6,8 @@ libraryDependencies ++= Seq( "org.scalameta" %% "munit" % "1.0.3" % Test ) +javaOptions ++= Seq("-Xss16m") + scalacOptions ++= Seq( "-deprecation", "-feature", diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index af3b4c6..3256d00 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -39,8 +39,8 @@ private def quotedLiftImpl[X: Type]( e: Expr[ParamEncoder[X]] )(using Quotes): Expr[ast.ScalarValueLift] = { import quotes.reflect.* - val name = x.asTerm.symbol.fullName - val liftId = x.asTerm.symbol.owner.fullName + "@" + name + val name = x.asTerm.show + val liftId = liftIdOfExpr(x) '{ ast.ScalarValueLift( ${ Expr(name) }, @@ -84,23 +84,19 @@ object EntityQuery { } inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { - val entity = e.asInstanceOf[ast.Entity] - val assignments = transformCaseClassToAssignments[E](v, entity.name) - ast.Insert(entity, assignments) + ast.Insert(e, transformCaseClassToAssignments[E](v)) } } } private inline def transformCaseClassToAssignments[E]( - v: E, - entityName: String + v: E )(using m: Mirror.ProductOf[E]): List[ast.Assignment] = ${ - transformCaseClassToAssignmentsImpl[E]('v, 'entityName) + transformCaseClassToAssignmentsImpl[E]('v) } private def transformCaseClassToAssignmentsImpl[E: Type]( - v: Expr[E], - entityName: Expr[String] + v: Expr[E] )(using Quotes): Expr[List[ast.Assignment]] = { import quotes.reflect.* @@ -115,10 +111,10 @@ private def transformCaseClassToAssignmentsImpl[E: Type]( case '[t] => '{ ast.Assignment( - ast.Ident($entityName), - ast.Property(ast.Ident($entityName), ${ Expr(fieldName) }), - quotedLift[t](${ Select.unique(v.asTerm, fieldName).asExprOf[t] })( - using summonInline[ParamEncoder[t]] + ast.Ident("v"), + ast.Property(ast.Ident("v"), ${ Expr(fieldName) }), + quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using + summonInline[ParamEncoder[t]] ) ) } @@ -186,8 +182,10 @@ private def compileImpl[I <: Idiom, N <: NamingStrategy]( n: Expr[N] )(using Quotes, Type[I], Type[N]): Expr[Statement] = { import quotes.reflect.* + println(s"Start q.value") q.value match { case Some(ast) => + println(s"Finish q.value: ${ast}") val idiom = LoadObject[I].getOrElse( report.errorAndAbort(s"Idiom not known at compile") ) diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 9f70b0d..e527a6f 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -46,8 +46,7 @@ private given FromExpr[ScalarValueLift] with { def unapply(x: Expr[ScalarValueLift])(using Quotes): Option[ScalarValueLift] = x match { case '{ ScalarValueLift(${ Expr(n) }, ${ Expr(id) }, $y) } => - // don't cared about value here, a little tricky - Some(ScalarValueLift(n, id, null)) + Some(ScalarValueLift(n, id, None)) } } diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 6f6bea5..c469064 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -1,13 +1,14 @@ package minisql.context -import scala.deriving.* -import scala.compiletime.* -import scala.util.Try import minisql.util.* import minisql.idiom.{Idiom, Statement, ReifyStatement} import minisql.{NamingStrategy, ParamEncoder} import minisql.ColumnDecoder import minisql.ast.{Ast, ScalarValueLift, CollectAst} +import scala.deriving.* +import scala.compiletime.* +import scala.util.Try +import scala.annotation.targetName trait RowExtract[A, Row] { def extract(row: Row): Try[A] @@ -89,6 +90,30 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => ) } + @targetName("ioAction") + inline def io[E](inline q: minisql.Action[E]): DBIO[E] = { + val extractor = summonFrom { + case e: RowExtract[E, DBRow] => e + case e: ColumnDecoder.Aux[DBRow, E] => + RowExtract.single(e) + } + + val lifts = q.liftMap + val stmt = minisql.compile(q, idiom, naming) + val (sql, params) = stmt.expand(lifts) + ( + sql = sql, + params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]), + mapper = (rows) => + rows + .traverse(extractor.extract) + .flatMap( + _.headOption.toRight(new Exception(s"No value return")).toTry + ) + ) + } + + @targetName("ioQuery") inline def io[E]( inline q: minisql.Query[E] ): DBIO[IArray[E]] = { diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index ba099b0..dffd56b 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -31,13 +31,13 @@ trait SqlIdiom extends Idiom { def querifyAst(ast: Ast) = SqlQuery(ast) - private def doTranslate(ast: Ast, cached: Boolean)(implicit + private def doTranslate(ast: Ast, cached: Boolean)(using naming: NamingStrategy ): (Ast, Statement) = { val normalizedAst = SqlNormalize(ast, concatBehavior, equalityBehavior) - implicit val tokernizer: Tokenizer[Ast] = defaultTokenizer + given Tokenizer[Ast] = defaultTokenizer val token = normalizedAst match { @@ -63,7 +63,7 @@ trait SqlIdiom extends Idiom { doTranslate(ast, false) } - def defaultTokenizer(implicit naming: NamingStrategy): Tokenizer[Ast] = + def defaultTokenizer(using naming: NamingStrategy): Tokenizer[Ast] = new Tokenizer[Ast] { private val stableTokenizer = astTokenizer(using this, naming) @@ -73,7 +73,7 @@ trait SqlIdiom extends Idiom { } - def astTokenizer(implicit + def astTokenizer(using astTokenizer: Tokenizer[Ast], strategy: NamingStrategy ): Tokenizer[Ast] = diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index fd18549..1507919 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -305,7 +305,7 @@ trait MirrorIdiomBase extends Idiom { Tokenizer[OnConflict.Target] { case OnConflict.NoTarget => stmt"" case OnConflict.Properties(props) => - val listTokens = listTokenizer(using astTokenizer).token(props) + val listTokens = props.token stmt"(${listTokens})" } diff --git a/src/main/scala/minisql/idiom/ReifyStatement.scala b/src/main/scala/minisql/idiom/ReifyStatement.scala index 7a4a07a..4206238 100644 --- a/src/main/scala/minisql/idiom/ReifyStatement.scala +++ b/src/main/scala/minisql/idiom/ReifyStatement.scala @@ -16,11 +16,12 @@ object ReifyStatement { liftMap: SMap[String, (Any, ParamEncoder[?])] ): (String, List[ScalarValueLift]) = { val expanded = expandLiftings(statement, emptySetContainsToken, liftMap) - token2string(expanded, liftingPlaceholder) + token2string(expanded, liftMap, liftingPlaceholder) } private def token2string( token: Token, + liftMap: SMap[String, (Any, ParamEncoder[?])], liftingPlaceholder: Int => String ): (String, List[ScalarValueLift]) = { @@ -44,7 +45,7 @@ object ReifyStatement { ) case ScalarLiftToken(lift: ScalarValueLift) => sqlBuilder ++= liftingPlaceholder(liftingSize) - liftBuilder += lift + liftBuilder += lift.copy(value = liftMap.get(lift.liftId)) loop(tail, liftingSize + 1) case ScalarLiftToken(o) => throw new Exception(s"Cannot tokenize ScalarQueryLift: ${o}") diff --git a/src/main/scala/minisql/idiom/StatementInterpolator.scala b/src/main/scala/minisql/idiom/StatementInterpolator.scala index 056d58f..2893c8d 100644 --- a/src/main/scala/minisql/idiom/StatementInterpolator.scala +++ b/src/main/scala/minisql/idiom/StatementInterpolator.scala @@ -8,12 +8,20 @@ import scala.collection.mutable.ListBuffer object StatementInterpolator { + extension [T](list: List[T]) { + private[minisql] def mkStmt( + sep: String = ", " + )(using tokenize: Tokenizer[T]) = { + val l1 = list.map(_.token) + val l2 = List.fill(l1.size - 1)(StringToken(sep)) + Statement(Interleave(l1, l2)) + } + } trait Tokenizer[T] { extension (v: T) { def token: Token } } - object Tokenizer { def apply[T](f: T => Token): Tokenizer[T] = new Tokenizer[T] { extension (v: T) { @@ -31,37 +39,29 @@ object StatementInterpolator { } } - implicit class TokenImplicit[T](v: T)(implicit tokenizer: Tokenizer[T]) { + extension [T](v: T)(using tokenizer: Tokenizer[T]) { def token = tokenizer.token(v) } - implicit def stringTokenizer: Tokenizer[String] = + given stringTokenizer: Tokenizer[String] = Tokenizer[String] { case string => StringToken(string) } - implicit def liftTokenizer: Tokenizer[Lift] = + given liftTokenizer: Tokenizer[Lift] = Tokenizer[Lift] { case lift: ScalarLift => ScalarLiftToken(lift) } - implicit def tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) - implicit def statementTokenizer: Tokenizer[Statement] = + given tokenTokenizer: Tokenizer[Token] = Tokenizer[Token](identity) + given statementTokenizer: Tokenizer[Statement] = Tokenizer[Statement](identity) - implicit def stringTokenTokenizer: Tokenizer[StringToken] = + given stringTokenTokenizer: Tokenizer[StringToken] = Tokenizer[StringToken](identity) - implicit def liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = + given liftingTokenTokenizer: Tokenizer[ScalarLiftToken] = Tokenizer[ScalarLiftToken](identity) - extension [T](list: List[T]) { - def mkStmt(sep: String = ", ")(implicit tokenize: Tokenizer[T]) = { - val l1 = list.map(_.token) - val l2 = List.fill(l1.size - 1)(StringToken(sep)) - Statement(Interleave(l1, l2)) - } - } - - implicit def listTokenizer[T](implicit + given listTokenizer[T](using tokenize: Tokenizer[T] ): Tokenizer[List[T]] = Tokenizer[List[T]] { diff --git a/src/main/scala/minisql/parsing/LiftParsing.scala b/src/main/scala/minisql/parsing/LiftParsing.scala index 9a0f32b..f0aba8a 100644 --- a/src/main/scala/minisql/parsing/LiftParsing.scala +++ b/src/main/scala/minisql/parsing/LiftParsing.scala @@ -4,13 +4,14 @@ import scala.quoted.* import minisql.ParamEncoder import minisql.ast import minisql.* +import minisql.util.* private[parsing] def liftParsing( astParser: => Parser[ast.Ast] )(using Quotes): Parser[ast.Lift] = { case '{ lift[t](${ x })(using $e: ParamEncoder[t]) } => import quotes.reflect.* - val name = x.asTerm.symbol.fullName - val liftId = x.asTerm.symbol.owner.fullName + "@" + name + val name = x.show + val liftId = liftIdOfExpr(x) '{ ast.ScalarValueLift(${ Expr(name) }, ${ Expr(liftId) }, Some($x -> $e)) } } diff --git a/src/main/scala/minisql/util/CollectTry.scala b/src/main/scala/minisql/util/CollectTry.scala index 6b466be..f4572b8 100644 --- a/src/main/scala/minisql/util/CollectTry.scala +++ b/src/main/scala/minisql/util/CollectTry.scala @@ -20,7 +20,7 @@ extension [A](xs: Iterable[A]) { } } -object CollectTry { +private[minisql] object CollectTry { def apply[T](list: List[Try[T]]): Try[List[T]] = list.foldLeft(Try(List.empty[T])) { case (list, t) => diff --git a/src/main/scala/minisql/util/QuotesHelper.scala b/src/main/scala/minisql/util/QuotesHelper.scala new file mode 100644 index 0000000..6ecbc76 --- /dev/null +++ b/src/main/scala/minisql/util/QuotesHelper.scala @@ -0,0 +1,24 @@ +package minisql.util + +import scala.quoted.* + +private[minisql] def splicePkgPath(using Quotes) = { + import quotes.reflect.* + def recurse(sym: Symbol): String = + sym match { + case s if s.isPackageDef => s.fullName + case s if s.isNoSymbol => "" + case _ => + recurse(sym.maybeOwner) + } + recurse(Symbol.spliceOwner) +} + +private[minisql] def liftIdOfExpr(x: Expr[?])(using Quotes) = { + import quotes.reflect.* + val name = x.asTerm.show + val packageName = splicePkgPath + val pos = x.asTerm.pos + val fileName = pos.sourceFile.name + s"${name}@${packageName}.${fileName}:${pos.startLine}:${pos.startColumn}" +} diff --git a/src/test/scala/minisql/mirror/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala similarity index 57% rename from src/test/scala/minisql/mirror/QuotedSuite.scala rename to src/test/scala/minisql/context/sql/QuotedSuite.scala index e8f4b39..c5fda24 100644 --- a/src/test/scala/minisql/mirror/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -1,4 +1,4 @@ -package minisql.context.mirror +package minisql.context.sql import minisql.* import minisql.ast.* @@ -9,23 +9,27 @@ import minisql.context.mirror.{*, given} class QuotedSuite extends munit.FunSuite { - case class Foo(id: Long) + case class Foo(id: Long, name: String) - import mirrorContext.given + inline def Foos = query[Foo]("foo") + + import testContext.given test("SimpleQuery") { - val o = mirrorContext.io( + val o = testContext.io( query[Foo]( "foo", alias("id", "id1") ).filter(_.id > 0) ) - println("============" + o) - o + println(o) } test("Insert") { + val v: Foo = Foo(0L, "foo") + + val o = testContext.io(Foos.insert(v)) + println(o) ??? } - } diff --git a/src/test/scala/minisql/context/sql/context.scala b/src/test/scala/minisql/context/sql/context.scala new file mode 100644 index 0000000..d3d36fa --- /dev/null +++ b/src/test/scala/minisql/context/sql/context.scala @@ -0,0 +1,5 @@ +package minisql.context.sql + +import minisql.* + +val testContext = new MirrorSqlContext(Literal) diff --git a/src/test/scala/minisql/mirror/context.scala b/src/test/scala/minisql/mirror/context.scala deleted file mode 100644 index 240a475..0000000 --- a/src/test/scala/minisql/mirror/context.scala +++ /dev/null @@ -1,6 +0,0 @@ -package minisql.context.mirror - -import minisql.* -import minisql.idiom.MirrorIdiom - -val mirrorContext = new MirrorContext(MirrorIdiom, Literal)