Skip to content

Commit

Permalink
Accepting Pool and PoolClient
Browse files Browse the repository at this point in the history
  • Loading branch information
golergka committed May 8, 2021
1 parent 35adc77 commit d74ef57
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 87 deletions.
284 changes: 202 additions & 82 deletions src/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,109 +1,229 @@
import { Pool } from 'pg'
import { Pool, PoolClient } from 'pg'
import tx from '.'

describe(`tx`, () => {
let pg: Pool
describe('with Pool client', () => {
let pg: Pool

beforeAll(async () => {
const { POSTGRES_URL } = process.env
beforeAll(async () => {
const { POSTGRES_URL } = process.env

if (!POSTGRES_URL) {
throw new Error('Must specify POSTGRES_URL')
}
if (!POSTGRES_URL) {
throw new Error('Must specify POSTGRES_URL')
}

pg = new Pool({ connectionString: POSTGRES_URL })
pg = new Pool({ connectionString: POSTGRES_URL })

await pg.query(`CREATE TABLE things (
id SERIAL PRIMARY KEY,
thing TEXT NOT NULL
)`)
})
afterAll(async () => {
await pg.query(`DROP TABLE THINGS`)
await pg.end()
})
await pg.query(`CREATE TABLE things (
id SERIAL PRIMARY KEY,
thing TEXT NOT NULL
)`)
})
afterAll(async () => {
await pg.query(`DROP TABLE THINGS`)
await pg.end()
})

it(`commits changes when there's no error`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
})

it(`commits changes when there's no error`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
const {
rows: [committed]
} = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'comitted'`
)
expect(committed).toBeDefined()
expect(committed.thing).toEqual('comitted')
})

const {
rows: [committed]
} = await pg.query(`SELECT id, thing FROM things WHERE thing = 'comitted'`)
expect(committed).toBeDefined()
expect(committed.thing).toEqual('comitted')
})
it(`doesn't commit changes with forcedRollback`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
})

it(`doesn't commit changes with forcedRollback`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
})

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
})
it(`doesn't commit changes when there's a Node exception`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things (thing) VALUES ('node_error')`)
throw new Error(`node error`)
})
).rejects.toThrowError(`node error`)

it(`doesn't commit changes when there's a Node exception`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things (thing) VALUES ('node_error')`)
throw new Error(`node error`)
const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
})

it(`doesn't commit changes when there's a query error`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things(thing) VALUES ('query_error')`)
await db.query(`this query has an error`)
})
).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error'`
)
expect(rowCount).toBe(0)
})

it(`doesn't commit changes on next tick after query error`, async () => {
let laterPromise
const txPromise = tx(pg, async (db) => {
laterPromise = (async () => {
try {
await txPromise
} catch (e) {
// ignore error, we just needed to wait until it completed
}
await new Promise((resolve) => {
setImmediate(resolve)
})
await db.query(
`INSERT INTO things(thing) VALUES ('query_error_tick')`
)
})()
await Promise.all([db.query(`this query has an error`), laterPromise])
})
).rejects.toThrowError(`node error`)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
await expect(txPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)

await expect(laterPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"Transaction client already released. Did you forget to await on something?"`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error_tick'`
)
expect(rowCount).toBe(0)
})
})

it(`doesn't commit changes when there's a query error`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things(thing) VALUES ('query_error')`)
await db.query(`this query has an error`)
describe('with PoolClient client', () => {
let pool: Pool
let pg: PoolClient

beforeAll(async () => {
const { POSTGRES_URL } = process.env

if (!POSTGRES_URL) {
throw new Error('Must specify POSTGRES_URL')
}

pool = new Pool({ connectionString: POSTGRES_URL })
pg = await pool.connect()

await pg.query(`CREATE TABLE things (
id SERIAL PRIMARY KEY,
thing TEXT NOT NULL
)`)
})
afterAll(async () => {
await pg.query(`DROP TABLE THINGS`)
await pg.release()
await pool.end()
})

it(`commits changes when there's no error`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
})

const {
rows: [committed]
} = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'comitted'`
)
expect(committed).toBeDefined()
expect(committed.thing).toEqual('comitted')
})

it(`doesn't commit changes with forcedRollback`, async () => {
await tx(pg, async (db) => {
db.query(`INSERT INTO things (thing) VALUES ('comitted')`)
})
).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error'`
)
expect(rowCount).toBe(0)
})

it(`doesn't commit changes on next tick after query error`, async () => {
let laterPromise
const txPromise = tx(pg, async (db) => {
laterPromise = (async () => {
try {
await txPromise
} catch (e) {
// ignore error, we just needed to wait until it completed
}
await new Promise((resolve) => {
setImmediate(resolve)
const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
})

it(`doesn't commit changes when there's a Node exception`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things (thing) VALUES ('node_error')`)
throw new Error(`node error`)
})
await db.query(`INSERT INTO things(thing) VALUES ('query_error_tick')`)
})()
await Promise.all([db.query(`this query has an error`), laterPromise])
).rejects.toThrowError(`node error`)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'node_error'`
)
expect(rowCount).toBe(0)
})

await expect(txPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)
it(`doesn't commit changes when there's a query error`, async () => {
await expect(
tx(pg, async (db) => {
await db.query(`INSERT INTO things(thing) VALUES ('query_error')`)
await db.query(`this query has an error`)
})
).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error'`
)
expect(rowCount).toBe(0)
})

await expect(laterPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"client already released"`
)
it(`doesn't commit changes on next tick after query error`, async () => {
let laterPromise
const txPromise = tx(pg, async (db) => {
laterPromise = (async () => {
try {
await txPromise
} catch (e) {
// ignore error, we just needed to wait until it completed
}
await new Promise((resolve) => {
setImmediate(resolve)
})
await db.query(
`INSERT INTO things(thing) VALUES ('query_error_tick')`
)
})()
await Promise.all([db.query(`this query has an error`), laterPromise])
})

await expect(txPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"syntax error at or near \\"this\\""`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error_tick'`
)
expect(rowCount).toBe(0)
await expect(laterPromise).rejects.toThrowErrorMatchingInlineSnapshot(
`"Transaction client already released. Did you forget to await on something?"`
)

const { rowCount } = await pg.query(
`SELECT id, thing FROM things WHERE thing = 'query_error_tick'`
)
expect(rowCount).toBe(0)
})
})
})
18 changes: 14 additions & 4 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ import { ProxyClient } from './proxy_client'
/**
* @param pg node postgres pool
* @param callback callback that will use provided transaction client
* @param forceRollback force rollback even without errors - useful for integration tests
* @param forceRollback force rollback even without errors - useful for tests
* @returns
*/
export default async function tx<T>(
pg: Pool,
pg: Pool|PoolClient,
callback: (db: PoolClient) => Promise<T>,
forceRollback?: boolean
): Promise<T> {
const client = await pg.connect()
let connected
let client: PoolClient
if (pg instanceof Pool) {
client = await pg.connect()
connected = true
} else {
client = pg
connected = false
}
const proxyClient = new ProxyClient(client)
await proxyClient.query(`BEGIN`)

Expand All @@ -26,6 +34,8 @@ export default async function tx<T>(
await client.query(`ROLLBACK`)
throw e
} finally {
client.release()
if (connected) {
client.release()
}
}
}
2 changes: 1 addition & 1 deletion src/proxy_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export class ProxyClient implements PoolClient {

private releaseCheck() {
if (this.released) {
throw new Error(`client already released`)
throw new Error(`Transaction client already released. Did you forget to await on something?`)
}
}

Expand Down

0 comments on commit d74ef57

Please sign in to comment.