From f2828ef49414960bdf84db1d856788a6ed8485cb Mon Sep 17 00:00:00 2001 From: jilen Date: Thu, 24 Jul 2025 12:59:38 +0800 Subject: [PATCH] add pg-async module --- build.sbt | 36 ++-- .../src/main/scala/minisql/ParamEncoder.scala | 2 +- .../main/scala/minisql/context/mirror.scala | 6 +- .../scala/minisql/context/AsyncCodecs.scala | 180 ++++++++++++++++++ .../minisql/context/PgAsyncContext.scala | 47 +++++ 5 files changed, 256 insertions(+), 15 deletions(-) create mode 100644 pg-async/src/main/scala/minisql/context/AsyncCodecs.scala create mode 100644 pg-async/src/main/scala/minisql/context/PgAsyncContext.scala diff --git a/build.sbt b/build.sbt index 6d1c303..c9ac7a1 100644 --- a/build.sbt +++ b/build.sbt @@ -1,23 +1,37 @@ -val prjScalaVersion = "3.7.1" +val pgAsyncVersion = "0.3.124" +val catsEffectVersion = "3.6.2" + +val commonSettings = Seq( + organization := "minisql", + scalaVersion := "3.7.1", + scalacOptions ++= Seq( + "-deprecation", + "-feature", + "-source:3.7-migration", + "-rewrite" + ) +) lazy val root = (project in file(".")) + .aggregate(core, pgAsync) + +lazy val pgAsync = (project in file("pg-async")) + .dependsOn(core) .aggregate(core) + .settings(commonSettings: _*) .settings( - name := "minisql", - scalaVersion := prjScalaVersion + name := "minisql-pg-async", + libraryDependencies ++= Seq( + "org.typelevel" %% "cats-effect" % catsEffectVersion, + "com.dripower" %% "postgresql-async" % pgAsyncVersion + ) ) lazy val core = (project in file("core")) + .settings(commonSettings: _*) .settings( - name := "minisql-core", - scalaVersion := prjScalaVersion, + name := "minisql-core", libraryDependencies ++= Seq( "org.scalameta" %% "munit" % "1.1.1" % Test - ), - scalacOptions ++= Seq( - "-deprecation", - "-feature", - "-source:3.7-migration", - "-rewrite" ) ) diff --git a/core/src/main/scala/minisql/ParamEncoder.scala b/core/src/main/scala/minisql/ParamEncoder.scala index 4d2abe4..1163a0c 100644 --- a/core/src/main/scala/minisql/ParamEncoder.scala +++ b/core/src/main/scala/minisql/ParamEncoder.scala @@ -5,7 +5,7 @@ import scala.util.Try trait ParamEncoder[E] { type Stmt - def setParam(s: Stmt, idx: Int, v: E): Stmt + def setParam(s: Stmt, idx: Int, v: Any): Stmt } trait ColumnDecoder[X] { diff --git a/core/src/main/scala/minisql/context/mirror.scala b/core/src/main/scala/minisql/context/mirror.scala index 81b8f98..d000457 100644 --- a/core/src/main/scala/minisql/context/mirror.scala +++ b/core/src/main/scala/minisql/context/mirror.scala @@ -14,7 +14,7 @@ trait MirrorCodecs { final protected def mirrorEncoder[V]: Encoder[V] = new ParamEncoder[V] { type Stmt = ctx.DBStatement - def setParam(s: Stmt, idx: Int, v: V): Stmt = { + def setParam(s: Stmt, idx: Int, v: Any): Stmt = { s + (idx -> v) } } @@ -53,11 +53,11 @@ trait MirrorCodecs { override def setParam( s: Stmt, idx: Int, - v: Option[T] + v: Any ): Stmt = v match { case Some(value) => e.setParam(s, idx, value) - case None => + case None => s + (idx -> null) } } diff --git a/pg-async/src/main/scala/minisql/context/AsyncCodecs.scala b/pg-async/src/main/scala/minisql/context/AsyncCodecs.scala new file mode 100644 index 0000000..8e2422d --- /dev/null +++ b/pg-async/src/main/scala/minisql/context/AsyncCodecs.scala @@ -0,0 +1,180 @@ +package minisql.context + +import com.github.mauricio.async.db.RowData +import java.util.{Date, UUID} +import java.time.LocalDate +import minisql.{ParamEncoder, ColumnDecoder} +import scala.util.* + +type AsyncStmt = (String, Array[Any]) +type AsyncEncoder[T] = ParamEncoder[T] { type Stmt = AsyncStmt } + +private def asyncEncoder[A]( + f: Any => Any +): AsyncEncoder[A] = new ParamEncoder[A] { + + type Stmt = AsyncStmt + + def setParam(stmt: Stmt, index: Int, value: Any): Stmt = { + val (sql, params) = stmt + params(index) = f(value) + stmt + } +} + +trait AsyncCodecs { + + given optionDecoder[T](using + d: ColumnDecoder.Aux[RowData, T] + ): ColumnDecoder.Aux[RowData, Option[T]] = + new ColumnDecoder[Option[T]] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Option[T]] = { + if (row(index) == null) scala.util.Success(None) + else d.decode(row, index).map(Some(_)) + } + } + + given optionEncoder[T](using e: AsyncEncoder[T]): AsyncEncoder[Option[T]] = + new ParamEncoder[Option[T]] { + type Stmt = AsyncStmt + def setParam(stmt: AsyncStmt, index: Int, value: Any) = { + value match { + case Some(v) => e.setParam(stmt, index, v) + case None => stmt + } + } + } + + given stringDecoder: ColumnDecoder.Aux[RowData, String] = + new ColumnDecoder[String] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[String] = + scala.util.Try(row(index).asInstanceOf[String]) + } + + given bigDecimalDecoder: ColumnDecoder.Aux[RowData, BigDecimal] = + new ColumnDecoder[BigDecimal] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[BigDecimal] = + scala.util.Try( + BigDecimal(row(index).asInstanceOf[java.math.BigDecimal]) + ) + } + + given booleanDecoder: ColumnDecoder.Aux[RowData, Boolean] = + new ColumnDecoder[Boolean] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Boolean] = + scala.util.Try(row(index).asInstanceOf[Boolean]) + } + + given byteDecoder: ColumnDecoder.Aux[RowData, Byte] = + new ColumnDecoder[Byte] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Byte] = + scala.util.Try(row(index).asInstanceOf[Byte]) + } + + given shortDecoder: ColumnDecoder.Aux[RowData, Short] = + new ColumnDecoder[Short] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Short] = + scala.util.Try(row(index).asInstanceOf[Short]) + } + + given intDecoder: ColumnDecoder.Aux[RowData, Int] = + new ColumnDecoder[Int] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Int] = + scala.util.Try(row(index).asInstanceOf[Int]) + } + + given longDecoder: ColumnDecoder.Aux[RowData, Long] = + new ColumnDecoder[Long] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Long] = + scala.util.Try(row(index).asInstanceOf[Long]) + } + + given floatDecoder: ColumnDecoder.Aux[RowData, Float] = + new ColumnDecoder[Float] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Float] = + scala.util.Try(row(index).asInstanceOf[Float]) + } + + given doubleDecoder: ColumnDecoder.Aux[RowData, Double] = + new ColumnDecoder[Double] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Double] = + scala.util.Try(row(index).asInstanceOf[Double]) + } + + given byteArrayDecoder: ColumnDecoder.Aux[RowData, Array[Byte]] = + new ColumnDecoder[Array[Byte]] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Array[Byte]] = + scala.util.Try(row(index).asInstanceOf[Array[Byte]]) + } + + given dateDecoder: ColumnDecoder.Aux[RowData, Date] = + new ColumnDecoder[Date] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[Date] = + scala.util.Try(row(index).asInstanceOf[Date]) + } + + given localDateDecoder: ColumnDecoder.Aux[RowData, LocalDate] = + new ColumnDecoder[LocalDate] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[LocalDate] = + scala.util.Try(row(index).asInstanceOf[LocalDate]) + } + + given uuidDecoder: ColumnDecoder.Aux[RowData, UUID] = + new ColumnDecoder[UUID] { + type DBRow = RowData + def decode(row: RowData, index: Int): scala.util.Try[UUID] = + scala.util.Try(row(index).asInstanceOf[UUID]) + } + + given stringEncoder: AsyncEncoder[String] = + asyncEncoder[String](identity) + + given bigDecimalEncoder: AsyncEncoder[BigDecimal] = + asyncEncoder[BigDecimal](_.asInstanceOf[java.math.BigDecimal]) + + given booleanEncoder: AsyncEncoder[Boolean] = + asyncEncoder[Boolean](identity) + + given byteEncoder: AsyncEncoder[Byte] = + asyncEncoder[Byte](identity) + + given shortEncoder: AsyncEncoder[Short] = + asyncEncoder[Short](identity) + + given intEncoder: AsyncEncoder[Int] = + asyncEncoder[Int](identity) + + given longEncoder: AsyncEncoder[Long] = + asyncEncoder[Long](identity) + + given floatEncoder: AsyncEncoder[Float] = + asyncEncoder[Float](identity) + + given doubleEncoder: AsyncEncoder[Double] = + asyncEncoder[Double](identity) + + given byteArrayEncoder: AsyncEncoder[Array[Byte]] = + asyncEncoder[Array[Byte]](identity) + + given dateEncoder: AsyncEncoder[Date] = + asyncEncoder[Date](identity) + + given localDateEncoder: AsyncEncoder[LocalDate] = + asyncEncoder[LocalDate](identity) + + given uuidEncoder: AsyncEncoder[UUID] = + asyncEncoder[UUID](identity) +} diff --git a/pg-async/src/main/scala/minisql/context/PgAsyncContext.scala b/pg-async/src/main/scala/minisql/context/PgAsyncContext.scala new file mode 100644 index 0000000..4dde266 --- /dev/null +++ b/pg-async/src/main/scala/minisql/context/PgAsyncContext.scala @@ -0,0 +1,47 @@ +package minisql.context + +import cats.syntax.all.* +import cats.effect.Async +import minisql.context.sql.* +import minisql.context.sql.idiom.PostgresDialect +import minisql.{NamingStrategy, ParamEncoder} +import com.github.mauricio.async.db.{RowData, QueryResult} +import com.github.mauricio.async.db.postgresql.PostgreSQLConnection +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Try, Success, Failure} + +class PgAsyncContext[F[_], I <: PostgresDialect, N <: NamingStrategy]( + val naming: N, + val idiom: I, + connection: PostgreSQLConnection +)(using Async[F]) + extends SqlContext[I, N] + with AsyncCodecs { + + type DBStatement = AsyncStmt + type DBRow = RowData + type DBResultSet = QueryResult + + private given ExecutionContext = ExecutionContext.parasitic + + def run[E](dbio: DBIO[E]): F[E] = { + + val (sql, params, mapper) = dbio + val initStmt = (sql, Array.ofDim[Any](params.size)) + val encodedParams = params.zipWithIndex.map { + case ((value, encoder), i) => + encoder.setParam(initStmt, i, value) + } + + Async[F].fromFuture { + Async[F].delay { + connection.sendPreparedStatement(sql, encodedParams).map { result => + mapper(result.rows.get).get + } + } + } + } + + def close(): F[Unit] = + Async[F].fromFuture(Async[F].delay(connection.disconnect)).void +}