diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 21eaebe..28f0e90 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -16,6 +16,9 @@ okeydoke = "2.0.3" strikt = "0.34.1" version-catalog-update = "1.0.1" versions = "0.53.0" +testcontainers = "2.0.2" +postgres-driver = "42.7.7" +mariadb-driver = "3.5.7" [libraries] hamkrest = { module = "com.natpryce:hamkrest", version.ref = "hamkrest" } @@ -40,6 +43,12 @@ kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serializa okeydoke = { module = "com.oneeyedmen:okeydoke", version.ref = "okeydoke" } strikt-core = { module = "io.strikt:strikt-core", version.ref = "strikt" } strikt-jvm = { module = "io.strikt:strikt-jvm", version.ref = "strikt" } +testcontainers-bom = { module = "org.testcontainers:testcontainers-bom", version.ref = "testcontainers" } +testcontainers-junit5 = { module = "org.testcontainers:testcontainers-junit-jupiter", version.ref = "testcontainers" } +testcontainers-postgres = { module = "org.testcontainers:testcontainers-postgresql", version.ref = "testcontainers" } +testcontainers-mariadb = { module = "org.testcontainers:testcontainers-mariadb", version.ref = "testcontainers" } +postgres-driver = { module = "org.postgresql:postgresql", version.ref = "postgres-driver" } +mariadb-driver = { module = "org.mariadb.jdbc:mariadb-java-client", version.ref = "mariadb-driver" } [bundles] jmh = [ @@ -62,6 +71,21 @@ strikt = [ "strikt-jvm", ] +testcontainers = [ + "testcontainers-bom", + "testcontainers-junit5" +] + +testcontainers-postgres = [ + "testcontainers-postgres", + "postgres-driver" +] + +testcontainers-mariadb = [ + "testcontainers-mariadb", + "mariadb-driver" +] + [plugins] jmh = { id = "me.champeau.jmh", version.ref = "champeau-jmh" } jmhreport = { id = "io.morethan.jmhreport", version.ref = "jmhreport" } diff --git a/settings.gradle.kts b/settings.gradle.kts index 2cfc8e3..2a9902c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -36,4 +36,5 @@ include("ropes4k") include("state4k") include("time4k") include("tuples4k") +include("tx4k") include("values4k") diff --git a/tx4k/build.gradle.kts b/tx4k/build.gradle.kts new file mode 100644 index 0000000..9b3b3fb --- /dev/null +++ b/tx4k/build.gradle.kts @@ -0,0 +1,18 @@ +import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile + +description = "ForkHandles Transactor library" + +dependencies { + testImplementation(kotlin("test-junit5")) + testImplementation(libs.bundles.junit) + testImplementation(libs.bundles.testcontainers) + testImplementation(libs.bundles.testcontainers.postgres) + testImplementation(libs.bundles.testcontainers.mariadb) +} + + +tasks.withType().configureEach { + compilerOptions { + freeCompilerArgs.set(freeCompilerArgs.get() + "-Xinline-classes") + } +} diff --git a/tx4k/src/main/kotlin/dev/forkhandles/tx/Transactor.kt b/tx4k/src/main/kotlin/dev/forkhandles/tx/Transactor.kt new file mode 100644 index 0000000..e541ac3 --- /dev/null +++ b/tx4k/src/main/kotlin/dev/forkhandles/tx/Transactor.kt @@ -0,0 +1,61 @@ +package dev.forkhandles.tx + +import java.time.Duration +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +abstract class Transactor { + abstract fun createResource(): Resource + abstract fun configureResource(resource: Resource) + abstract fun destroyResource(resource: Resource) + + abstract fun createApi(resource: Resource): API + + abstract fun startTransaction(resource: Resource) + abstract fun commitTransaction(resource: Resource) + abstract fun rollbackTransaction(resource: Resource) + + abstract fun canRetry(e: Exception): Boolean + abstract fun retryBackoff(attempt: Int): Duration? + + // Inline so that the `work` lambda can do an early return + @OptIn(ExperimentalContracts::class) + inline fun perform(work: (API) -> Result): Result { + contract { + callsInPlace(work, InvocationKind.AT_LEAST_ONCE) + } + + val resource = createResource() + try { + configureResource(resource) + val api = createApi(resource) + + var attempts = 0 + while (true) try { + startTransaction(resource) + val result = work(api) + commitTransaction(resource) + return result + } + catch (e: Exception) { + rollbackTransaction(resource) + if (canRetry(e)) { + attempts++ + when (val backoff = retryBackoff(attempts)) { + null -> throw SerialisabilityFailure(e) + else -> Thread.sleep(backoff) + } + } else { + throw e + } + } + } finally { + destroyResource(resource) + } + } +} + +class SerialisabilityFailure(cause: Exception) : Exception(cause) + +typealias Transactional = Transactor<*, API> diff --git a/tx4k/src/main/kotlin/dev/forkhandles/tx/jdbc/JdbcTransactor.kt b/tx4k/src/main/kotlin/dev/forkhandles/tx/jdbc/JdbcTransactor.kt new file mode 100644 index 0000000..004953d --- /dev/null +++ b/tx4k/src/main/kotlin/dev/forkhandles/tx/jdbc/JdbcTransactor.kt @@ -0,0 +1,49 @@ +package dev.forkhandles.tx.jdbc + +import dev.forkhandles.tx.RetryPolicy +import dev.forkhandles.tx.Transactor +import dev.forkhandles.tx.increasingBackoff +import dev.forkhandles.tx.withAdditiveJitter +import dev.forkhandles.tx.maxAttempts +import java.sql.Connection +import java.sql.SQLException +import java.time.Duration + + +class JdbcTransactor( + private val createConnection: () -> Connection, + private val createWrapper: (Connection) -> API, + private val retryPolicy: RetryPolicy = + increasingBackoff(Duration.ofMillis(50)).withAdditiveJitter().maxAttempts(5), + private val retryableFailurePolicy: (Exception) -> Boolean = + ::jdbcStandardRetryability +) : Transactor() { + override fun createResource(): Connection = createConnection() + + override fun configureResource(resource: Connection) { + resource.autoCommit = false + resource.transactionIsolation = Connection.TRANSACTION_SERIALIZABLE + } + + override fun destroyResource(resource: Connection) = resource.close() + + override fun createApi(resource: Connection): API = + createWrapper(resource) + + override fun startTransaction(resource: Connection) { + // Nothing required for JDBC + } + + override fun rollbackTransaction(resource: Connection) = resource.rollback() + override fun commitTransaction(resource: Connection) = resource.commit() + + override fun canRetry(e: Exception): Boolean = + retryableFailurePolicy(e) + + override fun retryBackoff(attempt: Int): Duration? = + retryPolicy(attempt) +} + + +fun jdbcStandardRetryability(e: Exception): Boolean = + e is SQLException && (e.sqlState == "40001" || e.sqlState == "40P01") diff --git a/tx4k/src/main/kotlin/dev/forkhandles/tx/mem/InMemoryTransactor.kt b/tx4k/src/main/kotlin/dev/forkhandles/tx/mem/InMemoryTransactor.kt new file mode 100644 index 0000000..dbfdc56 --- /dev/null +++ b/tx4k/src/main/kotlin/dev/forkhandles/tx/mem/InMemoryTransactor.kt @@ -0,0 +1,46 @@ +@file:OptIn(ExperimentalAtomicApi::class) + +package dev.forkhandles.tx.mem + +import dev.forkhandles.tx.RetryPolicy +import dev.forkhandles.tx.Transactor +import dev.forkhandles.tx.increasingBackoff +import dev.forkhandles.tx.withAdditiveJitter +import dev.forkhandles.tx.maxAttempts +import java.time.Duration +import kotlin.concurrent.atomics.AtomicReference +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.Unit as noop + +class InMemoryTransaction(val initialState: State) { + var state: State = initialState +} + +class InMemoryTransactor( + initialState: State, + private val createRepository: (InMemoryTransaction) -> API, + private val retryPolicy: RetryPolicy = + increasingBackoff(Duration.ofMillis(1)).withAdditiveJitter().maxAttempts(5) +) : Transactor, API>() { + val state = AtomicReference(initialState) + + override fun createResource() = InMemoryTransaction(state.load()) + override fun configureResource(resource: InMemoryTransaction) = noop + override fun destroyResource(resource: InMemoryTransaction) = noop + + override fun createApi(resource: InMemoryTransaction) = + createRepository(resource) + + override fun startTransaction(resource: InMemoryTransaction) = noop + override fun rollbackTransaction(resource: InMemoryTransaction) = noop + override fun commitTransaction(resource: InMemoryTransaction) { + if (!state.compareAndSet(expectedValue = resource.initialState, newValue = resource.state)) { + throw RetryException() + } + } + + override fun canRetry(e: Exception) = e is RetryException + override fun retryBackoff(attempt: Int) = retryPolicy(attempt) + + internal class RetryException : Exception() +} diff --git a/tx4k/src/main/kotlin/dev/forkhandles/tx/retry.kt b/tx4k/src/main/kotlin/dev/forkhandles/tx/retry.kt new file mode 100644 index 0000000..0ad883f --- /dev/null +++ b/tx4k/src/main/kotlin/dev/forkhandles/tx/retry.kt @@ -0,0 +1,41 @@ +package dev.forkhandles.tx + +import java.time.Duration +import kotlin.random.Random + +typealias RetryPolicy = (attempt: Int) -> Duration? + +/** + * A very simple backoff strategy that always backs off the same amount of time. + * This is not suitable for production workloads. + */ +fun linearBackoff(backoff: Duration): RetryPolicy = + fun(attempt: Int) = + backoff + +/** + * A backoff strategy that increases the retry delay with each attempt. + */ +fun increasingBackoff(step: Duration): RetryPolicy = + fun(attempt: Int) = + step.multipliedBy(attempt.toLong()) + +/** + * Limit the number of retries. + */ +fun RetryPolicy.maxAttempts(max: Int): RetryPolicy = + fun(attempt: Int): Duration? { + require(attempt > 0) { "attempt must be > 0" } + return if (attempt <= max) this(attempt) else null + } + +/** + * Applies jitter to a retry policy. + */ +fun RetryPolicy.withAdditiveJitter(maxFraction: Double = 0.1, random: Random = Random.Default): RetryPolicy = + fun(attempt: Int): Duration? = + this(attempt)?.let { baseDelay -> + val jitterScale = maxFraction * random.nextDouble(-1.0, 1.0) + val jitter = baseDelay.toMillis() * jitterScale + baseDelay + Duration.ofMillis(jitter.toLong()) + } diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/BackoffJitterTest.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/BackoffJitterTest.kt new file mode 100644 index 0000000..e2e1de4 --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/BackoffJitterTest.kt @@ -0,0 +1,27 @@ +package dev.forkhandles.tx + +import java.time.Duration +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class BackoffJitterTest { + @Test + fun `jittered policy adds jitter to base delay`() { + val baseDelay = Duration.ofSeconds(1) + val jitteredPolicy = linearBackoff(baseDelay).withAdditiveJitter(0.125) + + repeat(1000) { + val jitteredDelay = assertNotNull(jitteredPolicy(it)) + assertTrue(jitteredDelay in (Duration.ofMillis(875) .. Duration.ofMillis(1125))) + } + } + + @Test + fun `jittered policy signals end of retry`() { + val jitteredPolicy = { _ : Int -> null }.withAdditiveJitter(0.125) + + assertNull(jitteredPolicy(1)) + } +} diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/Counter.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/Counter.kt new file mode 100644 index 0000000..cecd863 --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/Counter.kt @@ -0,0 +1,6 @@ +package dev.forkhandles.tx + +interface Counter { + fun incrementBy(n: Int) + fun count(): Int +} diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/IncreasingBackoffTest.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/IncreasingBackoffTest.kt new file mode 100644 index 0000000..84100c1 --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/IncreasingBackoffTest.kt @@ -0,0 +1,23 @@ +package dev.forkhandles.tx + +import java.time.Duration +import kotlin.test.Test +import kotlin.test.assertEquals + + +class IncreasingBackoffTest { + @Test + fun `total backoff`() { + val policy = increasingBackoff(Duration.ofSeconds(1)) + + assertEquals(Duration.ofSeconds(1), totalBackoff(policy, 1), "attempt 1") + assertEquals(Duration.ofSeconds(3), totalBackoff(policy, 2), "attempt 2") + assertEquals(Duration.ofSeconds(6), totalBackoff(policy, 3), "attempt 3") + assertEquals(Duration.ofSeconds(10), totalBackoff(policy, 4), "attempt 4") + assertEquals(Duration.ofSeconds(15), totalBackoff(policy, 5), "attempt 5") + } + + private fun totalBackoff(policy: RetryPolicy, maxAttempt: Int): Duration = + (1..maxAttempt) + .fold(Duration.ZERO) { acc, attempt -> acc + policy(attempt) } +} diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/TransactorContract.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/TransactorContract.kt new file mode 100644 index 0000000..8ee63cd --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/TransactorContract.kt @@ -0,0 +1,72 @@ +@file:OptIn(kotlin.concurrent.atomics.ExperimentalAtomicApi::class) + +package dev.forkhandles.tx + +import java.util.concurrent.Executor +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit.SECONDS +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.incrementAndFetch +import kotlin.test.Test +import kotlin.test.assertEquals + +abstract class TransactorContract { + abstract val transactor: Transactional + + @Test + fun `with one thread`() { + testWith(Executors.newSingleThreadExecutor(), 180) + } + + @Test + fun `with multiple threads`() { + testWith(Executors.newFixedThreadPool(5), 400) + } + + @Test + fun `with multiple virtual threads`() { + testWith(Executors.newFixedThreadPool(5, Thread.ofVirtual().factory()), 600) + } + + private fun testWith(executorService: ExecutorService, intendedWriteCount: Int) { + val failureCount = AtomicInt(0) + + executorService.use { + useCounter(it, intendedWriteCount, failureCount) + } + executorService.awaitTermination(10, SECONDS) + + transactor.perform { counter -> + val actualWriteCount = counter.count() + val actualFailureCount = failureCount.load() + + if (actualFailureCount > 0) { + println("serialisation failures: $actualFailureCount") + } + + assertEquals(intendedWriteCount, actualWriteCount + actualFailureCount) + } + } + + private fun useCounter(executor: Executor, count: Int, failureCount: AtomicInt) { + repeat(count * 2) { n -> + executor.execute { + try { + transactor.perform { counter -> + when (n % 2) { + 1 -> throw ForceARollback() + else -> counter.incrementBy(1) + } + } + } catch (_: SerialisabilityFailure) { + failureCount.incrementAndFetch() + } catch (_: ForceARollback) { + // Expected, but catch here to prevent the executor writing a lot of noise to stderr + } + } + } + } + + private class ForceARollback() : Exception() +} diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/JdbcCounter.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/JdbcCounter.kt new file mode 100644 index 0000000..4f835da --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/JdbcCounter.kt @@ -0,0 +1,70 @@ +package dev.forkhandles.tx.jdbc + +import dev.forkhandles.tx.Counter +import dev.forkhandles.tx.Transactional +import org.testcontainers.containers.JdbcDatabaseContainer +import java.sql.Connection + +class JdbcCounter( + val connection: Connection, + val name: String +) : Counter { + fun init() { + connection.prepareStatement( + "INSERT INTO COUNTER (id,count) VALUES (?,0)" + ).use { s -> + s.setString(1, name) + s.executeUpdate() + } + } + + override fun incrementBy(n: Int) { + val newCount = count() + n + + connection.prepareStatement( + "UPDATE COUNTER SET count = ? WHERE id = ?" + ).use { s -> + s.setInt(1, newCount) + s.setString(2, name) + + s.executeUpdate() + } + } + + override fun count(): Int { + return connection.prepareStatement( + "SELECT count FROM COUNTER WHERE id = ?" + ).use { s -> + s.setString(1, name) + s.executeQuery().use { rs -> + require(rs.next()) { "no counter with id $name" } + rs.getInt("count") + } + } + } +} + +fun createSchema(c: Connection): Boolean = c.createStatement().use { s -> + s.execute( + """ + create table COUNTER ( + id VARCHAR(64) PRIMARY KEY, + count NUMERIC(8) NOT NULL DEFAULT 0 + ) + """ + ) +} + +fun createCounterTransactor( + database: JdbcDatabaseContainer<*>, + testName: String +): Transactional { + val transactor = JdbcTransactor( + createConnection = { database.createConnection("") }, + createWrapper = { JdbcCounter(it, testName) } + ) + + transactor.perform { it.init() } + + return transactor +} diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/MariaDBTransactorTest.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/MariaDBTransactorTest.kt new file mode 100644 index 0000000..c3f7fd0 --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/MariaDBTransactorTest.kt @@ -0,0 +1,40 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package dev.forkhandles.tx.jdbc + +import dev.forkhandles.tx.Counter +import dev.forkhandles.tx.Transactional +import dev.forkhandles.tx.TransactorContract +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.TestInfo +import org.testcontainers.containers.JdbcDatabaseContainer +import org.testcontainers.junit.jupiter.Container +import org.testcontainers.junit.jupiter.Testcontainers +import org.testcontainers.mariadb.MariaDBContainer +import kotlin.uuid.ExperimentalUuidApi + +// language = mariadb +@Testcontainers +class MariaDBTransactorTest : TransactorContract() { + override lateinit var transactor: Transactional + + @BeforeEach + fun createCounter(testInfo: TestInfo) { + transactor = createCounterTransactor(database, testInfo.displayName) + } + + companion object { + @Container + @JvmStatic + private val database: JdbcDatabaseContainer<*> = MariaDBContainer("mariadb:10.5.5") + + @BeforeAll + @JvmStatic + fun createSchema() { + database.createConnection("").use(::createSchema) + } + } +} + + diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/PostgreSQLTransactorTest.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/PostgreSQLTransactorTest.kt new file mode 100644 index 0000000..200a2b9 --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/jdbc/PostgreSQLTransactorTest.kt @@ -0,0 +1,39 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package dev.forkhandles.tx.jdbc + +import dev.forkhandles.tx.Counter +import dev.forkhandles.tx.Transactional +import dev.forkhandles.tx.TransactorContract +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.TestInfo +import org.testcontainers.containers.JdbcDatabaseContainer +import org.testcontainers.junit.jupiter.Container +import org.testcontainers.junit.jupiter.Testcontainers +import org.testcontainers.postgresql.PostgreSQLContainer +import kotlin.uuid.ExperimentalUuidApi + +// language = postgresql +@Testcontainers +class PostgreSQLTransactorTest : TransactorContract() { + override lateinit var transactor: Transactional + + @BeforeEach + fun createCounter(testInfo: TestInfo) { + transactor = createCounterTransactor(database, testInfo.displayName) + } + + companion object { + @Container + @JvmStatic + private val database: JdbcDatabaseContainer<*> = PostgreSQLContainer("postgres:18.2") + + @BeforeAll + @JvmStatic + fun createSchema() { + database.createConnection("").use(::createSchema) + } + } +} + diff --git a/tx4k/src/test/kotlin/dev/forkhandles/tx/mem/InMemoryTransactorTest.kt b/tx4k/src/test/kotlin/dev/forkhandles/tx/mem/InMemoryTransactorTest.kt new file mode 100644 index 0000000..cd61b6d --- /dev/null +++ b/tx4k/src/test/kotlin/dev/forkhandles/tx/mem/InMemoryTransactorTest.kt @@ -0,0 +1,19 @@ +@file:OptIn(kotlin.concurrent.atomics.ExperimentalAtomicApi::class) + +package dev.forkhandles.tx.mem + +import dev.forkhandles.tx.Counter +import dev.forkhandles.tx.TransactorContract + + +class InMemoryTransactorTest : TransactorContract() { + class InMemoryCounter(val tx: InMemoryTransaction) : Counter { + override fun incrementBy(n: Int) { + tx.state += n + } + + override fun count() = tx.state + } + + override val transactor = InMemoryTransactor(0, ::InMemoryCounter) +}