From 9a7f66a2681805d051f1baa3f3f9370ab993f9b2 Mon Sep 17 00:00:00 2001 From: jilen Date: Thu, 17 Jul 2025 17:11:31 +0800 Subject: [PATCH] Add filterOpt --- .scalafmt.conf | 4 +-- src/main/scala/minisql/Quoted.scala | 26 ++++++++++++++++++- src/main/scala/minisql/context/Context.scala | 10 +++---- .../context/sql/MirrorSqlContextSuite.scala | 21 ++++++++++++--- 4 files changed, 49 insertions(+), 12 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 25f62b6..b5fac54 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = "3.8.3" +version = "3.9.8" style = defaultWithAlign runner.dialect=scala3 maxColumn = 80 @@ -13,4 +13,4 @@ runner.optimizer.forceConfigStyleMinArgCount = 2 newlines.beforeCurlyLambdaParams = multilineWithCaseOnly indentOperator.topLevelOnly = true rewrite.imports.sort = original -docstrings.style = keep +docstrings.style = Asterisk diff --git a/src/main/scala/minisql/Quoted.scala b/src/main/scala/minisql/Quoted.scala index 42feb64..769c7ec 100644 --- a/src/main/scala/minisql/Quoted.scala +++ b/src/main/scala/minisql/Quoted.scala @@ -142,6 +142,24 @@ object Query { expandFields[E](e) } + /** + * Filter with pred `f` if x is Some(...), otherwise true + */ + inline def filterOpt[E1]( + x: Option[E1] + )( + inline f: (E, E1) => Boolean + )(using ParamEncoder[Option[E1]]): Query[E] = { + + transform2(e)(f) { (p, id1, id2, body) => + ast.Filter( + p, + id1, + ast.OptionForall(quotedLift(x), id2, body) + ) + } + } + inline def leftJoin[E1](inline e1: Query[E1]): Joined[E, Option[E1]] = Joined[E, Option[E1]](JoinType.LeftJoin, e, e1) @@ -258,7 +276,7 @@ private def transformCaseClassToAssignmentsImpl[E: Type]( ) } - val fields = TypeRepr.of[E].typeSymbol.caseFields + val fields = TypeRepr.of[E].typeSymbol.caseFields val assignments = fields.collect { case field if !excludeFields.contains(field.name) => val fieldName = field.name @@ -289,6 +307,12 @@ private inline def transform[A, B](inline q1: Quoted)( fast(q1, f.param0, f.body) } +private inline def transform2[A, A1, B](inline q1: Quoted)( + inline f: (A, A1) => B +)(inline fast: (Ast, Ident, Ident, Ast) => Ast): Quoted = { + fast(q1, f.param0, f.param1, f.body) +} + inline def alias(inline from: String, inline to: String): PropertyAlias = PropertyAlias(List(from), to) diff --git a/src/main/scala/minisql/context/Context.scala b/src/main/scala/minisql/context/Context.scala index 90bb8e7..c0e25ba 100644 --- a/src/main/scala/minisql/context/Context.scala +++ b/src/main/scala/minisql/context/Context.scala @@ -93,7 +93,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => @targetName("ioAction") inline def io[E](inline q: minisql.Action[E]): DBIO[E] = { val extractor = summonFrom { - case e: RowExtract[E, DBRow] => e + case e: RowExtract[E, DBRow] => e case e: ColumnDecoder.Aux[DBRow, E] => RowExtract.single(e) } @@ -120,7 +120,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => case _: (E <:< Option[?]) => (rows: Iterable[DBRow]) => rows.toVector match { - case Vector() => Success(None.asInstanceOf[E]) + case Vector() => Success(None.asInstanceOf[E]) case Vector(r) => RowExtract.single(e).extract(r).map(Some(_).asInstanceOf[E]) case o => @@ -134,7 +134,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => (rows) => rows.toVector match { case Vector(r) => RowExtract.single(e).extract(r) - case o => + case o => Failure( new IllegalStateException( s"Expect agg value, got ${o.size} rows" @@ -148,7 +148,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => val (sql, params) = stmt.expand(lifts) ( sql = sql, - params = params.map(_.value.get.asInstanceOf), + params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]), mapper = mapper ) } @@ -169,7 +169,7 @@ trait Context[I <: Idiom, N <: NamingStrategy] { selft => val (sql, params) = stmt.expand(lifts) ( sql = sql, - params = params.map(_.value.get.asInstanceOf), + params = params.map(_.value.get.asInstanceOf[(Any, Encoder[?])]), mapper = (rows) => rows.traverse(extractor.extract) ) } diff --git a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala index a6d2d6a..409ae6a 100644 --- a/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala +++ b/src/test/scala/minisql/context/sql/MirrorSqlContextSuite.scala @@ -58,7 +58,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { val o = testContext.io( Foos.insert( f => f.name -> lift(name), - f => f.age -> lift(age) + f => f.age -> lift(age) ) ) assertEquals( @@ -111,7 +111,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { test("Update/assignments") { val name = "new name" val id = 2L - val o = testContext.io( + val o = testContext.io( Foos .filter(_.id == 1) .update(f => f.name -> lift(name), f => f.id -> lift(id)) @@ -124,7 +124,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { test("Update/increment") { val delta = 1 - val o = testContext.io( + val o = testContext.io( Foos.filter(_.id == 1).update(f => f.id -> (f.id + lift(delta))) ) assertEquals( @@ -135,7 +135,7 @@ class MirrorSqlContextSuite extends munit.FunSuite { test("Update/excluding fields") { inline given UpdateMeta[Foo] = UpdateMeta[Foo](List("id")) - val o = testContext.io( + val o = testContext.io( Foos.filter(_.id == 1).update(foo0) ) assertEquals( @@ -185,4 +185,17 @@ class MirrorSqlContextSuite extends munit.FunSuite { "SELECT SUM(f.age) FROM foo f" ) } + + test("filterOpt - Some value") { + val maybeId: Option[Long] = Some(5L) + val o = testContext.io( + Foos.filterOpt(maybeId)((f, id) => f.id == id) + ) + assertEquals( + o.sql, + "SELECT f.id, f.name, f.age FROM foo f WHERE ? IS NULL OR f.id = ?" + ) + println(o.params) + } + }