Skip to content

Commit

Permalink
add statement code generation
Browse files Browse the repository at this point in the history
  • Loading branch information
klahap committed Jan 7, 2025
1 parent 56e1282 commit 3f4baca
Show file tree
Hide file tree
Showing 22 changed files with 584 additions and 95 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ build/
!**/src/test/**/build/
.env
test-schema.sql
test-queries.sql

### IntelliJ IDEA ###
.idea/modules.xml
Expand Down
7 changes: 6 additions & 1 deletion src/main/kotlin/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ import io.github.klahap.pgen.model.sql.Table
import io.github.klahap.pgen.model.Config
import io.github.klahap.pgen.model.sql.Enum
import io.github.klahap.pgen.model.sql.PgenSpec
import io.github.klahap.pgen.model.sql.Statement
import io.github.klahap.pgen.service.DbService
import io.github.klahap.pgen.service.DirectorySyncService.Companion.directorySync
import io.github.klahap.pgen.util.DefaultCodeFile
import io.github.klahap.pgen.service.EnvFileService
import io.github.klahap.pgen.util.codegen.CodeGenContext
import io.github.klahap.pgen.util.codegen.sync
import io.github.klahap.pgen.util.parseStatements
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.encodeToString
import org.gradle.api.Project
Expand All @@ -34,17 +36,19 @@ private fun generateSpec(config: Config) {
dbName = configDb.dbName,
connectionConfig = configDb.connectionConfig ?: error("no DB connection config defined")
).use { dbService ->
val statements = dbService.getStatements(parseStatements(configDb.statementScripts))
val tables = dbService.getTablesWithForeignTables(configDb.tableFilter)
val enumNames = tables.asSequence().flatMap { it.columns }.map { it.type }
.map { if (it is Table.Column.Type.NonPrimitive.Array) it.elementType else it }
.filterIsInstance<Table.Column.Type.NonPrimitive.Enum>().map { it.name }.toSet()
val enums = dbService.getEnums(enumNames)
PgenSpec(tables = tables, enums = enums)
PgenSpec(tables = tables, enums = enums, statements = statements)
}
}
val spec = PgenSpec(
tables = specData.flatMap(PgenSpec::tables).sortedBy(Table::name),
enums = specData.flatMap(PgenSpec::enums).sortedBy(Enum::name),
statements = specData.flatMap(PgenSpec::statements).sortedBy(Statement::name),
)
config.specFilePath.createParentDirectories().writeText(yaml.encodeToString(spec))
}
Expand All @@ -61,6 +65,7 @@ private fun generateCode(config: Config) {
DefaultCodeFile.all().forEach { sync(it) }
spec.enums.forEach { sync(it) }
spec.tables.forEach { sync(it) }
spec.statements.groupBy { it.name.dbName }.values.forEach { sync(it) }
cleanup()
}
}
Expand Down
49 changes: 48 additions & 1 deletion src/main/kotlin/dsl/Poet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ fun buildEnum(
block: TypeSpec.Builder.() -> Unit,
) = TypeSpec.enumBuilder(name).apply(block).build()

fun buildClass(
name: String,
block: TypeSpec.Builder.() -> Unit,
) = TypeSpec.classBuilder(name).apply(block).build()

fun TypeSpec.Builder.addClass(
name: String,
block: TypeSpec.Builder.() -> Unit,
) = addType(buildClass(name = name, block = block))

fun FileSpec.Builder.addClass(
name: String,
block: TypeSpec.Builder.() -> Unit,
) = addType(buildClass(name = name, block = block))

fun TypeSpec.Builder.addProperty(name: String, type: TypeName, block: PropertySpec.Builder.() -> Unit) =
addProperty(PropertySpec.builder(name = name, type = type).apply(block).build())

Expand All @@ -42,10 +57,42 @@ fun TypeSpec.Builder.addCompanionObject(
block: TypeSpec.Builder.() -> Unit,
) = addType(TypeSpec.companionObjectBuilder().apply(block).build())

fun buildFunction(
name: String,
block: FunSpec.Builder.() -> Unit,
) = FunSpec.builder(name).apply(block).build()

fun TypeSpec.Builder.addFunction(
name: String,
block: FunSpec.Builder.() -> Unit,
) = addFunction(FunSpec.builder(name).apply(block).build())
) = addFunction(buildFunction(name = name, block = block))

fun FileSpec.Builder.addFunction(
name: String,
block: FunSpec.Builder.() -> Unit,
) = addFunction(buildFunction(name = name, block = block))

fun TypeSpec.Builder.addInitializerBlock(block: CodeBlock.Builder.() -> Unit) =
addInitializerBlock(CodeBlock.builder().apply(block).build())

fun FunSpec.Builder.addParameter(
name: String,
type: TypeName,
block: ParameterSpec.Builder.() -> Unit,
) = addParameter(ParameterSpec.builder(name, type).apply(block).build())

fun FunSpec.Builder.addCode(
block: CodeBlock.Builder.() -> Unit
) = addCode(CodeBlock.builder().apply(block).build())

fun CodeBlock.Builder.addControlFlow(controlFlow: String, vararg args: Any, block: CodeBlock.Builder.() -> Unit) {
beginControlFlow(controlFlow, *args)
block()
endControlFlow()
}

fun CodeBlock.Builder.indent(block: CodeBlock.Builder.() -> Unit) {
indent()
block()
unindent()
}
4 changes: 4 additions & 0 deletions src/main/kotlin/dsl/Sql.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import java.sql.Connection
import java.sql.ResultSet


fun Connection.execute(@Language("sql") query: String) {
createStatement().use { statement -> statement.execute(query) }
}

