diff --git a/src/main/scala/minisql/SqlInfix.scala b/src/main/scala/minisql/SqlInfix.scala new file mode 100644 index 0000000..ea5fe6a --- /dev/null +++ b/src/main/scala/minisql/SqlInfix.scala @@ -0,0 +1,12 @@ +package minisql + +import minisql.ast.Ast +import scala.quoted.* + +sealed trait InfixValue { + def as[T]: T +} + +extension (sc: StringContext) { + def infix(args: Any*): InfixValue = throw NonQuotedException() +} diff --git a/src/main/scala/minisql/parsing/InfixParsing.scala b/src/main/scala/minisql/parsing/InfixParsing.scala index 8b13789..1b19cad 100644 --- a/src/main/scala/minisql/parsing/InfixParsing.scala +++ b/src/main/scala/minisql/parsing/InfixParsing.scala @@ -1 +1,29 @@ +package minisql.parsing +import minisql.ast +import scala.quoted.* + +private[parsing] def infixParsing( + astParser: => Parser[ast.Ast] +)(using Quotes): Parser[ast.Infix] = { + import quotes.reflect.* + { + case '{ ($x: minisql.InfixValue).as[t] } => infixParsing(astParser)(x) + case '{ + minisql.infix(StringContext(${ Varargs(partsExprs) }*))(${ + Varargs(argsExprs) + }*) + } => + val parts = partsExprs.map { p => + p.value.getOrElse( + report.errorAndAbort( + s"Expected a string literal in StringContext parts, but got: ${p.show}" + ) + ) + }.toList + + val params = argsExprs.map(arg => astParser(arg)).toList + + '{ ast.Infix(${ Expr(parts) }, ${ Expr.ofList(params) }, true, false) } + } +} diff --git a/src/main/scala/minisql/parsing/Parsing.scala b/src/main/scala/minisql/parsing/Parsing.scala index f044292..8794e7f 100644 --- a/src/main/scala/minisql/parsing/Parsing.scala +++ b/src/main/scala/minisql/parsing/Parsing.scala @@ -41,8 +41,10 @@ private[minisql] object Parsing { f: Parser[ast.Ast] ): Parser[ast.Ast] = { case expr => - val t = expr.asTerm - f(extractTerm(t).asExpr) + val t = extractTerm(expr.asTerm) + if (t.isExpr) + f(t.asExpr) + else f(expr) } lazy val astParser: Parser[ast.Ast] = @@ -50,6 +52,7 @@ private[minisql] object Parsing { typedParser .orElse(propertyParser) .orElse(liftParser) + .orElse(infixParser) .orElse(identParser) .orElse(valueParser) .orElse(operationParser) @@ -108,6 +111,10 @@ private[minisql] object Parsing { lazy val traversableOperationParser: Parser[ast.IterableOperation] = traversableOperationParsing(astParser) + lazy val infixParser: Parser[ast.Infix] = infixParsing( + astParser + ) + astParser(expr) } diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index 4e6d8c9..8820c12 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -86,6 +86,15 @@ class FromExprsSuite extends FunSuite { ) } + testFor("Infix with different parameters") { + Infix( + List("?", " + ", "?"), + List(Constant(1), Constant(2)), + pure = true, + noParen = true + ) + } + testFor("OptionOperation - OptionMap") { OptionMap(Ident("opt"), Ident("x"), Ident("x")) } diff --git a/src/test/scala/minisql/context/sql/QuotedSuite.scala b/src/test/scala/minisql/context/sql/QuotedSuite.scala index 17ea26b..467a10c 100644 --- a/src/test/scala/minisql/context/sql/QuotedSuite.scala +++ b/src/test/scala/minisql/context/sql/QuotedSuite.scala @@ -43,4 +43,11 @@ class QuotedSuite extends munit.FunSuite { println(o) } + + test("Infix string interpolation") { + val o = testContext.io( + Foos.map(f => infix"CONCAT(${f.name}, ' ', ${f.id})".as[String]) + ) + assertEquals(o.sql, "SELECT CONCAT(f.name, ' ', f.id) FROM foo f") + } }