diff --git a/src/client.ts b/src/client.ts index 70b15f8..0b484b5 100644 --- a/src/client.ts +++ b/src/client.ts @@ -227,7 +227,7 @@ export class SQLocal { this.reinitChannel.postMessage(message); }; - protected exec = async ( + exec = async ( sql: string, params: unknown[], method: Sqlite3Method = 'all', @@ -247,8 +247,10 @@ export class SQLocal { }; if (message.type === 'data') { - data.rows = message.data[0]?.rows ?? []; - data.columns = message.data[0]?.columns ?? []; + const results = message.data[0]; + data.rows = results?.rows ?? []; + data.columns = results?.columns ?? []; + data.numAffectedRows = results?.numAffectedRows; } return data; @@ -310,55 +312,52 @@ export class SQLocal { action: 'begin', }); - const query = async >( - passStatement: StatementInput - ): Promise => { - const statement = normalizeStatement(passStatement); - if (statement.exec) { - this.transactionQueryKeyQueue.push(transactionKey); - return statement.exec(); - } - const { rows, columns } = await this.exec( - statement.sql, - statement.params, - 'all', - transactionKey - ); - const resultRecords = convertRowsToObjects(rows, columns) as Result[]; - return resultRecords; - }; - - const sql = async >( - queryTemplate: TemplateStringsArray | string, - ...params: unknown[] - ): Promise => { - const statement = normalizeSql(queryTemplate, params); - const resultRecords = await query(statement); - return resultRecords; - }; - - const commit = async (): Promise => { - await this.createQuery({ - type: 'transaction', - transactionKey, - action: 'commit', - }); - }; - - const rollback = async (): Promise => { - await this.createQuery({ - type: 'transaction', - transactionKey, - action: 'rollback', - }); + const transaction: Transaction = { + transactionKey, + lastAffectedRows: undefined, + query: async >( + passStatement: StatementInput + ): Promise => { + const statement = normalizeStatement(passStatement); + if (statement.exec) { + this.transactionQueryKeyQueue.push(transactionKey); + return statement.exec(); + } + const { rows, columns, numAffectedRows } = await this.exec( + statement.sql, + statement.params, + 'all', + transactionKey + ); + transaction.lastAffectedRows = numAffectedRows; + const resultRecords = convertRowsToObjects(rows, columns) as Result[]; + return resultRecords; + }, + sql: async >( + queryTemplate: TemplateStringsArray | string, + ...params: unknown[] + ): Promise => { + const statement = normalizeSql(queryTemplate, params); + const resultRecords = await transaction.query(statement); + return resultRecords; + }, + commit: async (): Promise => { + await this.createQuery({ + type: 'transaction', + transactionKey, + action: 'commit', + }); + }, + rollback: async (): Promise => { + await this.createQuery({ + type: 'transaction', + transactionKey, + action: 'rollback', + }); + }, }; - return { - query, - sql, - commit, - rollback, - }; + return transaction; }; transaction = async ( diff --git a/src/drivers/sqlite-memory-driver.ts b/src/drivers/sqlite-memory-driver.ts index 3a4c8a0..bef5349 100644 --- a/src/drivers/sqlite-memory-driver.ts +++ b/src/drivers/sqlite-memory-driver.ts @@ -203,6 +203,7 @@ export class SQLiteMemoryDriver implements SQLocalDriver { } protected execOnDb(db: Sqlite3Db, statement: DriverStatement): RawResultData { + const changesBefore = db.changes(true, true); const statementData: RawResultData = { rows: [], columns: [], @@ -228,6 +229,8 @@ export class SQLiteMemoryDriver implements SQLocalDriver { break; } + // @ts-expect-error https://github.com/sqlite/sqlite-wasm/pull/122 + statementData.numAffectedRows = db.changes(true, true) - changesBefore; return statementData; } diff --git a/src/kysely/client.ts b/src/kysely/client.ts index 59154f1..ca5be45 100644 --- a/src/kysely/client.ts +++ b/src/kysely/client.ts @@ -7,6 +7,8 @@ import { import type { DatabaseConnection, Dialect, Driver, QueryResult } from 'kysely'; import { SQLocal } from '../index.js'; import type { Transaction } from '../types.js'; +import { convertRowsToObjects } from '../lib/convert-rows-to-objects.js'; +import { normalizeSql } from '../lib/normalize-sql.js'; export class SQLocalKysely extends SQLocal { dialect: Dialect = { @@ -58,15 +60,25 @@ class SQLocalKyselyConnection implements DatabaseConnection { query: CompiledQuery ): Promise> { let rows; + let affectedRows: bigint | undefined; if (this.transaction === null) { - rows = await this.client.sql(query.sql, ...query.parameters); + const statement = normalizeSql(query.sql, [...query.parameters]); + const result = await this.client.exec( + statement.sql, + statement.params, + 'all' + ); + rows = convertRowsToObjects(result.rows, result.columns); + affectedRows = result.numAffectedRows; } else { rows = await this.transaction.query(query); + affectedRows = this.transaction.lastAffectedRows; } return { rows: rows as Result[], + numAffectedRows: affectedRows, }; } diff --git a/src/messages.ts b/src/messages.ts index 6a6b9a8..3448f8d 100644 --- a/src/messages.ts +++ b/src/messages.ts @@ -108,6 +108,7 @@ export type DataMessage = { data: { columns: string[]; rows: unknown[] | unknown[][]; + numAffectedRows?: bigint; }[]; }; export type BufferMessage = { diff --git a/src/types.ts b/src/types.ts index 478fea6..f6cdf80 100644 --- a/src/types.ts +++ b/src/types.ts @@ -52,6 +52,8 @@ export type Transaction = { ) => Promise; commit: () => Promise; rollback: () => Promise; + lastAffectedRows?: bigint; + transactionKey: QueryKey; }; export type ReactiveQuery = { @@ -68,6 +70,7 @@ export type ReactiveQueryStatus = 'pending' | 'ok' | 'error'; export type RawResultData = { rows: unknown[] | unknown[][]; columns: string[]; + numAffectedRows?: bigint; }; // Driver diff --git a/test/kysely/dialect.test.ts b/test/kysely/dialect.test.ts index 5ae7de9..a4d9624 100644 --- a/test/kysely/dialect.test.ts +++ b/test/kysely/dialect.test.ts @@ -90,6 +90,72 @@ describe.each(testVariation('kysely-dialect'))( expect(select2).toEqual([{ name: 'white rice' }, { name: 'bread' }]); }); + it('should return affected rows for regular queries', async () => { + // Test INSERT + const insertResult = await db + .insertInto('groceries') + .values([{ name: 'bread' }, { name: 'milk' }, { name: 'rice' }]) + .executeTakeFirst(); + expect(insertResult.numInsertedOrUpdatedRows).toBe(3n); + + const [insertResult2] = await db + .insertInto('prices') + .values([{ groceryId: 1, price: 1.99 }]) + .execute(); + expect(insertResult2.numInsertedOrUpdatedRows).toBe(1n); + + // Test SELECT + const selectResult = await db + .selectFrom('groceries') + .selectAll() + .execute(); + expect(selectResult).toEqual([ + { id: 1, name: 'bread' }, + { id: 2, name: 'milk' }, + { id: 3, name: 'rice' }, + ]); + + // Test UPDATE + const updateResult = await db + .updateTable('groceries') + .set({ name: 'brown rice' }) + .where('name', '=', 'rice') + .executeTakeFirst(); + expect(updateResult.numUpdatedRows).toBe(1n); + + // Test UPDATE multiple rows + const updateAllResult = await db + .updateTable('groceries') + .set({ name: 'updated' }) + .executeTakeFirst(); + expect(updateAllResult.numUpdatedRows).toBe(3n); + + // Test DELETE + const deleteResult = await db + .deleteFrom('groceries') + .where('name', '=', 'updated') + .executeTakeFirst(); + expect(deleteResult.numDeletedRows).toBe(3n); + + // Verify all deleted + const remaining = await db.selectFrom('groceries').selectAll().execute(); + expect(remaining.length).toBe(0); + }); + + it('should have 0 affected rows for non-data-modifying queries', async () => { + const insert1 = db + .insertInto('groceries') + .values({ name: 'bread' }) + .compile(); + + const result = await db.executeQuery(insert1); + expect(result.numAffectedRows).toBe(1n); + + const select1 = db.selectFrom('groceries').selectAll().compile(); + const result2 = await db.executeQuery(select1); + expect(result2.numAffectedRows).toBe(0n); + }); + it('should execute queries with relations', async () => { await db .insertInto('groceries') @@ -224,6 +290,41 @@ describe.each(testVariation('kysely-dialect'))( expect(data.length).toBe(2); }); + it('should return affected rows in transactions using kysely way', async () => { + await db.transaction().execute(async (tx) => { + // Insert multiple rows + const insertResult = await tx + .insertInto('groceries') + .values([{ name: 'bread' }, { name: 'milk' }, { name: 'rice' }]) + .executeTakeFirst(); + expect(insertResult.numInsertedOrUpdatedRows).toBe(3n); + + // Update one row + const updateOneResult = await tx + .updateTable('groceries') + .set({ name: 'brown rice' }) + .where('name', '=', 'rice') + .executeTakeFirst(); + expect(updateOneResult.numUpdatedRows).toBe(1n); + + // Update all rows + const updateAllResult = await tx + .updateTable('groceries') + .set({ name: 'updated' }) + .executeTakeFirst(); + expect(updateAllResult.numUpdatedRows).toBe(3n); + + // Delete all rows + const deleteResult = await tx + .deleteFrom('groceries') + .executeTakeFirst(); + expect(deleteResult.numDeletedRows).toBe(3n); + }); + + const remaining = await db.selectFrom('groceries').selectAll().execute(); + expect(remaining.length).toBe(0); + }); + it('should rollback failed transaction using kysely way', async () => { await db .transaction()