fun <T> Connection.executeQuery(@Language("sql") query: String, mapper: (ResultSet) -> T): List<T> {
return createStatement().use { statement ->
statement.executeQuery(query).use { rs ->
Expand Down
15 changes: 14 additions & 1 deletion src/main/kotlin/model/Config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ data class Config(
val dbName: DbName,
val connectionConfig: DbConnectionConfig?,
val tableFilter: SqlObjectFilter,
val statementScripts: Set<Path>,
) {
data class DbConnectionConfig(
val url: String,
Expand All @@ -43,9 +44,16 @@ data class Config(
private val dbName = DbName(name.also {
if (it.isBlank()) error("empty DB name")
})

private var connectionConfig: DbConnectionConfig? = null
private var tableFilter: SqlObjectFilter? = null
private var statementScripts: Set<Path>? = null

class StatementCollectionBuilder {
private val scripts = linkedSetOf<Path>()
fun addScript(file: Path) = apply { scripts.add(file) }
fun addScript(file: String) = apply { scripts.add(Path(file)) }
fun build() = scripts.toSet()
}

fun connectionConfig(ignoreErrors: Boolean = true, block: DbConnectionConfig.Builder.() -> Unit) = apply {
this.connectionConfig = runCatching {
Expand All @@ -61,10 +69,15 @@ data class Config(
tableFilter = SqlObjectFilter.Builder(dbName = dbName).apply(block).build()
}

fun statements(block: StatementCollectionBuilder.() -> Unit) {
statementScripts = StatementCollectionBuilder().apply(block).build()
}

fun build() = Db(
dbName = dbName,
connectionConfig = connectionConfig,
tableFilter = tableFilter ?: error("no table filter defined for DB config '$dbName'"),
statementScripts = statementScripts ?: emptySet(),
)
}
}
Expand Down
23 changes: 23 additions & 0 deletions src/main/kotlin/model/sql/Common.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.squareup.kotlinpoet.ClassName
import io.github.klahap.pgen.util.codegen.CodeGenContext
import io.github.klahap.pgen.dsl.PackageName
import io.github.klahap.pgen.util.SqlObjectNameSerializer
import io.github.klahap.pgen.util.SqlStatementNameSerializer
import io.github.klahap.pgen.util.kotlinKeywords
import io.github.klahap.pgen.util.makeDifferent
import io.github.klahap.pgen.util.toCamelCase
Expand Down Expand Up @@ -32,6 +33,28 @@ sealed interface SqlObject {
val name: SqlObjectName
}

@Serializable(with = SqlStatementNameSerializer::class)
data class SqlStatementName(
val dbName: DbName,
val name: String,
) : Comparable<SqlStatementName> {

val prettyName get() = name.toCamelCase(capitalized = false)
val prettyResultClassName get() = name.toCamelCase(capitalized = true) + "Result"

context(CodeGenContext)
val packageName
get() = PackageName("$rootPackageName.db.${dbName}")

context(CodeGenContext)
val typeName
get() = ClassName(packageName.name, prettyName)

override fun compareTo(other: SqlStatementName): Int =
dbName.compareTo(other.dbName).takeIf { it != 0 }
?: name.compareTo(other.name)
}

@Serializable(with = SqlObjectNameSerializer::class)
data class SqlObjectName(
val schema: SchemaName,
Expand Down
1 change: 1 addition & 0 deletions src/main/kotlin/model/sql/PgenSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import kotlinx.serialization.Serializable
data class PgenSpec(
val tables: List<Table>,
val enums: List<Enum>,
val statements: List<Statement>,
)
10 changes: 10 additions & 0 deletions src/main/kotlin/model/sql/SqlObjectFilter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ sealed interface SqlObjectFilter {
override fun exactSizeOrNull(): Int = objectNames.size
}

data class TempTable(val names: Set<String>) : SqlObjectFilter {
override fun toFilterString(schemaField: String, tableField: String): String {
val schemasStr = names.toSet().joinToString(",") { "'$it'" }
return "$tableField IN ($schemasStr)"
}

override fun isEmpty(): Boolean = names.isEmpty()
override fun exactSizeOrNull(): Int = names.size
}

data class Multi(
val filters: List<SqlObjectFilter>,
) : SqlObjectFilter {
Expand Down
42 changes: 42 additions & 0 deletions src/main/kotlin/model/sql/Statement.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package io.github.klahap.pgen.model.sql

import io.github.klahap.pgen.util.kotlinKeywords
import io.github.klahap.pgen.util.makeDifferent
import io.github.klahap.pgen.util.toCamelCase
import kotlinx.serialization.Serializable


@Serializable
data class Statement(
val name: SqlStatementName,
val cardinality: Cardinality,
val variables: List<VariableName>,
val variableTypes: Map<VariableName, Table.Column.Type>,
val columns: List<Table.Column>,
val sql: String,
) {
@JvmInline
@Serializable
value class VariableName(val name: String) : Comparable<VariableName> {
val pretty get() = name.toCamelCase(capitalized = false)
.makeDifferent(kotlinKeywords + setOf("coroutineContext", "db"))
override fun compareTo(other: VariableName): Int = name.compareTo(other.name)
}

data class Raw(
val name: String,
val cardinality: Cardinality,
val allVariables: List<VariableName>,
val uniqueSortedVariables: List<VariableName>,
val nonNullColumns: Set<String>,
val sql: String,
val preparedPsql: String,
val preparedSql: String,
)

@Serializable
enum class Cardinality {
ONE,
MANY,
}
}
Loading

0 comments on commit 3f4baca

Please sign in to comment.