diff --git a/src/main/scala/minisql/ast/FromExprs.scala b/src/main/scala/minisql/ast/FromExprs.scala index c6d48cc..aca73db 100644 --- a/src/main/scala/minisql/ast/FromExprs.scala +++ b/src/main/scala/minisql/ast/FromExprs.scala @@ -70,7 +70,7 @@ private given FromExpr[Property] with { } => Some(Property(a, n, r, v)) case o => - println(s"Cannot extrat ${o.show}") + println(s"Cannot extract ${o.show}") None } } @@ -82,6 +82,8 @@ private given FromExpr[Ordering] with { case '{ Desc } => Some(Desc) case '{ AscNullsFirst } => Some(AscNullsFirst) case '{ AscNullsLast } => Some(AscNullsLast) + case '{ DescNullsFirst } => Some(DescNullsFirst) + case '{ DescNullsLast } => Some(DescNullsLast) case '{ TupleOrdering($xs) } => xs.value.map(TupleOrdering(_)) } } @@ -135,6 +137,35 @@ private given FromExpr[Query] with { Some(Take(b, n)) case '{ SortBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(s) }, ${ Expr(o) }) } => Some(SortBy(b, p, s, o)) + case '{ GroupBy(${ Expr(b) }, ${ Expr(p) }, ${ Expr(body) }) } => + Some(GroupBy(b, p, body)) + case '{ Distinct(${ Expr(a) }) } => + Some(Distinct(a)) + case '{ DistinctOn(${ Expr(q) }, ${ Expr(a) }, ${ Expr(body) }) } => + Some(DistinctOn(q, a, body)) + case '{ Aggregation(${ Expr(op) }, ${ Expr(a) }) } => + Some(Aggregation(op, a)) + case '{ Union(${ Expr(a) }, ${ Expr(b) }) } => + Some(Union(a, b)) + case '{ UnionAll(${ Expr(a) }, ${ Expr(b) }) } => + Some(UnionAll(a, b)) + case '{ + Join( + ${ Expr(t) }, + ${ Expr(a) }, + ${ Expr(b) }, + ${ Expr(ia) }, + ${ Expr(ib) }, + ${ Expr(on) } + ) + } => + Some(Join(t, a, b, ia, ib, on)) + case '{ + FlatJoin(${ Expr(t) }, ${ Expr(a) }, ${ Expr(ia) }, ${ Expr(on) }) + } => + Some(FlatJoin(t, a, ia, on)) + case '{ Nested(${ Expr(a) }) } => + Some(Nested(a)) case o => println(s"Cannot extract ${o.show}") None @@ -153,17 +184,21 @@ private given FromExpr[BinaryOperator] with { case '{ NumericOperator.* } => Some(NumericOperator.*) case '{ NumericOperator./ } => Some(NumericOperator./) case '{ NumericOperator.> } => Some(NumericOperator.>) + case '{ NumericOperator.>= } => Some(NumericOperator.>=) + case '{ NumericOperator.< } => Some(NumericOperator.<) + case '{ NumericOperator.<= } => Some(NumericOperator.<=) + case '{ NumericOperator.% } => Some(NumericOperator.%) case '{ StringOperator.split } => Some(StringOperator.split) case '{ StringOperator.startsWith } => Some(StringOperator.startsWith) case '{ StringOperator.concat } => Some(StringOperator.concat) case '{ BooleanOperator.&& } => Some(BooleanOperator.&&) case '{ BooleanOperator.|| } => Some(BooleanOperator.||) + case '{ SetOperator.contains } => Some(SetOperator.contains) } } } private given FromExpr[UnaryOperator] with { - def unapply(x: Expr[UnaryOperator])(using Quotes): Option[UnaryOperator] = { x match { case '{ BooleanOperator.! } => Some(BooleanOperator.!) @@ -171,6 +206,33 @@ private given FromExpr[UnaryOperator] with { case '{ StringOperator.toLowerCase } => Some(StringOperator.toLowerCase) case '{ StringOperator.toLong } => Some(StringOperator.toLong) case '{ StringOperator.toInt } => Some(StringOperator.toInt) + case '{ NumericOperator.- } => Some(NumericOperator.-) + case '{ SetOperator.nonEmpty } => Some(SetOperator.nonEmpty) + case '{ SetOperator.isEmpty } => Some(SetOperator.isEmpty) + } + } +} + +private given FromExpr[AggregationOperator] with { + def unapply( + x: Expr[AggregationOperator] + )(using Quotes): Option[AggregationOperator] = { + x match { + case '{ AggregationOperator.min } => Some(AggregationOperator.min) + case '{ AggregationOperator.max } => Some(AggregationOperator.max) + case '{ AggregationOperator.avg } => Some(AggregationOperator.avg) + case '{ AggregationOperator.sum } => Some(AggregationOperator.sum) + case '{ AggregationOperator.size } => Some(AggregationOperator.size) + } + } +} + +private given FromExpr[Operator] with { + def unapply(x: Expr[Operator])(using Quotes): Option[Operator] = { + x match { + case '{ $x: BinaryOperator } => x.value + case '{ $x: UnaryOperator } => x.value + case '{ $x: AggregationOperator } => x.value } } } @@ -225,18 +287,21 @@ private given FromExpr[Action] with { extension [A](xs: Seq[Expr[A]]) { private def sequence(using FromExpr[A], Quotes): Option[List[A]] = { - val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) => - for { - _r <- r - _x <- x.value - } yield _r += _x + if (xs.isEmpty) Some(Nil) + else { + val acc = xs.foldLeft(Option(List.newBuilder[A])) { (r, x) => + for { + _r <- r + _x <- x.value + } yield _r += _x + } + acc.map(_.result()) } - acc.map(b => b.result()) } } -private given FromExpr[Constant] with { - def unapply(x: Expr[Constant])(using Quotes): Option[Constant] = { +private given FromExpr[Value] with { + def unapply(x: Expr[Value])(using Quotes): Option[Value] = { import quotes.reflect.{Constant => *, *} x match { case '{ Constant($ce) } => @@ -244,8 +309,92 @@ private given FromExpr[Constant] with { case Literal(v) => Some(Constant(v.value)) } + case '{ NullValue } => + Some(NullValue) + case '{ $x: CaseClass } => x.value } + } +} +private given FromExpr[OptionOperation] with { + def unapply( + x: Expr[OptionOperation] + )(using Quotes): Option[OptionOperation] = { + x match { + case '{ OptionFlatten(${ Expr(ast) }) } => + Some(OptionFlatten(ast)) + case '{ OptionGetOrElse(${ Expr(ast) }, ${ Expr(body) }) } => + Some(OptionGetOrElse(ast, body)) + case '{ + OptionFlatMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionFlatMap(ast, alias, body)) + case '{ OptionMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) } => + Some(OptionMap(ast, alias, body)) + case '{ + OptionForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionForall(ast, alias, body)) + case '{ + OptionExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionExists(ast, alias, body)) + case '{ OptionContains(${ Expr(ast) }, ${ Expr(body) }) } => + Some(OptionContains(ast, body)) + case '{ OptionIsEmpty(${ Expr(ast) }) } => + Some(OptionIsEmpty(ast)) + case '{ OptionNonEmpty(${ Expr(ast) }) } => + Some(OptionNonEmpty(ast)) + case '{ OptionIsDefined(${ Expr(ast) }) } => + Some(OptionIsDefined(ast)) + case '{ + OptionTableFlatMap( + ${ Expr(ast) }, + ${ Expr(alias) }, + ${ Expr(body) } + ) + } => + Some(OptionTableFlatMap(ast, alias, body)) + case '{ + OptionTableMap(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableMap(ast, alias, body)) + case '{ + OptionTableExists(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableExists(ast, alias, body)) + case '{ + OptionTableForall(${ Expr(ast) }, ${ Expr(alias) }, ${ Expr(body) }) + } => + Some(OptionTableForall(ast, alias, body)) + case '{ OptionNone } => Some(OptionNone) + case '{ OptionSome(${ Expr(ast) }) } => Some(OptionSome(ast)) + case '{ OptionApply(${ Expr(ast) }) } => Some(OptionApply(ast)) + case '{ OptionOrNull(${ Expr(ast) }) } => Some(OptionOrNull(ast)) + case '{ OptionGetOrNull(${ Expr(ast) }) } => Some(OptionGetOrNull(ast)) + case _ => None + } + } +} + +private given FromExpr[CaseClass] with { + def unapply(x: Expr[CaseClass])(using Quotes): Option[CaseClass] = { + import quotes.reflect.* + x match { + case '{ CaseClass(${ Expr(values) }) } => + // Verify the values are properly structured as List[(String, Ast)] + try { + Some(CaseClass(values)) + } catch { + case e: Exception => + report.warning( + s"Failed to extract CaseClass values: ${e.getMessage}", + x.asTerm.pos + ) + None + } + case _ => None + } } } @@ -316,12 +465,14 @@ given astFromExpr: FromExpr[Ast] = new FromExpr[Ast] { case '{ $x: Property } => x.value case '{ $x: Ident } => x.value case '{ $x: Tuple } => x.value - case '{ $x: Constant } => x.value + case '{ $x: Value } => x.value case '{ $x: Operation } => x.value case '{ $x: Ordering } => x.value case '{ $x: Action } => x.value case '{ $x: If } => x.value case '{ $x: Infix } => x.value + case '{ $x: CaseClass } => x.value + case '{ $x: OptionOperation } => x.value case o => import quotes.reflect.* report.warning(s"Cannot get value from ${o.show}", o.asTerm.pos) diff --git a/src/main/scala/minisql/ast/JoinType.scala b/src/main/scala/minisql/ast/JoinType.scala index bcb623b..911b4f2 100644 --- a/src/main/scala/minisql/ast/JoinType.scala +++ b/src/main/scala/minisql/ast/JoinType.scala @@ -1,8 +1,22 @@ package minisql.ast -sealed trait JoinType +import scala.quoted.* -case object InnerJoin extends JoinType -case object LeftJoin extends JoinType -case object RightJoin extends JoinType -case object FullJoin extends JoinType +enum JoinType { + case InnerJoin + case LeftJoin + case RightJoin + case FullJoin +} + +object JoinType { + given FromExpr[JoinType] with { + + def unapply(x: Expr[JoinType])(using Quotes): Option[JoinType] = x match { + case '{ JoinType.InnerJoin } => Some(JoinType.InnerJoin) + case '{ JoinType.LeftJoin } => Some(JoinType.LeftJoin) + case '{ JoinType.RightJoin } => Some(JoinType.RightJoin) + case '{ JoinType.FullJoin } => Some(JoinType.FullJoin) + } + } +} diff --git a/src/main/scala/minisql/context/sql/SqlIdiom.scala b/src/main/scala/minisql/context/sql/SqlIdiom.scala index dffd56b..daeee76 100644 --- a/src/main/scala/minisql/context/sql/SqlIdiom.scala +++ b/src/main/scala/minisql/context/sql/SqlIdiom.scala @@ -346,10 +346,10 @@ trait SqlIdiom extends Idiom { } implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { - case InnerJoin => stmt"INNER JOIN" - case LeftJoin => stmt"LEFT JOIN" - case RightJoin => stmt"RIGHT JOIN" - case FullJoin => stmt"FULL JOIN" + case JoinType.InnerJoin => stmt"INNER JOIN" + case JoinType.LeftJoin => stmt"LEFT JOIN" + case JoinType.RightJoin => stmt"RIGHT JOIN" + case JoinType.FullJoin => stmt"FULL JOIN" } implicit def orderByCriteriaTokenizer(implicit diff --git a/src/main/scala/minisql/idiom/MirrorIdiom.scala b/src/main/scala/minisql/idiom/MirrorIdiom.scala index 1507919..2630288 100644 --- a/src/main/scala/minisql/idiom/MirrorIdiom.scala +++ b/src/main/scala/minisql/idiom/MirrorIdiom.scala @@ -192,10 +192,10 @@ trait MirrorIdiomBase extends Idiom { } implicit val joinTypeTokenizer: Tokenizer[JoinType] = Tokenizer[JoinType] { - case InnerJoin => stmt"join" - case LeftJoin => stmt"leftJoin" - case RightJoin => stmt"rightJoin" - case FullJoin => stmt"fullJoin" + case JoinType.InnerJoin => stmt"join" + case JoinType.LeftJoin => stmt"leftJoin" + case JoinType.RightJoin => stmt"rightJoin" + case JoinType.FullJoin => stmt"fullJoin" } implicit def functionTokenizer(implicit diff --git a/src/test/scala/minisql/ast/FromExprsSuite.scala b/src/test/scala/minisql/ast/FromExprsSuite.scala index ea6d14b..051da57 100644 --- a/src/test/scala/minisql/ast/FromExprsSuite.scala +++ b/src/test/scala/minisql/ast/FromExprsSuite.scala @@ -100,7 +100,7 @@ class FromExprsSuite extends FunSuite { testFor("Join") { Join( - InnerJoin, + JoinType.InnerJoin, Ident("a"), Ident("b"), Ident("a1"),