diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 1efc593..45376f1 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -28,7 +28,23 @@ opaque type Query[E] <: Quoted = Quoted opaque type Action[E] <: Quoted = Quoted -opaque type Insert <: Action[Long] = Quoted +opaque type Insert[E] <: Action[Long] = Quoted + +object Insert { + extension [E](inline insert: Insert[E]) { + inline def returning[E1](inline f: E => E1): InsertReturning[E1] = { + transform(insert)(f)(ast.Returning.apply) + } + + inline def returningGenerated[E1]( + inline f: E => E1 + ): InsertReturning[E1] = { + transform(insert)(f)(ast.ReturningGenerated.apply) + } + } +} + +opaque type InsertReturning[E] <: Action[E] = Quoted sealed trait Joined[E1, E2] @@ -66,9 +82,7 @@ private def joinQueryOf[E1, E2]( )(using Quotes, Type[E1], Type[E2]): Expr[Join] = { import quotes.reflect.* extractTerm(x.asTerm).asExpr match { - case '{ - Joined[E1, E2]($jt, $a, $b) - } => + case '{ Joined[E1, E2]($jt, $a, $b) } => '{ Join($jt, $a, $b, $aliasA, $aliasB, $on) } @@ -148,7 +162,7 @@ object EntityQuery { transform(e)(f)(Filter.apply) } - inline def insert(v: E)(using m: Mirror.ProductOf[E]): Insert = { + inline def insert(v: E): Insert[E] = { ast.Insert(e, transformCaseClassToAssignments[E](v)) } } @@ -156,7 +170,7 @@ object EntityQuery { private inline def transformCaseClassToAssignments[E]( v: E -)(using m: Mirror.ProductOf[E]): List[ast.Assignment] = ${ +): List[ast.Assignment] = ${ transformCaseClassToAssignmentsImpl[E]('v) } diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index 072cf21..08686b3 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -270,12 +270,29 @@ private given FromExpr[Action] with { ass.sequence.map { ass1 => Update(a, ass1) } - case '{ Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } => + case '{ + Returning(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) + } => + Some(Returning(act, id, body)) + case '{ + val x: Ast = ${ Expr(act) } + val y: Ident = ${ Expr(id) } + val z: Ast = ${ Expr(body) } + Returning(x, y, z) + } => Some(Returning(act, id, body)) case '{ ReturningGenerated(${ Expr(act) }, ${ Expr(id) }, ${ Expr(body) }) } => Some(ReturningGenerated(act, id, body)) + case '{ + val x: Ast = ${ Expr(act) } + val y: Ident = ${ Expr(id) } + val z: Ast = ${ Expr(body) } + ReturningGenerated(x, y, z) + } => + Some(ReturningGenerated(act, id, body)) + } } } diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 8820c12..9cab3b1 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -73,6 +73,22 @@ class FromExprsSuite extends FunSuite { ) } + testFor("Action - Returning") { + Returning( + Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Ident("x"), + Property(Ident("x"), "id") + ) + } + + testFor("Action - ReturningGenerated") { + ReturningGenerated( + Insert(Ident("table"), List(Assignment(Ident("x"), Ident("col"), Ident("val")))), + Ident("x"), + Property(Ident("x"), "generatedId") + ) + } + testFor("If expression") { If(Ident("cond"), Ident("then"), Ident("else")) } diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index ff26942..0b79e14 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -35,6 +35,16 @@ class MirrorSqlContextSuite extends munit.FunSuite { ) } + test("InsertReturningGenerated") { + val v: Foo = Foo(0L, "foo") + + val o = testContext.io(Foos.insert(v).returningGenerated(_.id)) + assertEquals( + o.sql, + "INSERT INTO foo (name) VALUES (?) RETURNING id" + ) + } + test("LeftJoin") { val o = testContext .io(Foos.join(Foos).on((f1, f2) => f1.id == f2.id).map {