diff --git a/generator/src/ba/sake/squery/generator/SqueryGenerator.scala b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala index 3f29937..ed3edde 100644 --- a/generator/src/ba/sake/squery/generator/SqueryGenerator.scala +++ b/generator/src/ba/sake/squery/generator/SqueryGenerator.scala @@ -83,11 +83,13 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene basePackage: String, fileGen: Boolean ): (Seq[GeneratedFileSource], Seq[GeneratedFileSource]) = { - val enumDefs = schemaDef.tables.flatMap { - _.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration => - e + val enumDefs = schemaDef.tables + .flatMap { + _.columnDefs.map(_.scalaType).collect { case e: ColumnType.Enumeration => + e + } } - }.distinctBy(_.name) + .distinctBy(_.name) val enumFiles = enumDefs.map { enumDef => val enumCaseDefs = Defn.RepeatedEnumCase( List.empty, @@ -547,8 +549,8 @@ class SqueryGenerator(ds: DataSource, config: SqueryGeneratorConfig = SqueryGene q"import java.util.UUID", q"import ba.sake.squery.{*, given}", q"import ..${List(dbSpecificImporter)}", - q"import ba.sake.squery.read.SqlRead", - q"import ba.sake.squery.write.SqlWrite" + q"import ba.sake.squery.read.{*, given}", + q"import ba.sake.squery.write.{*, given}" ) } private def generateDaoImports(dbType: DbType, basePackage: String) = { diff --git a/squery/src/ba/sake/squery/SqlNonScalarType.scala b/squery/src/ba/sake/squery/SqlNonScalarType.scala new file mode 100644 index 0000000..aae78f3 --- /dev/null +++ b/squery/src/ba/sake/squery/SqlNonScalarType.scala @@ -0,0 +1,12 @@ +package ba.sake.squery + +// - a marker typeclass for non-scalar types +// i.e. 
array-type and similar + +// - this is to prevent infinite recursion of Array[Array[Array...T +// the SqlWrite, SqlRead typeclasses would break + +trait SqlNonScalarType[T] + +given [T]: SqlNonScalarType[Array[T]] = new {} +given [T]: SqlNonScalarType[Seq[T]] = new {} diff --git a/squery/src/ba/sake/squery/postgres/reads.scala b/squery/src/ba/sake/squery/postgres/reads.scala index 9fdf417..3601a1a 100644 --- a/squery/src/ba/sake/squery/postgres/reads.scala +++ b/squery/src/ba/sake/squery/postgres/reads.scala @@ -11,11 +11,3 @@ given SqlRead[UUID] with { def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[UUID] = Option(jRes.getObject(colIdx, classOf[UUID])) } - -given [T: SqlRead]: SqlRead[Array[T]] with { - def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[T]] = - Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[T]]) - - def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[T]] = - Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[T]]) -} diff --git a/squery/src/ba/sake/squery/postgres/typeNames.scala b/squery/src/ba/sake/squery/postgres/typeNames.scala new file mode 100644 index 0000000..5c92ad2 --- /dev/null +++ b/squery/src/ba/sake/squery/postgres/typeNames.scala @@ -0,0 +1,8 @@ +package ba.sake.squery.postgres + +import ba.sake.squery.write.SqlTypeName +import java.util.UUID + +given SqlTypeName[Array[UUID]] with { + def value: String = "UUID" +} diff --git a/squery/src/ba/sake/squery/read/SqlRead.scala b/squery/src/ba/sake/squery/read/SqlRead.scala index 63c6f9a..521ddf3 100644 --- a/squery/src/ba/sake/squery/read/SqlRead.scala +++ b/squery/src/ba/sake/squery/read/SqlRead.scala @@ -6,6 +6,8 @@ import java.time.* import java.util.UUID import scala.deriving.* import scala.quoted.* +import scala.util.NotGiven +import scala.reflect.ClassTag // reads a value from a column trait SqlRead[T]: @@ -44,7 +46,6 @@ object SqlRead { Option(jRes.getShort(colIdx)).filterNot(_ => jRes.wasNull()) } - given 
SqlRead[Int] with { def readByName(jRes: jsql.ResultSet, colName: String): Option[Int] = Option(jRes.getInt(colName)).filterNot(_ => jRes.wasNull()) @@ -98,6 +99,35 @@ object SqlRead { Option(jRes.getTimestamp(colIdx)).map(_.toLocalDateTime()) } + /* Arrays */ + // - general first, then specific ones, for implicits ordering + // - _.map(_.asInstanceOf[T]) because of boxing/unboxing... + given sqlReadArray1[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[T]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[T]] = + Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[T]].map(_.asInstanceOf[T])) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[T]] = + Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[T]].map(_.asInstanceOf[T])) + } + + given sqlReadArray2[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[Array[T]]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[Array[T]]] = + Option(jRes.getArray(colName)).map(_.getArray().asInstanceOf[Array[Array[T]]].map(_.map(_.asInstanceOf[T]))) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[Array[T]]] = + Option(jRes.getArray(colIdx)).map(_.getArray().asInstanceOf[Array[Array[T]]].map(_.map(_.asInstanceOf[T]))) + } + + given sqlReadArray3[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Array[Array[Array[T]]]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Array[Array[Array[T]]]] = + Option(jRes.getArray(colName)) + .map(_.getArray().asInstanceOf[Array[Array[Array[T]]]].map(_.map(_.map(_.asInstanceOf[T])))) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Array[Array[Array[T]]]] = + Option(jRes.getArray(colIdx)) + .map(_.getArray().asInstanceOf[Array[Array[Array[T]]]].map(_.map(_.map(_.asInstanceOf[T])))) + } + given SqlRead[Array[Byte]] with { def readByName(jRes: jsql.ResultSet, colName: String): 
Option[Array[Byte]] = Option(jRes.getBytes(colName)) @@ -106,6 +136,40 @@ object SqlRead { Option(jRes.getBytes(colIdx)) } + // vector utils, nicer to deal with + given sqlReadVector1[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[T]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[T]] = + SqlRead[Array[T]].readByName(jRes, colName).map(_.toVector) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[T]] = + SqlRead[Array[T]].readByIdx(jRes, colIdx).map(_.toVector) + } + + given sqlReadVector2[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[Vector[T]]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Vector[T]]] = + SqlRead[Array[Array[T]]].readByName(jRes, colName).map(_.toVector.map(_.toVector)) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Vector[T]]] = + SqlRead[Array[Array[T]]].readByIdx(jRes, colIdx).map(_.toVector.map(_.toVector)) + } + + given sqlReadVector3[T: SqlRead: ClassTag](using NotGiven[SqlNonScalarType[T]]): SqlRead[Vector[Vector[Vector[T]]]] + with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Vector[Vector[T]]]] = + SqlRead[Array[Array[Array[T]]]].readByName(jRes, colName).map(_.toVector.map(_.toVector.map(_.toVector))) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Vector[Vector[T]]]] = + SqlRead[Array[Array[Array[T]]]].readByIdx(jRes, colIdx).map(_.toVector.map(_.toVector.map(_.toVector))) + } + + given SqlRead[Vector[Byte]] with { + def readByName(jRes: jsql.ResultSet, colName: String): Option[Vector[Byte]] = + SqlRead[Array[Byte]].readByName(jRes, colName).map(_.toVector) + + def readByIdx(jRes: jsql.ResultSet, colIdx: Int): Option[Vector[Byte]] = + SqlRead[Array[Byte]].readByIdx(jRes, colIdx).map(_.toVector) + } + // this "cannot fail" given [T](using sr: SqlRead[T]): SqlRead[Option[T]] with { def readByName(jRes: jsql.ResultSet, colName: 
String): Option[Option[T]] = diff --git a/squery/src/ba/sake/squery/write/SqlTypeName.scala b/squery/src/ba/sake/squery/write/SqlTypeName.scala new file mode 100644 index 0000000..1faa568 --- /dev/null +++ b/squery/src/ba/sake/squery/write/SqlTypeName.scala @@ -0,0 +1,53 @@ +package ba.sake.squery.write + +import java.time.* + +// used for createArrayOf(sqlTypeName, myArray) +trait SqlTypeName[T]: + def value: String + +// for Array[Array[... T]] the value is the inner-most T name +given [T](using stn: SqlTypeName[T]): SqlTypeName[Array[T]] with { + def value: String = stn.value +} +// for Seq[Seq[... T]] the value is the inner-most T name +given [T](using stn: SqlTypeName[T]): SqlTypeName[Seq[T]] with { + def value: String = stn.value +} + +given SqlTypeName[String] with { + def value: String = "VARCHAR" +} + +given SqlTypeName[Boolean] with { + def value: String = "BOOLEAN" +} +given SqlTypeName[Byte] with { + def value: String = "TINYINT" +} +given SqlTypeName[Short] with { + def value: String = "SMALLINT" +} +given SqlTypeName[Int] with { + def value: String = "INTEGER" +} +given SqlTypeName[Long] with { + def value: String = "BIGINT" +} +given SqlTypeName[Double] with { + def value: String = "DOUBLE PRECISION" +} + +given SqlTypeName[LocalDate] with { + def value: String = "DATE" +} +given SqlTypeName[LocalDateTime] with { + def value: String = "TIMESTAMP" +} +given SqlTypeName[Instant] with { + def value: String = "TIMESTAMPTZ" +} + +given SqlTypeName[Array[Byte]] with { + def value: String = "BINARY" +} diff --git a/squery/src/ba/sake/squery/write/SqlWrite.scala b/squery/src/ba/sake/squery/write/SqlWrite.scala index 8a96d9c..be1fc89 100644 --- a/squery/src/ba/sake/squery/write/SqlWrite.scala +++ b/squery/src/ba/sake/squery/write/SqlWrite.scala @@ -1,4 +1,5 @@ -package ba.sake.squery.write +package ba.sake.squery +package write import java.{sql => jsql} import java.time.Instant @@ -11,6 +12,8 @@ import java.time.ZoneId import java.time.OffsetDateTime import 
scala.deriving.* import scala.quoted.* +import scala.reflect.ClassTag +import scala.util.NotGiven trait SqlWrite[T]: def write(ps: jsql.PreparedStatement, idx: Int, valueOpt: Option[T]): Unit @@ -129,6 +132,51 @@ object SqlWrite { case None => ps.setNull(idx, jsql.Types.TIMESTAMP) } + /* Arrays */ + given sqlWriteArray1[T: SqlWrite](using stn: SqlTypeName[T], ng: NotGiven[SqlNonScalarType[T]]): SqlWrite[Array[T]] + with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Array[T]] + ): Unit = valueOpt match + case Some(value) => + val valuesAsAnyRef = value.map(_.asInstanceOf[AnyRef]) // box primitives like Array[Int] + val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef) + ps.setArray(idx, sqlArray) + case None => ps.setArray(idx, null) + } + given sqlWriteArray2[T: SqlWrite](using + stn: SqlTypeName[T], + ng: NotGiven[SqlNonScalarType[T]] + ): SqlWrite[Array[Array[T]]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Array[Array[T]]] + ): Unit = valueOpt match + case Some(value) => + val valuesAsAnyRef = value.map(_.map(_.asInstanceOf[AnyRef])) // box primitives like Array[Array[Int]] + val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef.asInstanceOf[Array[AnyRef]]) + ps.setArray(idx, sqlArray) + case None => ps.setArray(idx, null) + } + given sqlWriteArray3[T: SqlWrite](using + stn: SqlTypeName[T], + ng: NotGiven[SqlNonScalarType[T]] + ): SqlWrite[Array[Array[Array[T]]]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Array[Array[Array[T]]]] + ): Unit = valueOpt match + case Some(value) => + val valuesAsAnyRef = + value.map(_.map(_.map(_.asInstanceOf[AnyRef]))) // box primitives like Array[Array[Array[Int]]] + val sqlArray = ps.getConnection().createArrayOf(stn.value, valuesAsAnyRef.asInstanceOf[Array[AnyRef]]) + ps.setArray(idx, sqlArray) + case None => ps.setArray(idx, null) + } given SqlWrite[Array[Byte]] with { def write( ps: 
jsql.PreparedStatement, @@ -139,6 +187,47 @@ object SqlWrite { case None => ps.setNull(idx, jsql.Types.BINARY) } + given sqlWriteVector1[T: SqlWrite: ClassTag](using + stn: SqlTypeName[T], + ng: NotGiven[SqlNonScalarType[T]] + ): SqlWrite[Vector[T]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Vector[T]] + ): Unit = SqlWrite[Array[T]].write(ps, idx, valueOpt.map(_.toArray)) + } + + given sqlWriteVector2[T: SqlWrite: ClassTag](using + stn: SqlTypeName[T], + ng: NotGiven[SqlNonScalarType[T]] + ): SqlWrite[Vector[Vector[T]]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Vector[Vector[T]]] + ): Unit = SqlWrite[Array[Array[T]]].write(ps, idx, valueOpt.map(_.toArray.map(_.toArray))) + } + + given sqlWriteVector3[T: SqlWrite: ClassTag](using + stn: SqlTypeName[T], + ng: NotGiven[SqlNonScalarType[T]] + ): SqlWrite[Vector[Vector[Vector[T]]]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Vector[Vector[Vector[T]]]] + ): Unit = SqlWrite[Array[Array[Array[T]]]].write(ps, idx, valueOpt.map(_.toArray.map(_.toArray.map(_.toArray)))) + } + + given SqlWrite[Vector[Byte]] with { + def write( + ps: jsql.PreparedStatement, + idx: Int, + valueOpt: Option[Vector[Byte]] + ): Unit = SqlWrite[Array[Byte]].write(ps, idx, valueOpt.map(_.toArray)) + } + given [T](using sw: SqlWrite[T]): SqlWrite[Option[T]] with { def write( ps: jsql.PreparedStatement, diff --git a/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala b/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala index caea73c..40589d4 100644 --- a/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala +++ b/squery/test/src/ba/sake/squery/mariadb/MariaDbSuite.scala @@ -97,9 +97,6 @@ class MariaDbSuite extends munit.FunSuite { (${customer2.name}, ${customer2.street}) """.insertReturningGenKeys[Int]() - - println(customerIds) - customer1 = customer1.copy(id = customerIds(0)) customer2 = customer2.copy(id = customerIds(1)) 
diff --git a/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala b/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala index 5aec672..1fc638e 100644 --- a/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala +++ b/squery/test/src/ba/sake/squery/postgres/PostgresSuite.scala @@ -6,6 +6,8 @@ import java.time.Instant import java.time.temporal.ChronoUnit import scala.collection.decorators._ import org.testcontainers.containers.PostgreSQLContainer +import ba.sake.squery.read.{*, given} +import ba.sake.squery.write.{*, given} // UUID, enum.. Postgres specific case class Datatypes( @@ -16,12 +18,22 @@ case class Datatypes( d_string: Option[String], d_uuid: Option[UUID], d_tstz: Option[Instant], - d_clr: Option[Color] + d_clr: Option[Color], + d_array_bytes: Option[Vector[Byte]], + d_array_int: Option[Vector[Int]], + d_array_array_int: Option[Vector[Vector[Int]]], + d_array_str: Option[Vector[String]], + d_array_array_str: Option[Vector[Vector[String]]] ) derives SqlReadRow: - def insertTuple = sql"(${d_int}, ${d_long}, ${d_double}, ${d_boolean}, ${d_string}, ${d_uuid}, ${d_tstz}, ${d_clr})" + + def insertTuple = + sql"""(${d_int}, ${d_long}, ${d_double}, ${d_boolean}, ${d_string}, ${d_uuid}, ${d_tstz}, ${d_clr}, + ${d_array_bytes}, ${d_array_int}, ${d_array_array_int}, ${d_array_str}, ${d_array_array_str} + )""" object Datatypes: - inline val allCols = "d_int, d_long, d_double, d_boolean, d_string, d_uuid, d_tstz, d_clr" + inline val allCols = + "d_int, d_long, d_double, d_boolean, d_string, d_uuid, d_tstz, d_clr, d_array_bytes, d_array_int, d_array_array_int, d_array_str, d_array_array_str" enum Color derives SqlRead, SqlWrite: case red, green, blue @@ -273,7 +285,12 @@ class PostgresSuite extends munit.FunSuite { d_string VARCHAR(255), d_uuid UUID, d_tstz TIMESTAMPTZ, - d_clr color + d_clr color, + d_array_bytes bytea, + d_array_int INTEGER[], + d_array_array_int INTEGER[][], + d_array_str VARCHAR[], + d_array_array_str VARCHAR[][] ) """.update() 
val dt1 = Datatypes( @@ -284,9 +301,14 @@ Some("abc"), Some(UUID.randomUUID), Some(Instant.now.truncatedTo(ChronoUnit.MICROS)), - Some(Color.red) + Some(Color.red), + Some("array".getBytes("utf8").toVector), + Some(Vector(1, 2, 3)), + Some(Vector(Vector(1, 1, 1), Vector(2, 2, 2), Vector(3, 3, 3))), + Some(Vector("abc")), + Some(Vector(Vector("aaa"), Vector("bbb"), Vector("ccc"))) ) - val dt2 = Datatypes(None, None, None, None, None, None, None, None) + val dt2 = Datatypes(None, None, None, None, None, None, None, None, None, None, None, None, None) val values = Seq(dt1, dt2) .map(_.insertTuple) @@ -301,10 +323,21 @@ SELECT ${Datatypes.allCols} FROM datatypes """.readRows[Datatypes]() - assertEquals( - storedRows, - Seq(dt1) - ) + val firstRow = storedRows.head + + assertEquals(firstRow.d_int, dt1.d_int) + assertEquals(firstRow.d_long, dt1.d_long) + assertEquals(firstRow.d_double, dt1.d_double) + assertEquals(firstRow.d_boolean, dt1.d_boolean) + assertEquals(firstRow.d_string, dt1.d_string) + assertEquals(firstRow.d_uuid, dt1.d_uuid) + assertEquals(firstRow.d_tstz, dt1.d_tstz) + assertEquals(firstRow.d_clr, dt1.d_clr) + assertEquals(Some("array"), firstRow.d_array_bytes.map(b => new String(b.toArray, "utf8"))) + assertEquals(firstRow.d_array_int, dt1.d_array_int) + assertEquals(firstRow.d_array_array_int, dt1.d_array_array_int) + assertEquals(firstRow.d_array_str, dt1.d_array_str) + assertEquals(firstRow.d_array_array_str, dt1.d_array_array_str) } }