This commit is contained in:
jilen 2025-07-09 19:22:17 +08:00
parent 0fcd4ac140
commit 13633e0b20
3 changed files with 77 additions and 30 deletions

View file

@ -1,3 +1,18 @@
package minisql
type QueryMeta
import minisql.util.*
import scala.deriving.*
import scala.quoted.*
opaque type UpdateMeta[E] <: List[String] = List[String]
object UpdateMeta {
inline def noExclude[E]: UpdateMeta[E] = unsafe(Nil)
inline def apply[E](using m: Mirror.ProductOf[E])(
inline exclude: List[Tuple.Union[m.MirroredElemLabels] & String]
): UpdateMeta[E] = unsafe(exclude)
inline def unsafe[E](inline ex: List[String]): UpdateMeta[E] = ex
}

View file

@ -21,6 +21,7 @@ import scala.deriving.*
import scala.compiletime.*
import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast
@ -165,7 +166,13 @@ object EntityQuery {
}
inline def insert(v: E): Insert[E] = {
ast.Insert(e, transformCaseClassToAssignments[E](v))
ast.Insert(e, transformCaseClassToAssignments[E](v, Nil))
}
inline def update(
v: E
)(using inline m: UpdateMeta[E] = UpdateMeta.noExclude[E]): Update[Long] = {
ast.Update(e, transformCaseClassToAssignments[E](v, m))
}
inline def update(inline ass: (E => (Any, Any))*): Update[Long] = {
@ -193,35 +200,45 @@ private def parseFuncAssignImpl[E](x: Expr[Seq[E => (Any, Any)]])(using
}
private inline def transformCaseClassToAssignments[E](
v: E
v: E,
inline exclude: List[String]
): List[ast.Assignment] = ${
transformCaseClassToAssignmentsImpl[E]('v)
transformCaseClassToAssignmentsImpl[E]('v, 'exclude)
}
private def transformCaseClassToAssignmentsImpl[E: Type](
v: Expr[E]
v: Expr[E],
exclude: Expr[List[String]]
)(using Quotes): Expr[List[ast.Assignment]] = {
import quotes.reflect.*
val excludeFields = extractTerm(exclude.asTerm).asExpr match {
case '{ $xs: List[String] } =>
xs.value.getOrElse(
report.errorAndAbort(s"Cannot handle exclude ${xs.show}")
)
}
val fields = TypeRepr.of[E].typeSymbol.caseFields
val assignments = fields.map { field =>
val fieldName = field.name
val fieldType = field.tree match {
case v: ValDef => v.tpt.tpe
case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName")
}
fieldType.asType match {
case '[t] =>
'{
ast.Assignment(
ast.Ident("v"),
ast.Property(ast.Ident("v"), ${ Expr(fieldName) }),
quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using
summonInline[ParamEncoder[t]]
val assignments = fields.collect {
case field if !excludeFields.contains(field.name) =>
val fieldName = field.name
val fieldType = field.tree match {
case v: ValDef => v.tpt.tpe
case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName")
}
fieldType.asType match {
case '[t] =>
'{
ast.Assignment(
ast.Ident("v"),
ast.Property(ast.Ident("v"), ${ Expr(fieldName) }),
quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using
summonInline[ParamEncoder[t]]
)
)
)
}
}
}
}
}
Expr.ofList(assignments)

View file

@ -6,10 +6,12 @@ import minisql.idiom.*
import minisql.NamingStrategy
import minisql.MirrorContext
import minisql.context.mirror.{*, given}
import scala.quoted.* // Needed for inline/summonFrom
class MirrorSqlContextSuite extends munit.FunSuite {
case class Foo(id: Long, name: String)
case class Foo(id: Long, name: String, age: Int)
val foo0: Foo = Foo(0L, "foo", 10)
inline def Foos = query[Foo]("foo")
@ -22,26 +24,28 @@ class MirrorSqlContextSuite extends munit.FunSuite {
alias("id", "id1")
).filter(x => x.id > 0)
)
assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0")
assertEquals(
o.sql,
"SELECT x.id1, x.name, x.age FROM foo x WHERE x.id1 > 0"
)
}
test("Insert") {
val v: Foo = Foo(0L, "foo")
val v: Foo = Foo(0L, "foo", 1)
val o = testContext.io(Foos.insert(v))
assertEquals(
o.sql,
"INSERT INTO foo (id,name) VALUES (?, ?)"
"INSERT INTO foo (id,name,age) VALUES (?, ?, ?)"
)
}
test("InsertReturningGenerated") {
val v: Foo = Foo(0L, "foo")
val o = testContext.io(Foos.insert(v).returningGenerated(_.id))
val o = testContext.io(Foos.insert(foo0).returningGenerated(_.id))
assertEquals(
o.sql,
"INSERT INTO foo (name) VALUES (?) RETURNING id"
"INSERT INTO foo (name,age) VALUES (?, ?) RETURNING id"
)
}
@ -64,7 +68,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f")
}
test("Update") {
test("Update with explicit assignments") {
val name = "new name"
val id = 2L
val o = testContext.io(
@ -88,4 +92,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET id = (id + ?) WHERE id = 1"
)
}
test("Update/excluding fields") {
inline given UpdateMeta[Foo] = UpdateMeta[Foo](List("id"))
val o = testContext.io(
Foos.filter(_.id == 1).update(foo0)
)
assertEquals(
o.sql,
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
)
}
}