diff --git a/src/main/scala/minisql/Meta.scala b/src/main/scala/minisql/Meta.scala index 38d8fb9..d7c5464 100644 --- a/src/main/scala/minisql/Meta.scala +++ b/src/main/scala/minisql/Meta.scala @@ -1,3 +1,18 @@ package minisql -type QueryMeta +import minisql.util.* +import scala.deriving.* +import scala.quoted.* + +opaque type UpdateMeta[E] <: List[String] = List[String] + +object UpdateMeta { + + inline def noExclude[E]: UpdateMeta[E] = unsafe(Nil) + + inline def apply[E](using m: Mirror.ProductOf[E])( + inline exclude: List[Tuple.Union[m.MirroredElemLabels] & String] + ): UpdateMeta[E] = unsafe(exclude) + + inline def unsafe[E](inline ex: List[String]): UpdateMeta[E] = ex +} diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index a566339..0f16d93 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -21,6 +21,7 @@ import scala.deriving.* import scala.compiletime.* import scala.compiletime.ops.string.* import scala.collection.immutable.{Map => IMap} +import scala.util.NotGiven opaque type Quoted <: Ast = Ast @@ -165,7 +166,13 @@ object EntityQuery { } inline def insert(v: E): Insert[E] = { - ast.Insert(e, transformCaseClassToAssignments[E](v)) + ast.Insert(e, transformCaseClassToAssignments[E](v, Nil)) + } + + inline def update( + v: E + )(using inline m: UpdateMeta[E] = UpdateMeta.noExclude[E]): Update[Long] = { + ast.Update(e, transformCaseClassToAssignments[E](v, m)) } inline def update(inline ass: (E => (Any, Any))*): Update[Long] = { @@ -193,35 +200,45 @@ private def parseFuncAssignImpl[E](x: Expr[Seq[E => (Any, Any)]])(using } private inline def transformCaseClassToAssignments[E]( - v: E + v: E, + inline exclude: List[String] ): List[ast.Assignment] = ${ - transformCaseClassToAssignmentsImpl[E]('v) + transformCaseClassToAssignmentsImpl[E]('v, 'exclude) } private def transformCaseClassToAssignmentsImpl[E: Type]( - v: Expr[E] + v: Expr[E], + exclude: Expr[List[String]] )(using Quotes): Expr[List[ast.Assignment]] = { import quotes.reflect.* + val excludeFields = extractTerm(exclude.asTerm).asExpr match { + case '{ $xs: List[String] } => + xs.value.getOrElse( + report.errorAndAbort(s"Cannot handle exclude ${xs.show}") + ) + } + val fields = TypeRepr.of[E].typeSymbol.caseFields - val assignments = fields.map { field => - val fieldName = field.name - val fieldType = field.tree match { - case v: ValDef => v.tpt.tpe - case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName") - } - fieldType.asType match { - case '[t] => - '{ - ast.Assignment( - ast.Ident("v"), - ast.Property(ast.Ident("v"), ${ Expr(fieldName) }), - quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using - summonInline[ParamEncoder[t]] + val assignments = fields.collect { + case field if !excludeFields.contains(field.name) => + val fieldName = field.name + val fieldType = field.tree match { + case v: ValDef => v.tpt.tpe + case _ => report.errorAndAbort(s"Expected ValDef for field $fieldName") + } + fieldType.asType match { + case '[t] => + '{ + ast.Assignment( + ast.Ident("v"), + ast.Property(ast.Ident("v"), ${ Expr(fieldName) }), + quotedLift[t](${ Select(v.asTerm, field).asExprOf[t] })(using + summonInline[ParamEncoder[t]] + ) ) - ) - } - } + } + } } Expr.ofList(assignments) diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index 210a80f..ee320c3 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -6,10 +6,12 @@ import minisql.idiom.* import minisql.NamingStrategy import minisql.MirrorContext import minisql.context.mirror.{*, given} +import scala.quoted.* // Needed for inline/summonFrom class MirrorSqlContextSuite extends munit.FunSuite { - case class Foo(id: Long, name: String) + case class Foo(id: Long, name: String, age: Int) + val foo0: Foo = Foo(0L, "foo", 10) inline def Foos = query[Foo]("foo") @@ -22,26 +24,28 @@ class MirrorSqlContextSuite extends munit.FunSuite { alias("id", "id1") ).filter(x => x.id > 0) ) - assertEquals(o.sql, "SELECT x.id1, x.name FROM foo x WHERE x.id1 > 0") + assertEquals( + o.sql, + "SELECT x.id1, x.name, x.age FROM foo x WHERE x.id1 > 0" + ) } test("Insert") { - val v: Foo = Foo(0L, "foo") + val v: Foo = Foo(0L, "foo", 1) val o = testContext.io(Foos.insert(v)) assertEquals( o.sql, - "INSERT INTO foo (id,name) VALUES (?, ?)" + "INSERT INTO foo (id,name,age) VALUES (?, ?, ?)" ) } test("InsertReturningGenerated") { - val v: Foo = Foo(0L, "foo") - val o = testContext.io(Foos.insert(v).returningGenerated(_.id)) + val o = testContext.io(Foos.insert(foo0).returningGenerated(_.id)) assertEquals( o.sql, - "INSERT INTO foo (name) VALUES (?) RETURNING id" + "INSERT INTO foo (name,age) VALUES (?, ?) RETURNING id" ) } @@ -64,7 +68,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f") } - test("Update") { + test("Update with explicit assignments") { val name = "new name" val id = 2L val o = testContext.io( @@ -88,4 +92,15 @@ class MirrorSqlContextSuite extends munit.FunSuite { "UPDATE foo SET id = (id + ?) WHERE id = 1" ) } + + test("Update/excluding fields") { + inline given UpdateMeta[Foo] = UpdateMeta[Foo](List("id")) + val o = testContext.io( + Foos.filter(_.id == 1).update(foo0) + ) + assertEquals( + o.sql, + "UPDATE foo SET name = ?, age = ? WHERE id = 1" + ) + } }