Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 49 additions & 50 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ export class SQLocal {
this.reinitChannel.postMessage(message);
};

protected exec = async (
exec = async (
sql: string,
params: unknown[],
method: Sqlite3Method = 'all',
Expand All @@ -247,8 +247,10 @@ export class SQLocal {
};

if (message.type === 'data') {
data.rows = message.data[0]?.rows ?? [];
data.columns = message.data[0]?.columns ?? [];
const results = message.data[0];
data.rows = results?.rows ?? [];
data.columns = results?.columns ?? [];
data.numAffectedRows = results?.numAffectedRows;
}

return data;
Expand Down Expand Up @@ -310,55 +312,52 @@ export class SQLocal {
action: 'begin',
});

const query = async <Result extends Record<string, any>>(
passStatement: StatementInput<Result>
): Promise<Result[]> => {
const statement = normalizeStatement(passStatement);
if (statement.exec) {
this.transactionQueryKeyQueue.push(transactionKey);
return statement.exec();
}
const { rows, columns } = await this.exec(
statement.sql,
statement.params,
'all',
transactionKey
);
const resultRecords = convertRowsToObjects(rows, columns) as Result[];
return resultRecords;
};

const sql = async <Result extends Record<string, any>>(
queryTemplate: TemplateStringsArray | string,
...params: unknown[]
): Promise<Result[]> => {
const statement = normalizeSql(queryTemplate, params);
const resultRecords = await query<Result>(statement);
return resultRecords;
};

const commit = async (): Promise<void> => {
await this.createQuery({
type: 'transaction',
transactionKey,
action: 'commit',
});
};

const rollback = async (): Promise<void> => {
await this.createQuery({
type: 'transaction',
transactionKey,
action: 'rollback',
});
const transaction: Transaction = {
transactionKey,
lastAffectedRows: undefined,
query: async <Result extends Record<string, any>>(
passStatement: StatementInput<Result>
): Promise<Result[]> => {
const statement = normalizeStatement(passStatement);
if (statement.exec) {
this.transactionQueryKeyQueue.push(transactionKey);
return statement.exec();
}
const { rows, columns, numAffectedRows } = await this.exec(
statement.sql,
statement.params,
'all',
transactionKey
);
transaction.lastAffectedRows = numAffectedRows;
const resultRecords = convertRowsToObjects(rows, columns) as Result[];
return resultRecords;
},
sql: async <Result extends Record<string, any>>(
queryTemplate: TemplateStringsArray | string,
...params: unknown[]
): Promise<Result[]> => {
const statement = normalizeSql(queryTemplate, params);
const resultRecords = await transaction.query<Result>(statement);
return resultRecords;
},
commit: async (): Promise<void> => {
await this.createQuery({
type: 'transaction',
transactionKey,
action: 'commit',
});
},
rollback: async (): Promise<void> => {
await this.createQuery({
type: 'transaction',
transactionKey,
action: 'rollback',
});
},
};

return {
query,
sql,
commit,
rollback,
};
return transaction;
};

transaction = async <Result>(
Expand Down
3 changes: 3 additions & 0 deletions src/drivers/sqlite-memory-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ export class SQLiteMemoryDriver implements SQLocalDriver {
}

protected execOnDb(db: Sqlite3Db, statement: DriverStatement): RawResultData {
const changesBefore = db.changes(true, true);
const statementData: RawResultData = {
rows: [],
columns: [],
Expand All @@ -228,6 +229,8 @@ export class SQLiteMemoryDriver implements SQLocalDriver {
break;
}

// @ts-expect-error https://github.com/sqlite/sqlite-wasm/pull/122
statementData.numAffectedRows = db.changes(true, true) - changesBefore;
return statementData;
}

Expand Down
14 changes: 13 additions & 1 deletion src/kysely/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
import type { DatabaseConnection, Dialect, Driver, QueryResult } from 'kysely';
import { SQLocal } from '../index.js';
import type { Transaction } from '../types.js';
import { convertRowsToObjects } from '../lib/convert-rows-to-objects.js';
import { normalizeSql } from '../lib/normalize-sql.js';

export class SQLocalKysely extends SQLocal {
dialect: Dialect = {
Expand Down Expand Up @@ -58,15 +60,25 @@ class SQLocalKyselyConnection implements DatabaseConnection {
query: CompiledQuery
): Promise<QueryResult<Result>> {
let rows;
let affectedRows: bigint | undefined;

if (this.transaction === null) {
rows = await this.client.sql(query.sql, ...query.parameters);
const statement = normalizeSql(query.sql, [...query.parameters]);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to make a copy of the array here?

const result = await this.client.exec(
statement.sql,
statement.params,
'all'
);
rows = convertRowsToObjects(result.rows, result.columns);
affectedRows = result.numAffectedRows;
} else {
rows = await this.transaction.query(query);
affectedRows = this.transaction.lastAffectedRows;
}

return {
rows: rows as Result[],
numAffectedRows: affectedRows,
};
}

Expand Down
1 change: 1 addition & 0 deletions src/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ export type DataMessage = {
data: {
columns: string[];
rows: unknown[] | unknown[][];
numAffectedRows?: bigint;
}[];
};
export type BufferMessage = {
Expand Down
3 changes: 3 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ export type Transaction = {
) => Promise<Result[]>;
commit: () => Promise<void>;
rollback: () => Promise<void>;
lastAffectedRows?: bigint;
transactionKey: QueryKey;
};

export type ReactiveQuery<Result = unknown> = {
Expand All @@ -68,6 +70,7 @@ export type ReactiveQueryStatus = 'pending' | 'ok' | 'error';
export type RawResultData = {
rows: unknown[] | unknown[][];
columns: string[];
numAffectedRows?: bigint;
};

// Driver
Expand Down
101 changes: 101 additions & 0 deletions test/kysely/dialect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,72 @@ describe.each(testVariation('kysely-dialect'))(
expect(select2).toEqual([{ name: 'white rice' }, { name: 'bread' }]);
});

it('should return affected rows for regular queries', async () => {
// Test INSERT
const insertResult = await db
.insertInto('groceries')
.values([{ name: 'bread' }, { name: 'milk' }, { name: 'rice' }])
.executeTakeFirst();
expect(insertResult.numInsertedOrUpdatedRows).toBe(3n);

const [insertResult2] = await db
.insertInto('prices')
.values([{ groceryId: 1, price: 1.99 }])
.execute();
expect(insertResult2.numInsertedOrUpdatedRows).toBe(1n);

// Test SELECT
const selectResult = await db
.selectFrom('groceries')
.selectAll()
.execute();
expect(selectResult).toEqual([
{ id: 1, name: 'bread' },
{ id: 2, name: 'milk' },
{ id: 3, name: 'rice' },
]);

// Test UPDATE
const updateResult = await db
.updateTable('groceries')
.set({ name: 'brown rice' })
.where('name', '=', 'rice')
.executeTakeFirst();
expect(updateResult.numUpdatedRows).toBe(1n);

// Test UPDATE multiple rows
const updateAllResult = await db
.updateTable('groceries')
.set({ name: 'updated' })
.executeTakeFirst();
expect(updateAllResult.numUpdatedRows).toBe(3n);

// Test DELETE
const deleteResult = await db
.deleteFrom('groceries')
.where('name', '=', 'updated')
.executeTakeFirst();
expect(deleteResult.numDeletedRows).toBe(3n);

// Verify all deleted
const remaining = await db.selectFrom('groceries').selectAll().execute();
expect(remaining.length).toBe(0);
});

it('should have 0 affected rows for non-data-modifying queries', async () => {
const insert1 = db
.insertInto('groceries')
.values({ name: 'bread' })
.compile();

const result = await db.executeQuery(insert1);
expect(result.numAffectedRows).toBe(1n);

const select1 = db.selectFrom('groceries').selectAll().compile();
const result2 = await db.executeQuery(select1);
expect(result2.numAffectedRows).toBe(0n);
});

it('should execute queries with relations', async () => {
await db
.insertInto('groceries')
Expand Down Expand Up @@ -224,6 +290,41 @@ describe.each(testVariation('kysely-dialect'))(
expect(data.length).toBe(2);
});

it('should return affected rows in transactions using kysely way', async () => {
await db.transaction().execute(async (tx) => {
// Insert multiple rows
const insertResult = await tx
.insertInto('groceries')
.values([{ name: 'bread' }, { name: 'milk' }, { name: 'rice' }])
.executeTakeFirst();
expect(insertResult.numInsertedOrUpdatedRows).toBe(3n);

// Update one row
const updateOneResult = await tx
.updateTable('groceries')
.set({ name: 'brown rice' })
.where('name', '=', 'rice')
.executeTakeFirst();
expect(updateOneResult.numUpdatedRows).toBe(1n);

// Update all rows
const updateAllResult = await tx
.updateTable('groceries')
.set({ name: 'updated' })
.executeTakeFirst();
expect(updateAllResult.numUpdatedRows).toBe(3n);

// Delete all rows
const deleteResult = await tx
.deleteFrom('groceries')
.executeTakeFirst();
expect(deleteResult.numDeletedRows).toBe(3n);
});

const remaining = await db.selectFrom('groceries').selectAll().execute();
expect(remaining.length).toBe(0);
});

it('should rollback failed transaction using kysely way', async () => {
await db
.transaction()
Expand Down