From 23d1ca480e64c93ae9a0ad61769df6b1579ca1be Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 9 Jul 2025 17:11:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/scala/minisql/Quoted.scala | 24 +++++++++++++++++++ src/main/scala/minisql/parsing/Parser.scala | 16 +++++++++++++ .../context/sql/MirrorSqlContextSuite.scala | 14 +++++++++++ 3 files changed, 54 insertions(+) diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 45376f1..a566339 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -28,6 +28,8 @@ opaque type Query[E] <: Quoted = Quoted opaque type Action[E] <: Quoted = Quoted +opaque type Update[E] <: Action[Long] = Quoted + opaque type Insert[E] <: Action[Long] = Quoted object Insert { @@ -165,6 +167,28 @@ object EntityQuery { inline def insert(v: E): Insert[E] = { ast.Insert(e, transformCaseClassToAssignments[E](v)) } + + inline def update(inline ass: (E => (Any, Any))*): Update[Long] = { + ast.Update(e, parseFuncAssign(ass)) + } + } +} + +private inline def parseFuncAssign[E]( + inline ass: Seq[(E => (Any, Any))] +): List[ast.Assignment] = ${ parseFuncAssignImpl[E]('ass) } + +private def parseFuncAssignImpl[E](x: Expr[Seq[E => (Any, Any)]])(using + Quotes, + Type[E] +): Expr[List[ast.Assignment]] = { + import quotes.reflect.* + x match { + case '{ ${ Varargs(ass) } } => + val assExprs = ass.map { a => + parseAssignment(a) + } + Expr.ofList(assExprs) } } diff --git a/src/main/scala/minisql/parsing/Parser.scala b/src/main/scala/minisql/parsing/Parser.scala index 4235398..576a5c2 100644 --- a/src/main/scala/minisql/parsing/Parser.scala +++ b/src/main/scala/minisql/parsing/Parser.scala @@ -46,3 +46,19 @@ private[minisql] def parseBody[X]( report.errorAndAbort(s"Can only parse function") } } + +private[minisql] def parseAssignment(x: Expr[?])(using + Quotes +): Expr[ast.Assignment] = { + import quotes.reflect.* + x.asTerm match { + case Lambda(List(ValDef(n, _, _)), IsTuple2(prop, value)) => + '{ + ast.Assignment( + ast.Ident(${ Expr(n) }), + ${ Parsing.parseExpr(prop) }, + ${ Parsing.parseExpr(value) } + ) + } + } +} diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index 0b79e14..ceb7706 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -60,4 +60,18 @@ class MirrorSqlContextSuite extends munit.FunSuite { ) assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f") } + + test("Update") { + val name = "new name" + val id = 2L + val o = testContext.io( + Foos + .filter(_.id == 1) + .update(f => f.name -> lift(name), f => f.id -> lift(id)) + ) + assertEquals( + o.sql, + "UPDATE foo SET name = ?, id = ? WHERE id = 1" + ) + } }