This commit is contained in:
jilen 2025-07-09 19:22:17 +08:00
parent 0fcd4ac140
commit 13633e0b20
3 changed files with 77 additions and 30 deletions

View file

@ -1,3 +1,18 @@
package minisql package minisql
type QueryMeta import minisql.util.*
import scala.deriving.*
import scala.quoted.*
opaque type UpdateMeta[E] <: List[String] = List[String]
object UpdateMeta {
inline def noExclude[E]: UpdateMeta[E] = unsafe(Nil)
inline def apply[E](using m: Mirror.ProductOf[E])(
inline exclude: List[Tuple.Union[m.MirroredElemLabels] & String]
): UpdateMeta[E] = unsafe(exclude)
inline def unsafe[E](inline ex: List[String]): UpdateMeta[E] = ex
}

View file

@ -21,6 +21,7 @@ import scala.deriving.*
import scala.compiletime.* import scala.compiletime.*
import scala.compiletime.ops.string.* import scala.compiletime.ops.string.*
import scala.collection.immutable.{Map => IMap} import scala.collection.immutable.{Map => IMap}
import scala.util.NotGiven
opaque type Quoted <: Ast = Ast opaque type Quoted <: Ast = Ast
@ -165,7 +166,13 @@ object EntityQuery {
} }
inline def insert(v: E): Insert[E] = { inline def insert(v: E): Insert[E] = {
ast.Insert(e, transformCaseClassToAssignments[E](v)) ast.Insert(e, transformCaseClassToAssignments[E](v, Nil))
}
inline def update(
v: E
)(using inline m: UpdateMeta[E] = UpdateMeta.noExclude[E]): Update[Long] = {
ast.Update(e, transformCaseClassToAssignments[E](v, m))
} }
inline def update(inline ass: (E => (Any, Any))*): Update[Long] = { inline def update(inline ass: (E => (Any, Any))*): Update[Long] = {
@ -193,35 +200,45 @@ private def parseFuncAssignImpl[E](x: Expr[Seq[E => (Any, Any)]])(using
} }
private inline def transformCaseClassToAssignments[E]( private inline def transformCaseClassToAssignments[E](
v: E v: E,
inline exclude: List[String]
): List[ast.Assignment] = ${ ): List[ast.Assignment] = ${
transformCaseClassToAssignmentsImpl[E]('v) transformCaseClassToAssignmentsImpl[E]('v, 'exclude)
} }
private def transformCaseClassToAssignmentsImpl[E: Type]( private def transformCaseClassToAssignmentsImpl[E: Type](
v: Expr[E] v: Expr[E],
exclude: Expr[List[String]]
)(using Quotes): Expr[List[ast.Assignment]] = { )(using Quotes): Expr[List[ast.Assignment]] = {
import quotes.reflect.* import quotes.reflect.*
val excludeFields = extractTerm(exclude.asTerm).asExpr match {
case '{ $xs: List[String] } =>
xs.value.getOrElse(
report.errorAndAbort(s"Cannot handle exclude ${xs.show}")
)
}
val fields = TypeRepr.of[E].typeSymbol.caseFields val fields = TypeRepr.of[E].typeSymbol.caseFields
val assignments = fields.map { field => val assignments = fields.collect {
val fieldName = field.name case field if !excludeFields.contains(field.name) =>
val fieldType = field.tree match { val fieldName = field.name
case v: ValDef => v.tpt.tpe val fieldType = field.tree match {
case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName") case v: ValDef => v.tpt.tpe
} case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName")
fieldType.asType match { }
case '[t] => fieldType.asType match {
'{ case '[t] =>
ast.Assignment( '{
ast.Ident("v"), ast.Assignment(
ast.Property(ast.Ident("v"), ${ Expr(fieldName) }), ast.Ident("v"),
quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using ast.Property(ast.Ident("v"), ${ Expr(fieldName) }),
summonInline[ParamEncoder[t]] quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using
summonInline[ParamEncoder[t]]
)
) )
) }
} }
}
} }
Expr.ofList(assignments) Expr.ofList(assignments)

View file

@ -6,10 +6,12 @@ import minisql.idiom.*
import minisql.NamingStrategy import minisql.NamingStrategy
import minisql.MirrorContext import minisql.MirrorContext
import minisql.context.mirror.{*, given} import minisql.context.mirror.{*, given}
import scala.quoted.* // Needed for inline/summonFrom
class MirrorSqlContextSuite extends munit.FunSuite { class MirrorSqlContextSuite extends munit.FunSuite {
case class Foo(id: Long, name: String) case class Foo(id: Long, name: String, age: Int)
val foo0: Foo = Foo(0L, "foo", 10)
inline def Foos = query[Foo]("foo") inline def Foos = query[Foo]("foo")
@ -22,26 +24,28 @@ class MirrorSqlContextSuite extends munit.FunSuite {
alias("id", "id1") alias("id", "id1")
).filter(x => x.id > 0) ).filter(x => x.id > 0)
) )
assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0") assertEquals(
o.sql,
"SELECT x.id1, x.name, x.age FROM foo x WHERE x.id1 > 0"
)
} }
test("Insert") { test("Insert") {
val v: Foo = Foo(0L, "foo") val v: Foo = Foo(0L, "foo", 1)
val o = testContext.io(Foos.insert(v)) val o = testContext.io(Foos.insert(v))
assertEquals( assertEquals(
o.sql, o.sql,
"INSERT INTO foo (id,name) VALUES (?, ?)" "INSERT INTO foo (id,name,age) VALUES (?, ?, ?)"
) )
} }
test("InsertReturningGenerated") { test("InsertReturningGenerated") {
val v: Foo = Foo(0L, "foo")
val o = testContext.io(Foos.insert(v).returningGenerated(_.id)) val o = testContext.io(Foos.insert(foo0).returningGenerated(_.id))
assertEquals( assertEquals(
o.sql, o.sql,
"INSERT INTO foo (name) VALUES (?) RETURNING id" "INSERT INTO foo (name,age) VALUES (?, ?) RETURNING id"
) )
} }
@ -64,7 +68,7 @@ class MirrorSqlContextSuite extends munit.FunSuite {
assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f") assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f")
} }
test("Update") { test("Update with explicit assignments") {
val name = "new name" val name = "new name"
val id = 2L val id = 2L
val o = testContext.io( val o = testContext.io(
@ -88,4 +92,15 @@ class MirrorSqlContextSuite extends munit.FunSuite {
"UPDATE foo SET id = (id + ?) WHERE id = 1" "UPDATE foo SET id = (id + ?) WHERE id = 1"
) )
} }
test("Update/excluding fields") {
inline given UpdateMeta[Foo] = UpdateMeta[Foo](List("id"))
val o = testContext.io(
Foos.filter(_.id == 1).update(foo0)
)
assertEquals(
o.sql,
"UPDATE foo SET name = ?, age = ? WHERE id = 1"
)
}
} }