From 34a6cf37ca81541fdae2f8b0d7ee2d11ae8d2571 Mon Sep 17 00:00:00 2001 From: Forbes Lindesay Date: Mon, 10 Apr 2023 15:33:18 +0100 Subject: [PATCH 1/5] refactor: redesign pg-bulk API to be more flexibile & efficient --- .../pg-bulk/src/__tests__/index.test.pg.ts | 140 ++++--- packages/pg-bulk/src/index.ts | 372 +++++++----------- .../src/__tests__/bulk-insert.test.pg.ts | 52 ++- packages/pg-typed/src/index.ts | 371 +++++++---------- 4 files changed, 388 insertions(+), 547 deletions(-) diff --git a/packages/pg-bulk/src/__tests__/index.test.pg.ts b/packages/pg-bulk/src/__tests__/index.test.pg.ts index 3667ed23..f9d9ff02 100644 --- a/packages/pg-bulk/src/__tests__/index.test.pg.ts +++ b/packages/pg-bulk/src/__tests__/index.test.pg.ts @@ -1,15 +1,20 @@ import connect, {sql} from '@databases/pg'; import { - bulkInsert, - bulkSelect, - bulkUpdate, - bulkDelete, - BulkOperationOptions, + bulkInsertStatement, + bulkWhereCondition, + bulkUpdateStatement, + bulkDeleteStatement, } from '..'; const SCHEMA_NAME = `bulk_utils_test`; const TABLE_NAME = `users`; const table = sql.ident(SCHEMA_NAME, TABLE_NAME); +const COLUMN_TYPES = { + id: sql`BIGINT`, + screen_name: sql`TEXT`, + bio: sql`TEXT`, + age: sql`INT`, +}; let queries: {readonly text: string; readonly values: readonly any[]}[] = []; const db = connect({ @@ -38,18 +43,6 @@ function expectQueries(fn: () => Promise) { ).resolves; } -const options: BulkOperationOptions<'id' | 'screen_name' | 'bio' | 'age'> = { - database: db, - schemaName: SCHEMA_NAME, - tableName: TABLE_NAME, - columnTypes: { - id: sql`BIGINT`, - screen_name: sql`TEXT`, - bio: sql`TEXT`, - age: sql`INT`, - }, -}; - afterAll(async () => { await db.dispose(); }); @@ -94,19 +87,21 @@ test('create users in bulk', async () => { names.push(`bulk_insert_name_${i}`); } await expectQueries(async () => { - await bulkInsert({ - ...options, - columnsToInsert: [`screen_name`, `age`, `bio`], - records: names.map((n) => ({ - screen_name: n, - age: 42, - bio: `My name is ${n}`, - })), - }); + await db.query( + bulkInsertStatement({ + table, + operations: names, + columns: { + screen_name: {getValue: (n) => n, type: COLUMN_TYPES.screen_name}, + age: {value: 42}, + bio: {getValue: (n) => `My name is ${n}`, type: COLUMN_TYPES.bio}, + }, + }), + ); }).toEqual([ { - text: `INSERT INTO "users" ("screen_name","age","bio") SELECT * FROM UNNEST($1::TEXT[],$2::INT[],$3::TEXT[])`, - values: ['Array', 'Array', 'Array'], + text: `INSERT INTO "users" ("age","screen_name","bio") SELECT $1,* FROM UNNEST($2::TEXT[],$3::TEXT[])`, + values: [42, 'Array', 'Array'], }, ]); const records = await db.query( @@ -119,17 +114,24 @@ test('create users in bulk', async () => { test('query users in bulk', async () => { await expectQueries(async () => { expect( - await bulkSelect({ - ...options, - whereColumnNames: [`screen_name`, `age`], - whereConditions: [ - {screen_name: `bulk_insert_name_5`, age: 42}, - {screen_name: `bulk_insert_name_6`, age: 42}, - {screen_name: `bulk_insert_name_7`, age: 32}, - ], - selectColumnNames: [`screen_name`, `age`, `bio`], - orderBy: [{columnName: `screen_name`, direction: `ASC`}], - }), + await db.query( + sql`SELECT screen_name, age, bio FROM ${table} WHERE ${bulkWhereCondition( + { + operations: [ + {screen_name: `bulk_insert_name_5`, age: 42}, + {screen_name: `bulk_insert_name_6`, age: 42}, + {screen_name: `bulk_insert_name_7`, age: 32}, + ], + whereColumns: { + screen_name: { + getValue: (o) => o.screen_name, + type: COLUMN_TYPES.screen_name, + }, + age: {getValue: (o) => o.age, type: COLUMN_TYPES.age}, + }, + }, + )} ORDER BY screen_name ASC`, + ), ).toEqual([ { screen_name: `bulk_insert_name_5`, @@ -144,7 +146,7 @@ test('query users in bulk', async () => { ]); }).toEqual([ { - text: `SELECT "screen_name","age","bio" FROM "users" WHERE ("screen_name","age") IN (SELECT * FROM UNNEST($1::TEXT[],$2::INT[])) ORDER BY "screen_name" ASC`, + text: `SELECT screen_name, age, bio FROM "users" WHERE ("screen_name","age") IN (SELECT * FROM UNNEST($1::TEXT[],$2::INT[])) ORDER BY screen_name ASC`, values: ['Array', 'Array'], }, ]); @@ -152,20 +154,28 @@ test('query users in bulk', async () => { test('update users in bulk', async () => { await expectQueries(async () => { - await bulkUpdate({ - ...options, - whereColumnNames: [`screen_name`, `age`], - setColumnNames: [`age`], - updates: [ - {where: {screen_name: `bulk_insert_name_10`, age: 42}, set: {age: 1}}, - {where: {screen_name: `bulk_insert_name_11`, age: 42}, set: {age: 2}}, - {where: {screen_name: `bulk_insert_name_12`, age: 32}, set: {age: 3}}, - ], - }); + await db.query( + bulkUpdateStatement({ + table, + operations: [ + {name: `bulk_insert_name_10`, age: 42, new_age: 1}, + {name: `bulk_insert_name_11`, age: 42, new_age: 2}, + {name: `bulk_insert_name_12`, age: 32, new_age: 3}, + ], + whereColumns: { + screen_name: { + getValue: (o) => o.name, + type: COLUMN_TYPES.screen_name, + }, + age: {getValue: (o) => o.age, type: COLUMN_TYPES.age}, + }, + setColumns: {age: {getValue: (o) => o.new_age, type: COLUMN_TYPES.age}}, + }), + ); }).toEqual([ { - text: `UPDATE "users" SET "age" = "bulk_query"."updated_value_of_age" FROM (SELECT * FROM UNNEST($1::TEXT[],$2::INT[],$3::INT[]) AS bulk_query("screen_name","age","updated_value_of_age")) AS bulk_query WHERE "users"."screen_name" = "bulk_query"."screen_name" AND "users"."age" = "bulk_query"."age"`, - values: ['Array', 'Array', 'Array'], + text: `UPDATE "users" SET "age" = bulk_query."set_0" FROM UNNEST($1::INT[],$2::TEXT[],$3::INT[]) AS bulk_query("set_0","where_0","where_1") WHERE "screen_name" = bulk_query."where_0" AND "age" = bulk_query."where_1"`, + values: ['Array', 'Array', 'Array'], }, ]); expect( @@ -185,15 +195,23 @@ test('update users in bulk', async () => { test('delete users in bulk', async () => { await expectQueries(async () => { - await bulkDelete({ - ...options, - whereColumnNames: [`screen_name`, `age`], - whereConditions: [ - {screen_name: `bulk_insert_name_15`, age: 42}, - {screen_name: `bulk_insert_name_16`, age: 42}, - {screen_name: `bulk_insert_name_17`, age: 32}, - ], - }); + await db.query( + bulkDeleteStatement({ + table, + operations: [ + {name: `bulk_insert_name_15`, age: 42}, + {name: `bulk_insert_name_16`, age: 42}, + {name: `bulk_insert_name_17`, age: 32}, + ], + whereColumns: { + screen_name: { + getValue: (o) => o.name, + type: COLUMN_TYPES.screen_name, + }, + age: {getValue: (o) => o.age, type: COLUMN_TYPES.age}, + }, + }), + ); }).toEqual([ { text: `DELETE FROM "users" WHERE ("screen_name","age") IN (SELECT * FROM UNNEST($1::TEXT[],$2::INT[]))`, diff --git a/packages/pg-bulk/src/index.ts b/packages/pg-bulk/src/index.ts index 6248a5b0..1abaeaa2 100644 --- a/packages/pg-bulk/src/index.ts +++ b/packages/pg-bulk/src/index.ts @@ -1,248 +1,178 @@ -import {SQLQuery, Queryable} from '@databases/pg'; +import {sql, SQLQuery} from '@databases/pg'; -type ColumnName = string | number | symbol; -export interface BulkOperationOptions { - readonly database: Queryable; - readonly tableName: string; - readonly columnTypes: {readonly [K in TColumnName]: SQLQuery}; - readonly schemaName?: string; - readonly serializeValue?: (columnName: string, value: unknown) => unknown; +export interface ConstantOperationValue { + readonly getValue?: undefined; + readonly value: unknown; + readonly type?: undefined; } - -export interface BulkInsertOptions - extends BulkOperationOptions { - readonly columnsToInsert: readonly TColumnToInsert[]; - readonly records: readonly any[]; +export interface DynamicOperationValue { + readonly getValue: ( + operation: TOperation, + index: number, + operations: readonly TOperation[], + ) => unknown; + readonly value?: undefined; + readonly type: SQLQuery; } +export type BulkOperationValue = + | ConstantOperationValue + | DynamicOperationValue; -export interface BulkConditionOptions - extends BulkOperationOptions { - readonly whereColumnNames: readonly TWhereColumn[]; - readonly whereConditions: readonly any[]; -} -export interface BulkSelectOptions - extends BulkConditionOptions { - readonly distinctColumnNames?: readonly string[]; - readonly selectColumnNames?: readonly string[]; - readonly orderBy?: readonly { - readonly columnName: string; - readonly direction: 'ASC' | 'DESC'; - }[]; - readonly limit?: number; +function prepareColumns( + columns: Record>, +): readonly [ + readonly [SQLQuery, unknown][], + readonly [SQLQuery, DynamicOperationValue][], +] { + const constantColumns: [SQLQuery, unknown][] = []; + const dynamicColumns: [SQLQuery, DynamicOperationValue][] = []; + for (const [columnName, value] of Object.entries(columns)) { + if (value.getValue) { + dynamicColumns.push([sql.ident(columnName), value]); + } else { + constantColumns.push([sql.ident(columnName), value.value]); + } + } + return [constantColumns, dynamicColumns] as const; } -export interface BulkUpdateOptions< - TWhereColumn extends ColumnName, - TSetColumn extends ColumnName, -> extends BulkOperationOptions { - readonly whereColumnNames: readonly TWhereColumn[]; - readonly setColumnNames: readonly TSetColumn[]; - readonly updates: readonly {readonly where: any; readonly set: any}[]; +export interface BulkInsertOptions { + readonly table: SQLQuery; + readonly columns: Record>; + readonly operations: readonly TOperation[]; } -export interface BulkDeleteOptions - extends BulkConditionOptions {} - -function tableId( - options: BulkOperationOptions, +export function bulkInsertStatement( + options: BulkInsertOptions, ) { - const {sql} = options.database; - return options.schemaName - ? sql.ident(options.schemaName, options.tableName) - : sql.ident(options.tableName); -} + const {table, columns, operations} = options; -function select( - columns: readonly { - readonly name: TColumnName; - readonly getValue?: (record: any) => unknown; - }[], - records: readonly any[], - options: BulkOperationOptions, -) { - const {database, columnTypes, serializeValue} = options; - const {sql} = database; - return sql`SELECT * FROM UNNEST(${sql.join( - columns.map(({name, getValue}) => { - const typeName = columnTypes[name]; - if (!typeName) { - throw new Error(`Missing type name for ${name as string}`); - } - return sql`${records.map((r) => { - const value = getValue ? getValue(r) : r[name]; - return serializeValue - ? serializeValue(`${name as string}`, value) - : value; - })}::${typeName}[]`; - }), + const [constantColumns, dynamicColumns] = prepareColumns(columns); + + // TODO: handle cases where all columns are constant + + return sql`INSERT INTO ${table} (${sql.join( + [...constantColumns, ...dynamicColumns].map(([c]) => c), + `,`, + )}) SELECT ${sql.join( + constantColumns.map(([, value]) => sql`${value}`), + `,`, + )},* FROM UNNEST(${sql.join( + dynamicColumns.map( + ([, {getValue, type}]) => sql`${operations.map(getValue)}::${type}[]`, + ), `,`, )})`; } -export function bulkInsertStatement( - options: BulkInsertOptions, -): SQLQuery { - const {database, columnsToInsert, records} = options; - const {sql} = database; - return sql`INSERT INTO ${tableId(options)} (${sql.join( - columnsToInsert.map((columnName) => sql.ident(columnName)), - `,`, - )}) ${select( - columnsToInsert.map((name) => ({name})), - records, - options, - )}`; +export interface BulkUpdateOptions { + readonly table: SQLQuery; + readonly setColumns: Record>; + readonly whereColumns: Record>; + readonly operations: readonly TOperation[]; } +export function bulkUpdateStatement( + options: BulkUpdateOptions, +) { + const {table, whereColumns, setColumns, operations} = options; + const [constantSetColumns, dynamicSetColumns] = prepareColumns(setColumns); + const [constantWhereColumns, dynamicWhereColumns] = + prepareColumns(whereColumns); -export async function bulkInsert( - options: BulkInsertOptions & {returning: SQLQuery}, -): Promise; -export async function bulkInsert( - options: BulkInsertOptions, -): Promise; -export async function bulkInsert( - options: BulkInsertOptions & {returning?: SQLQuery}, -): Promise { - const {database, returning} = options; - const {sql} = database; - return await database.query( - returning - ? sql`${bulkInsertStatement(options)} RETURNING ${returning}` - : bulkInsertStatement(options), - ); -} + if (dynamicSetColumns.length === 0 && dynamicWhereColumns.length === 0) { + if (constantWhereColumns.length === 0) { + return sql`UPDATE ${table} SET ${sql.join( + constantSetColumns.map(([c, v]) => sql`${c} = ${v}`), + `,`, + )}`; + } + return sql`UPDATE ${table} SET ${sql.join( + constantSetColumns.map(([c, v]) => sql`${c} = ${v}`), + `,`, + )} WHERE ${sql.join( + constantWhereColumns.map(([c, v]) => sql`${c} = ${v}`), + ` AND `, + )}`; + } + if (dynamicWhereColumns.length === 0) { + throw new Error( + `You cannot have dynamic set columns but no dynamic where columns in a bulk update.`, + ); + } -export function bulkCondition( - options: BulkConditionOptions, -): SQLQuery { - const {database, whereColumnNames, whereConditions} = options; - const {sql} = database; - return sql`(${sql.join( - whereColumnNames.map((columnName) => sql.ident(columnName)), + return sql`UPDATE ${table} SET ${sql.join( + [ + ...constantSetColumns.map(([c, v]) => sql`${c} = ${v}`), + ...dynamicSetColumns.map( + ([c], i) => sql`${c} = bulk_query.${sql.ident(`set_${i}`)}`, + ), + ], `,`, - )}) IN (${select( - whereColumnNames.map((columnName) => ({name: columnName})), - whereConditions, - options, - )})`; -} - -export async function bulkSelect( - options: BulkSelectOptions, -): Promise { - const {database, distinctColumnNames, selectColumnNames, orderBy, limit} = - options; - const {sql} = database; - return await database.query( - sql.join( - [ - sql`SELECT`, - distinctColumnNames?.length - ? sql`DISTINCT ON (${sql.join( - distinctColumnNames.map((columnName) => sql.ident(columnName)), - `,`, - )})` - : null, - selectColumnNames - ? sql.join( - selectColumnNames.map((columnName) => sql.ident(columnName)), - ',', - ) - : sql`*`, - sql`FROM ${tableId(options)} WHERE`, - bulkCondition(options), - orderBy?.length - ? sql`ORDER BY ${sql.join( - orderBy.map((q) => - q.direction === 'ASC' - ? sql`${sql.ident(q.columnName)} ASC` - : sql`${sql.ident(q.columnName)} DESC`, - ), - sql`, `, - )}` - : null, - limit ? sql`LIMIT ${limit}` : null, - ].filter((v: T): v is Exclude => v !== null), - sql` `, + )} FROM UNNEST(${sql.join( + [...dynamicSetColumns, ...dynamicWhereColumns].map( + ([, {getValue, type}]) => sql`${operations.map(getValue)}::${type}[]`, ), - ); + `,`, + )}) AS bulk_query(${sql.join( + [ + ...dynamicSetColumns.map((_, i) => sql.ident(`set_${i}`)), + ...dynamicWhereColumns.map((_, i) => sql.ident(`where_${i}`)), + ], + `,`, + )}) WHERE ${sql.join( + [ + ...constantWhereColumns.map(([c, v]) => sql`${c} = ${v}`), + ...dynamicWhereColumns.map( + ([c], i) => sql`${c} = bulk_query.${sql.ident(`where_${i}`)}`, + ), + ], + ` AND `, + )}`; } -export async function bulkUpdate< - TWhereColumn extends ColumnName, - TSetColumn extends ColumnName, ->( - options: BulkUpdateOptions & {returning: SQLQuery}, -): Promise; -export async function bulkUpdate< - TWhereColumn extends ColumnName, - TSetColumn extends ColumnName, ->(options: BulkUpdateOptions): Promise; -export async function bulkUpdate< - TWhereColumn extends ColumnName, - TSetColumn extends ColumnName, ->( - options: BulkUpdateOptions & {returning?: SQLQuery}, -): Promise { - const { - database, - tableName, - whereColumnNames, - setColumnNames, - updates, - returning, - } = options; - const {sql} = database; - return await database.query( - sql`UPDATE ${tableId(options)} SET ${sql.join( - setColumnNames.map( - (columnName) => - sql`${sql.ident(columnName)} = ${sql.ident( - `bulk_query`, - `updated_value_of_${columnName as string}`, - )}`, +export interface BulkWhereConditionOptions { + readonly table?: SQLQuery; + readonly whereColumns: Record>; + readonly operations: readonly TOperation[]; +} +export function bulkWhereCondition( + options: BulkWhereConditionOptions, +) { + const {table, whereColumns: columns, operations} = options; + const [constantColumns, dynamicColumns] = prepareColumns(columns); + const conditions: SQLQuery[] = constantColumns.map(([c, v]) => + table ? sql`${table}.${c} = ${v}` : sql`${c} = ${v}`, + ); + if (dynamicColumns.length !== 0) { + const columns = sql.join( + dynamicColumns.map(([columnName]) => + table ? sql`${table}.${columnName}` : columnName, ), `,`, - )} FROM (${select( - [ - ...whereColumnNames.map((columnName) => ({ - name: columnName, - getValue: (u: any) => u.where[columnName], - })), - ...setColumnNames.map((columnName) => ({ - name: columnName, - getValue: (u: any) => u.set[columnName], - })), - ], - updates, - options, - )} AS bulk_query(${sql.join( - [ - ...whereColumnNames.map((columnName) => sql.ident(columnName)), - ...setColumnNames.map((columnName) => - sql.ident(`updated_value_of_${columnName as string}`), - ), - ], - `,`, - )})) AS bulk_query WHERE ${sql.join( - whereColumnNames.map( - (columnName) => - sql`${sql.ident(tableName, columnName)} = ${sql.ident( - `bulk_query`, - columnName, - )}`, + ); + const unnestExpression = sql`SELECT * FROM UNNEST(${sql.join( + dynamicColumns.map( + ([, {getValue, type}]) => sql`${operations.map(getValue)}::${type}[]`, ), - ` AND `, - )}${returning ? sql` RETURNING ${returning}` : sql``}`, - ); + `,`, + )})`; + const condition = sql`(${columns}) IN (${unnestExpression})`; + conditions.push(condition); + } + return sql.join(conditions, ` AND `); } -export async function bulkDelete( - options: BulkDeleteOptions, -): Promise { - const {database} = options; - const {sql} = database; - await database.query( - sql`DELETE FROM ${tableId(options)} WHERE ${bulkCondition(options)}`, - ); +export interface BulkDeleteOptions + extends BulkWhereConditionOptions { + readonly table: SQLQuery; +} +export function bulkDeleteStatement( + options: BulkDeleteOptions, +) { + const {table, whereColumns, operations} = options; + const condition = bulkWhereCondition({ + whereColumns, + operations, + }); + return sql`DELETE FROM ${table} WHERE ${condition}`; } diff --git a/packages/pg-typed/src/__tests__/bulk-insert.test.pg.ts b/packages/pg-typed/src/__tests__/bulk-insert.test.pg.ts index 6ecf980f..59541b5c 100644 --- a/packages/pg-typed/src/__tests__/bulk-insert.test.pg.ts +++ b/packages/pg-typed/src/__tests__/bulk-insert.test.pg.ts @@ -135,8 +135,8 @@ test('create users in bulk', async () => { names.push(`bulk_insert_name_${i}`); } const inserted = await users(db).bulkInsert({ - columnsToInsert: [`age`], - records: names.map((n) => ({screen_name: n, age: 42})), + columnsToInsert: {screen_name: (n) => n, age: 42}, + records: names, }); expect(inserted.map((i) => i.screen_name)).toEqual(names); @@ -154,8 +154,8 @@ test('query users in bulk', async () => { expect( await users(db) .bulkFind({ - whereColumnNames: [`screen_name`, `age`], - whereConditions: [ + whereColumns: {screen_name: (c) => c.screen_name, age: (c) => c.age}, + records: [ {screen_name: `bulk_insert_name_5`, age: 42}, {screen_name: `bulk_insert_name_6`, age: 42}, {screen_name: `bulk_insert_name_7`, age: 32}, @@ -172,8 +172,8 @@ test('query users in bulk', async () => { expect( await users(db) .bulkFind({ - whereColumnNames: [`screen_name`, `age`], - whereConditions: [ + whereColumns: {screen_name: (c) => c.screen_name, age: (c) => c.age}, + records: [ {screen_name: `bulk_insert_name_3`, age: 42}, {screen_name: `bulk_insert_name_4`, age: 42}, {screen_name: `bulk_insert_name_5`, age: 42}, @@ -195,12 +195,12 @@ test('query users in bulk', async () => { test('update users in bulk', async () => { const updated = await users(db).bulkUpdate({ - whereColumnNames: [`screen_name`, `age`], - setColumnNames: [`age`], - updates: [ - {where: {screen_name: `bulk_insert_name_10`, age: 42}, set: {age: 1}}, - {where: {screen_name: `bulk_insert_name_11`, age: 42}, set: {age: 2}}, - {where: {screen_name: `bulk_insert_name_12`, age: 32}, set: {age: 3}}, + whereColumns: {screen_name: (r) => r.screen_name, age: (r) => r.age}, + setColumns: {age: (r) => r.new_age}, + records: [ + {screen_name: `bulk_insert_name_10`, age: 42, new_age: 1}, + {screen_name: `bulk_insert_name_11`, age: 42, new_age: 2}, + {screen_name: `bulk_insert_name_12`, age: 32, new_age: 3}, ], }); expect(updated.map(({screen_name, age}) => ({screen_name, age}))).toEqual([ @@ -228,8 +228,8 @@ test('update users in bulk', async () => { test('delete users in bulk', async () => { await users(db).bulkDelete({ - whereColumnNames: [`screen_name`, `age`], - whereConditions: [ + whereColumns: {screen_name: (c) => c.screen_name, age: (c) => c.age}, + records: [ {screen_name: `bulk_insert_name_15`, age: 42}, {screen_name: `bulk_insert_name_16`, age: 42}, {screen_name: `bulk_insert_name_17`, age: 32}, @@ -252,12 +252,12 @@ test('delete users in bulk', async () => { test('insertOrIgnore users in bulk', async () => { await users(db).bulkInsertOrIgnore({ - columnsToInsert: [`age`], + columnsToInsert: {screen_name: (n) => n, age: 56}, records: [ - {screen_name: `bulk_insert_name_18`, age: 56}, - {screen_name: `bulk_insert_name_19`, age: 56}, - {screen_name: `bulk_insert_name_20`, age: 56}, - {screen_name: `bulk_insert_or_ignore_name_1`, age: 56}, + `bulk_insert_name_18`, + `bulk_insert_name_19`, + `bulk_insert_name_20`, + `bulk_insert_or_ignore_name_1`, ], }); expect( @@ -283,18 +283,14 @@ test('insertOrIgnore users in bulk', async () => { test('insertOrUpdate users in bulk', async () => { await users(db).bulkInsertOrUpdate({ - columnsToInsert: [`screen_name`, `age`, `bio`], + columnsToInsert: {screen_name: (n) => n, age: 56, bio: 'Updated in bulk'}, columnsThatConflict: [`screen_name`], columnsToUpdate: [`bio`], records: [ - {screen_name: `bulk_insert_name_21`, age: 56, bio: `Updated in bulk`}, - {screen_name: `bulk_insert_name_22`, age: 56, bio: `Updated in bulk`}, - {screen_name: `bulk_insert_name_23`, age: 56, bio: `Updated in bulk`}, - { - screen_name: `bulk_insert_or_update_name_1`, - age: 56, - bio: `Updated in bulk`, - }, + `bulk_insert_name_21`, + `bulk_insert_name_22`, + `bulk_insert_name_23`, + `bulk_insert_or_update_name_1`, ], }); expect( diff --git a/packages/pg-typed/src/index.ts b/packages/pg-typed/src/index.ts index 089821b5..7acaf5ef 100644 --- a/packages/pg-typed/src/index.ts +++ b/packages/pg-typed/src/index.ts @@ -1,11 +1,11 @@ import assertNever from 'assert-never'; -import {SQLQuery, Queryable} from '@databases/pg'; +import {sql, SQLQuery, Queryable} from '@databases/pg'; import { - bulkUpdate, - bulkDelete, - BulkOperationOptions, - bulkCondition, + bulkUpdateStatement, + bulkDeleteStatement, + bulkWhereCondition, bulkInsertStatement, + BulkOperationValue, } from '@databases/pg-bulk'; const NO_RESULT_FOUND = `NO_RESULT_FOUND`; @@ -698,93 +698,28 @@ class SelectQueryImplementation } } -type BulkRecord = { - readonly [key in TKey]-?: Exclude; -} & { - readonly [key in Exclude]?: undefined; +export type BulkColumns = { + [TColumn in keyof TRecord]: + | TRecord[TColumn] + | (( + operation: TOperation, + index: number, + operations: readonly TOperation[], + ) => TRecord[TColumn]); }; -type BulkInsertFields< - TInsertParameters, - TKey extends keyof TInsertParameters, -> = - | TKey - | { - readonly [K in keyof TInsertParameters]: undefined extends TInsertParameters[K] - ? never - : K; - }[keyof TInsertParameters]; - -type BulkInsertRecord< - TInsertParameters, - TKey extends keyof TInsertParameters, -> = BulkRecord>; - -type BulkOperationOptionsBase< - TColumnName extends string | number | symbol, - TInsertColumnName extends string | number | symbol, -> = Omit, 'database'> & { - requiredInsertColumnNames: readonly TInsertColumnName[]; -}; -function getBulkOperationOptionsBase< - TColumnName extends string | number | symbol, - TInsertColumnName extends string | number | symbol, ->( - table: DatabaseSchemaTable, - { - sql, - schemaName, - serializeValue, - }: { - sql: Queryable['sql']; - schemaName?: string; - serializeValue: (column: string, value: unknown) => unknown; - }, -): BulkOperationOptionsBase { - return { - tableName: table.name, - columnTypes: Object.fromEntries( - table.columns.map((c) => [ - c.name, - sql.__dangerous__rawValue(`${c.typeName}`), - ]), - ) as any, - schemaName, - serializeValue, - requiredInsertColumnNames: table.columns - .filter((c) => !c.isNullable && !c.hasDefault) - .map((c) => c.name as TInsertColumnName), - }; -} class Table { private readonly _value: (columnName: string, value: any) => unknown; - private readonly _bulkOperationOptions: - | (BulkOperationOptions & { - requiredInsertColumnNames: readonly (keyof TInsertParameters)[]; - }) - | undefined; + private readonly _columnTypes: {[columnName: string]: SQLQuery} | undefined; constructor( private readonly _underlyingDb: Queryable, public readonly tableId: SQLQuery, public readonly tableName: string, serializeValue: (columnName: string, value: unknown) => unknown, - bulkOperationOptions: - | (BulkOperationOptions & { - requiredInsertColumnNames: readonly (keyof TInsertParameters)[]; - }) - | undefined, + columnTypes: {[columnName: string]: SQLQuery} | undefined, ) { this._value = (c, v) => serializeValue(c, v); - this._bulkOperationOptions = bulkOperationOptions; - } - - private _getBulkOperationOptions() { - if (!this._bulkOperationOptions) { - throw new Error( - `You must provide a "databaseSchema" when constructing pg-typed to use bulk operations.`, - ); - } - return this._bulkOperationOptions; + this._columnTypes = columnTypes; } conditionToSql( @@ -804,108 +739,94 @@ class Table { : query; } - async bulkInsert< - TColumnsToInsert extends readonly [ - ...(readonly (keyof TInsertParameters)[]), - ], - >({ + private _getColumnTypes() { + if (!this._columnTypes) { + throw new Error( + `You must provide a "databaseSchema" when constructing pg-typed to use bulk operations.`, + ); + } + return this._columnTypes; + } + private _prepareBulkColumns( + columns: Partial>, + ): Record> { + const columnTypes = this._getColumnTypes(); + return Object.fromEntries( + Object.entries(columns).map( + ([columnName, value]): [string, BulkOperationValue] => { + if (typeof value === 'function') { + return [ + columnName, + { + // @ts-expect-error + getValue: value, + type: columnTypes[columnName], + }, + ]; + } + return [columnName, {value}]; + }, + ), + ); + } + async bulkInsert({ columnsToInsert, records, }: { - readonly columnsToInsert: TColumnsToInsert; - readonly records: readonly BulkInsertRecord< - TInsertParameters, - TColumnsToInsert[number] - >[]; + readonly columnsToInsert: BulkColumns; + readonly records: readonly TOperation[]; }): Promise { - if (records.length === 0) { - return []; - } + const columns = this._prepareBulkColumns(columnsToInsert); + if (records.length === 0) return []; const {sql} = this._underlyingDb; return await this._underlyingDb.query( - sql`${bulkInsertStatement({ - ...this._getBulkOperationOptions(), - columnsToInsert: [ - ...new Set([ - ...columnsToInsert, - ...this._getBulkOperationOptions().requiredInsertColumnNames, - ]), - ].sort(), - records, + sql`${bulkInsertStatement({ + table: this.tableId, + columns, + operations: records, })} RETURNING ${this.tableId}.*`, ); } - async bulkInsertOrIgnore< - TColumnsToInsert extends readonly [ - ...(readonly (keyof TInsertParameters)[]), - ], - >({ + async bulkInsertOrIgnore({ columnsToInsert, records, }: { - readonly columnsToInsert: TColumnsToInsert; - readonly records: readonly BulkInsertRecord< - TInsertParameters, - TColumnsToInsert[number] - >[]; + readonly columnsToInsert: BulkColumns; + readonly records: readonly TOperation[]; }): Promise { - if (records.length === 0) { - return []; - } + const columns = this._prepareBulkColumns(columnsToInsert); + if (records.length === 0) return []; const {sql} = this._underlyingDb; return await this._underlyingDb.query( - sql`${bulkInsertStatement({ - ...this._getBulkOperationOptions(), - columnsToInsert: [ - ...new Set([ - ...columnsToInsert, - ...this._getBulkOperationOptions().requiredInsertColumnNames, - ]), - ].sort(), - records, + sql`${bulkInsertStatement({ + table: this.tableId, + columns, + operations: records, })} ON CONFLICT DO NOTHING RETURNING ${this.tableId}.*`, ); } - async bulkInsertOrUpdate< - TColumnsToInsert extends readonly [ - ...(readonly (keyof TInsertParameters)[]), - ], - >({ + async bulkInsertOrUpdate({ columnsToInsert, columnsThatConflict, columnsToUpdate, records, }: { - readonly columnsToInsert: TColumnsToInsert; - readonly columnsThatConflict: readonly [ - TColumnsToInsert[number], - ...TColumnsToInsert[number][], - ]; - readonly columnsToUpdate: readonly [ - TColumnsToInsert[number], - ...TColumnsToInsert[number][], - ]; - readonly records: readonly BulkInsertRecord< - TInsertParameters, - TColumnsToInsert[number] - >[]; + readonly columnsToInsert: BulkColumns; + readonly columnsThatConflict: readonly (keyof TRecord)[]; + // TODO: allow more complex update expressions + readonly columnsToUpdate: readonly (keyof TRecord)[]; + readonly records: readonly TOperation[]; }): Promise { - if (records.length === 0) { - return []; - } + const columns = this._prepareBulkColumns(columnsToInsert); + if (records.length === 0) return []; const {sql} = this._underlyingDb; return await this._underlyingDb.query( - sql`${bulkInsertStatement({ - ...this._getBulkOperationOptions(), - columnsToInsert: [ - ...new Set([ - ...columnsToInsert, - ...this._getBulkOperationOptions().requiredInsertColumnNames, - ]), - ].sort(), - records, + sql`${bulkInsertStatement({ + table: this.tableId, + columns, + operations: records, })} ON CONFLICT (${sql.join( columnsThatConflict.map((k) => sql.ident(k)), sql`, `, @@ -918,76 +839,60 @@ class Table { ); } - bulkFind({ - whereColumnNames, - whereConditions, - }: { - readonly whereColumnNames: TWhereColumns; - readonly whereConditions: readonly BulkRecord< - TRecord, - TWhereColumns[number] - >[]; + bulkFind(options: { + readonly whereColumns: Partial>; + readonly records: TOperation[]; }): UnorderedSelectQuery { - const bulkOperationOptions = this._getBulkOperationOptions(); + const whereColumns = this._prepareBulkColumns(options.whereColumns); + const records = options.records; + return this._findUntyped( - whereConditions.length - ? bulkCondition({ - ...bulkOperationOptions, - whereColumnNames, - whereConditions, + records.length + ? bulkWhereCondition({ + whereColumns, + operations: records, }) : 'FALSE', ); } - async bulkUpdate< - TWhereColumns extends readonly [...(readonly (keyof TRecord)[])], - TSetColumns extends readonly [...(readonly (keyof TRecord)[])], - >({ - whereColumnNames, - setColumnNames, - updates, - }: { - readonly whereColumnNames: TWhereColumns; - readonly setColumnNames: TSetColumns; - readonly updates: readonly { - readonly where: BulkRecord; - readonly set: BulkRecord; - }[]; + async bulkUpdate(options: { + readonly whereColumns: Partial>; + // TODO: allow more complex update expressions + readonly setColumns: Partial>; + readonly records: TOperation[]; }): Promise { - if (updates.length === 0) { - return []; - } + const whereColumns = this._prepareBulkColumns(options.whereColumns); + const setColumns = this._prepareBulkColumns(options.setColumns); + const records = options.records; + if (records.length === 0) return []; const {sql} = this._underlyingDb; - return await bulkUpdate({ - ...this._getBulkOperationOptions(), - whereColumnNames, - setColumnNames, - updates, - returning: sql`${this.tableId}.*`, - }); + return await this._underlyingDb.query( + sql`${bulkUpdateStatement({ + table: this.tableId, + whereColumns, + setColumns, + operations: records, + })} RETURNING ${this.tableId}.*`, + ); } - async bulkDelete< - TWhereColumns extends readonly [...(readonly (keyof TRecord)[])], - >({ - whereColumnNames, - whereConditions, - }: { - readonly whereColumnNames: TWhereColumns; - readonly whereConditions: readonly BulkRecord< - TRecord, - TWhereColumns[number] - >[]; + async bulkDelete(options: { + readonly whereColumns: Partial>; + readonly records: TOperation[]; }) { - if (whereConditions.length === 0) { - return; - } - await bulkDelete({ - ...this._getBulkOperationOptions(), - whereColumnNames, - whereConditions, - }); + const whereColumns = this._prepareBulkColumns(options.whereColumns); + const records = options.records; + if (records.length === 0) return; + + const {sql} = this._underlyingDb; + await this._underlyingDb.query( + sql`${bulkDeleteStatement({ + table: this.tableId, + whereColumns, + operations: records, + })}`, + ); } private async _insert( @@ -1263,13 +1168,21 @@ function getTable( tableSchema?: DatabaseSchemaTable, ): TableHelper { const cache = new WeakMap>(); - const bulkOperationOptionsCache = new Map< - Queryable['sql'], - BulkOperationOptionsBase< - keyof TRecord | keyof TInsertParameters, - keyof TInsertParameters - > - >(); + + let columnTypes: {[columnName: string]: SQLQuery} | undefined; + if (tableSchema) { + columnTypes = Object.fromEntries( + tableSchema.columns.map((c) => { + if (!c.typeName) { + throw new Error( + `Missing typeName for column ${c.name} in table ${tableName}`, + ); + } + return [c.name, sql.__dangerous__rawValue(c.typeName)]; + }), + ); + } + return Object.assign( ( queryable: Queryable | undefined = defaultConnection, @@ -1282,20 +1195,6 @@ function getTable( const cached = cache.get(queryable); if (cached) return cached; - let bulkOperationsBase = bulkOperationOptionsCache.get(queryable.sql); - if (tableSchema && !bulkOperationsBase) { - bulkOperationsBase = - tableSchema && - getBulkOperationOptionsBase< - keyof TRecord | keyof TInsertParameters, - keyof TInsertParameters - >(tableSchema, { - sql: queryable.sql, - schemaName, - serializeValue, - }); - bulkOperationOptionsCache.set(queryable.sql, bulkOperationsBase); - } const fresh = new Table( queryable, schemaName @@ -1303,9 +1202,7 @@ function getTable( : queryable.sql.ident(tableName), tableName, serializeValue, - bulkOperationsBase - ? {...bulkOperationsBase, database: queryable} - : undefined, + columnTypes, ); cache.set(queryable, fresh); From 5cf34c5454ac16213bf22f09958578815dd3de5a Mon Sep 17 00:00:00 2001 From: Forbes Lindesay Date: Wed, 17 May 2023 20:02:15 +0100 Subject: [PATCH 2/5] feat: work in progress v2 for pg-typed --- packages/pg-typed/src/temp.ts | 233 +++++ packages/pg-typed/src/v2/AliasedQuery.ts | 9 + packages/pg-typed/src/v2/GroupByQuery.ts | 8 + packages/pg-typed/src/v2/InsertQuery.ts | 1 + packages/pg-typed/src/v2/PostgresOperators.ts | 160 ++++ .../pg-typed/src/v2/ProjectedLimitQuery.ts | 22 + .../pg-typed/src/v2/QueryImplementation.ts | 655 ++++++++++++++ packages/pg-typed/src/v2/SelectQuery.ts | 29 + packages/pg-typed/src/v2/Table.ts | 61 ++ packages/pg-typed/src/v2/TableSchema.ts | 9 + packages/pg-typed/src/v2/WhereCondition.ts | 27 + .../pg-typed/src/v2/__tests__/index.test.ts | 74 ++ .../pg-typed/src/v2/implementation/Columns.ts | 91 ++ .../src/v2/implementation/Operators.ts | 797 ++++++++++++++++++ packages/pg-typed/src/v2/index.ts | 10 + packages/pg-typed/src/v2/types/Columns.ts | 42 + packages/pg-typed/src/v2/types/Join.ts | 23 + .../pg-typed/src/v2/types/JoinableQuery.ts | 51 ++ packages/pg-typed/src/v2/types/Operators.ts | 42 + packages/pg-typed/src/v2/types/Queries.ts | 70 ++ .../pg-typed/src/v2/types/SelectionSet.ts | 9 + .../pg-typed/src/v2/types/SpecialValues.ts | 94 +++ .../src/v2/types/TypedDatabaseQuery.ts | 9 + packages/sql/src/web.ts | 12 +- 24 files changed, 2532 insertions(+), 6 deletions(-) create mode 100644 packages/pg-typed/src/temp.ts create mode 100644 packages/pg-typed/src/v2/AliasedQuery.ts create mode 100644 packages/pg-typed/src/v2/GroupByQuery.ts create mode 100644 packages/pg-typed/src/v2/InsertQuery.ts create mode 100644 packages/pg-typed/src/v2/PostgresOperators.ts create mode 100644 packages/pg-typed/src/v2/ProjectedLimitQuery.ts create mode 100644 packages/pg-typed/src/v2/QueryImplementation.ts create mode 100644 packages/pg-typed/src/v2/SelectQuery.ts create mode 100644 packages/pg-typed/src/v2/Table.ts create mode 100644 packages/pg-typed/src/v2/TableSchema.ts create mode 100644 packages/pg-typed/src/v2/WhereCondition.ts create mode 100644 packages/pg-typed/src/v2/__tests__/index.test.ts create mode 100644 packages/pg-typed/src/v2/implementation/Columns.ts create mode 100644 packages/pg-typed/src/v2/implementation/Operators.ts create mode 100644 packages/pg-typed/src/v2/index.ts create mode 100644 packages/pg-typed/src/v2/types/Columns.ts create mode 100644 packages/pg-typed/src/v2/types/Join.ts create mode 100644 packages/pg-typed/src/v2/types/JoinableQuery.ts create mode 100644 packages/pg-typed/src/v2/types/Operators.ts create mode 100644 packages/pg-typed/src/v2/types/Queries.ts create mode 100644 packages/pg-typed/src/v2/types/SelectionSet.ts create mode 100644 packages/pg-typed/src/v2/types/SpecialValues.ts create mode 100644 packages/pg-typed/src/v2/types/TypedDatabaseQuery.ts diff --git a/packages/pg-typed/src/temp.ts b/packages/pg-typed/src/temp.ts new file mode 100644 index 00000000..61da99e4 --- /dev/null +++ b/packages/pg-typed/src/temp.ts @@ -0,0 +1,233 @@ +import {q, Table} from './v2'; + +export {}; + +interface User { + id: number; + username: string; +} +interface Post { + author_id: number; + title: string; + created_at: Date; +} + +declare const users: Table; +declare const posts: Table; + +// users.select(`id`, `username`).all(); +// const join = users +// .as(`u`) +// .innerJoin(posts.as(`p`)) +// .on((c) => op.eq(c(`u.id`), c(`p.author_id`))) +// .select((c) => ({ +// id: c(`u.id`), +// username: c(`u.username`), +// title: c(`p.title`), +// })); +const join = users + .as(`u`) + .where((u) => q.eq(u.id, 10)) + .innerJoin(posts.as(`p`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .select(({u, p}) => ({ + id: u.id, + username: u.username, + title: p.title, + })); + +// const result = users +// .as(`u`) +// .innerJoin(posts.as(`p`)) +// .on((c) => op.eq(c(`u.id`), c(`p.author_id`))) +// .selectGroupBy( +// ({u}) => ({ +// id: u.id, +// username: u.use, +// }), +// (c) => ({ +// last_posted_at: op.max(c(`p.created_at`)), +// }), +// ) +// .orderByDesc(`last_posted_at`) +// .all(); + +const result2 = users + .as(`u`) + .innerJoin(posts.as(`p`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .groupBy(({u}) => ({ + id: u.id, + username: u.username, + })) + .selectAggregate(({p}) => ({ + last_posted_at: q.max(p.created_at), + })) + .orderByDesc(`last_posted_at`) + .all(); + +const lastPostedAt = posts.as(`p`).selectAggregate(({p}) => ({ + last_posted_at: q.max(p.created_at), +})); + +// SELECT u.id, u.username, MAX(p.created_at) AS last_posted_at +// FROM users AS u +// INNER JOIN posts AS p ON u.id=p.author_id +// GROUP BY u.id, u.username; + +interface BaseQuery { + select>( + columns: TColumns, + ): FinishedQuery>; + where(query: Partial): BaseQuery; +} + +// interface SelectQuery { +// // as( +// // alias: TAliasTableName, +// // ): AliasedQuery<{ +// // [TKey in `${TAliasTableName}.${string & +// // keyof TRecord}`]: TKey extends `${TAliasTableName}.${infer TColumn}` +// // ? TColumn extends keyof TRecord +// // ? TRecord[TColumn] +// // : never +// // : never; +// // }>; +// as( +// alias: TAliasTableName, +// ): AliasedQuery<{[TKey in TAliasTableName]: TRecord}>; +// select( +// ...columnNames: TColumnNames +// ): FinishedQuery>; +// select>( +// columns: TColumns, +// ): FinishedQuery>; +// } + +// interface Table extends SelectQuery { +// find(where?: Partial): SelectQuery; +// } + +interface FinishedQuery { + all(): Promise; +} + +interface SelectionSet { + [k: string]: + | keyof TRecord + | `${`MAX` | `MIN` | `SUM`}(${string & keyof TRecord})` + | `COUNT(*)`; +} +type SelectionSetResult< + TRecord, + TSelectionSet extends SelectionSet, +> = { + [TAliasName in keyof TSelectionSet]: TSelectionSet[TAliasName] extends `COUNT(*)` + ? number + : TSelectionSet[TAliasName] extends `${ + | `MAX` + | `MIN` + | `SUM`}(${infer TColumnName})` + ? TColumnName extends keyof TRecord + ? TRecord[TColumnName] + : never + : TSelectionSet[TAliasName] extends keyof TRecord + ? TRecord[TSelectionSet[TAliasName]] + : never; +}; + +// interface JoinedQuery {} + +// declare function innerJoin< +// TLeftName extends string, +// TLeftRecord, +// TRightName extends string, +// TRightRecord, +// >( +// left: SelectQuery, +// right: SelectQuery, +// ): { +// on(where: { +// [key in `${TLeftName}.${string & +// keyof TLeftRecord}`]?: `${TRightName}.${string & keyof TRightRecord}`; +// }): { +// select< +// TColumns extends { +// [k: string]: +// | `${TLeftName}.${string & keyof TLeftRecord}` +// | `${TRightName}.${string & keyof TRightRecord}`; +// }, +// >( +// columns: TColumns, +// ): { +// [TAliasName in keyof TColumns]: TColumns[TAliasName] extends `${infer TTableName}.${infer TColumnName}` +// ? TTableName extends TLeftName +// ? TColumnName extends keyof TLeftRecord +// ? TLeftRecord[TColumnName] +// : unknown +// : TTableName extends TRightName +// ? TColumnName extends keyof TRightRecord +// ? TRightRecord[TColumnName] +// : unknown +// : unknown +// : unknown; +// }; +// }; +// }; + +// innerJoin(users.find().as(`u`), posts.find().as(`p`)) +// .on({ +// 'u.id': 'p.author_id', +// }) +// .select({ +// user_id: `u.id`, +// post_title: `p.title`, +// }); +// const res = users.find().select({name: `username`}).as(`u`); +// //(posts.find().as('p')).select('u.id', 'p.title'); + +// declare function innerJoinMany< +// TParts extends {[alias: string]: SelectQuery}, +// >( +// parts: TParts, +// ): { +// on(): { +// select< +// TColumns extends { +// [k: string]: { +// [TTableName in keyof TParts]: `${string & +// TTableName}.${TParts[TTableName] extends SelectQuery< +// any, +// infer TRecord +// > +// ? keyof TRecord +// : never}`; +// }[keyof TParts]; +// }, +// >( +// columns: TColumns, +// ): { +// [TAliasName in keyof TColumns]: TColumns[TAliasName] extends `${infer TTableName}.${infer TColumnName}` +// ? TTableName extends TLeftName +// ? TColumnName extends keyof TLeftRecord +// ? TLeftRecord[TColumnName] +// : unknown +// : TTableName extends TRightName +// ? TColumnName extends keyof TRightRecord +// ? TRightRecord[TColumnName] +// : unknown +// : unknown +// : unknown; +// }; +// }; +// }; + +// innerJoinMany({ +// u: users.find(), +// p: posts.find(), +// }) +// .on() +// .select({ +// username: `u.username`, +// post_title: `p.title`, +// }); diff --git a/packages/pg-typed/src/v2/AliasedQuery.ts b/packages/pg-typed/src/v2/AliasedQuery.ts new file mode 100644 index 00000000..232add8a --- /dev/null +++ b/packages/pg-typed/src/v2/AliasedQuery.ts @@ -0,0 +1,9 @@ +import {Columns} from './types/Columns'; + +import SelectQuery from './SelectQuery'; +import {JoinableQueryLeft, JoinableQueryRight} from './types/JoinableQuery'; + +export default interface AliasedQuery + extends SelectQuery, + JoinableQueryRight, + JoinableQueryLeft<{[TKey in TAlias]: Columns}> {} diff --git a/packages/pg-typed/src/v2/GroupByQuery.ts b/packages/pg-typed/src/v2/GroupByQuery.ts new file mode 100644 index 00000000..6145456c --- /dev/null +++ b/packages/pg-typed/src/v2/GroupByQuery.ts @@ -0,0 +1,8 @@ +import {ProjectedSortedQuery} from './types/Queries'; +import {AggregatedSelectionSet} from './types/SelectionSet'; + +export default interface GroupByQuery { + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedSortedQuery; +} diff --git a/packages/pg-typed/src/v2/InsertQuery.ts b/packages/pg-typed/src/v2/InsertQuery.ts new file mode 100644 index 00000000..d933b680 --- /dev/null +++ b/packages/pg-typed/src/v2/InsertQuery.ts @@ -0,0 +1 @@ +export default interface InsertQuery {} diff --git a/packages/pg-typed/src/v2/PostgresOperators.ts b/packages/pg-typed/src/v2/PostgresOperators.ts new file mode 100644 index 00000000..7ec6725a --- /dev/null +++ b/packages/pg-typed/src/v2/PostgresOperators.ts @@ -0,0 +1,160 @@ +import {SQLQuery, sql} from '@databases/pg'; + +// https://www.postgresql.org/docs/15/sql-syntax-lexical.html#SQL-PRECEDENCE + +export interface BinaryInput { + left: SQLQuery; + right: SQLQuery; +} + +export interface OperatorDefinition { + readonly toSql: ( + input: TInput, + ctx: {parentOperatorPrecedence: number | null}, + ) => SQLQuery; + readonly precedence: number; + // readonly staticValue?: (input: TStaticValueInput) => boolean | null; +} + +function operatorDefinition( + toSql: (input: TInput) => SQLQuery, + precedence: number, + options: Omit, 'toSql' | 'precedence'> = {}, +): OperatorDefinition { + return { + toSql: (input, ctx) => { + const expression = toSql(input); + if ( + ctx.parentOperatorPrecedence !== null && + ctx.parentOperatorPrecedence <= precedence + ) { + return sql`(${expression})`; + } else { + return expression; + } + }, + precedence, + ...options, + }; +} + +export const OperatorDefinitions = { + // table/column name separator + TABLE_COLUMN: operatorDefinition( + (p: {tableAlias?: string; columnName: string}) => + p.tableAlias === undefined + ? sql.ident(p.columnName) + : sql.ident(p.tableAlias, p.columnName), + 1, + ), + + // PostgreSQL-style typecast + TYPECAST: operatorDefinition( + (p: {expression: SQLQuery; type: SQLQuery}) => + sql`${p.expression}::${p.type}`, + 2, + ), + + // array element selection + ARRAY_ELEMENT_SELECTION: operatorDefinition( + (p: {expression: SQLQuery; index: SQLQuery}) => + sql`${p.expression}[${p.index}]`, + 3, + ), + + // unary plus, unary minus + UNARY_PLUS: operatorDefinition((exp: SQLQuery) => sql`+${exp}`, 4), + UNARY_MINUS: operatorDefinition((exp: SQLQuery) => sql`-${exp}`, 4), + + // exponentiation + EXPONENTIATION: operatorDefinition( + (p: BinaryInput) => sql`${p.left}^${p.right}`, + 5, + ), + + // multiplication, division, modulo + MULTIPLICATION: operatorDefinition((p: SQLQuery[]) => sql.join(p, sql`*`), 6), + DIVISION: operatorDefinition((p: SQLQuery[]) => sql.join(p, sql`/`), 6), + MODULUS: operatorDefinition((p: SQLQuery[]) => sql.join(p, sql`%`), 6), + + // addition, subtraction + ADDITION: operatorDefinition((p: SQLQuery[]) => sql.join(p, sql`+`), 7), + SUBTRACTION: operatorDefinition((p: SQLQuery[]) => sql.join(p, sql`-`), 7), + + // ... any other operator ... + CUSTOM: operatorDefinition((expression: SQLQuery) => expression, 8), + + // range containment, set membership, string matching + BETWEEN: operatorDefinition( + (p: {expression: SQLQuery; lower: SQLQuery; upper: SQLQuery}) => + sql`${p.expression} BETWEEN ${p.lower} AND ${p.upper}`, + 9, + ), + IN: operatorDefinition( + (p: {expression: SQLQuery; set: SQLQuery}) => + sql`${p.expression} IN ${p.set}`, + 9, + ), + LIKE: operatorDefinition( + (p: BinaryInput) => sql`${p.left} LIKE ${p.right}`, + 9, + ), + ILIKE: operatorDefinition( + (p: BinaryInput) => sql`${p.left} ILIKE ${p.right}`, + 9, + ), + SIMILAR: operatorDefinition( + (p: { + expression: SQLQuery; + pattern: SQLQuery; + escapeCharacter?: SQLQuery; + }) => + p.escapeCharacter + ? sql`${p.expression} SIMILAR TO ${p.pattern} ESCAPE ${p.escapeCharacter}` + : sql`${p.expression} SIMILAR TO ${p.pattern}`, + 9, + ), + + // comparison operators + LT: operatorDefinition((p: BinaryInput) => sql`${p.left}<${p.right}`, 10), + GT: operatorDefinition((p: BinaryInput) => sql`${p.left}>${p.right}`, 10), + LTE: operatorDefinition((p: BinaryInput) => sql`${p.left}<=${p.right}`, 10), + GTE: operatorDefinition((p: BinaryInput) => sql`${p.left}>=${p.right}`, 10), + EQ: operatorDefinition((p: BinaryInput) => sql`${p.left}=${p.right}`, 10), + NEQ: operatorDefinition((p: BinaryInput) => sql`${p.left}<>${p.right}`, 10), + + // IS + IS_NULL: operatorDefinition((exp: SQLQuery) => sql`${exp} IS NULL`, 11), + IS_NOT_NULL: operatorDefinition( + (exp: SQLQuery) => sql`${exp} IS NOT NULL`, + 11, + ), + + // NOT + NOT: operatorDefinition((exp: SQLQuery) => sql`NOT ${exp}`, 12, { + // staticValue: (input: boolean | null) => { + // if (input === null) return null; + // return !input; + // }, + }), + + // AND + AND: operatorDefinition((parts: SQLQuery[]) => sql.join(parts, ` AND `), 13, { + // staticValue: (parts: (boolean | null)[]) => { + // if (parts.every((p) => p === true)) return true; + // if (parts.some((p) => p === false)) return false; + // return null; + // }, + }), + + // OR + OR: operatorDefinition((parts: SQLQuery[]) => sql.join(parts, ` OR `), 14, { + // staticValue: (parts: (boolean | null)[]) => { + // if (parts.some((p) => p === true)) return true; + // if (parts.every((p) => p === false)) return false; + // return null; + // }, + }), +} as const; + +export type Operator = keyof typeof OperatorDefinitions; diff --git a/packages/pg-typed/src/v2/ProjectedLimitQuery.ts b/packages/pg-typed/src/v2/ProjectedLimitQuery.ts new file mode 100644 index 00000000..43040c34 --- /dev/null +++ b/packages/pg-typed/src/v2/ProjectedLimitQuery.ts @@ -0,0 +1,22 @@ +import {SQLQuery} from '@databases/pg'; +import AliasedQuery from './AliasedQuery'; +import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; + +export default interface ProjectedLimitQuery + extends TypedDatabaseQuery { + /** + * Get the SQL query that would be executed. This is useful if you want to use this query as a sub-query in a query that is not type safe. + */ + toSql(): SQLQuery; + + /** + * If this is a complex query: + * Wrap the entire query in parentheses, and give it an alias. This lets you use joins, group by, etc. as sub-queries. + * + * If this is a simple query: + * Give the table an alias. This lets you use it in a join. + */ + as( + alias: TAliasTableName, + ): AliasedQuery; +} diff --git a/packages/pg-typed/src/v2/QueryImplementation.ts b/packages/pg-typed/src/v2/QueryImplementation.ts new file mode 100644 index 00000000..78e51d5d --- /dev/null +++ b/packages/pg-typed/src/v2/QueryImplementation.ts @@ -0,0 +1,655 @@ +import {Queryable, SQLQuery, sql} from '@databases/pg'; +import {aliasColumns, columns} from './implementation/Columns'; +import WhereCondition from './WhereCondition'; +import Value from './types/SpecialValues'; +import {Columns} from './types/Columns'; +import Operators, { + aliasTableInValue, + fieldConditionToPredicateValue, + valueToSelect, + valueToSql, +} from './implementation/Operators'; +import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; +import AliasedQuery from './AliasedQuery'; +import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; +import {escapePostgresIdentifier} from '@databases/escape-identifier'; + +import SelectQuery from './SelectQuery'; +import GroupByQuery from './GroupByQuery'; +import { + ProjectedDistinctColumnsQuery, + ProjectedDistinctQuery, + ProjectedLimitQuery, + ProjectedSortedQuery, + ProjectedQuery, +} from './types/Queries'; +import {JoinQueryBuilder, JoinQuery} from './types/Join'; +import { + InnerJoinedColumns, + JoinableQueryLeft, + JoinableQueryRight, + LeftOuterJoinedColumns, +} from './types/JoinableQuery'; + +const NO_RESULT_FOUND = `NO_RESULT_FOUND`; +const MULTIPLE_RESULTS_FOUND = `MULTIPLE_RESULTS_FOUND`; + +export default function createQuery( + tableName: string, + tableId: SQLQuery, + columns: Columns, +): SelectQuery { + return new SelectQueryImplementation({ + columns, + distinct: false, + distinctColumns: [], + groupBy: 0, + isAliased: false, + isJoin: false, + limit: null, + orderBy: [], + projection: null, + tableId, + tableName, + where: [], + }); +} + +interface CompleteQuery< + TAlias extends string, + TRecord, + TColumns, + TAliasedColumns, +> extends ProjectedQuery, + JoinableQueryLeft, + JoinableQueryRight { + where(condition: WhereCondition): this; + + select( + ...columnNames: TColumnNames + ): ProjectedQuery>; + select( + selection: (column: TColumns) => SelectionSet, + ): ProjectedQuery; + + groupBy( + ...columnNames: TColumnNames + ): GroupByQuery, TColumns>; + groupBy( + selection: (column: TColumns) => SelectionSet, + ): GroupByQuery; + + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedQuery; +} + +interface QueryConfig { + columns: TColumns; + distinct: boolean; + distinctColumns: readonly SQLQuery[]; + groupBy: number; + isAliased: boolean; + isJoin: boolean; + limit: number | null; + orderBy: readonly SQLQuery[]; + projection: Projection | null; + tableId: SQLQuery; + tableName: TAlias; + where: readonly Value[]; +} + +interface FinalQueryConfig { + query: SQLQuery; + isEmpty: boolean; +} + +abstract class FinalQuery implements TypedDatabaseQuery { + private readonly _q: FinalQueryConfig; + constructor(q: FinalQueryConfig) { + this._q = q; + } + async executeQuery(database: Queryable): Promise { + return this._prepareResults( + this._q.isEmpty ? [] : await database.query(this._q.query), + ); + } + protected _queryForError(): string { + return this._q.query.format({ + escapeIdentifier: escapePostgresIdentifier, + formatValue: () => ({placeholder: `?`, value: undefined}), + }).text; + } + protected abstract _prepareResults(results: TRecord[]): T; +} + +class FirstQuery extends FinalQuery { + protected _prepareResults(results: TRecord[]): TRecord | undefined { + if (!results.length) return undefined; + return results[0]; + } +} + +class OneQuery extends FinalQuery { + protected _prepareResults(results: TRecord[]): TRecord | undefined { + if (!results.length) return undefined; + if (results.length > 1) { + throw Object.assign( + new Error( + `More than one row matched this query but we only expected one: ${this._queryForError()}`, + ), + {code: MULTIPLE_RESULTS_FOUND}, + ); + } + return results[0]; + } +} + +class OneRequiredQuery extends FinalQuery { + protected _prepareResults(results: TRecord[]): TRecord { + if (!results.length) { + throw Object.assign( + new Error(`No results matched this query: ${this._queryForError()}.`), + {code: NO_RESULT_FOUND}, + ); + } + if (results.length > 1) { + throw Object.assign( + new Error( + `More than one row matched this query but we only expected one: ${this._queryForError()}`, + ), + {code: MULTIPLE_RESULTS_FOUND}, + ); + } + return results[0]; + } +} + +class SelectQueryImplementation< + TAlias extends string, + TRecord, + TColumns, + TAliasedColumns, +> implements CompleteQuery +{ + public readonly alias: TAlias; + private readonly _config: QueryConfig; + constructor(config: QueryConfig) { + this.alias = config.tableName; + this._config = config; + } + + private _query(): FinalQueryConfig { + const parts = [sql`SELECT`]; + + if (this._config.distinct) { + parts.push(sql`DISTINCT`); + } + if (this._config.distinctColumns.length) { + parts.push( + sql`DISTINCT ON (${sql.join(this._config.distinctColumns, `,`)})`, + ); + } + + parts.push( + sql`${this._config.projection?.query ?? sql`*`} FROM ${ + this._config.tableId + }`, + ); + + const whereCondition = Operators.and(...this._config.where); + if (whereCondition !== true) { + parts.push( + sql`WHERE ${valueToSql(whereCondition, { + toValue: (v) => v, + tableAlias: () => null, + parentOperatorPrecedence: null, + })}`, + ); + } + + if (this._config.groupBy !== 0) { + const groupByColumns: number[] = []; + for (let i = 0; i < this._config.groupBy; i++) { + groupByColumns.push(i + 1); + } + parts.push( + sql`GROUP BY ${sql.__dangerous__rawValue(groupByColumns.join(`,`))}`, + ); + } + + if (this._config.orderBy.length) { + parts.push(sql`ORDER BY ${sql.join(this._config.orderBy, `,`)}`); + } + + if (this._config.limit !== null) { + parts.push(sql`LIMIT ${this._config.limit}`); + } + + return { + query: sql.join(parts, sql` `), + isEmpty: whereCondition === false, + }; + } + + private _projectedQuery( + projection: Projection, + ): ProjectedQuery { + // Projecting the query (i.e. choosing a selection set) may change the column names and types, so + // we can no longer use the columns object from the schema, we must create a new one. + const preparedColumns = columns( + this._config.tableName, + projection.columnNames.map((n) => ({columnName: n})), + this._config.isAliased, + ); + return new SelectQueryImplementation< + TAlias, + TRecord, + Columns, + any + >({ + columns: preparedColumns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where: this._config.where, + }); + } + + toSql(): SQLQuery { + const {query} = this._query(); + return query; + } + + as(alias: TAlias): AliasedQuery { + if (!/^[a-z][a-z0-9_]*$/.test(alias)) { + throw new Error( + `Table aliases must start with a lower case letter and only contain letters, numbers and underscores`, + ); + } + if (this._config.isAliased) { + throw new Error(`Cannot alias a query that has already been aliased`); + } + + const { + columns, + distinct, + distinctColumns, + groupBy, + limit, + projection, + tableId, + tableName, + where, + orderBy, + isJoin, + } = this._config; + + const aliasedColumns = aliasColumns(alias, columns as Columns); + + if ( + distinct || + distinctColumns.length || + groupBy || + isJoin || + limit || + orderBy.length || + projection + ) { + return new SelectQueryImplementation({ + columns: aliasedColumns, + distinct: false, + distinctColumns: [], + groupBy: 0, + isAliased: true, + isJoin: false, + limit: null, + orderBy: [], + projection: null, + tableId: sql`(${this.toSql()}) AS ${sql.ident(alias)}`, + tableName: alias, + where: [], + }); + } + + return new SelectQueryImplementation({ + columns: aliasedColumns, + distinct: false, + distinctColumns: [], + groupBy: 0, + isAliased: true, + isJoin: false, + limit: null, + orderBy: [], + projection: null, + tableId: sql`${tableId} AS ${sql.ident(alias)}`, + tableName: alias, + where: where.map((c) => aliasTableInValue(tableName, alias, c)), + }); + } + + where(condition: WhereCondition): any { + return new SelectQueryImplementation({ + ...this._config, + where: [ + ...this._config.where, + ...(sql.isSqlQuery(condition) + ? [condition] + : typeof condition === 'function' + ? [condition(this._config.columns)] + : Object.entries(condition).map(([columnName, value]) => + fieldConditionToPredicateValue( + (this._config.columns as Columns)[ + columnName as keyof Columns + ], + value, + ), + )), + ], + }); + } + + select(...selection: any[]): ProjectedQuery { + return this._projectedQuery( + selection.length === 1 && typeof selection[0] === 'function' + ? selectionSetToProjection( + selection[0](this._config.columns) as SelectionSet, + ) + : columnNamesToProjection(selection), + ); + } + + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedQuery { + return this._projectedQuery( + selectionSetToProjection(aggregation(this._config.columns)), + ); + } + + groupBy(...selection: any[]): GroupByQuery { + return new GroupByQueryImplementation( + selection.length === 1 && typeof selection[0] === 'function' + ? selectionSetToProjection( + selection[0](this._config.columns) as SelectionSet, + ) + : columnNamesToProjection(selection), + this._config, + ); + } + + private _orderByColumn(columnName: keyof TRecord): SQLQuery { + if (this._config.projection) { + const index = this._config.projection.columnNames.indexOf( + columnName as string, + ); + if (index === -1) { + throw new Error(`Cannot find column: "${columnName as string}"`); + } + return sql.__dangerous__rawValue((index + 1).toString(10)); + } else { + return sql.ident(columnName); + } + } + private _orderByInternal( + columnName: keyof TRecord, + distinct: boolean, + direction: SQLQuery, + ): ProjectedDistinctColumnsQuery { + return new SelectQueryImplementation({ + ...this._config, + distinctColumns: distinct + ? [...this._config.distinctColumns, sql.ident(columnName)] + : this._config.distinctColumns, + orderBy: [ + ...this._config.orderBy, + sql`${this._orderByColumn(columnName)} ${direction}`, + ], + }); + } + + orderByAscDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._orderByInternal(columnName, true, sql`ASC`); + } + orderByDescDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._orderByInternal(columnName, true, sql`DESC`); + } + orderByAsc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._orderByInternal(columnName, false, sql`ASC`); + } + orderByDesc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._orderByInternal(columnName, false, sql`DESC`); + } + + distinct(): ProjectedDistinctQuery { + if (this._config.distinct || this._config.distinctColumns.length) { + throw new Error( + `Cannot call distinct() after a query has already been marked as distinct.`, + ); + } + return new SelectQueryImplementation({ + ...this._config, + distinct: true, + distinctColumns: [], + }); + } + limit(n: number): ProjectedLimitQuery { + return new SelectQueryImplementation({ + ...this._config, + limit: n, + }); + } + + innerJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + InnerJoinedColumns + > { + if ( + !(otherQuery instanceof SelectQueryImplementation) || + !otherQuery._config.isAliased || + otherQuery._config.isJoin + ) { + throw new Error(`Right hand side of join is not valid.`); + } + return new JoinImplementation< + InnerJoinedColumns + >({ + ...this._config, + tableId: sql`${this._config.tableId} INNER JOIN ${otherQuery._config.tableId}`, + where: [...this._config.where, ...otherQuery._config.where], + columns: Object.assign( + {[otherQuery._config.tableName]: otherQuery._config.columns}, + this._config.isJoin + ? this._config.columns + : {[this._config.tableName]: this._config.columns}, + ) as any, + }); + } + + leftOuterJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + LeftOuterJoinedColumns + > { + if ( + !(otherQuery instanceof SelectQueryImplementation) || + !otherQuery._config.isAliased || + otherQuery._config.isJoin + ) { + throw new Error(`Right hand side of join is not valid.`); + } + if (otherQuery._config.where.length) { + throw new Error( + `Right hand side of a LEFT OUTER JOIN cannot have a WHERE clause.`, + ); + } + return new JoinImplementation< + InnerJoinedColumns + >({ + ...this._config, + tableId: sql`${this._config.tableId} LEFT OUTER JOIN ${otherQuery._config.tableId}`, + columns: Object.assign( + {[otherQuery._config.tableName]: otherQuery._config.columns}, + this._config.isJoin + ? this._config.columns + : {[this._config.tableName]: this._config.columns}, + ) as any, + }); + } + + one(): TypedDatabaseQuery { + return new OneQuery(this._query()); + } + + oneRequired(): TypedDatabaseQuery { + return new OneRequiredQuery(this._query()); + } + + first(): TypedDatabaseQuery { + return new FirstQuery(this._query()); + } + + async executeQuery(database: Queryable): Promise { + const {query, isEmpty} = this._query(); + if (isEmpty) return []; + return await database.query(query); + } +} + +class GroupByQueryImplementation + implements GroupByQuery +{ + private readonly _groupByProjection: Projection; + private readonly _config: QueryConfig; + constructor( + groupByProjection: Projection, + config: QueryConfig, + ) { + this._groupByProjection = groupByProjection; + this._config = config; + } + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedSortedQuery { + const groupByProjection = this._groupByProjection; + const aggregatedProjection = selectionSetToProjection( + aggregation(this._config.columns), + ); + const columnNames = [ + ...groupByProjection.columnNames, + ...aggregatedProjection.columnNames, + ]; + + // Projecting the query (i.e. choosing a selection set) may change the column names and types, so + // we can no longer use the columns object from the schema, we must create a new one. + const rawColumns = columns( + this._config.tableName, + columnNames.map((n) => ({columnName: n})), + ); + const projectionParts: SQLQuery[] = []; + if (groupByProjection.columnNames.length) { + projectionParts.push(groupByProjection.query); + } + if (aggregatedProjection.columnNames.length) { + projectionParts.push(aggregatedProjection.query); + } + const projection = { + query: sql.join(projectionParts, `,`), + columnNames, + }; + return new SelectQueryImplementation< + TAlias, + TSelection & TAggregation, + Columns, + never + >({ + columns: rawColumns, + distinct: false, + distinctColumns: [], + groupBy: groupByProjection.columnNames.length, + isAliased: false, + isJoin: false, + limit: this._config.limit, + orderBy: [], + projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where: this._config.where, + }); + } +} + +class JoinImplementation implements JoinQueryBuilder { + private readonly _config: QueryConfig; + constructor(config: QueryConfig) { + this._config = config; + } + on(predicate: (column: TColumns) => Value): JoinQuery { + return new SelectQueryImplementation({ + ...this._config, + tableId: sql`${this._config.tableId} ON (${valueToSql( + predicate(this._config.columns), + { + toValue: (v) => v, + tableAlias: () => null, + parentOperatorPrecedence: null, + }, + )})`, + }); + } +} + +export interface Projection { + /** + * The SQL for the selection set. + * + * e.g. u.name AS user_name, u.email AS user_email, COUNT(*) AS post_count + */ + readonly query: SQLQuery; + /** + * The column names (in order) returned by this projection. + * + * e.g. ['user_name', 'user_email', 'post_count'] + */ + readonly columnNames: readonly string[]; +} + +function selectionSetToProjection( + ...selections: (SelectionSet | AggregatedSelectionSet)[] +): Projection { + const entries = selections.flatMap((selection) => Object.entries(selection)); + return { + query: sql.join( + entries.map(([alias, value]) => valueToSelect(alias, value)), + `,`, + ), + columnNames: entries.map(([alias]) => alias), + }; +} + +function columnNamesToProjection(columnNames: readonly string[]): Projection { + return { + query: sql.join( + columnNames.map((column) => { + if (typeof column !== 'string') { + throw new Error(`Expected column names to be strings.`); + } + return sql.ident(column); + }), + `,`, + ), + columnNames, + }; +} diff --git a/packages/pg-typed/src/v2/SelectQuery.ts b/packages/pg-typed/src/v2/SelectQuery.ts new file mode 100644 index 00000000..c7cc1e4b --- /dev/null +++ b/packages/pg-typed/src/v2/SelectQuery.ts @@ -0,0 +1,29 @@ +import {Columns} from './types/Columns'; +import {ProjectedQuery} from './types/Queries'; +import GroupByQuery from './GroupByQuery'; +import WhereCondition from './WhereCondition'; +import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; + +export default interface SelectQuery extends ProjectedQuery { + where(condition: WhereCondition>): this; + + select( + ...columnNames: TColumnNames + ): ProjectedQuery>; + select( + selection: (column: Columns) => SelectionSet, + ): ProjectedQuery; + + groupBy( + ...columnNames: TColumnNames + ): GroupByQuery, Columns>; + groupBy( + selection: (column: Columns) => SelectionSet, + ): GroupByQuery>; + + selectAggregate( + aggregation: ( + column: Columns, + ) => AggregatedSelectionSet, + ): ProjectedQuery; +} diff --git a/packages/pg-typed/src/v2/Table.ts b/packages/pg-typed/src/v2/Table.ts new file mode 100644 index 00000000..924caae0 --- /dev/null +++ b/packages/pg-typed/src/v2/Table.ts @@ -0,0 +1,61 @@ +import AliasedQuery from './AliasedQuery'; +import {Columns} from './types/Columns'; +import {ProjectedQuery} from './types/Queries'; +import GroupByQuery from './GroupByQuery'; +import InsertQuery from './InsertQuery'; +import SelectQuery, {WhereCondition, selectQuery} from './SelectQuery'; +import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; +import TableSchema from './TableSchema'; +import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; + +export interface Table + extends SelectQuery { + insert(...records: TInsertParameters[]): InsertQuery; +} + +class TableImplementation + implements Table +{ + private _table: TableSchema; + constructor(table: TableSchema) { + this._table = table; + } + + as( + alias: TAliasTableName, + ): AliasedQuery<{[TKey in TAliasTableName]: TRecord}> { + return selectQuery(this._table).as(alias); + } + + where(condition: WhereCondition): SelectQuery { + return selectQuery(this._table).where(condition); + } + + select( + ...columnNames: TColumnNames + ): ProjectedQuery>; + select( + selection: (column: Columns) => SelectionSet, + ): ProjectedQuery; + select(...args: any): any { + return selectQuery(this._table).select(...args); + } + + groupBy( + ...columnNames: TColumnNames + ): GroupByQuery, TRecord>; + groupBy( + selection: (column: Columns) => SelectionSet, + ): GroupByQuery; + groupBy(...args: any): any { + return selectQuery(this._table).groupBy(...args); + } + + selectAggregate( + aggregation: ( + column: Columns, + ) => AggregatedSelectionSet, + ): TypedDatabaseQuery { + return selectQuery(this._table).selectAggregate(aggregation); + } +} diff --git a/packages/pg-typed/src/v2/TableSchema.ts b/packages/pg-typed/src/v2/TableSchema.ts new file mode 100644 index 00000000..9598c44a --- /dev/null +++ b/packages/pg-typed/src/v2/TableSchema.ts @@ -0,0 +1,9 @@ +import {SQLQuery} from '@databases/pg'; +import {Columns} from './types/Columns'; + +export default interface TableSchema { + __getType(): TRecord; + tableName: string; + tableId: SQLQuery; + columns: Columns; +} diff --git a/packages/pg-typed/src/v2/WhereCondition.ts b/packages/pg-typed/src/v2/WhereCondition.ts new file mode 100644 index 00000000..d8bcf674 --- /dev/null +++ b/packages/pg-typed/src/v2/WhereCondition.ts @@ -0,0 +1,27 @@ +import {SQLQuery} from '@databases/pg'; +import Value, {FieldCondition} from './types/SpecialValues'; +import {Columns} from './types/Columns'; + +// TODO: this was for doing AND/OR with the simplified API - not sure if we should/can still support this +export interface WhereCombinedCondition { + readonly __isSpecialValue: true; + readonly __isWhereCombinedCondition: true; + readonly conditions: readonly WhereConditionObject[]; + readonly combiner: 'AND' | 'OR'; +} + +export type WhereConditionObject = + | Partial<{ + readonly [key in keyof TRecord]: + | TRecord[key] + | FieldCondition; + }> + | WhereCombinedCondition + | SQLQuery; + +export type WhereConditionFunction = (c: TColumns) => Value; + +type WhereCondition> = + | WhereConditionObject + | WhereConditionFunction; +export default WhereCondition; diff --git a/packages/pg-typed/src/v2/__tests__/index.test.ts b/packages/pg-typed/src/v2/__tests__/index.test.ts new file mode 100644 index 00000000..6f816867 --- /dev/null +++ b/packages/pg-typed/src/v2/__tests__/index.test.ts @@ -0,0 +1,74 @@ +import {SQLQuery, sql} from '@databases/pg'; +import {columns} from '../implementation/Columns'; +import createQuery from '../QueryImplementation'; +import {q} from '..'; +import {escapePostgresIdentifier} from '@databases/escape-identifier'; + +interface DbUser { + id: number; + username: string; +} +interface DbPost { + author_id: number; + title: string; + created_at: Date; +} + +const users = createQuery('users', sql`users`, columns(`users`)); +const posts = createQuery('posts', sql`posts`, columns(`posts`)); + +test(`q`, () => { + const join = users + .as(`u`) + .where((u) => q.eq(u.id, 10)) + .innerJoin(posts.as(`p`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + // .where(({u}) => q.eq(u.id, 10)) + .select(({u, p}) => ({ + id: u.id, + username: u.username, + title: p.title, + })); + expect(printQueryForTest(join)).toEqual( + `SELECT "u"."id","u"."username","p"."title" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") WHERE "u"."id"=\${ 10 }`, + ); + + const groupBy = users + .as(`u`) + .innerJoin(posts.as(`p`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .groupBy(({u}) => ({ + id: u.id, + username: u.username, + })) + .selectAggregate(({p}) => ({ + last_posted_at: q.max(p.created_at), + total_count: q.count(), + })) + .orderByDesc(`last_posted_at`); + + expect(printQueryForTest(groupBy)).toEqual( + `SELECT "u"."id","u"."username",MAX("p"."created_at") AS "last_posted_at",COUNT(*) AS "total_count" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") GROUP BY 1,2 ORDER BY 3 DESC`, + ); + + const conditions = users + .where({id: 10}) + .where(sql`username='x' OR username='y'`) + .where((u) => + q.and(q.gt(u.id, 1), q.lt(u.id, 20), sql`id % 2 = 0 OR id % 3 = 0`), + ); + + expect(printQueryForTest(conditions)).toEqual( + `SELECT * FROM users WHERE "id"=\${ 10 } AND (username='x' OR username='y') AND "id">\${ 1 } AND "id"<\${ 20 } AND (id % 2 = 0 OR id % 3 = 0)`, + ); +}); + +function printQueryForTest(query: {toSql(): SQLQuery}) { + return query.toSql().format({ + escapeIdentifier: escapePostgresIdentifier, + formatValue: (value: unknown) => ({ + placeholder: '${ ' + JSON.stringify(value) + ' }', + value: undefined, + }), + }).text; +} diff --git a/packages/pg-typed/src/v2/implementation/Columns.ts b/packages/pg-typed/src/v2/implementation/Columns.ts new file mode 100644 index 00000000..971303dd --- /dev/null +++ b/packages/pg-typed/src/v2/implementation/Columns.ts @@ -0,0 +1,91 @@ +import {SQLQuery} from '@databases/pg'; +import {columnReference} from './Operators'; +import {Columns} from '../types/Columns'; + +const IS_PROXIED = Symbol('IS_PROXIED'); + +export function columns( + tableName: string, + schema?: { + columnName: string; + postgresTypeQuery?: SQLQuery; + postgresType?: string; + }[], + isAlias: boolean = false, +): Columns { + if (schema) { + return Object.fromEntries( + schema.map(({columnName, postgresTypeQuery, postgresType}) => [ + columnName, + columnReference( + tableName, + columnName, + isAlias, + postgresTypeQuery, + postgresType, + ), + ]), + ) as Columns; + } else { + return new Proxy( + {}, + { + get: (_target, columnName, _receiver) => { + if (columnName === IS_PROXIED) return true; + if (columnName === 'then' || typeof columnName !== 'string') { + return undefined; + } + return columnReference(tableName, columnName, isAlias); + }, + }, + ) as any; + } +} + +const cache = new Map, Columns>>(); +export function aliasColumns( + tableAlias: string, + columns: Columns, +): Columns { + let cachedAlias = cache.get(tableAlias); + if (!cachedAlias) { + cachedAlias = new WeakMap(); + cache.set(tableAlias, cachedAlias); + } + const cached = cachedAlias.get(columns); + if (cached) return cached; + const aliasedColumns = (columns as any)[IS_PROXIED] + ? aliasColumnsByProxy(tableAlias, columns) + : aliasColumnsWithPlainObject(tableAlias, columns); + cachedAlias.set(columns, aliasedColumns); + return aliasedColumns; +} + +function aliasColumnsByProxy( + tableAlias: string, + columns: Columns, +): Columns { + return new Proxy( + {}, + { + get: (_target, columnName, _receiver) => { + if (columnName === IS_PROXIED) return true; + const column = columns[columnName as keyof typeof columns]; + if (column === undefined) return column; + return column.setAlias(tableAlias); + }, + }, + ) as Columns; +} + +function aliasColumnsWithPlainObject( + tableAlias: string, + columns: Columns, +): Columns { + return Object.fromEntries( + Object.entries(columns).map(([columnName, column]) => [ + columnName, + (column as Columns[keyof Columns]).setAlias(tableAlias), + ]), + ) as Columns; +} diff --git a/packages/pg-typed/src/v2/implementation/Operators.ts b/packages/pg-typed/src/v2/implementation/Operators.ts new file mode 100644 index 00000000..f5948806 --- /dev/null +++ b/packages/pg-typed/src/v2/implementation/Operators.ts @@ -0,0 +1,797 @@ +import {SQLQuery, sql} from '@databases/pg'; +import Value, { + AggregatedTypedValue, + FieldCondition, + ComputedFieldCondition, + isSpecialValue, + NonAggregatedTypedValue, + FieldConditionToSqlContext, + RawValue, + AnyOf, + ValueToSqlContext, + BaseAggregatedTypedValue, + isComputedFieldQuery, + ComputedValue, + AggregatedValue, + isAnyOfCondition, +} from '../types/SpecialValues'; +import { + BinaryInput, + OperatorDefinition, + OperatorDefinitions, +} from '../PostgresOperators'; +import {IOperators, List} from '../types/Operators'; +import {ColumnReference} from '../types/Columns'; + +export function columnReference( + tableName: string, + columnName: string, + isAlias: boolean, + postgresTypeQuery?: SQLQuery, + postgresType?: string, +): Value { + return new ColumnReferenceImplementation( + tableName, + columnName, + isAlias, + postgresTypeQuery, + postgresType, + ); +} + +export function fieldConditionToPredicateValue( + column: ColumnReference, + f: FieldCondition, +): Value { + const constantValue = fieldConditionToConstant(f); + if (constantValue !== null) return constantValue; + return new FieldConditionValue(column, f); +} + +export function valueToSelect( + alias: string, + value: Value | AggregatedValue, +): SQLQuery { + if ( + value instanceof ColumnReferenceImplementation && + value.columnName === alias + ) { + return valueToSql(value, { + parentOperatorPrecedence: null, + toValue: (v) => v, + tableAlias: () => null, + }); + } + return sql`${valueToSql(value, { + parentOperatorPrecedence: null, + toValue: (v) => v, + tableAlias: () => null, + })} AS ${sql.ident(alias)}`; +} + +export function aliasTableInValue( + tableName: string, + tableAlias: string, + value: Value, +): Value { + if (!isSpecialValue(value)) { + return value; + } + return new AliasTableInValue(tableName, tableAlias, value); +} + +export function valueToSql( + value: Value | BaseAggregatedTypedValue, + ctx: ValueToSqlContext, +): SQLQuery { + if (isSpecialValue(value)) return value.toSql(ctx); + if (sql.isSqlQuery(value)) { + if (ctx.parentOperatorPrecedence !== null) { + return sql`(${value})`; + } + return value; + } + return sql.value(ctx.toValue(value)); +} + +const STAR = {}; // marker value for the "*" in COUNT(*) +const AGGREGATE_FUNCTIONS = { + MAX: sql`MAX`, + MIN: sql`MIN`, + SUM: sql`SUM`, + COUNT: sql`COUNT`, +}; + +const NON_AGGREGATE_FUNCTIONS = { + LOWER: sql`LOWER`, + UPPER: sql`UPPER`, +}; + +const ORDER_BY_DIRECTION = { + ASC: sql`ASC`, + DESC: sql`DESC`, +}; + +abstract class BaseExpression + implements BaseAggregatedTypedValue, NonAggregatedTypedValue +{ + public readonly __isSpecialValue = true; + public readonly __isAggregatedValue = true; + public readonly __isNonAggregatedComputedValue = true; + + public abstract toSql(ctx: ValueToSqlContext): SQLQuery; + public __getType(): T { + throw new Error( + `The "getType" function should not be called. It is only there to help TypeScript infer the correct type.`, + ); + } +} + +abstract class BaseFieldQuery implements ComputedFieldCondition { + public readonly __isSpecialValue = true; + public readonly __isFieldQuery = true; + public __getType(): T { + throw new Error( + `The "getType" function should not be called. It is only there to help TypeScript infer the correct type.`, + ); + } + public abstract getStaticValue(): boolean | null; + public abstract toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery; +} + +class ColumnReferenceImplementation + extends BaseExpression + implements ColumnReference +{ + public readonly tableName: string; + public readonly columnName: string; + public readonly isAlias: boolean; + + // TODO: make use of schema info + public readonly postgresTypeQuery: SQLQuery | undefined; + public readonly postgresType: string | undefined; + + constructor( + tableName: string, + columnName: string, + isAlias: boolean, + postgresTypeQuery: SQLQuery | undefined, + postgresType: string | undefined, + ) { + super(); + this.tableName = tableName; + this.columnName = columnName; + this.isAlias = isAlias; + this.postgresTypeQuery = postgresTypeQuery; + this.postgresType = postgresType; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + if (this.isAlias) { + return sql.ident(this.tableName, this.columnName); + } + const tableAlias = ctx.tableAlias(this.tableName); + if (tableAlias) { + return sql.ident(tableAlias, this.columnName); + } + return sql.ident(this.columnName); + } + public setAlias(tableAlias: string): ColumnReference { + return new ColumnReferenceImplementation( + tableAlias, + this.columnName, + true, + this.postgresTypeQuery, + this.postgresType, + ); + } +} + +class OperatorExpression< + TPreparedInput, + TInput, + TResult, +> extends BaseExpression { + public readonly op: OperatorDefinition; + public readonly input: TInput; + public readonly prepareInput: ( + input: TInput, + ctx: ValueToSqlContext, + ) => TPreparedInput; + + constructor( + op: OperatorDefinition, + input: TInput, + prepareInput: (input: TInput, ctx: ValueToSqlContext) => TPreparedInput, + ) { + super(); + this.op = op; + this.input = input; + this.prepareInput = prepareInput; + } + + public toSql(ctx: ValueToSqlContext): SQLQuery { + return this.op.toSql( + this.prepareInput(this.input, { + ...ctx, + parentOperatorPrecedence: this.op.precedence, + }), + ctx, + ); + } +} + +class OperatorFieldQuery< + TPreparedInput, + TInput, + TLeft, +> extends BaseFieldQuery { + public readonly op: OperatorDefinition; + public readonly input: TInput; + public readonly prepareInput: ( + input: TInput, + ctx: FieldConditionToSqlContext, + ) => TPreparedInput; + public readonly staticValue: boolean | null; + + constructor( + op: OperatorDefinition, + input: TInput, + prepareInput: ( + input: TInput, + ctx: FieldConditionToSqlContext, + ) => TPreparedInput, + staticValue: boolean | null, + ) { + super(); + this.op = op; + this.input = input; + this.prepareInput = prepareInput; + this.staticValue = staticValue; + } + + public getStaticValue(): boolean | null { + return this.staticValue; + } + + public toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery { + return this.op.toSql( + this.prepareInput(this.input, { + ...ctx, + parentOperatorPrecedence: this.op.precedence, + }), + ctx, + ); + } +} + +class NonAggregateFunction extends BaseExpression { + public readonly fn: keyof typeof NON_AGGREGATE_FUNCTIONS; + public readonly values: Value[]; + constructor(fn: keyof typeof NON_AGGREGATE_FUNCTIONS, values: Value[]) { + super(); + this.fn = fn; + this.values = values; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + return sql`${NON_AGGREGATE_FUNCTIONS[this.fn]}(${sql.join( + this.values.map((v) => + valueToSql(v, {...ctx, parentOperatorPrecedence: null}), + ), + `, `, + )})`; + } +} + +class AggregateFunction + extends BaseExpression + implements AggregatedTypedValue +{ + public readonly fn: keyof typeof AGGREGATE_FUNCTIONS; + public readonly values: Value[]; + public readonly condition: undefined | Value; + public readonly orderByClauses: { + direction: keyof typeof ORDER_BY_DIRECTION; + value: Value; + }[]; + public readonly isDistinct: boolean; + constructor( + fn: keyof typeof AGGREGATE_FUNCTIONS, + values: Value[], + condition?: Value, + orderBy?: Value[], + distinct?: boolean, + ) { + super(); + this.fn = fn; + this.values = values; + this.condition = condition; + this.orderByClauses = orderBy ?? []; + this.isDistinct = distinct ?? false; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + const fn = AGGREGATE_FUNCTIONS[this.fn]; + let args = sql.join( + this.values.map((v) => (v === STAR ? sql`*` : valueToSql(v, ctx))), + `, `, + ); + if (this.isDistinct) { + args = sql`DISTINCT ${args}`; + } + if (this.orderByClauses.length) { + args = sql`${args} ORDER BY ${sql.join( + this.orderByClauses.map( + (c) => + sql`${valueToSql(c.value, ctx)} ${ORDER_BY_DIRECTION[c.direction]}`, + ), + `, `, + )}`; + } + if (this.condition !== undefined) { + return sql`${fn}(${args}) FILTER (WHERE ${valueToSql( + this.condition, + ctx, + )})`; + } + return sql`${fn}(${args})`; + } + + public distinct(): AggregatedTypedValue { + return new AggregateFunction( + this.fn, + this.values, + this.condition, + this.orderByClauses, + true, + ); + } + + public orderByAsc( + value: Value, + ): AggregatedTypedValue { + return new AggregateFunction( + this.fn, + this.values, + this.condition, + [...this.orderByClauses, {direction: 'ASC', value}], + this.isDistinct, + ); + } + + public orderByDesc( + value: Value, + ): AggregatedTypedValue { + return new AggregateFunction( + this.fn, + this.values, + this.condition, + [...this.orderByClauses, {direction: 'DESC', value}], + this.isDistinct, + ); + } + + public filter(condition: Value): AggregatedTypedValue { + return new AggregateFunction( + this.fn, + this.values, + condition, + this.orderByClauses, + this.isDistinct, + ); + } +} + +class AllOf extends BaseFieldQuery { + public readonly values: List>; + constructor(values: List>) { + super(); + this.values = values; + } + public getStaticValue(): boolean | null { + const staticValues = [...this.values].map(fieldConditionToConstant); + if (staticValues.every((v) => v === true)) return true; + if (staticValues.some((v) => v === false)) return false; + return null; + } + + public toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery { + const parts: FieldCondition[] = [...this.values].filter( + (part) => fieldConditionToConstant(part) !== true, + ); + const partCount = parts.length; + if (partCount === 0) return sql`TRUE`; + + const childCtx: FieldConditionToSqlContext = { + ...ctx, + parentOperatorPrecedence: + partCount === 1 + ? ctx.parentOperatorPrecedence + : OperatorDefinitions.AND.precedence, + }; + + const sqlParts = parts.map((p) => fieldConditionToSql(p, childCtx)); + + if (sqlParts.length === 1) return sqlParts[0]; + + return OperatorDefinitions.AND.toSql(sqlParts, ctx); + } +} + +class AnyOfImplementation extends BaseFieldQuery implements AnyOf { + public readonly __isAnyOf = true; + public readonly values: Value>>; + constructor(values: Value>>) { + super(); + this.values = values; + } + public getStaticValue(): boolean | null { + if (isSpecialValue(this.values) || sql.isSqlQuery(this.values)) return null; + const staticValues = [...this.values].map(fieldConditionToConstant); + if (staticValues.some((v) => v === true)) return true; + if (staticValues.every((v) => v === false)) return false; + return null; + } + + public toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery { + if (isSpecialValue(this.values) || sql.isSqlQuery(this.values)) { + return OperatorDefinitions.EQ.toSql( + { + left: ctx.left, + right: sql`ALL(${valueToSql(this.values, { + ...ctx, + parentOperatorPrecedence: OperatorDefinitions.EQ.precedence, + })})`, + }, + ctx, + ); + } + + const values = new Set>(); + const parts: FieldCondition[] = []; + for (const value of this.values) { + if (fieldConditionToConstant(value) !== false) { + if (isSpecialValue(value)) { + parts.push(value); + } else { + values.add(value); + } + } + } + const partCount = parts.length + (values.size ? 1 : 0); + if (partCount === 0) return sql`FALSE`; + + const childCtx: FieldConditionToSqlContext = { + ...ctx, + parentOperatorPrecedence: + partCount === 1 + ? ctx.parentOperatorPrecedence + : OperatorDefinitions.OR.precedence, + }; + + const sqlParts = [ + ...parts.map((p) => fieldConditionToSql(p, childCtx)), + ...(values.size + ? [ + OperatorDefinitions.EQ.toSql( + { + left: ctx.left, + right: + values.size > 1 + ? sql`ANY(${[...values].map(ctx.toValue)})` + : sql.value(ctx.toValue([...values][0])), + }, + childCtx, + ), + ] + : []), + ]; + + if (sqlParts.length === 1) return sqlParts[0]; + + return OperatorDefinitions.OR.toSql(sqlParts, ctx); + } +} + +class EqualsAnyOf extends BaseExpression { + public readonly left: AggregatedValue | Value; + public readonly right: AnyOf; + + constructor( + op: OperatorDefinition, + left: AggregatedValue | Value, + right: AnyOf, + ) { + super(); + this.left = left; + this.right = right; + } + + public toSql(ctx: ValueToSqlContext): SQLQuery { + return this.right.toSqlCondition({ + ...ctx, + left: valueToSql(this.left, { + ...ctx, + parentOperatorPrecedence: OperatorDefinitions.EQ.precedence, + }), + }); + } +} + +class CaseInsensitive extends BaseFieldQuery { + public readonly value: FieldCondition; + constructor(value: FieldCondition) { + super(); + this.value = value; + } + public getStaticValue(): boolean | null { + return null; + } + + public toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery { + return fieldConditionToSql(this.value, { + ...ctx, + left: sql`LOWER(${ctx.left})`, + toValue: (value) => { + const v = ctx.toValue(value); + if (typeof v === 'string') return v.toLowerCase(); + return v; + }, + }); + } +} + +function prepareBinaryOperatorExpression( + { + left, + right, + }: { + left: AggregatedValue | Value; + right: AggregatedValue | Value; + }, + ctx: ValueToSqlContext, +) { + return {left: valueToSql(left, ctx), right: valueToSql(right, ctx)}; +} + +function prepareBinaryOperatorFieldQuery( + right: RawValue, + ctx: FieldConditionToSqlContext, +) { + return {left: ctx.left, right: sql.value(ctx.toValue(right))}; +} + +function binaryOperator( + operator: OperatorDefinition<{left: SQLQuery; right: SQLQuery}>, +): (leftOrOnly: any, right?: any) => any { + return ( + leftOrOnly: AggregatedValue | Value | RawValue, + right?: AggregatedValue | Value | AnyOf, + ): + | (Value & NonAggregatedTypedValue) + | FieldCondition => { + if (right === undefined) { + return new OperatorFieldQuery( + operator, + leftOrOnly as RawValue, + prepareBinaryOperatorFieldQuery, + null, + ); + } else if (isAnyOfCondition(right)) { + if (operator !== OperatorDefinitions.EQ) { + throw new Error( + `The only operator that can be used with "anyOf" is "eq".`, + ); + } + if ( + right instanceof AnyOfImplementation && + !isSpecialValue(right.values) && + !sql.isSqlQuery(right.values) && + [...right.values].length === 0 + ) { + // @ts-expect-error + return false; + } + return new EqualsAnyOf( + operator, + leftOrOnly as AggregatedValue | Value, + right, + ) as any; + } else { + return new OperatorExpression( + operator, + {left: leftOrOnly as AggregatedValue | Value, right}, + prepareBinaryOperatorExpression, + ); + } + }; +} + +function prepareVariadicOperatorInput( + inputs: (Value | AggregatedTypedValue)[], + ctx: ValueToSqlContext, +) { + return inputs.map((input) => valueToSql(input, ctx)); +} + +function variadicOperator( + operator: OperatorDefinition, + getConstantValue?: ( + ...params: (Value | AggregatedTypedValue)[] + ) => TStaticValue | undefined, +) { + return ( + ...params: (Value | AggregatedTypedValue)[] + ): + | TStaticValue + | (BaseAggregatedTypedValue & + NonAggregatedTypedValue) => { + const flatParams = params.flatMap((p) => { + if (p instanceof OperatorExpression && p.op === operator) { + return p.input as (Value | AggregatedTypedValue)[]; + } + return [p]; + }); + const constantValue = getConstantValue && getConstantValue(...flatParams); + if (constantValue !== undefined) return constantValue; + + return new OperatorExpression( + operator, + flatParams, + prepareVariadicOperatorInput, + ); + }; +} + +function nonAggregateFunction(fn: keyof typeof NON_AGGREGATE_FUNCTIONS) { + return ( + ...args: TArgs + ): ComputedValue => { + return new NonAggregateFunction(fn, args); + }; +} + +function aggregateFunction(fn: keyof typeof AGGREGATE_FUNCTIONS) { + return ( + ...args: TArgs + ): AggregatedTypedValue => { + return new AggregateFunction(fn, args); + }; +} + +function fieldConditionToSql( + value: FieldCondition, + ctx: FieldConditionToSqlContext, +) { + if (isSpecialValue(value)) return value.toSqlCondition(ctx); + return OperatorDefinitions.EQ.toSql( + { + left: ctx.left, + right: sql.value(ctx.toValue(value)), + }, + ctx, + ); +} + +function fieldConditionToConstant(q: FieldCondition): boolean | null { + if (isSpecialValue(q)) return q.getStaticValue(); + return null; +} + +function overload( + overloads: Record any>, + chooseOverload: (...args: any[]) => TKey, +) { + return (...args: any[]): any => { + return overloads[chooseOverload(...args)](...args); + }; +} + +class FieldConditionValue extends BaseExpression { + public readonly left: ColumnReference; + public readonly right: FieldCondition; + constructor(left: ColumnReference, right: FieldCondition) { + super(); + this.left = left; + this.right = right; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + return fieldConditionToSql(this.right, { + ...ctx, + left: valueToSql(this.left, ctx), + }); + } +} + +class AliasTableInValue extends BaseExpression { + public readonly tableName: string; + public readonly tableAlias: string; + public readonly value: Value; + constructor(tableName: string, tableAlias: string, value: Value) { + super(); + this.tableName = tableName; + this.tableAlias = tableAlias; + this.value = value; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + return valueToSql(this.value, { + ...ctx, + tableAlias: (tableName: string) => + tableName === this.tableName + ? this.tableAlias + : ctx.tableAlias(tableName), + }); + } +} + +const Operators: IOperators = { + allOf: (values) => new AllOf(values), + and: variadicOperator(OperatorDefinitions.AND, (...params) => { + if (params.every((p) => p === true)) return true; + if (params.some((p) => p === false)) return false; + return undefined; + }), + anyOf: (values) => new AnyOfImplementation(values), + caseInsensitive: (value) => new CaseInsensitive(value), + count(expression) { + return new AggregateFunction(`COUNT`, [expression ?? STAR]); + }, + eq: binaryOperator(OperatorDefinitions.EQ), + gt: binaryOperator(OperatorDefinitions.GT), + gte: binaryOperator(OperatorDefinitions.GTE), + ilike: binaryOperator(OperatorDefinitions.ILIKE), + like: binaryOperator(OperatorDefinitions.LIKE), + lower: nonAggregateFunction(`LOWER`), + lt: binaryOperator(OperatorDefinitions.LT), + lte: binaryOperator(OperatorDefinitions.LTE), + max: aggregateFunction(`MAX`), + min: aggregateFunction(`MIN`), + neq: binaryOperator(OperatorDefinitions.NEQ), + not: overload( + { + expression(value: Value): Value { + if (typeof value === 'boolean') return !value; + return new OperatorExpression( + OperatorDefinitions.NOT, + value, + valueToSql, + ); + }, + fieldQuery(value: FieldCondition): FieldCondition { + if (isSpecialValue(value)) { + const constantValueOfExpression = fieldConditionToConstant(value); + const constantValueOfNot = + constantValueOfExpression !== null + ? !constantValueOfExpression + : null; + return new OperatorFieldQuery( + OperatorDefinitions.NOT, + value, + fieldConditionToSql, + constantValueOfNot, + ); + } else { + return new OperatorFieldQuery( + OperatorDefinitions.NEQ, + value, + prepareBinaryOperatorFieldQuery, + null, + ); + } + }, + }, + (value: any) => { + return (isSpecialValue(value) && !isComputedFieldQuery(value)) || + sql.isSqlQuery(value) || + typeof value === 'boolean' + ? `expression` + : `fieldQuery`; + }, + ), + or: variadicOperator(OperatorDefinitions.OR, (...params) => { + if (params.some((p) => p === true)) return true; + if (params.every((p) => p === false)) return false; + return undefined; + }), + sum: aggregateFunction(`SUM`), +}; + +export default Operators; diff --git a/packages/pg-typed/src/v2/index.ts b/packages/pg-typed/src/v2/index.ts new file mode 100644 index 00000000..5f0b0065 --- /dev/null +++ b/packages/pg-typed/src/v2/index.ts @@ -0,0 +1,10 @@ +import Operators from './implementation/Operators'; +import Value from './types/SpecialValues'; +import {JoinQueryBuilder, JoinQuery} from './types/Join'; +import AliasedQuery from './AliasedQuery'; +import {Table} from './Table'; +import {IOperators} from './types/Operators'; + +export const q: IOperators = Operators; + +export type {AliasedQuery, JoinQueryBuilder, JoinQuery, Table, Value}; diff --git a/packages/pg-typed/src/v2/types/Columns.ts b/packages/pg-typed/src/v2/types/Columns.ts new file mode 100644 index 00000000..8ad1eee0 --- /dev/null +++ b/packages/pg-typed/src/v2/types/Columns.ts @@ -0,0 +1,42 @@ +import {SQLQuery} from '@databases/pg'; +import {NonAggregatedTypedValue} from './SpecialValues'; + +export interface ColumnReference extends NonAggregatedTypedValue { + readonly postgresTypeQuery?: SQLQuery; + readonly postgresType?: string; + setAlias(tableAlias: string): ColumnReference; +} + +export type Columns = { + [TColumnName in keyof TRecord]: ColumnReference; +}; + +export type JoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecordColumns, +> = { + [TChildAlias in + | keyof TLeftTables + | TRightAlias]: TChildAlias extends keyof TLeftTables + ? TLeftTables[TChildAlias] + : TRightRecordColumns; +}; + +export type InnerJoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecord, +> = JoinedColumns>; + +export type LeftOuterJoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecord, +> = JoinedColumns< + TLeftTables, + TRightAlias, + Columns<{ + [TColumnName in keyof TRightRecord]: TRightRecord[TColumnName] | null; + }> +>; diff --git a/packages/pg-typed/src/v2/types/Join.ts b/packages/pg-typed/src/v2/types/Join.ts new file mode 100644 index 00000000..c2afbedf --- /dev/null +++ b/packages/pg-typed/src/v2/types/Join.ts @@ -0,0 +1,23 @@ +import GroupByQuery from '../GroupByQuery'; +import {JoinableQueryLeft} from './JoinableQuery'; +import {ProjectedQuery} from './Queries'; +import {AggregatedSelectionSet, SelectionSet} from './SelectionSet'; +import Value from './SpecialValues'; + +export interface JoinQueryBuilder { + on(predicate: (column: TColumns) => Value): JoinQuery; +} + +export interface JoinQuery extends JoinableQueryLeft { + select( + selection: (column: TColumns) => SelectionSet, + ): ProjectedQuery; + groupBy( + selection: (column: TColumns) => SelectionSet, + ): GroupByQuery; + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedQuery; + + where(predicate: (column: TColumns) => Value): JoinQuery; +} diff --git a/packages/pg-typed/src/v2/types/JoinableQuery.ts b/packages/pg-typed/src/v2/types/JoinableQuery.ts new file mode 100644 index 00000000..86cd8e2f --- /dev/null +++ b/packages/pg-typed/src/v2/types/JoinableQuery.ts @@ -0,0 +1,51 @@ +import {Columns} from './Columns'; +import {JoinQueryBuilder} from './Join'; +import ProjectedLimitQuery from '../ProjectedLimitQuery'; + +export type JoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecordColumns, +> = { + [TChildAlias in + | keyof TLeftTables + | TRightAlias]: TChildAlias extends keyof TLeftTables + ? TLeftTables[TChildAlias] + : TRightRecordColumns; +}; + +export type InnerJoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecord, +> = JoinedColumns>; + +export type LeftOuterJoinedColumns< + TLeftTables, + TRightAlias extends string, + TRightRecord, +> = JoinedColumns< + TLeftTables, + TRightAlias, + Columns<{ + [TColumnName in keyof TRightRecord]: TRightRecord[TColumnName] | null; + }> +>; + +export interface JoinableQueryLeft { + innerJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + InnerJoinedColumns + >; + leftOuterJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + LeftOuterJoinedColumns + >; +} + +export interface JoinableQueryRight + extends ProjectedLimitQuery { + alias: TAlias; +} diff --git a/packages/pg-typed/src/v2/types/Operators.ts b/packages/pg-typed/src/v2/types/Operators.ts new file mode 100644 index 00000000..5721627b --- /dev/null +++ b/packages/pg-typed/src/v2/types/Operators.ts @@ -0,0 +1,42 @@ +import Value, { + AggregatedTypedValue, + FieldCondition, + RawValue, + AnyOf, +} from './SpecialValues'; + +export interface List { + [Symbol.iterator](): IterableIterator; +} + +// prettier-ignore +export interface IOperators { + allOf(values: List>): FieldCondition; + and(...values: Value[]): Value; + anyOf(values: Value>>): AnyOf; + caseInsensitive: (value: FieldCondition) => FieldCondition; + count(expression?: Value): AggregatedTypedValue; + eq(left: Value, right: Value | AnyOf): Value; + gt(right: RawValue): FieldCondition; + gt(left: Value, right: Value): Value; + gte(right: RawValue): FieldCondition; + gte(left: Value, right: Value): Value; + neq(left: Value, right: Value): Value; + ilike(right: string): FieldCondition; + ilike(left: Value, right: Value): Value; + // TODO: IN should probably have an SQL query as the right hand side + // in(left: Value, right: Value): Value; + like(right: string): FieldCondition; + like(left: Value, right: Value): Value; + lower(value: Value): Value; + lt(right: RawValue): FieldCondition; + lt(left: Value, right: Value): Value; + lte(right: RawValue): FieldCondition; + lte(left: Value, right: Value): Value; + max(value: Value): AggregatedTypedValue; + min(value: Value): AggregatedTypedValue; + not(value: Value): Value; + not(value: FieldCondition): FieldCondition; + or(...values: Value[]): Value; + sum(value: Value): AggregatedTypedValue; +} diff --git a/packages/pg-typed/src/v2/types/Queries.ts b/packages/pg-typed/src/v2/types/Queries.ts new file mode 100644 index 00000000..b8b7c535 --- /dev/null +++ b/packages/pg-typed/src/v2/types/Queries.ts @@ -0,0 +1,70 @@ +import {SQLQuery} from '@databases/pg'; +import AliasedQuery from '../AliasedQuery'; +import {TypedDatabaseQuery} from './TypedDatabaseQuery'; + +export interface ProjectedLimitQuery + extends TypedDatabaseQuery { + /** + * Get the SQL query that would be executed. This is useful if you want to use this query as a sub-query in a query that is not type safe. + */ + toSql(): SQLQuery; + + /** + * If this is a complex query: + * Wrap the entire query in parentheses, and give it an alias. This lets you use joins, group by, etc. as sub-queries. + * + * If this is a simple query: + * Give the table an alias. This lets you use it in a join. + */ + as( + alias: TAliasTableName, + ): AliasedQuery; +} + +export interface ProjectedDistinctQuery + extends ProjectedLimitQuery { + /** + * If the query returns exactly one row, it is returned. + * Throws an error if multiple rows are returned by the query. + * Returns undefined if no rows are returned by the query. + */ + one(): TypedDatabaseQuery; + + /** + * If the query returns exactly one row, it is returned. + * Throws an error if multiple rows are returned by the query. + * Throws an error if no rows are returned by the query. + */ + oneRequired(): TypedDatabaseQuery; + + /** + * Returns the first row, or undefined if there are no rows. + * This will automatically add `LIMIT 1` to the query. + * + * If you want the raw SQL query, you should call `.limit(1).toSql()` instead. + */ + first(): TypedDatabaseQuery; + + limit(count: number): ProjectedLimitQuery; +} + +export interface ProjectedSortedQuery + extends ProjectedDistinctQuery { + orderByAsc(columnName: keyof TRecord): ProjectedSortedQuery; + orderByDesc(columnName: keyof TRecord): ProjectedSortedQuery; +} + +export interface ProjectedDistinctColumnsQuery + extends ProjectedSortedQuery { + orderByAscDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery; + orderByDescDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery; +} + +export interface ProjectedQuery + extends ProjectedDistinctColumnsQuery { + distinct(): ProjectedDistinctQuery; +} diff --git a/packages/pg-typed/src/v2/types/SelectionSet.ts b/packages/pg-typed/src/v2/types/SelectionSet.ts new file mode 100644 index 00000000..a88fc719 --- /dev/null +++ b/packages/pg-typed/src/v2/types/SelectionSet.ts @@ -0,0 +1,9 @@ +import Value, {AggregatedValue} from './SpecialValues'; + +export type SelectionSet = { + [key in keyof TSelection]: Value; +}; + +export type AggregatedSelectionSet = { + [key in keyof TSelection]: AggregatedValue; +}; diff --git a/packages/pg-typed/src/v2/types/SpecialValues.ts b/packages/pg-typed/src/v2/types/SpecialValues.ts new file mode 100644 index 00000000..4e874a53 --- /dev/null +++ b/packages/pg-typed/src/v2/types/SpecialValues.ts @@ -0,0 +1,94 @@ +import {SQLQuery} from '@databases/pg'; + +export type BooleanCombiner = 'AND' | 'OR'; + +export type RawValue = T extends SQLQuery + ? never + : T extends {readonly __isSpecialValue: true} + ? never + : T; + +export interface ValueToSqlContext { + readonly tableAlias: (tableName: string) => string | null; + readonly toValue: (value: RawValue) => unknown; + readonly parentOperatorPrecedence: number | null; +} + +export interface NonAggregatedTypedValue { + readonly __isSpecialValue: true; + readonly __isNonAggregatedComputedValue: true; + __getType(): T; + toSql(ctx: ValueToSqlContext): SQLQuery; +} + +export interface BaseAggregatedTypedValue { + readonly __isSpecialValue: true; + readonly __isAggregatedValue: true; + __getType(): T; + toSql(ctx: ValueToSqlContext): SQLQuery; +} + +export interface FieldConditionToSqlContext extends ValueToSqlContext { + readonly left: SQLQuery; +} + +export interface ComputedFieldCondition { + readonly __isSpecialValue: true; + readonly __isFieldQuery: true; + __getType(): T; + getStaticValue(): boolean | null; + toSqlCondition(ctx: FieldConditionToSqlContext): SQLQuery; +} + +export interface AggregatedTypedValue extends BaseAggregatedTypedValue { + distinct(): AggregatedTypedValue; + orderByAsc(value: Value): AggregatedTypedValue; + orderByDesc(value: Value): AggregatedTypedValue; + filter(condition: Value): BaseAggregatedTypedValue; +} + +export interface AnyOf extends ComputedFieldCondition { + readonly __isSpecialValue: true; + readonly __isAnyOf: true; + __getType(): T; +} + +export type ComputedValue = SQLQuery | NonAggregatedTypedValue; +export type AggregatedValue = SQLQuery | BaseAggregatedTypedValue; + +type Value = RawValue | ComputedValue; +export default Value; + +export function isSpecialValue( + value: unknown, +): value is {readonly __isSpecialValue: true} { + return ( + typeof value === 'object' && + value !== null && + '__isSpecialValue' in value && + value.__isSpecialValue === true + ); +} +export function isComputedFieldQuery(value: unknown): value is { + readonly __isSpecialValue: true; + readonly __isFieldQuery: true; +} { + return isSpecialValue(value) && (value as any).__isFieldQuery === true; +} +export function isNonAggregatedComputedValue(value: unknown): value is { + readonly __isSpecialValue: true; + readonly __isNonAggregatedComputedValue: true; +} { + return ( + isSpecialValue(value) && + (value as any).__isNonAggregatedComputedValue === true + ); +} +export function isAnyOfCondition(value: unknown): value is { + readonly __isSpecialValue: true; + readonly __isAnyOf: true; +} { + return isSpecialValue(value) && (value as any).__isAnyOf === true; +} + +export type FieldCondition = RawValue | ComputedFieldCondition; diff --git a/packages/pg-typed/src/v2/types/TypedDatabaseQuery.ts b/packages/pg-typed/src/v2/types/TypedDatabaseQuery.ts new file mode 100644 index 00000000..5be12207 --- /dev/null +++ b/packages/pg-typed/src/v2/types/TypedDatabaseQuery.ts @@ -0,0 +1,9 @@ +import {SQLQuery} from '@databases/pg'; + +export interface Queryable { + query(query: SQLQuery): Promise; +} + +export interface TypedDatabaseQuery { + executeQuery(database: Queryable): Promise; +} diff --git a/packages/sql/src/web.ts b/packages/sql/src/web.ts index 69ae049a..8ce48a43 100644 --- a/packages/sql/src/web.ts +++ b/packages/sql/src/web.ts @@ -10,7 +10,7 @@ export enum SQLItemType { export type SQLItem = | {type: SQLItemType.RAW; text: string} | {type: SQLItemType.VALUE; value: any} - | {type: SQLItemType.IDENTIFIER; names: Array}; + | {type: SQLItemType.IDENTIFIER; names: readonly any[]}; export interface FormatConfig { escapeIdentifier: (str: string) => string; @@ -57,9 +57,9 @@ class SQLQuery { */ public static query( strings: TemplateStringsArray, - ...values: Array + ...values: readonly any[] ): SQLQuery { - const items: Array = []; + const items: SQLItem[] = []; // Add all of the strings as raw items and values as placeholder values. for (let i = 0; i < strings.length; i++) { @@ -114,7 +114,7 @@ class SQLQuery { * separator was defined. */ public static join( - queries: Array, + queries: readonly SQLQuery[], separator?: LiteralSeparator | SQLQuery, ): SQLQuery { if (typeof separator === 'string' && !literalSeparators.has(separator)) { @@ -126,7 +126,7 @@ class SQLQuery { .join(', ')}`, ); } - const items: Array = []; + const items: SQLItem[] = []; const separatorItems: readonly SQLItem[] | undefined = separator ? typeof separator === 'string' ? [{type: SQLItemType.RAW, text: separator}] @@ -176,7 +176,7 @@ class SQLQuery { * Creates an identifier query. Each name will be escaped, and the * names will be concatenated with a period (`.`). */ - public static ident(...names: Array): SQLQuery { + public static ident(...names: readonly any[]): SQLQuery { return new SQLQuery([{type: SQLItemType.IDENTIFIER, names}]); } From b63ecb475adc0bd185f31f70e61aed9f10f6fb3c Mon Sep 17 00:00:00 2001 From: Forbes Lindesay Date: Thu, 18 May 2023 15:40:44 +0100 Subject: [PATCH 3/5] feat: more functionality --- .../pg-typed/src/v2/QueryImplementation.ts | 219 ++++-- packages/pg-typed/src/v2/WhereCondition.ts | 27 +- .../pg-typed/src/v2/__tests__/index.test.ts | 233 +++++- .../pg-typed/src/v2/implementation/Columns.ts | 16 +- .../src/v2/implementation/Operators.ts | 669 +++++++++++++----- packages/pg-typed/src/v2/index.ts | 10 +- packages/pg-typed/src/v2/types/Columns.ts | 4 +- packages/pg-typed/src/v2/types/Join.ts | 10 +- packages/pg-typed/src/v2/types/Operators.ts | 66 +- .../pg-typed/src/v2/types/SelectionSet.ts | 4 +- .../pg-typed/src/v2/types/SpecialValues.ts | 45 +- 11 files changed, 970 insertions(+), 333 deletions(-) diff --git a/packages/pg-typed/src/v2/QueryImplementation.ts b/packages/pg-typed/src/v2/QueryImplementation.ts index 78e51d5d..701084ce 100644 --- a/packages/pg-typed/src/v2/QueryImplementation.ts +++ b/packages/pg-typed/src/v2/QueryImplementation.ts @@ -1,7 +1,11 @@ import {Queryable, SQLQuery, sql} from '@databases/pg'; import {aliasColumns, columns} from './implementation/Columns'; import WhereCondition from './WhereCondition'; -import Value from './types/SpecialValues'; +import { + FieldCondition, + NonAggregatedValue, + isSpecialValue, +} from './types/SpecialValues'; import {Columns} from './types/Columns'; import Operators, { aliasTableInValue, @@ -96,7 +100,7 @@ interface QueryConfig { projection: Projection | null; tableId: SQLQuery; tableName: TAlias; - where: readonly Value[]; + where: readonly NonAggregatedValue[]; } interface FinalQueryConfig { @@ -336,23 +340,36 @@ class SelectQueryImplementation< } where(condition: WhereCondition): any { + const where = [ + ...this._config.where, + ...(sql.isSqlQuery(condition) || + isSpecialValue(condition) || + typeof condition === 'boolean' + ? [condition] + : typeof condition === 'function' + ? [condition(this._config.columns)] + : Object.entries(condition).map(([columnName, value]) => + fieldConditionToPredicateValue( + (this._config.columns as Columns)[ + columnName as keyof Columns + ], + value as FieldCondition, + ), + )), + ]; return new SelectQueryImplementation({ - ...this._config, - where: [ - ...this._config.where, - ...(sql.isSqlQuery(condition) - ? [condition] - : typeof condition === 'function' - ? [condition(this._config.columns)] - : Object.entries(condition).map(([columnName, value]) => - fieldConditionToPredicateValue( - (this._config.columns as Columns)[ - columnName as keyof Columns - ], - value, - ), - )), - ], + columns: this._config.columns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where, }); } @@ -403,15 +420,26 @@ class SelectQueryImplementation< distinct: boolean, direction: SQLQuery, ): ProjectedDistinctColumnsQuery { + const distinctColumns = distinct + ? [...this._config.distinctColumns, sql.ident(columnName)] + : this._config.distinctColumns; + const orderBy = [ + ...this._config.orderBy, + sql`${this._orderByColumn(columnName)} ${direction}`, + ]; return new SelectQueryImplementation({ - ...this._config, - distinctColumns: distinct - ? [...this._config.distinctColumns, sql.ident(columnName)] - : this._config.distinctColumns, - orderBy: [ - ...this._config.orderBy, - sql`${this._orderByColumn(columnName)} ${direction}`, - ], + columns: this._config.columns, + distinct: this._config.distinct, + distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy, + projection: this._config.projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where: this._config.where, }); } @@ -439,15 +467,34 @@ class SelectQueryImplementation< ); } return new SelectQueryImplementation({ - ...this._config, + columns: this._config.columns, distinct: true, distinctColumns: [], + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where: this._config.where, }); } limit(n: number): ProjectedLimitQuery { return new SelectQueryImplementation({ - ...this._config, + columns: this._config.columns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, limit: n, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId: this._config.tableId, + tableName: this._config.tableName, + where: this._config.where, }); } @@ -463,19 +510,32 @@ class SelectQueryImplementation< ) { throw new Error(`Right hand side of join is not valid.`); } + const tableId = sql`${this._config.tableId} INNER JOIN ${otherQuery._config.tableId}`; + const columns: any = Object.assign( + {[otherQuery._config.tableName]: otherQuery._config.columns}, + this._config.isJoin + ? this._config.columns + : {[this._config.tableName]: this._config.columns}, + ); return new JoinImplementation< InnerJoinedColumns - >({ - ...this._config, - tableId: sql`${this._config.tableId} INNER JOIN ${otherQuery._config.tableId}`, - where: [...this._config.where, ...otherQuery._config.where], - columns: Object.assign( - {[otherQuery._config.tableName]: otherQuery._config.columns}, - this._config.isJoin - ? this._config.columns - : {[this._config.tableName]: this._config.columns}, - ) as any, - }); + >( + { + columns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId, + tableName: this._config.tableName, + where: this._config.where, + }, + otherQuery._config.where, + ); } leftOuterJoin( @@ -490,23 +550,32 @@ class SelectQueryImplementation< ) { throw new Error(`Right hand side of join is not valid.`); } - if (otherQuery._config.where.length) { - throw new Error( - `Right hand side of a LEFT OUTER JOIN cannot have a WHERE clause.`, - ); - } + const tableId = sql`${this._config.tableId} LEFT OUTER JOIN ${otherQuery._config.tableId}`; + const columns: any = Object.assign( + {[otherQuery._config.tableName]: otherQuery._config.columns}, + this._config.isJoin + ? this._config.columns + : {[this._config.tableName]: this._config.columns}, + ); return new JoinImplementation< InnerJoinedColumns - >({ - ...this._config, - tableId: sql`${this._config.tableId} LEFT OUTER JOIN ${otherQuery._config.tableId}`, - columns: Object.assign( - {[otherQuery._config.tableName]: otherQuery._config.columns}, - this._config.isJoin - ? this._config.columns - : {[this._config.tableName]: this._config.columns}, - ) as any, - }); + >( + { + columns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId, + tableName: this._config.tableName, + where: this._config.where, + }, + otherQuery._config.where, + ); } one(): TypedDatabaseQuery { @@ -593,20 +662,38 @@ class GroupByQueryImplementation class JoinImplementation implements JoinQueryBuilder { private readonly _config: QueryConfig; - constructor(config: QueryConfig) { + private readonly _rightWhere: readonly NonAggregatedValue[]; + constructor( + config: QueryConfig, + rightWhere: readonly NonAggregatedValue[], + ) { this._config = config; - } - on(predicate: (column: TColumns) => Value): JoinQuery { + this._rightWhere = rightWhere; + } + on( + predicate: (column: TColumns) => NonAggregatedValue, + ): JoinQuery { + const tableId = sql`${this._config.tableId} ON (${valueToSql( + Operators.and(predicate(this._config.columns), ...this._rightWhere), + { + toValue: (v) => v, + tableAlias: () => null, + parentOperatorPrecedence: null, + }, + )})`; return new SelectQueryImplementation({ - ...this._config, - tableId: sql`${this._config.tableId} ON (${valueToSql( - predicate(this._config.columns), - { - toValue: (v) => v, - tableAlias: () => null, - parentOperatorPrecedence: null, - }, - )})`, + columns: this._config.columns, + distinct: this._config.distinct, + distinctColumns: this._config.distinctColumns, + groupBy: this._config.groupBy, + isAliased: this._config.isAliased, + isJoin: this._config.isJoin, + limit: this._config.limit, + orderBy: this._config.orderBy, + projection: this._config.projection, + tableId, + tableName: this._config.tableName, + where: this._config.where, }); } } diff --git a/packages/pg-typed/src/v2/WhereCondition.ts b/packages/pg-typed/src/v2/WhereCondition.ts index d8bcf674..9e64423e 100644 --- a/packages/pg-typed/src/v2/WhereCondition.ts +++ b/packages/pg-typed/src/v2/WhereCondition.ts @@ -1,27 +1,16 @@ -import {SQLQuery} from '@databases/pg'; -import Value, {FieldCondition} from './types/SpecialValues'; +import {NonAggregatedValue, FieldCondition} from './types/SpecialValues'; import {Columns} from './types/Columns'; -// TODO: this was for doing AND/OR with the simplified API - not sure if we should/can still support this -export interface WhereCombinedCondition { - readonly __isSpecialValue: true; - readonly __isWhereCombinedCondition: true; - readonly conditions: readonly WhereConditionObject[]; - readonly combiner: 'AND' | 'OR'; -} +export type WhereConditionObject = { + readonly [key in keyof TRecord]?: FieldCondition; +}; -export type WhereConditionObject = - | Partial<{ - readonly [key in keyof TRecord]: - | TRecord[key] - | FieldCondition; - }> - | WhereCombinedCondition - | SQLQuery; - -export type WhereConditionFunction = (c: TColumns) => Value; +export type WhereConditionFunction = ( + c: TColumns, +) => NonAggregatedValue; type WhereCondition> = + | NonAggregatedValue | WhereConditionObject | WhereConditionFunction; export default WhereCondition; diff --git a/packages/pg-typed/src/v2/__tests__/index.test.ts b/packages/pg-typed/src/v2/__tests__/index.test.ts index 6f816867..e19f0b7e 100644 --- a/packages/pg-typed/src/v2/__tests__/index.test.ts +++ b/packages/pg-typed/src/v2/__tests__/index.test.ts @@ -1,12 +1,14 @@ -import {SQLQuery, sql} from '@databases/pg'; +import {sql} from '@databases/pg'; import {columns} from '../implementation/Columns'; import createQuery from '../QueryImplementation'; import {q} from '..'; import {escapePostgresIdentifier} from '@databases/escape-identifier'; +import {ProjectedLimitQuery} from '../types/Queries'; interface DbUser { id: number; username: string; + profile_image_url: string | null; } interface DbPost { author_id: number; @@ -17,22 +19,60 @@ interface DbPost { const users = createQuery('users', sql`users`, columns(`users`)); const posts = createQuery('posts', sql`posts`, columns(`posts`)); -test(`q`, () => { - const join = users +const testFormat = { + escapeIdentifier: escapePostgresIdentifier, + formatValue: (value: unknown) => ({ + placeholder: '${ ' + JSON.stringify(value) + ' }', + value: undefined, + }), +}; + +test(`INNER JOIN`, () => { + const joinWithWhereBeforeJoin = users .as(`u`) .where((u) => q.eq(u.id, 10)) .innerJoin(posts.as(`p`)) .on(({u, p}) => q.eq(u.id, p.author_id)) - // .where(({u}) => q.eq(u.id, 10)) .select(({u, p}) => ({ id: u.id, username: u.username, title: p.title, })); - expect(printQueryForTest(join)).toEqual( + const joinWithWhereAfterJoin = users + .as(`u`) + .innerJoin(posts.as(`p`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .where(({u}) => q.eq(u.id, 10)) + .select(({u, p}) => ({ + id: u.id, + username: u.username, + title: p.title, + })); + expect( + printQueryForTest<{ + id: number; + username: string; + title: string; + }>(joinWithWhereBeforeJoin), + ).toEqual( `SELECT "u"."id","u"."username","p"."title" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") WHERE "u"."id"=\${ 10 }`, ); + expect( + printQueryForTest<{ + id: number; + username: string; + title: string; + }>(joinWithWhereAfterJoin), + ).toEqual( + printQueryForTest<{ + id: number; + username: string; + title: string; + }>(joinWithWhereBeforeJoin), + ); +}); +test(`group by`, () => { const groupBy = users .as(`u`) .innerJoin(posts.as(`p`)) @@ -47,10 +87,19 @@ test(`q`, () => { })) .orderByDesc(`last_posted_at`); - expect(printQueryForTest(groupBy)).toEqual( - `SELECT "u"."id","u"."username",MAX("p"."created_at") AS "last_posted_at",COUNT(*) AS "total_count" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") GROUP BY 1,2 ORDER BY 3 DESC`, + expect( + printQueryForTest<{ + id: number; + username: string; + last_posted_at: Date; + total_count: number; + }>(groupBy), + ).toEqual( + `SELECT "u"."id","u"."username",MAX("p"."created_at") AS "last_posted_at",(COUNT(*))::INT AS "total_count" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") GROUP BY 1,2 ORDER BY 3 DESC`, ); +}); +test(`arbitrary SQL conditions`, () => { const conditions = users .where({id: 10}) .where(sql`username='x' OR username='y'`) @@ -58,17 +107,169 @@ test(`q`, () => { q.and(q.gt(u.id, 1), q.lt(u.id, 20), sql`id % 2 = 0 OR id % 3 = 0`), ); - expect(printQueryForTest(conditions)).toEqual( + expect(printQueryForTest(conditions)).toEqual( `SELECT * FROM users WHERE "id"=\${ 10 } AND (username='x' OR username='y') AND "id">\${ 1 } AND "id"<\${ 20 } AND (id % 2 = 0 OR id % 3 = 0)`, ); }); -function printQueryForTest(query: {toSql(): SQLQuery}) { - return query.toSql().format({ - escapeIdentifier: escapePostgresIdentifier, - formatValue: (value: unknown) => ({ - placeholder: '${ ' + JSON.stringify(value) + ' }', - value: undefined, - }), - }).text; +test(`condition nesting`, () => { + const or = users.where((u) => q.or(q.eq(u.id, 1), q.eq(u.id, 10))); + const and = users.where((u) => q.and(q.gte(u.id, 1), q.lte(u.id, 10))); + const and_or = users.where((u) => + q.and( + q.or(q.eq(u.id, 1), q.eq(u.id, 10)), + q.neq(u.profile_image_url, null), + ), + ); + const or_and = users.where((u) => + q.or( + q.and(q.eq(u.id, 1), q.eq(u.username, 'ForbesLindesay')), + q.eq(u.id, 10), + ), + ); + + expect(printQueryForTest(or)).toEqual( + `SELECT * FROM users WHERE "id"=\${ 1 } OR "id"=\${ 10 }`, + ); + expect(printQueryForTest(and)).toEqual( + `SELECT * FROM users WHERE "id">=\${ 1 } AND "id"<=\${ 10 }`, + ); + expect(printQueryForTest(and_or)).toEqual( + `SELECT * FROM users WHERE ("id"=\${ 1 } OR "id"=\${ 10 }) AND "profile_image_url" IS NOT NULL`, + ); + // AND already binds more tightly than OR so no extra parentheses are needed here - although extra parentheses + // would potentially make it easier for humans to read. + expect(printQueryForTest(or_and)).toEqual( + `SELECT * FROM users WHERE "id"=\${ 1 } AND "username"=\${ "ForbesLindesay" } OR "id"=\${ 10 }`, + ); +}); + +test(`condition nesting - objects`, () => { + const or = users.where(q.or({id: 1}, {id: 10})); + const and = users.where(q.and({id: q.gte(1)}, {id: q.lte(10)})); + const and_or = users.where( + q.and(q.or({id: 1}, {id: 10}), {profile_image_url: q.not(null)}), + ); + const or_and = users.where( + q.or(q.and({id: 1}, {username: 'ForbesLindesay'}), {id: 10}), + ); + + expect(printQueryForTest(or)).toEqual( + `SELECT * FROM users WHERE "id"=\${ 1 } OR "id"=\${ 10 }`, + ); + expect(printQueryForTest(and)).toEqual( + `SELECT * FROM users WHERE "id">=\${ 1 } AND "id"<=\${ 10 }`, + ); + expect(printQueryForTest(and_or)).toEqual( + `SELECT * FROM users WHERE ("id"=\${ 1 } OR "id"=\${ 10 }) AND "profile_image_url" IS NOT NULL`, + ); + // AND already binds more tightly than OR so no extra parentheses are needed here - although extra parentheses + // would potentially make it easier for humans to read. + expect(printQueryForTest(or_and)).toEqual( + `SELECT * FROM users WHERE "id"=\${ 1 } AND "username"=\${ "ForbesLindesay" } OR "id"=\${ 10 }`, + ); +}); + +test(`operators - lower`, () => { + const lowerUsernames = users.select((u) => ({ + username: q.lower(u.username), + greeting: q.lower('HELLO WORLD'), + })); + expect( + printQueryForTest<{ + username: string; + greeting: string; + }>(lowerUsernames), + ).toEqual( + `SELECT LOWER("username") AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, + ); + + const lowerMaxUsername = users.selectAggregate((u) => ({ + username: q.lower(q.max(u.username)), + greeting: q.lower('HELLO WORLD'), + })); + expect( + printQueryForTest<{ + username: string; + greeting: string; + }>(lowerMaxUsername), + ).toEqual( + `SELECT LOWER(MAX("username")) AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, + ); + + const maxLowerUsername = users.selectAggregate((u) => ({ + username: q.max(q.lower(u.username)), + greeting: q.lower('HELLO WORLD'), + })); + expect( + printQueryForTest<{ + username: string; + greeting: string; + }>(maxLowerUsername), + ).toEqual( + `SELECT MAX(LOWER("username")) AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, + ); +}); + +test(`operators - json`, () => { + interface DbRecord { + id: number; + data: {foo: string; bar: number}; + } + const records = createQuery( + `records`, + sql`records`, + columns(`records`, [ + {columnName: 'id', type: 'INT'}, + {columnName: 'data', type: 'JSONB'}, + ]), + ); + const query = records + .where((r) => + q.or( + q.eq(q.json(r.data).prop(`foo`).asJson(), 'hello'), + q.eq(q.json(r.data).prop(`foo`).asString(), 'world'), + ), + ) + .select((r) => ({ + foo: q.json(r.data).prop(`foo`).asString(), + bar: q.json(r.data).prop(`bar`).asJson(), + })); + expect( + printQueryForTest<{ + foo: string; + bar: number; + }>(query), + ).toEqual( + `SELECT "data"#>>\${ ["foo"] } AS "foo","data"#>\${ ["bar"] } AS "bar" FROM records WHERE "data"#>\${ ["foo"] }=\${ "\\"hello\\"" } OR "data"#>>\${ ["foo"] }=\${ "world" }`, + ); + + const lowerMaxUsername = users.selectAggregate((u) => ({ + username: q.lower(q.max(u.username)), + greeting: q.lower('HELLO WORLD'), + })); + expect( + printQueryForTest<{ + username: string; + greeting: string; + }>(lowerMaxUsername), + ).toEqual( + `SELECT LOWER(MAX("username")) AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, + ); + + const maxLowerUsername = users.selectAggregate((u) => ({ + username: q.max(q.lower(u.username)), + greeting: q.lower('HELLO WORLD'), + })); + expect( + printQueryForTest<{ + username: string; + greeting: string; + }>(maxLowerUsername), + ).toEqual( + `SELECT MAX(LOWER("username")) AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, + ); +}); +function printQueryForTest(query: ProjectedLimitQuery) { + return query.toSql().format(testFormat).text; } diff --git a/packages/pg-typed/src/v2/implementation/Columns.ts b/packages/pg-typed/src/v2/implementation/Columns.ts index 971303dd..62a02936 100644 --- a/packages/pg-typed/src/v2/implementation/Columns.ts +++ b/packages/pg-typed/src/v2/implementation/Columns.ts @@ -1,4 +1,3 @@ -import {SQLQuery} from '@databases/pg'; import {columnReference} from './Operators'; import {Columns} from '../types/Columns'; @@ -8,22 +7,15 @@ export function columns( tableName: string, schema?: { columnName: string; - postgresTypeQuery?: SQLQuery; - postgresType?: string; + type?: string; }[], isAlias: boolean = false, ): Columns { if (schema) { return Object.fromEntries( - schema.map(({columnName, postgresTypeQuery, postgresType}) => [ + schema.map(({columnName, type}) => [ columnName, - columnReference( - tableName, - columnName, - isAlias, - postgresTypeQuery, - postgresType, - ), + columnReference(tableName, columnName, isAlias, type ?? null), ]), ) as Columns; } else { @@ -35,7 +27,7 @@ export function columns( if (columnName === 'then' || typeof columnName !== 'string') { return undefined; } - return columnReference(tableName, columnName, isAlias); + return columnReference(tableName, columnName, isAlias, null); }, }, ) as any; diff --git a/packages/pg-typed/src/v2/implementation/Operators.ts b/packages/pg-typed/src/v2/implementation/Operators.ts index f5948806..d352c3da 100644 --- a/packages/pg-typed/src/v2/implementation/Operators.ts +++ b/packages/pg-typed/src/v2/implementation/Operators.ts @@ -1,48 +1,52 @@ import {SQLQuery, sql} from '@databases/pg'; -import Value, { +import { AggregatedTypedValue, - FieldCondition, + AnyOf, ComputedFieldCondition, - isSpecialValue, - NonAggregatedTypedValue, + FieldCondition, FieldConditionToSqlContext, + isAnyOfCondition, + isComputedFieldQuery, + isSpecialValue, + NonAggregatedValue, RawValue, - AnyOf, + TypedValue, + UnknownValue, + Value, ValueToSqlContext, - BaseAggregatedTypedValue, - isComputedFieldQuery, - ComputedValue, - AggregatedValue, - isAnyOfCondition, } from '../types/SpecialValues'; import { BinaryInput, OperatorDefinition, OperatorDefinitions, } from '../PostgresOperators'; -import {IOperators, List} from '../types/Operators'; +import { + AggregatedJsonValue, + IOperators, + List, + NonAggregatedJsonValue, +} from '../types/Operators'; import {ColumnReference} from '../types/Columns'; +import WhereCondition from '../WhereCondition'; export function columnReference( tableName: string, columnName: string, isAlias: boolean, - postgresTypeQuery?: SQLQuery, - postgresType?: string, -): Value { + sqlType: string | null, +): NonAggregatedValue { return new ColumnReferenceImplementation( tableName, columnName, isAlias, - postgresTypeQuery, - postgresType, + sqlType, ); } export function fieldConditionToPredicateValue( column: ColumnReference, f: FieldCondition, -): Value { +): NonAggregatedValue { const constantValue = fieldConditionToConstant(f); if (constantValue !== null) return constantValue; return new FieldConditionValue(column, f); @@ -50,7 +54,7 @@ export function fieldConditionToPredicateValue( export function valueToSelect( alias: string, - value: Value | AggregatedValue, + value: UnknownValue, ): SQLQuery { if ( value instanceof ColumnReferenceImplementation && @@ -72,8 +76,8 @@ export function valueToSelect( export function aliasTableInValue( tableName: string, tableAlias: string, - value: Value, -): Value { + value: NonAggregatedValue, +): NonAggregatedValue { if (!isSpecialValue(value)) { return value; } @@ -81,7 +85,7 @@ export function aliasTableInValue( } export function valueToSql( - value: Value | BaseAggregatedTypedValue, + value: UnknownValue, ctx: ValueToSqlContext, ): SQLQuery { if (isSpecialValue(value)) return value.toSql(ctx); @@ -112,13 +116,14 @@ const ORDER_BY_DIRECTION = { DESC: sql`DESC`, }; -abstract class BaseExpression - implements BaseAggregatedTypedValue, NonAggregatedTypedValue -{ +abstract class BaseExpression implements TypedValue { public readonly __isSpecialValue = true; public readonly __isAggregatedValue = true; public readonly __isNonAggregatedComputedValue = true; - + public readonly sqlType: string | null | undefined; + constructor(sqlType: string | null | undefined) { + this.sqlType = sqlType; + } public abstract toSql(ctx: ValueToSqlContext): SQLQuery; public __getType(): T { throw new Error( @@ -147,23 +152,19 @@ class ColumnReferenceImplementation public readonly columnName: string; public readonly isAlias: boolean; - // TODO: make use of schema info - public readonly postgresTypeQuery: SQLQuery | undefined; - public readonly postgresType: string | undefined; + // This property is assigned in the base class + public readonly sqlType!: string | null; constructor( tableName: string, columnName: string, isAlias: boolean, - postgresTypeQuery: SQLQuery | undefined, - postgresType: string | undefined, + sqlType: string | null, ) { - super(); + super(sqlType); this.tableName = tableName; this.columnName = columnName; this.isAlias = isAlias; - this.postgresTypeQuery = postgresTypeQuery; - this.postgresType = postgresType; } public toSql(ctx: ValueToSqlContext): SQLQuery { if (this.isAlias) { @@ -180,8 +181,7 @@ class ColumnReferenceImplementation tableAlias, this.columnName, true, - this.postgresTypeQuery, - this.postgresType, + this.sqlType, ); } } @@ -193,7 +193,8 @@ class OperatorExpression< > extends BaseExpression { public readonly op: OperatorDefinition; public readonly input: TInput; - public readonly prepareInput: ( + + private readonly _prepareInput: ( input: TInput, ctx: ValueToSqlContext, ) => TPreparedInput; @@ -202,16 +203,17 @@ class OperatorExpression< op: OperatorDefinition, input: TInput, prepareInput: (input: TInput, ctx: ValueToSqlContext) => TPreparedInput, + sqlType: string | null | undefined, ) { - super(); + super(sqlType); this.op = op; this.input = input; - this.prepareInput = prepareInput; + this._prepareInput = prepareInput; } public toSql(ctx: ValueToSqlContext): SQLQuery { return this.op.toSql( - this.prepareInput(this.input, { + this._prepareInput(this.input, { ...ctx, parentOperatorPrecedence: this.op.precedence, }), @@ -266,9 +268,13 @@ class OperatorFieldQuery< class NonAggregateFunction extends BaseExpression { public readonly fn: keyof typeof NON_AGGREGATE_FUNCTIONS; - public readonly values: Value[]; - constructor(fn: keyof typeof NON_AGGREGATE_FUNCTIONS, values: Value[]) { - super(); + public readonly values: UnknownValue[]; + constructor( + fn: keyof typeof NON_AGGREGATE_FUNCTIONS, + values: UnknownValue[], + sqlType: string | null | undefined, + ) { + super(sqlType); this.fn = fn; this.values = values; } @@ -287,23 +293,27 @@ class AggregateFunction implements AggregatedTypedValue { public readonly fn: keyof typeof AGGREGATE_FUNCTIONS; - public readonly values: Value[]; - public readonly condition: undefined | Value; + public readonly values: NonAggregatedValue[]; + public readonly typeCast: SQLQuery | undefined; + public readonly condition: undefined | NonAggregatedValue; public readonly orderByClauses: { direction: keyof typeof ORDER_BY_DIRECTION; - value: Value; + value: NonAggregatedValue; }[]; public readonly isDistinct: boolean; constructor( fn: keyof typeof AGGREGATE_FUNCTIONS, - values: Value[], - condition?: Value, - orderBy?: Value[], + values: NonAggregatedValue[], + sqlType: string | null | undefined, + typeCast?: SQLQuery, + condition?: NonAggregatedValue, + orderBy?: NonAggregatedValue[], distinct?: boolean, ) { - super(); + super(sqlType); this.fn = fn; this.values = values; + this.typeCast = typeCast; this.condition = condition; this.orderByClauses = orderBy ?? []; this.isDistinct = distinct ?? false; @@ -326,19 +336,22 @@ class AggregateFunction `, `, )}`; } - if (this.condition !== undefined) { - return sql`${fn}(${args}) FILTER (WHERE ${valueToSql( - this.condition, - ctx, - )})`; + let result = sql`${fn}(${args})`; + if (this.condition) { + result = sql`${result} FILTER (WHERE ${valueToSql(this.condition, ctx)})`; + } + if (this.typeCast) { + result = sql`(${result})::${this.typeCast}`; } - return sql`${fn}(${args})`; + return result; } public distinct(): AggregatedTypedValue { return new AggregateFunction( this.fn, this.values, + this.sqlType, + this.typeCast, this.condition, this.orderByClauses, true, @@ -346,11 +359,13 @@ class AggregateFunction } public orderByAsc( - value: Value, + value: NonAggregatedValue, ): AggregatedTypedValue { return new AggregateFunction( this.fn, this.values, + this.sqlType, + this.typeCast, this.condition, [...this.orderByClauses, {direction: 'ASC', value}], this.isDistinct, @@ -358,21 +373,27 @@ class AggregateFunction } public orderByDesc( - value: Value, + value: NonAggregatedValue, ): AggregatedTypedValue { return new AggregateFunction( this.fn, this.values, + this.sqlType, + this.typeCast, this.condition, [...this.orderByClauses, {direction: 'DESC', value}], this.isDistinct, ); } - public filter(condition: Value): AggregatedTypedValue { + public filter( + condition: NonAggregatedValue, + ): AggregatedTypedValue { return new AggregateFunction( this.fn, this.values, + this.sqlType, + this.typeCast, condition, this.orderByClauses, this.isDistinct, @@ -418,8 +439,8 @@ class AllOf extends BaseFieldQuery { class AnyOfImplementation extends BaseFieldQuery implements AnyOf { public readonly __isAnyOf = true; - public readonly values: Value>>; - constructor(values: Value>>) { + public readonly values: NonAggregatedValue>>; + constructor(values: NonAggregatedValue>>) { super(); this.values = values; } @@ -492,15 +513,11 @@ class AnyOfImplementation extends BaseFieldQuery implements AnyOf { } class EqualsAnyOf extends BaseExpression { - public readonly left: AggregatedValue | Value; + public readonly left: UnknownValue; public readonly right: AnyOf; - constructor( - op: OperatorDefinition, - left: AggregatedValue | Value, - right: AnyOf, - ) { - super(); + constructor(left: UnknownValue, right: AnyOf) { + super(`BOOLEAN`); this.left = left; this.right = right; } @@ -539,121 +556,240 @@ class CaseInsensitive extends BaseFieldQuery { } } +function prepareUnaryOperatorExpression( + expression: UnknownValue, + ctx: ValueToSqlContext, +): SQLQuery { + return valueToSql(expression, ctx); +} + function prepareBinaryOperatorExpression( { left, right, }: { - left: AggregatedValue | Value; - right: AggregatedValue | Value; + left: UnknownValue; + right: UnknownValue; }, ctx: ValueToSqlContext, -) { - return {left: valueToSql(left, ctx), right: valueToSql(right, ctx)}; +): BinaryInput { + const leftSqlType = isSpecialValue(left) ? left.sqlType : null; + const rightSqlType = isSpecialValue(right) ? right.sqlType : null; + + const ctxWithType = setSqlType( + ctx, + leftSqlType === rightSqlType || rightSqlType === null + ? leftSqlType + : leftSqlType === null + ? rightSqlType + : undefined, + ); + return { + left: valueToSql(left, ctxWithType), + right: valueToSql(right, ctxWithType), + }; +} + +function prepareUnaryOperatorFieldQuery( + _input: null, + ctx: FieldConditionToSqlContext, +): SQLQuery { + return ctx.left; } function prepareBinaryOperatorFieldQuery( right: RawValue, ctx: FieldConditionToSqlContext, -) { +): BinaryInput { return {left: ctx.left, right: sql.value(ctx.toValue(right))}; } function binaryOperator( operator: OperatorDefinition<{left: SQLQuery; right: SQLQuery}>, + { + getType, + getUnaryOperator, + handleAnyOf, + }: { + getType?: ( + left: UnknownValue, + right: UnknownValue, + ) => string | null | undefined; + getUnaryOperator?: (value: unknown) => null | OperatorDefinition; + handleAnyOf?: ( + left: NonAggregatedValue, + right: AnyOf, + ) => NonAggregatedValue; + } = {}, ): (leftOrOnly: any, right?: any) => any { return ( - leftOrOnly: AggregatedValue | Value | RawValue, - right?: AggregatedValue | Value | AnyOf, - ): - | (Value & NonAggregatedTypedValue) - | FieldCondition => { + leftOrOnly: UnknownValue | RawValue, + right?: UnknownValue | AnyOf, + ): Value | FieldCondition => { if (right === undefined) { - return new OperatorFieldQuery( + const unaryOperator = getUnaryOperator?.(leftOrOnly); + if (unaryOperator) { + return new OperatorFieldQuery( + unaryOperator, + null, + prepareUnaryOperatorFieldQuery, + null, + ); + } + return new OperatorFieldQuery, TLeft>( operator, leftOrOnly as RawValue, prepareBinaryOperatorFieldQuery, null, ); } else if (isAnyOfCondition(right)) { - if (operator !== OperatorDefinitions.EQ) { - throw new Error( - `The only operator that can be used with "anyOf" is "eq".`, - ); + if (handleAnyOf) { + return handleAnyOf( + leftOrOnly as NonAggregatedValue, + right, + ) as any; } - if ( - right instanceof AnyOfImplementation && - !isSpecialValue(right.values) && - !sql.isSqlQuery(right.values) && - [...right.values].length === 0 - ) { - // @ts-expect-error - return false; - } - return new EqualsAnyOf( - operator, - leftOrOnly as AggregatedValue | Value, - right, - ) as any; + throw new Error(`"anyOf" cannot be used with this operator.`); } else { + const unaryOperator = getUnaryOperator?.(right); + if (unaryOperator) { + return new OperatorExpression, TResult>( + unaryOperator, + leftOrOnly as UnknownValue, + prepareUnaryOperatorExpression, + getType + ? getType(leftOrOnly as UnknownValue, right) + : `BOOLEAN`, + ); + } return new OperatorExpression( operator, - {left: leftOrOnly as AggregatedValue | Value, right}, + { + left: leftOrOnly as UnknownValue, + right, + }, prepareBinaryOperatorExpression, + getType ? getType(leftOrOnly as UnknownValue, right) : `BOOLEAN`, ); } }; } function prepareVariadicOperatorInput( - inputs: (Value | AggregatedTypedValue)[], + inputs: readonly UnknownValue[], ctx: ValueToSqlContext, ) { return inputs.map((input) => valueToSql(input, ctx)); } -function variadicOperator( - operator: OperatorDefinition, - getConstantValue?: ( - ...params: (Value | AggregatedTypedValue)[] - ) => TStaticValue | undefined, +function booleanOperator( + operatorName: 'AND' | 'OR', + { + getConstantValue, + }: { + getConstantValue?: ( + ...params: UnknownValue[] + ) => boolean | undefined; + }, ) { - return ( - ...params: (Value | AggregatedTypedValue)[] - ): - | TStaticValue - | (BaseAggregatedTypedValue & - NonAggregatedTypedValue) => { - const flatParams = params.flatMap((p) => { + const operator: OperatorDefinition = + OperatorDefinitions[operatorName]; + const fn = ( + ...params: readonly (UnknownValue | WhereCondition)[] + ): any => { + if ( + params.some( + (p) => + typeof p === 'function' || + (typeof p === 'object' && + p !== null && + !isSpecialValue(p) && + !sql.isSqlQuery(p)), + ) + ) { + return (columns: any) => { + return fn( + ...params.map((condition) => { + if (typeof condition === 'function') return condition(columns); + if ( + typeof condition === 'object' && + condition !== null && + !isSpecialValue(condition) && + !sql.isSqlQuery(condition) + ) { + return Operators.and( + ...Object.entries(condition).map(([columnName, value]) => + fieldConditionToPredicateValue(columns[columnName], value), + ), + ); + } + return condition; + }), + ); + }; + } + + const flatParams: readonly UnknownValue[] = ( + params as UnknownValue[] + ).flatMap((p) => { if (p instanceof OperatorExpression && p.op === operator) { - return p.input as (Value | AggregatedTypedValue)[]; + return p.input as UnknownValue[]; } return [p]; }); const constantValue = getConstantValue && getConstantValue(...flatParams); if (constantValue !== undefined) return constantValue; + if (flatParams.length === 1) { + return flatParams[0] as Value; + } - return new OperatorExpression( - operator, - flatParams, - prepareVariadicOperatorInput, - ); + return new OperatorExpression< + SQLQuery[], + readonly UnknownValue[], + boolean + >(operator, flatParams, prepareVariadicOperatorInput, `BOOLEAN`); }; + return fn; } -function nonAggregateFunction(fn: keyof typeof NON_AGGREGATE_FUNCTIONS) { - return ( - ...args: TArgs - ): ComputedValue => { - return new NonAggregateFunction(fn, args); +function sqlTypeFromArgs(args: any[]) { + return args.reduce((t, arg) => { + if (t === undefined) return undefined; + if (isSpecialValue(arg) && (arg as any).sqlType !== null) { + if (t === null || t === (args as any).sqlType) { + return (arg as any).sqlType; + } else { + return undefined; + } + } + return t; + }, null); +} +function nonAggregateFunction( + fn: keyof typeof NON_AGGREGATE_FUNCTIONS, + sqlType?: string, +) { + return (...args: TArgs): Value => { + return new NonAggregateFunction( + fn, + args, + sqlType === undefined ? sqlTypeFromArgs(args) : sqlType, + ); }; } -function aggregateFunction(fn: keyof typeof AGGREGATE_FUNCTIONS) { +function aggregateFunction( + fn: keyof typeof AGGREGATE_FUNCTIONS, + sqlType?: string, +) { return ( ...args: TArgs ): AggregatedTypedValue => { - return new AggregateFunction(fn, args); + return new AggregateFunction( + fn, + args, + sqlType === undefined ? sqlTypeFromArgs(args) : sqlType, + ); }; } @@ -662,10 +798,14 @@ function fieldConditionToSql( ctx: FieldConditionToSqlContext, ) { if (isSpecialValue(value)) return value.toSqlCondition(ctx); + const v = ctx.toValue(value); + if (v === null) { + return OperatorDefinitions.IS_NULL.toSql(ctx.left, ctx); + } return OperatorDefinitions.EQ.toSql( { left: ctx.left, - right: sql.value(ctx.toValue(value)), + right: sql.value(v), }, ctx, ); @@ -676,26 +816,19 @@ function fieldConditionToConstant(q: FieldCondition): boolean | null { return null; } -function overload( - overloads: Record any>, - chooseOverload: (...args: any[]) => TKey, -) { - return (...args: any[]): any => { - return overloads[chooseOverload(...args)](...args); - }; -} - class FieldConditionValue extends BaseExpression { public readonly left: ColumnReference; public readonly right: FieldCondition; constructor(left: ColumnReference, right: FieldCondition) { - super(); + super(`BOOLEAN`); this.left = left; this.right = right; } public toSql(ctx: ValueToSqlContext): SQLQuery { + if (this.left.sqlType) { + } return fieldConditionToSql(this.right, { - ...ctx, + ...setSqlType(ctx, this.left.sqlType), left: valueToSql(this.left, ctx), }); } @@ -704,9 +837,9 @@ class FieldConditionValue extends BaseExpression { class AliasTableInValue extends BaseExpression { public readonly tableName: string; public readonly tableAlias: string; - public readonly value: Value; - constructor(tableName: string, tableAlias: string, value: Value) { - super(); + public readonly value: UnknownValue; + constructor(tableName: string, tableAlias: string, value: UnknownValue) { + super(isSpecialValue(value) ? value.sqlType : null); this.tableName = tableName; this.tableAlias = tableAlias; this.value = value; @@ -722,76 +855,236 @@ class AliasTableInValue extends BaseExpression { } } +// export interface NonAggregatedJsonValue { +// asString(): NonAggregatedValue; +// asJson(): NonAggregatedValue; +// prop( +// key: TKey, +// ): NonAggregatedJsonValue; +// } + +// export interface AggregatedJsonValue { +// asString(): AggregatedValue; +// asJson(): AggregatedValue; +// prop(key: TKey): AggregatedJsonValue; +// } + +class JsonValue + implements + NonAggregatedJsonValue, + AggregatedJsonValue +{ + readonly __isSpecialValue = true; + private readonly _value: UnknownValue; + private readonly _path: string[]; + constructor(value: UnknownValue, path: string[]) { + this._value = value; + this._path = path; + } + prop(key: TKey) { + return new JsonValue(this._value, [ + ...this._path, + `${key as string | number}`, + ]); + } + asString(): Value { + return new JsonValueResult(this._value, this._path, true, `TEXT`); + } + asJson(): Value { + return new JsonValueResult( + this._value, + this._path, + false, + isSpecialValue(this._value) ? this._value.sqlType : null, + ); + } +} + +class JsonValueResult< + TBaseValue, + TValueAtPath, +> extends BaseExpression { + private readonly _value: UnknownValue; + private readonly _path: string[]; + private readonly _asText: boolean; + constructor( + value: UnknownValue, + path: string[], + asText: boolean, + sqlType: string | null | undefined, + ) { + super(sqlType); + this._value = value; + this._path = path; + this._asText = asText; + } + public toSql(ctx: ValueToSqlContext): SQLQuery { + if (this._asText) { + return sql`${valueToSql(this._value, ctx)}#>>${this._path}`; + } else { + return sql`${valueToSql(this._value, ctx)}#>${this._path}`; + } + } +} + +// function prepareExpressionAndType( +// {expression, type}: {expression: UnknownValue; type: string}, +// ctx: ValueToSqlContext, +// ) { +// if (!/^[a-z]([a-z0-9_]*[a-z0-9])?(\[\])?$/i.test(type)) { +// throw new Error(`Invalid type: ${type}`); +// } +// return { +// expression: valueToSql(expression, ctx), +// type: sql.__dangerous__rawValue(type), +// }; +// } +// function typeCast( +// expression: UnknownValue, +// sqlType: string, +// ): TypedValue { +// return new OperatorExpression< +// {expression: SQLQuery; type: SQLQuery}, +// {expression: UnknownValue; type: string}, +// TResult +// >( +// OperatorDefinitions.TYPECAST, +// {expression, type: sqlType}, +// prepareExpressionAndType, +// sqlType, +// ); +// } + +function setSqlType( + ctx: ValueToSqlContext, + sqlType: string | null | undefined, +): ValueToSqlContext { + if (sqlType === 'JSON' || sqlType === 'JSONB') { + return { + ...ctx, + toValue: (v) => JSON.stringify(v), + }; + } + if (sqlType === 'JSON[]' || sqlType === 'JSONB[]') { + return { + ...ctx, + toValue: (v) => + Array.isArray(v) ? v.map((v) => JSON.stringify(v)) : ctx.toValue(v), + }; + } + return ctx; +} + const Operators: IOperators = { allOf: (values) => new AllOf(values), - and: variadicOperator(OperatorDefinitions.AND, (...params) => { - if (params.every((p) => p === true)) return true; - if (params.some((p) => p === false)) return false; - return undefined; + and: booleanOperator(`AND`, { + getConstantValue: (...params) => { + if (params.every((p) => p === true)) return true; + if (params.some((p) => p === false)) return false; + return undefined; + }, }), anyOf: (values) => new AnyOfImplementation(values), caseInsensitive: (value) => new CaseInsensitive(value), count(expression) { - return new AggregateFunction(`COUNT`, [expression ?? STAR]); + return new AggregateFunction( + `COUNT`, + [expression ?? STAR], + `INT`, + sql`INT`, + ); }, - eq: binaryOperator(OperatorDefinitions.EQ), + eq: binaryOperator(OperatorDefinitions.EQ, { + getUnaryOperator(right) { + if (right === null) { + return OperatorDefinitions.IS_NULL; + } + return null; + }, + handleAnyOf( + left: NonAggregatedValue, + right: AnyOf, + ): NonAggregatedValue { + if ( + right instanceof AnyOfImplementation && + !isSpecialValue(right.values) && + !sql.isSqlQuery(right.values) && + [...right.values].length === 0 + ) { + return false; + } + return new EqualsAnyOf(left, right) as any; + }, + }), gt: binaryOperator(OperatorDefinitions.GT), gte: binaryOperator(OperatorDefinitions.GTE), ilike: binaryOperator(OperatorDefinitions.ILIKE), + json: (value: UnknownValue): JsonValue => { + return new JsonValue(value, []); + }, like: binaryOperator(OperatorDefinitions.LIKE), lower: nonAggregateFunction(`LOWER`), lt: binaryOperator(OperatorDefinitions.LT), lte: binaryOperator(OperatorDefinitions.LTE), max: aggregateFunction(`MAX`), min: aggregateFunction(`MIN`), - neq: binaryOperator(OperatorDefinitions.NEQ), - not: overload( - { - expression(value: Value): Value { - if (typeof value === 'boolean') return !value; - return new OperatorExpression( - OperatorDefinitions.NOT, - value, - valueToSql, - ); - }, - fieldQuery(value: FieldCondition): FieldCondition { - if (isSpecialValue(value)) { - const constantValueOfExpression = fieldConditionToConstant(value); - const constantValueOfNot = - constantValueOfExpression !== null - ? !constantValueOfExpression - : null; - return new OperatorFieldQuery( - OperatorDefinitions.NOT, - value, - fieldConditionToSql, - constantValueOfNot, - ); - } else { - return new OperatorFieldQuery( - OperatorDefinitions.NEQ, - value, - prepareBinaryOperatorFieldQuery, - null, - ); - } - }, + neq: binaryOperator(OperatorDefinitions.NEQ, { + getUnaryOperator(right) { + if (right === null) { + return OperatorDefinitions.IS_NOT_NULL; + } + return null; }, - (value: any) => { - return (isSpecialValue(value) && !isComputedFieldQuery(value)) || - sql.isSqlQuery(value) || - typeof value === 'boolean' - ? `expression` - : `fieldQuery`; + }), + not: (value: UnknownValue | FieldCondition): any => { + if (typeof value === 'boolean') return !value; + + if (isSpecialValue(value) && isComputedFieldQuery(value)) { + const constantValueOfExpression = fieldConditionToConstant(value); + const constantValueOfNot = + constantValueOfExpression !== null ? !constantValueOfExpression : null; + return new OperatorFieldQuery, T>( + OperatorDefinitions.NOT, + value, + fieldConditionToSql, + constantValueOfNot, + ); + } + + if (sql.isSqlQuery(value) || isSpecialValue(value)) { + return new OperatorExpression, boolean>( + OperatorDefinitions.NOT, + value, + valueToSql, + `BOOLEAN`, + ); + } + + if (value === null) { + return new OperatorFieldQuery( + OperatorDefinitions.IS_NOT_NULL, + null, + prepareUnaryOperatorFieldQuery, + null, + ); + } else { + return new OperatorFieldQuery, T>( + OperatorDefinitions.NEQ, + value, + prepareBinaryOperatorFieldQuery, + null, + ); + } + }, + or: booleanOperator(`OR`, { + getConstantValue: (...params) => { + if (params.some((p) => p === true)) return true; + if (params.every((p) => p === false)) return false; + return undefined; }, - ), - or: variadicOperator(OperatorDefinitions.OR, (...params) => { - if (params.some((p) => p === true)) return true; - if (params.every((p) => p === false)) return false; - return undefined; }), sum: aggregateFunction(`SUM`), + upper: nonAggregateFunction(`UPPER`), }; export default Operators; diff --git a/packages/pg-typed/src/v2/index.ts b/packages/pg-typed/src/v2/index.ts index 5f0b0065..c21c5cf7 100644 --- a/packages/pg-typed/src/v2/index.ts +++ b/packages/pg-typed/src/v2/index.ts @@ -1,5 +1,5 @@ import Operators from './implementation/Operators'; -import Value from './types/SpecialValues'; +import NonAggregatedValue from './types/SpecialValues'; import {JoinQueryBuilder, JoinQuery} from './types/Join'; import AliasedQuery from './AliasedQuery'; import {Table} from './Table'; @@ -7,4 +7,10 @@ import {IOperators} from './types/Operators'; export const q: IOperators = Operators; -export type {AliasedQuery, JoinQueryBuilder, JoinQuery, Table, Value}; +export type { + AliasedQuery, + JoinQueryBuilder, + JoinQuery, + Table, + NonAggregatedValue as Value, +}; diff --git a/packages/pg-typed/src/v2/types/Columns.ts b/packages/pg-typed/src/v2/types/Columns.ts index 8ad1eee0..1db2a3d5 100644 --- a/packages/pg-typed/src/v2/types/Columns.ts +++ b/packages/pg-typed/src/v2/types/Columns.ts @@ -1,9 +1,7 @@ -import {SQLQuery} from '@databases/pg'; import {NonAggregatedTypedValue} from './SpecialValues'; export interface ColumnReference extends NonAggregatedTypedValue { - readonly postgresTypeQuery?: SQLQuery; - readonly postgresType?: string; + readonly sqlType: string | null; setAlias(tableAlias: string): ColumnReference; } diff --git a/packages/pg-typed/src/v2/types/Join.ts b/packages/pg-typed/src/v2/types/Join.ts index c2afbedf..31860402 100644 --- a/packages/pg-typed/src/v2/types/Join.ts +++ b/packages/pg-typed/src/v2/types/Join.ts @@ -2,10 +2,12 @@ import GroupByQuery from '../GroupByQuery'; import {JoinableQueryLeft} from './JoinableQuery'; import {ProjectedQuery} from './Queries'; import {AggregatedSelectionSet, SelectionSet} from './SelectionSet'; -import Value from './SpecialValues'; +import {NonAggregatedValue} from './SpecialValues'; export interface JoinQueryBuilder { - on(predicate: (column: TColumns) => Value): JoinQuery; + on( + predicate: (column: TColumns) => NonAggregatedValue, + ): JoinQuery; } export interface JoinQuery extends JoinableQueryLeft { @@ -19,5 +21,7 @@ export interface JoinQuery extends JoinableQueryLeft { aggregation: (column: TColumns) => AggregatedSelectionSet, ): ProjectedQuery; - where(predicate: (column: TColumns) => Value): JoinQuery; + where( + predicate: (column: TColumns) => NonAggregatedValue, + ): JoinQuery; } diff --git a/packages/pg-typed/src/v2/types/Operators.ts b/packages/pg-typed/src/v2/types/Operators.ts index 5721627b..3b09b5d8 100644 --- a/packages/pg-typed/src/v2/types/Operators.ts +++ b/packages/pg-typed/src/v2/types/Operators.ts @@ -1,42 +1,90 @@ -import Value, { +import WhereCondition from '../WhereCondition'; +import { AggregatedTypedValue, + NonAggregatedValue, FieldCondition, RawValue, AnyOf, + AggregatedValue, + Value, } from './SpecialValues'; export interface List { [Symbol.iterator](): IterableIterator; } +export interface NonAggregatedJsonValue { + readonly __isSpecialValue: true; + asString(): NonAggregatedValue; + asJson(): NonAggregatedValue; + prop( + key: TKey, + ): NonAggregatedJsonValue; +} + +export interface AggregatedJsonValue { + readonly __isSpecialValue: true; + asString(): AggregatedValue; + asJson(): AggregatedValue; + prop(key: TKey): AggregatedJsonValue; +} + // prettier-ignore export interface IOperators { allOf(values: List>): FieldCondition; and(...values: Value[]): Value; - anyOf(values: Value>>): AnyOf; + and(...values: NonAggregatedValue[]): NonAggregatedValue; + and(...values: AggregatedValue[]): AggregatedValue; + and(...conditions: WhereCondition[]): WhereCondition; + anyOf(values: NonAggregatedValue>>): AnyOf; caseInsensitive: (value: FieldCondition) => FieldCondition; - count(expression?: Value): AggregatedTypedValue; - eq(left: Value, right: Value | AnyOf): Value; + count(expression?: NonAggregatedValue): AggregatedTypedValue; + eq(left: Value, right: Value): Value; + eq(left: NonAggregatedValue, right: NonAggregatedValue | AnyOf): NonAggregatedValue; + eq(left: AggregatedValue, right: AggregatedValue): AggregatedValue; gt(right: RawValue): FieldCondition; gt(left: Value, right: Value): Value; + gt(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; + gt(left: AggregatedValue, right: AggregatedValue): AggregatedValue; gte(right: RawValue): FieldCondition; gte(left: Value, right: Value): Value; + gte(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; + gte(left: AggregatedValue, right: AggregatedValue): AggregatedValue; neq(left: Value, right: Value): Value; + neq(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; + neq(left: AggregatedValue, right: AggregatedValue): AggregatedValue; ilike(right: string): FieldCondition; - ilike(left: Value, right: Value): Value; + ilike(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; // TODO: IN should probably have an SQL query as the right hand side // in(left: Value, right: Value): Value; + json(value: NonAggregatedValue): NonAggregatedJsonValue; + json(value: AggregatedValue): AggregatedJsonValue; like(right: string): FieldCondition; - like(left: Value, right: Value): Value; + like(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; lower(value: Value): Value; + lower(value: NonAggregatedValue): NonAggregatedValue; + lower(value: AggregatedValue): AggregatedValue; lt(right: RawValue): FieldCondition; lt(left: Value, right: Value): Value; + lt(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; + lt(left: AggregatedValue, right: AggregatedValue): AggregatedValue; lte(right: RawValue): FieldCondition; lte(left: Value, right: Value): Value; - max(value: Value): AggregatedTypedValue; - min(value: Value): AggregatedTypedValue; + lte(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; + lte(left: AggregatedValue, right: AggregatedValue): AggregatedValue; + max(value: NonAggregatedValue): AggregatedTypedValue; + min(value: NonAggregatedValue): AggregatedTypedValue; + not(value: boolean): boolean; not(value: Value): Value; + not(value: NonAggregatedValue): NonAggregatedValue; + not(value: AggregatedValue): AggregatedValue; not(value: FieldCondition): FieldCondition; or(...values: Value[]): Value; - sum(value: Value): AggregatedTypedValue; + or(...values: NonAggregatedValue[]): NonAggregatedValue; + or(...values: AggregatedValue[]): AggregatedValue; + or(...conditions: WhereCondition[]): WhereCondition; + sum(value: NonAggregatedValue): AggregatedTypedValue; + upper(value: Value): Value; + upper(value: NonAggregatedValue): NonAggregatedValue; + upper(value: AggregatedValue): AggregatedValue; } diff --git a/packages/pg-typed/src/v2/types/SelectionSet.ts b/packages/pg-typed/src/v2/types/SelectionSet.ts index a88fc719..066540be 100644 --- a/packages/pg-typed/src/v2/types/SelectionSet.ts +++ b/packages/pg-typed/src/v2/types/SelectionSet.ts @@ -1,7 +1,7 @@ -import Value, {AggregatedValue} from './SpecialValues'; +import {AggregatedValue, NonAggregatedValue} from './SpecialValues'; export type SelectionSet = { - [key in keyof TSelection]: Value; + [key in keyof TSelection]: NonAggregatedValue; }; export type AggregatedSelectionSet = { diff --git a/packages/pg-typed/src/v2/types/SpecialValues.ts b/packages/pg-typed/src/v2/types/SpecialValues.ts index 4e874a53..bd1ef8d2 100644 --- a/packages/pg-typed/src/v2/types/SpecialValues.ts +++ b/packages/pg-typed/src/v2/types/SpecialValues.ts @@ -14,18 +14,19 @@ export interface ValueToSqlContext { readonly parentOperatorPrecedence: number | null; } -export interface NonAggregatedTypedValue { +export interface SpecialTypedValue { readonly __isSpecialValue: true; - readonly __isNonAggregatedComputedValue: true; __getType(): T; + readonly sqlType: string | null | undefined; toSql(ctx: ValueToSqlContext): SQLQuery; } -export interface BaseAggregatedTypedValue { - readonly __isSpecialValue: true; +export interface NonAggregatedTypedValue extends SpecialTypedValue { + readonly __isNonAggregatedComputedValue: true; +} + +export interface BaseAggregatedTypedValue extends SpecialTypedValue { readonly __isAggregatedValue: true; - __getType(): T; - toSql(ctx: ValueToSqlContext): SQLQuery; } export interface FieldConditionToSqlContext extends ValueToSqlContext { @@ -42,9 +43,13 @@ export interface ComputedFieldCondition { export interface AggregatedTypedValue extends BaseAggregatedTypedValue { distinct(): AggregatedTypedValue; - orderByAsc(value: Value): AggregatedTypedValue; - orderByDesc(value: Value): AggregatedTypedValue; - filter(condition: Value): BaseAggregatedTypedValue; + orderByAsc( + value: NonAggregatedValue, + ): AggregatedTypedValue; + orderByDesc( + value: NonAggregatedValue, + ): AggregatedTypedValue; + filter(condition: NonAggregatedValue): BaseAggregatedTypedValue; } export interface AnyOf extends ComputedFieldCondition { @@ -53,11 +58,25 @@ export interface AnyOf extends ComputedFieldCondition { __getType(): T; } -export type ComputedValue = SQLQuery | NonAggregatedTypedValue; -export type AggregatedValue = SQLQuery | BaseAggregatedTypedValue; +export type AggregatedValue = + | RawValue + | SQLQuery + | BaseAggregatedTypedValue; + +export type NonAggregatedValue = + | RawValue + | SQLQuery + | NonAggregatedTypedValue; -type Value = RawValue | ComputedValue; -export default Value; +export interface TypedValue + extends BaseAggregatedTypedValue, + NonAggregatedTypedValue {} +export type Value = RawValue | SQLQuery | TypedValue; +export type UnknownValue = + | RawValue + | SQLQuery + | BaseAggregatedTypedValue + | NonAggregatedTypedValue; export function isSpecialValue( value: unknown, From d7f49d55fee12a92d7c08d5cfc7824060645f440 Mon Sep 17 00:00:00 2001 From: Forbes Lindesay Date: Fri, 19 May 2023 17:25:21 +0100 Subject: [PATCH 4/5] more functionality --- packages/pg-typed/src/v2/AliasedQuery.ts | 9 - packages/pg-typed/src/v2/GroupByQuery.ts | 8 - packages/pg-typed/src/v2/InsertQuery.ts | 1 - .../pg-typed/src/v2/ProjectedLimitQuery.ts | 22 - packages/pg-typed/src/v2/SelectQuery.ts | 29 -- packages/pg-typed/src/v2/Table.ts | 61 --- .../pg-typed/src/v2/__tests__/insert.test.ts | 182 ++++++++ .../{index.test.ts => queries.test.ts} | 117 ++++-- .../pg-typed/src/v2/implementation/Columns.ts | 71 +++- .../OperatorDefinitions.ts} | 1 - .../src/v2/implementation/Operators.ts | 50 ++- .../Queries.ts} | 397 +++++++++++++++--- .../src/v2/implementation/Statements.ts | 157 +++++++ .../pg-typed/src/v2/implementation/Table.ts | 169 ++++++++ packages/pg-typed/src/v2/index.ts | 4 +- packages/pg-typed/src/v2/types/Columns.ts | 3 +- packages/pg-typed/src/v2/types/Join.ts | 27 -- .../pg-typed/src/v2/types/JoinableQuery.ts | 51 --- packages/pg-typed/src/v2/types/Operators.ts | 17 +- packages/pg-typed/src/v2/types/Queries.ts | 111 ++++- .../pg-typed/src/v2/types/SelectionSet.ts | 25 +- packages/pg-typed/src/v2/types/Statements.ts | 44 ++ packages/pg-typed/src/v2/types/Table.ts | 19 + .../src/v2/{ => types}/TableSchema.ts | 4 +- .../src/v2/{ => types}/WhereCondition.ts | 4 +- 25 files changed, 1234 insertions(+), 349 deletions(-) delete mode 100644 packages/pg-typed/src/v2/AliasedQuery.ts delete mode 100644 packages/pg-typed/src/v2/GroupByQuery.ts delete mode 100644 packages/pg-typed/src/v2/InsertQuery.ts delete mode 100644 packages/pg-typed/src/v2/ProjectedLimitQuery.ts delete mode 100644 packages/pg-typed/src/v2/SelectQuery.ts delete mode 100644 packages/pg-typed/src/v2/Table.ts create mode 100644 packages/pg-typed/src/v2/__tests__/insert.test.ts rename packages/pg-typed/src/v2/__tests__/{index.test.ts => queries.test.ts} (76%) rename packages/pg-typed/src/v2/{PostgresOperators.ts => implementation/OperatorDefinitions.ts} (98%) rename packages/pg-typed/src/v2/{QueryImplementation.ts => implementation/Queries.ts} (67%) create mode 100644 packages/pg-typed/src/v2/implementation/Statements.ts create mode 100644 packages/pg-typed/src/v2/implementation/Table.ts delete mode 100644 packages/pg-typed/src/v2/types/Join.ts delete mode 100644 packages/pg-typed/src/v2/types/JoinableQuery.ts create mode 100644 packages/pg-typed/src/v2/types/Statements.ts create mode 100644 packages/pg-typed/src/v2/types/Table.ts rename packages/pg-typed/src/v2/{ => types}/TableSchema.ts (68%) rename packages/pg-typed/src/v2/{ => types}/WhereCondition.ts (78%) diff --git a/packages/pg-typed/src/v2/AliasedQuery.ts b/packages/pg-typed/src/v2/AliasedQuery.ts deleted file mode 100644 index 232add8a..00000000 --- a/packages/pg-typed/src/v2/AliasedQuery.ts +++ /dev/null @@ -1,9 +0,0 @@ -import {Columns} from './types/Columns'; - -import SelectQuery from './SelectQuery'; -import {JoinableQueryLeft, JoinableQueryRight} from './types/JoinableQuery'; - -export default interface AliasedQuery - extends SelectQuery, - JoinableQueryRight, - JoinableQueryLeft<{[TKey in TAlias]: Columns}> {} diff --git a/packages/pg-typed/src/v2/GroupByQuery.ts b/packages/pg-typed/src/v2/GroupByQuery.ts deleted file mode 100644 index 6145456c..00000000 --- a/packages/pg-typed/src/v2/GroupByQuery.ts +++ /dev/null @@ -1,8 +0,0 @@ -import {ProjectedSortedQuery} from './types/Queries'; -import {AggregatedSelectionSet} from './types/SelectionSet'; - -export default interface GroupByQuery { - selectAggregate( - aggregation: (column: TColumns) => AggregatedSelectionSet, - ): ProjectedSortedQuery; -} diff --git a/packages/pg-typed/src/v2/InsertQuery.ts b/packages/pg-typed/src/v2/InsertQuery.ts deleted file mode 100644 index d933b680..00000000 --- a/packages/pg-typed/src/v2/InsertQuery.ts +++ /dev/null @@ -1 +0,0 @@ -export default interface InsertQuery {} diff --git a/packages/pg-typed/src/v2/ProjectedLimitQuery.ts b/packages/pg-typed/src/v2/ProjectedLimitQuery.ts deleted file mode 100644 index 43040c34..00000000 --- a/packages/pg-typed/src/v2/ProjectedLimitQuery.ts +++ /dev/null @@ -1,22 +0,0 @@ -import {SQLQuery} from '@databases/pg'; -import AliasedQuery from './AliasedQuery'; -import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; - -export default interface ProjectedLimitQuery - extends TypedDatabaseQuery { - /** - * Get the SQL query that would be executed. This is useful if you want to use this query as a sub-query in a query that is not type safe. - */ - toSql(): SQLQuery; - - /** - * If this is a complex query: - * Wrap the entire query in parentheses, and give it an alias. This lets you use joins, group by, etc. as sub-queries. - * - * If this is a simple query: - * Give the table an alias. This lets you use it in a join. - */ - as( - alias: TAliasTableName, - ): AliasedQuery; -} diff --git a/packages/pg-typed/src/v2/SelectQuery.ts b/packages/pg-typed/src/v2/SelectQuery.ts deleted file mode 100644 index c7cc1e4b..00000000 --- a/packages/pg-typed/src/v2/SelectQuery.ts +++ /dev/null @@ -1,29 +0,0 @@ -import {Columns} from './types/Columns'; -import {ProjectedQuery} from './types/Queries'; -import GroupByQuery from './GroupByQuery'; -import WhereCondition from './WhereCondition'; -import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; - -export default interface SelectQuery extends ProjectedQuery { - where(condition: WhereCondition>): this; - - select( - ...columnNames: TColumnNames - ): ProjectedQuery>; - select( - selection: (column: Columns) => SelectionSet, - ): ProjectedQuery; - - groupBy( - ...columnNames: TColumnNames - ): GroupByQuery, Columns>; - groupBy( - selection: (column: Columns) => SelectionSet, - ): GroupByQuery>; - - selectAggregate( - aggregation: ( - column: Columns, - ) => AggregatedSelectionSet, - ): ProjectedQuery; -} diff --git a/packages/pg-typed/src/v2/Table.ts b/packages/pg-typed/src/v2/Table.ts deleted file mode 100644 index 924caae0..00000000 --- a/packages/pg-typed/src/v2/Table.ts +++ /dev/null @@ -1,61 +0,0 @@ -import AliasedQuery from './AliasedQuery'; -import {Columns} from './types/Columns'; -import {ProjectedQuery} from './types/Queries'; -import GroupByQuery from './GroupByQuery'; -import InsertQuery from './InsertQuery'; -import SelectQuery, {WhereCondition, selectQuery} from './SelectQuery'; -import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; -import TableSchema from './TableSchema'; -import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; - -export interface Table - extends SelectQuery { - insert(...records: TInsertParameters[]): InsertQuery; -} - -class TableImplementation - implements Table -{ - private _table: TableSchema; - constructor(table: TableSchema) { - this._table = table; - } - - as( - alias: TAliasTableName, - ): AliasedQuery<{[TKey in TAliasTableName]: TRecord}> { - return selectQuery(this._table).as(alias); - } - - where(condition: WhereCondition): SelectQuery { - return selectQuery(this._table).where(condition); - } - - select( - ...columnNames: TColumnNames - ): ProjectedQuery>; - select( - selection: (column: Columns) => SelectionSet, - ): ProjectedQuery; - select(...args: any): any { - return selectQuery(this._table).select(...args); - } - - groupBy( - ...columnNames: TColumnNames - ): GroupByQuery, TRecord>; - groupBy( - selection: (column: Columns) => SelectionSet, - ): GroupByQuery; - groupBy(...args: any): any { - return selectQuery(this._table).groupBy(...args); - } - - selectAggregate( - aggregation: ( - column: Columns, - ) => AggregatedSelectionSet, - ): TypedDatabaseQuery { - return selectQuery(this._table).selectAggregate(aggregation); - } -} diff --git a/packages/pg-typed/src/v2/__tests__/insert.test.ts b/packages/pg-typed/src/v2/__tests__/insert.test.ts new file mode 100644 index 00000000..ce2c9d64 --- /dev/null +++ b/packages/pg-typed/src/v2/__tests__/insert.test.ts @@ -0,0 +1,182 @@ +import {SQLQuery, sql} from '@databases/pg'; +import {columns} from '../implementation/Columns'; +import {q} from '..'; +import {escapePostgresIdentifier} from '@databases/escape-identifier'; +import createTableApi from '../implementation/Table'; +import {TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; + +interface DbUser { + id: number; + username: string; + profile_image_url: string | null; +} +interface DbPost { + author_id: number; + title: string; + created_at: Date; +} + +const users = createTableApi('users', sql`users`, columns(`users`)); +const posts = createTableApi('posts', sql`posts`, columns(`posts`)); + +const testFormat = { + escapeIdentifier: escapePostgresIdentifier, + formatValue: (value: unknown) => ({ + placeholder: '${ ' + JSON.stringify(value) + ' }', + value: undefined, + }), +}; + +test(`Basic Insert`, async () => { + const insertNothing = users.insert(); + const mock = {query: jest.fn()}; + await insertNothing.executeQuery(mock); + expect(mock.query).not.toBeCalled(); + expect(insertNothing.toSql()).toBe(null); + expect(await mockResult(insertNothing)).toBe(undefined); + + const insertOne = users.insert({ + id: 1, + username: 'test', + profile_image_url: null, + }); + + expect( + await mockResult( + insertOne, + `INSERT INTO users ("id","profile_image_url","username") VALUES (\${ 1 },\${ null },\${ "test" })`, + [], + ), + ).toBe(undefined); + + const insertReturningStar = insertOne.returning(`*`); + expect( + await mockResult( + insertReturningStar.one(), + `INSERT INTO users ("id","profile_image_url","username") VALUES (\${ 1 },\${ null },\${ "test" }) RETURNING *`, + [{id: 1, username: 'test', profile_image_url: null}], + ), + ).toEqual({id: 1, username: 'test', profile_image_url: null}); + + const insertReturningId = insertOne.returning(`id`); + expect( + await mockResult<{id: number}[]>( + insertReturningId, + `INSERT INTO users ("id","profile_image_url","username") VALUES (\${ 1 },\${ null },\${ "test" }) RETURNING "id"`, + [{id: 1}], + ), + ).toEqual([{id: 1}]); + + const insertReturningCount = insertOne.returningCount(); + expect( + await mockResult( + insertReturningCount, + `INSERT INTO users ("id","profile_image_url","username") VALUES (\${ 1 },\${ null },\${ "test" }) RETURNING (COUNT(*))::INT AS row_count`, + [{row_count: 1}], + ), + ).toBe(1); +}); + +test(`Max/Min Inserted ID`, async () => { + const insertAndSelectMaxAndMinId = users + .insert( + {id: 1, username: 'test1', profile_image_url: null}, + {id: 2, username: 'test2', profile_image_url: null}, + {id: 3, username: 'test3', profile_image_url: null}, + ) + .returning(`id`) + .selectAggregate((c) => ({min_id: q.min(c.id), max_id: q.max(c.id)})); + + const expectedQueryForId = `SELECT MIN("id") AS "min_id",MAX("id") AS "max_id" FROM (INSERT INTO users ("id","profile_image_url","username") VALUES (\${ 1 },\${ null },\${ "test1" }),(\${ 2 },\${ null },\${ "test2" }),(\${ 3 },\${ null },\${ "test3" }) RETURNING "id") AS "users"`; + expect( + await mockResult<{max_id: number}>( + insertAndSelectMaxAndMinId, + expectedQueryForId, + [{min_id: 1, max_id: 3}], + ), + ).toEqual({min_id: 1, max_id: 3}); + + insertAndSelectMaxAndMinId.as(`i`); + const insertAndSelectMaxAndMinRecord = users + .as(`u`) + .innerJoin(insertAndSelectMaxAndMinId.as(`i`)) + .on(({u, i}) => q.or(q.eq(u.id, i.min_id), q.eq(u.id, i.max_id))) + .select(({u}) => q.star(u)); + + expect( + await mockResult( + insertAndSelectMaxAndMinRecord, + `SELECT "u".* FROM users AS "u" INNER JOIN (${expectedQueryForId}) AS "i" ON ("u"."id"="i"."min_id" OR "u"."id"="i"."max_id")`, + [ + {id: 1, username: 'test1', profile_image_url: null}, + {id: 3, username: 'test3', profile_image_url: null}, + ], + ), + ).toEqual([ + {id: 1, username: 'test1', profile_image_url: null}, + {id: 3, username: 'test3', profile_image_url: null}, + ]); +}); + +test(`INNER JOIN`, async () => { + const insertAsRightOfJoin = users + .as(`u`) + .innerJoin( + posts + .insert({author_id: 1, title: 'test', created_at: new Date(0)}) + .returning(`*`) + .as(`p`), + ) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .select(({u, p}) => ({ + username: u.username, + title: p.title, + })); + + expect( + await mockResult< + { + username: string; + title: string; + }[] + >( + insertAsRightOfJoin, + `SELECT "u"."username","p"."title" FROM users AS "u" INNER JOIN (INSERT INTO posts ("author_id","created_at","title") VALUES (\${ 1 },\${ "1970-01-01T00:00:00.000Z" },\${ "test" }) RETURNING *) AS "p" ON ("u"."id"="p"."author_id")`, + [{username: `ForbesLindesay`, title: `test`}], + ), + ).toEqual([{username: `ForbesLindesay`, title: `test`}]); +}); + +// function printQueryForTest( +// query: TypedDatabaseQuery & {toSql(): SQLQuery | null}, +// ) { +// const q = query.toSql(); +// if (q === null) return null; +// return q.format(testFormat).text; +// } +async function mockResult( + query: TypedDatabaseQuery, + expectedQuery?: string, + results?: any[], +): Promise { + if ((expectedQuery === undefined) !== (results === undefined)) { + throw new Error( + `Mock results should have either an expected query and results, or neither.`, + ); + } + let called = false; + const result = await query.executeQuery({ + query: async (q: SQLQuery) => { + if (expectedQuery === undefined || results === undefined) { + throw new Error(`Did not expect query to be called`); + } + called = true; + expect(q.format(testFormat).text).toEqual(expectedQuery); + return results; + }, + }); + if (expectedQuery) { + expect(called).toBe(true); + } + return result; +} diff --git a/packages/pg-typed/src/v2/__tests__/index.test.ts b/packages/pg-typed/src/v2/__tests__/queries.test.ts similarity index 76% rename from packages/pg-typed/src/v2/__tests__/index.test.ts rename to packages/pg-typed/src/v2/__tests__/queries.test.ts index e19f0b7e..3ae6437f 100644 --- a/packages/pg-typed/src/v2/__tests__/index.test.ts +++ b/packages/pg-typed/src/v2/__tests__/queries.test.ts @@ -1,9 +1,9 @@ -import {sql} from '@databases/pg'; +import {SQLQuery, sql} from '@databases/pg'; import {columns} from '../implementation/Columns'; -import createQuery from '../QueryImplementation'; import {q} from '..'; import {escapePostgresIdentifier} from '@databases/escape-identifier'; -import {ProjectedLimitQuery} from '../types/Queries'; +import createTableApi from '../implementation/Table'; +import {TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; interface DbUser { id: number; @@ -16,8 +16,8 @@ interface DbPost { created_at: Date; } -const users = createQuery('users', sql`users`, columns(`users`)); -const posts = createQuery('posts', sql`posts`, columns(`posts`)); +const users = createTableApi('users', sql`users`, columns(`users`)); +const posts = createTableApi('posts', sql`posts`, columns(`posts`)); const testFormat = { escapeIdentifier: escapePostgresIdentifier, @@ -49,26 +49,44 @@ test(`INNER JOIN`, () => { title: p.title, })); expect( - printQueryForTest<{ - id: number; - username: string; - title: string; - }>(joinWithWhereBeforeJoin), + printQueryForTest< + { + id: number; + username: string; + title: string; + }[] + >(joinWithWhereBeforeJoin), ).toEqual( `SELECT "u"."id","u"."username","p"."title" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") WHERE "u"."id"=\${ 10 }`, ); expect( - printQueryForTest<{ - id: number; - username: string; - title: string; - }>(joinWithWhereAfterJoin), + printQueryForTest< + { + id: number; + username: string; + title: string; + }[] + >(joinWithWhereAfterJoin), ).toEqual( - printQueryForTest<{ - id: number; - username: string; - title: string; - }>(joinWithWhereBeforeJoin), + printQueryForTest< + { + id: number; + username: string; + title: string; + }[] + >(joinWithWhereBeforeJoin), + ); + + const postsWithUsernames = posts + .as(`p`) + .innerJoin(users.as(`u`)) + .on(({u, p}) => q.eq(u.id, p.author_id)) + .select(({u, p}) => q.mergeColumns(q.star(p), {username: u.username})); + + expect( + printQueryForTest<(DbPost & {username: string})[]>(postsWithUsernames), + ).toEqual( + `SELECT "p".*,"u"."username" FROM posts AS "p" INNER JOIN users AS "u" ON ("u"."id"="p"."author_id")`, ); }); @@ -88,12 +106,14 @@ test(`group by`, () => { .orderByDesc(`last_posted_at`); expect( - printQueryForTest<{ - id: number; - username: string; - last_posted_at: Date; - total_count: number; - }>(groupBy), + printQueryForTest< + { + id: number; + username: string; + last_posted_at: Date; + total_count: number; + }[] + >(groupBy), ).toEqual( `SELECT "u"."id","u"."username",MAX("p"."created_at") AS "last_posted_at",(COUNT(*))::INT AS "total_count" FROM users AS "u" INNER JOIN posts AS "p" ON ("u"."id"="p"."author_id") GROUP BY 1,2 ORDER BY 3 DESC`, ); @@ -107,7 +127,7 @@ test(`arbitrary SQL conditions`, () => { q.and(q.gt(u.id, 1), q.lt(u.id, 20), sql`id % 2 = 0 OR id % 3 = 0`), ); - expect(printQueryForTest(conditions)).toEqual( + expect(printQueryForTest(conditions)).toEqual( `SELECT * FROM users WHERE "id"=\${ 10 } AND (username='x' OR username='y') AND "id">\${ 1 } AND "id"<\${ 20 } AND (id % 2 = 0 OR id % 3 = 0)`, ); }); @@ -128,18 +148,18 @@ test(`condition nesting`, () => { ), ); - expect(printQueryForTest(or)).toEqual( + expect(printQueryForTest(or)).toEqual( `SELECT * FROM users WHERE "id"=\${ 1 } OR "id"=\${ 10 }`, ); - expect(printQueryForTest(and)).toEqual( + expect(printQueryForTest(and)).toEqual( `SELECT * FROM users WHERE "id">=\${ 1 } AND "id"<=\${ 10 }`, ); - expect(printQueryForTest(and_or)).toEqual( + expect(printQueryForTest(and_or)).toEqual( `SELECT * FROM users WHERE ("id"=\${ 1 } OR "id"=\${ 10 }) AND "profile_image_url" IS NOT NULL`, ); // AND already binds more tightly than OR so no extra parentheses are needed here - although extra parentheses // would potentially make it easier for humans to read. - expect(printQueryForTest(or_and)).toEqual( + expect(printQueryForTest(or_and)).toEqual( `SELECT * FROM users WHERE "id"=\${ 1 } AND "username"=\${ "ForbesLindesay" } OR "id"=\${ 10 }`, ); }); @@ -154,18 +174,18 @@ test(`condition nesting - objects`, () => { q.or(q.and({id: 1}, {username: 'ForbesLindesay'}), {id: 10}), ); - expect(printQueryForTest(or)).toEqual( + expect(printQueryForTest(or)).toEqual( `SELECT * FROM users WHERE "id"=\${ 1 } OR "id"=\${ 10 }`, ); - expect(printQueryForTest(and)).toEqual( + expect(printQueryForTest(and)).toEqual( `SELECT * FROM users WHERE "id">=\${ 1 } AND "id"<=\${ 10 }`, ); - expect(printQueryForTest(and_or)).toEqual( + expect(printQueryForTest(and_or)).toEqual( `SELECT * FROM users WHERE ("id"=\${ 1 } OR "id"=\${ 10 }) AND "profile_image_url" IS NOT NULL`, ); // AND already binds more tightly than OR so no extra parentheses are needed here - although extra parentheses // would potentially make it easier for humans to read. - expect(printQueryForTest(or_and)).toEqual( + expect(printQueryForTest(or_and)).toEqual( `SELECT * FROM users WHERE "id"=\${ 1 } AND "username"=\${ "ForbesLindesay" } OR "id"=\${ 10 }`, ); }); @@ -176,10 +196,12 @@ test(`operators - lower`, () => { greeting: q.lower('HELLO WORLD'), })); expect( - printQueryForTest<{ - username: string; - greeting: string; - }>(lowerUsernames), + printQueryForTest< + { + username: string; + greeting: string; + }[] + >(lowerUsernames), ).toEqual( `SELECT LOWER("username") AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, ); @@ -216,7 +238,7 @@ test(`operators - json`, () => { id: number; data: {foo: string; bar: number}; } - const records = createQuery( + const records = createTableApi( `records`, sql`records`, columns(`records`, [ @@ -236,10 +258,12 @@ test(`operators - json`, () => { bar: q.json(r.data).prop(`bar`).asJson(), })); expect( - printQueryForTest<{ - foo: string; - bar: number; - }>(query), + printQueryForTest< + { + foo: string; + bar: number; + }[] + >(query), ).toEqual( `SELECT "data"#>>\${ ["foo"] } AS "foo","data"#>\${ ["bar"] } AS "bar" FROM records WHERE "data"#>\${ ["foo"] }=\${ "\\"hello\\"" } OR "data"#>>\${ ["foo"] }=\${ "world" }`, ); @@ -270,6 +294,9 @@ test(`operators - json`, () => { `SELECT MAX(LOWER("username")) AS "username",LOWER(\${ "HELLO WORLD" }) AS "greeting" FROM users`, ); }); -function printQueryForTest(query: ProjectedLimitQuery) { + +function printQueryForTest( + query: TypedDatabaseQuery & {toSql(): SQLQuery}, +) { return query.toSql().format(testFormat).text; } diff --git a/packages/pg-typed/src/v2/implementation/Columns.ts b/packages/pg-typed/src/v2/implementation/Columns.ts index 62a02936..6eee30d3 100644 --- a/packages/pg-typed/src/v2/implementation/Columns.ts +++ b/packages/pg-typed/src/v2/implementation/Columns.ts @@ -1,8 +1,19 @@ import {columnReference} from './Operators'; -import {Columns} from '../types/Columns'; +import {ColumnReference, Columns} from '../types/Columns'; const IS_PROXIED = Symbol('IS_PROXIED'); +function baseToValue(value: unknown): unknown { + return value; +} +function jsonToValue(value: unknown): unknown { + return JSON.stringify(value); +} +function jsonArrayToValue(value: unknown): unknown { + if (!Array.isArray(value)) return value; + return value.map((v) => JSON.stringify(v)); +} + export function columns( tableName: string, schema?: { @@ -12,11 +23,24 @@ export function columns( isAlias: boolean = false, ): Columns { if (schema) { - return Object.fromEntries( - schema.map(({columnName, type}) => [ - columnName, - columnReference(tableName, columnName, isAlias, type ?? null), - ]), + return Object.assign( + {__isSpecialValue: true, __tableName: tableName}, + Object.fromEntries( + schema.map(({columnName, type}) => [ + columnName, + columnReference( + tableName, + columnName, + isAlias, + type ?? null, + type === `JSON` || type === `JSONB` + ? jsonToValue + : type === `JSON[]` || type === `JSONB[]` + ? jsonArrayToValue + : baseToValue, + ), + ]), + ), ) as Columns; } else { return new Proxy( @@ -24,10 +48,18 @@ export function columns( { get: (_target, columnName, _receiver) => { if (columnName === IS_PROXIED) return true; + if (columnName === `__isSpecialValue`) return true; + if (columnName === `__tableName`) return tableName; if (columnName === 'then' || typeof columnName !== 'string') { return undefined; } - return columnReference(tableName, columnName, isAlias, null); + return columnReference( + tableName, + columnName, + isAlias, + null, + baseToValue, + ); }, }, ) as any; @@ -44,12 +76,12 @@ export function aliasColumns( cachedAlias = new WeakMap(); cache.set(tableAlias, cachedAlias); } - const cached = cachedAlias.get(columns); - if (cached) return cached; + const cached = cachedAlias.get(columns as any); + if (cached) return cached as any; const aliasedColumns = (columns as any)[IS_PROXIED] ? aliasColumnsByProxy(tableAlias, columns) : aliasColumnsWithPlainObject(tableAlias, columns); - cachedAlias.set(columns, aliasedColumns); + cachedAlias.set(columns as any, aliasedColumns as any); return aliasedColumns; } @@ -62,9 +94,11 @@ function aliasColumnsByProxy( { get: (_target, columnName, _receiver) => { if (columnName === IS_PROXIED) return true; + if (columnName === `__isSpecialValue`) return true; + if (columnName === `__tableName`) return tableAlias; const column = columns[columnName as keyof typeof columns]; if (column === undefined) return column; - return column.setAlias(tableAlias); + return (column as ColumnReference).setAlias(tableAlias); }, }, ) as Columns; @@ -74,10 +108,15 @@ function aliasColumnsWithPlainObject( tableAlias: string, columns: Columns, ): Columns { - return Object.fromEntries( - Object.entries(columns).map(([columnName, column]) => [ - columnName, - (column as Columns[keyof Columns]).setAlias(tableAlias), - ]), + return Object.assign( + {__isSpecialValue: true, __tableName: tableAlias}, + Object.fromEntries( + Object.entries(columns) + .filter(([n]) => n !== '__isSpecialValue' && n !== '__tableName') + .map(([columnName, column]) => [ + columnName, + (column as any as ColumnReference).setAlias(tableAlias), + ]), + ), ) as Columns; } diff --git a/packages/pg-typed/src/v2/PostgresOperators.ts b/packages/pg-typed/src/v2/implementation/OperatorDefinitions.ts similarity index 98% rename from packages/pg-typed/src/v2/PostgresOperators.ts rename to packages/pg-typed/src/v2/implementation/OperatorDefinitions.ts index 7ec6725a..fb26f1f1 100644 --- a/packages/pg-typed/src/v2/PostgresOperators.ts +++ b/packages/pg-typed/src/v2/implementation/OperatorDefinitions.ts @@ -13,7 +13,6 @@ export interface OperatorDefinition { ctx: {parentOperatorPrecedence: number | null}, ) => SQLQuery; readonly precedence: number; - // readonly staticValue?: (input: TStaticValueInput) => boolean | null; } function operatorDefinition( diff --git a/packages/pg-typed/src/v2/implementation/Operators.ts b/packages/pg-typed/src/v2/implementation/Operators.ts index d352c3da..1e396a50 100644 --- a/packages/pg-typed/src/v2/implementation/Operators.ts +++ b/packages/pg-typed/src/v2/implementation/Operators.ts @@ -19,7 +19,7 @@ import { BinaryInput, OperatorDefinition, OperatorDefinitions, -} from '../PostgresOperators'; +} from './OperatorDefinitions'; import { AggregatedJsonValue, IOperators, @@ -27,19 +27,26 @@ import { NonAggregatedJsonValue, } from '../types/Operators'; import {ColumnReference} from '../types/Columns'; -import WhereCondition from '../WhereCondition'; +import WhereCondition from '../types/WhereCondition'; +import { + SelectionSet, + SelectionSetMerged, + SelectionSetStar, +} from '../types/SelectionSet'; export function columnReference( tableName: string, columnName: string, isAlias: boolean, sqlType: string | null, + serializeValue: (value: T) => unknown, ): NonAggregatedValue { return new ColumnReferenceImplementation( tableName, columnName, isAlias, sqlType, + serializeValue, ); } @@ -154,17 +161,20 @@ class ColumnReferenceImplementation // This property is assigned in the base class public readonly sqlType!: string | null; + public readonly serializeValue: (value: T) => unknown; constructor( tableName: string, columnName: string, isAlias: boolean, sqlType: string | null, + serializeValue: (value: T) => unknown, ) { super(sqlType); this.tableName = tableName; this.columnName = columnName; this.isAlias = isAlias; + this.serializeValue = serializeValue; } public toSql(ctx: ValueToSqlContext): SQLQuery { if (this.isAlias) { @@ -182,6 +192,7 @@ class ColumnReferenceImplementation this.columnName, true, this.sqlType, + this.serializeValue, ); } } @@ -975,6 +986,38 @@ function setSqlType( return ctx; } +class SelectionSetStarImplementation + implements SelectionSetStar +{ + public readonly __isSpecialValue = true; + public readonly __selectionSetType = 'STAR'; + public readonly tableName: string; + constructor(tableName: string) { + this.tableName = tableName; + } + public __getType(): TSelection { + throw new Error( + `The "getType" function should not be called. It is only there to help TypeScript infer the correct type.`, + ); + } +} + +class SelectionSetMergedImplementation + implements SelectionSetMerged +{ + public readonly __isSpecialValue = true; + public readonly __selectionSetType = 'MERGED'; + public readonly selections: SelectionSet>[]; + constructor(selections: SelectionSet>[]) { + this.selections = selections; + } + public __getType(): TSelection { + throw new Error( + `The "getType" function should not be called. It is only there to help TypeScript infer the correct type.`, + ); + } +} + const Operators: IOperators = { allOf: (values) => new AllOf(values), and: booleanOperator(`AND`, { @@ -1027,6 +1070,8 @@ const Operators: IOperators = { lt: binaryOperator(OperatorDefinitions.LT), lte: binaryOperator(OperatorDefinitions.LTE), max: aggregateFunction(`MAX`), + mergeColumns: (...selections) => + new SelectionSetMergedImplementation(selections), min: aggregateFunction(`MIN`), neq: binaryOperator(OperatorDefinitions.NEQ, { getUnaryOperator(right) { @@ -1083,6 +1128,7 @@ const Operators: IOperators = { return undefined; }, }), + star: (columns) => new SelectionSetStarImplementation(columns.__tableName), sum: aggregateFunction(`SUM`), upper: nonAggregateFunction(`UPPER`), }; diff --git a/packages/pg-typed/src/v2/QueryImplementation.ts b/packages/pg-typed/src/v2/implementation/Queries.ts similarity index 67% rename from packages/pg-typed/src/v2/QueryImplementation.ts rename to packages/pg-typed/src/v2/implementation/Queries.ts index 701084ce..cec5e79a 100644 --- a/packages/pg-typed/src/v2/QueryImplementation.ts +++ b/packages/pg-typed/src/v2/implementation/Queries.ts @@ -1,63 +1,105 @@ -import {Queryable, SQLQuery, sql} from '@databases/pg'; -import {aliasColumns, columns} from './implementation/Columns'; -import WhereCondition from './WhereCondition'; +import {SQLQuery, sql} from '@databases/pg'; +import {escapePostgresIdentifier} from '@databases/escape-identifier'; + +import {aliasColumns, columns} from './Columns'; +import WhereCondition from '../types/WhereCondition'; import { FieldCondition, NonAggregatedValue, isSpecialValue, -} from './types/SpecialValues'; -import {Columns} from './types/Columns'; +} from '../types/SpecialValues'; +import { + ColumnReference, + Columns, + InnerJoinedColumns, + LeftOuterJoinedColumns, +} from '../types/Columns'; import Operators, { aliasTableInValue, fieldConditionToPredicateValue, valueToSelect, valueToSql, -} from './implementation/Operators'; -import {AggregatedSelectionSet, SelectionSet} from './types/SelectionSet'; -import AliasedQuery from './AliasedQuery'; -import {TypedDatabaseQuery} from './types/TypedDatabaseQuery'; -import {escapePostgresIdentifier} from '@databases/escape-identifier'; - -import SelectQuery from './SelectQuery'; -import GroupByQuery from './GroupByQuery'; +} from './Operators'; +import { + AggregatedSelectionSet, + SelectionSet, + SelectionSetObject, +} from '../types/SelectionSet'; +import {Queryable, TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; import { + AggregatedQuery, + AliasedQuery, + GroupByQuery, + JoinableQueryLeft, + JoinableQueryRight, + JoinQuery, + JoinQueryBuilder, ProjectedDistinctColumnsQuery, ProjectedDistinctQuery, ProjectedLimitQuery, - ProjectedSortedQuery, ProjectedQuery, -} from './types/Queries'; -import {JoinQueryBuilder, JoinQuery} from './types/Join'; -import { - InnerJoinedColumns, - JoinableQueryLeft, - JoinableQueryRight, - LeftOuterJoinedColumns, -} from './types/JoinableQuery'; + ProjectedSortedQuery, + SelectQuery, +} from '../types/Queries'; +import TableSchema from '../types/TableSchema'; const NO_RESULT_FOUND = `NO_RESULT_FOUND`; const MULTIPLE_RESULTS_FOUND = `MULTIPLE_RESULTS_FOUND`; export default function createQuery( - tableName: string, - tableId: SQLQuery, - columns: Columns, + table: TableSchema, ): SelectQuery { return new SelectQueryImplementation({ - columns, + columns: table.columns, distinct: false, distinctColumns: [], groupBy: 0, + hasSideEffects: false, isAliased: false, + isEmpty: false, isJoin: false, limit: null, orderBy: [], projection: null, - tableId, - tableName, + tableId: table.tableId, + tableName: table.tableName, where: [], }); } +export function createStatementReturn( + table: TableSchema, + query: SQLQuery | null, + selection: + | readonly string[] + | readonly [(columns: Columns) => SelectionSetObject], +): SelectQuery { + return new StatementReturning( + table as any, + query, + typeof selection[0] === 'function' + ? selectionSetToProjection(selection[0](table.columns)) + : selection[0] === '*' + ? null + : columnNamesToProjection(selection as string[]), + ); +} + +interface Projection { + /** + * The SQL for the selection set. + * + * e.g. u.name AS user_name, u.email AS user_email, COUNT(*) AS post_count + */ + readonly query: SQLQuery; + /** + * The column names (in order) returned by this projection. + * + * e.g. ['user_name', 'user_email', 'post_count'] + * + * Star is used if the number of columns is unknown + */ + readonly columnNames: readonly string[]; +} interface CompleteQuery< TAlias extends string, @@ -73,19 +115,19 @@ interface CompleteQuery< ...columnNames: TColumnNames ): ProjectedQuery>; select( - selection: (column: TColumns) => SelectionSet, + selection: (column: TColumns) => SelectionSetObject, ): ProjectedQuery; groupBy( ...columnNames: TColumnNames ): GroupByQuery, TColumns>; groupBy( - selection: (column: TColumns) => SelectionSet, + selection: (column: TColumns) => SelectionSetObject, ): GroupByQuery; selectAggregate( aggregation: (column: TColumns) => AggregatedSelectionSet, - ): ProjectedQuery; + ): AggregatedQuery; } interface QueryConfig { @@ -94,7 +136,9 @@ interface QueryConfig { distinctColumns: readonly SQLQuery[]; groupBy: number; isAliased: boolean; + isEmpty: boolean; isJoin: boolean; + hasSideEffects: boolean; limit: number | null; orderBy: readonly SQLQuery[]; projection: Projection | null; @@ -232,7 +276,7 @@ class SelectQueryImplementation< return { query: sql.join(parts, sql` `), - isEmpty: whereCondition === false, + isEmpty: whereCondition === false && !this._config.hasSideEffects, }; } @@ -256,7 +300,9 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -278,22 +324,20 @@ class SelectQueryImplementation< `Table aliases must start with a lower case letter and only contain letters, numbers and underscores`, ); } - if (this._config.isAliased) { - throw new Error(`Cannot alias a query that has already been aliased`); - } const { columns, distinct, distinctColumns, groupBy, + isAliased, + isJoin, limit, + orderBy, projection, tableId, tableName, where, - orderBy, - isJoin, } = this._config; const aliasedColumns = aliasColumns(alias, columns as Columns); @@ -305,14 +349,17 @@ class SelectQueryImplementation< isJoin || limit || orderBy.length || - projection + projection || + isAliased ) { return new SelectQueryImplementation({ columns: aliasedColumns, distinct: false, distinctColumns: [], groupBy: 0, + hasSideEffects: this._config.hasSideEffects, isAliased: true, + isEmpty: this._config.isEmpty || Operators.and(...where) === false, isJoin: false, limit: null, orderBy: [], @@ -328,7 +375,9 @@ class SelectQueryImplementation< distinct: false, distinctColumns: [], groupBy: 0, + hasSideEffects: this._config.hasSideEffects, isAliased: true, + isEmpty: this._config.isEmpty, isJoin: false, limit: null, orderBy: [], @@ -352,7 +401,7 @@ class SelectQueryImplementation< fieldConditionToPredicateValue( (this._config.columns as Columns)[ columnName as keyof Columns - ], + ] as ColumnReference, value as FieldCondition, ), )), @@ -362,7 +411,9 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -376,7 +427,7 @@ class SelectQueryImplementation< select(...selection: any[]): ProjectedQuery { return this._projectedQuery( selection.length === 1 && typeof selection[0] === 'function' - ? selectionSetToProjection( + ? selectionSetSpecialToProjection( selection[0](this._config.columns) as SelectionSet, ) : columnNamesToProjection(selection), @@ -385,9 +436,11 @@ class SelectQueryImplementation< selectAggregate( aggregation: (column: TColumns) => AggregatedSelectionSet, - ): ProjectedQuery { - return this._projectedQuery( - selectionSetToProjection(aggregation(this._config.columns)), + ): AggregatedQuery { + return new AggregatedQueryImplementation( + this._projectedQuery( + selectionSetToProjection(aggregation(this._config.columns)), + ), ); } @@ -395,7 +448,9 @@ class SelectQueryImplementation< return new GroupByQueryImplementation( selection.length === 1 && typeof selection[0] === 'function' ? selectionSetToProjection( - selection[0](this._config.columns) as SelectionSet, + selection[0]( + this._config.columns, + ) as SelectionSetObject, ) : columnNamesToProjection(selection), this._config, @@ -403,23 +458,36 @@ class SelectQueryImplementation< } private _orderByColumn(columnName: keyof TRecord): SQLQuery { - if (this._config.projection) { + if ( + this._config.projection && + !this._config.projection.columnNames.includes(`*`) + ) { const index = this._config.projection.columnNames.indexOf( columnName as string, ); - if (index === -1) { - throw new Error(`Cannot find column: "${columnName as string}"`); + if (index !== -1) { + return sql.__dangerous__rawValue((index + 1).toString(10)); } - return sql.__dangerous__rawValue((index + 1).toString(10)); - } else { - return sql.ident(columnName); } + return sql.ident(columnName); } private _orderByInternal( columnName: keyof TRecord, distinct: boolean, direction: SQLQuery, ): ProjectedDistinctColumnsQuery { + if (distinct) { + if (this._config.distinct) { + throw new Error( + `Cannot call orderByAscDistinct() or orderByDescDistinct() after distinct()`, + ); + } + if (this._config.orderBy.length > this._config.distinctColumns.length) { + throw new Error( + `Cannot call orderByAscDistinct() or orderByDescDistinct() after orderByAsc() or orderByDesc()`, + ); + } + } const distinctColumns = distinct ? [...this._config.distinctColumns, sql.ident(columnName)] : this._config.distinctColumns; @@ -432,7 +500,9 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy, @@ -471,7 +541,9 @@ class SelectQueryImplementation< distinct: true, distinctColumns: [], groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -487,7 +559,9 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: n, orderBy: this._config.orderBy, @@ -525,7 +599,13 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: + this._config.hasSideEffects || otherQuery._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: + this._config.isEmpty || + otherQuery._config.isEmpty || + Operators.and(...otherQuery._config.where) === false, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -565,7 +645,10 @@ class SelectQueryImplementation< distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: + this._config.hasSideEffects || otherQuery._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -648,7 +731,9 @@ class GroupByQueryImplementation distinct: false, distinctColumns: [], groupBy: groupByProjection.columnNames.length, + hasSideEffects: this._config.hasSideEffects, isAliased: false, + isEmpty: this._config.isEmpty, isJoin: false, limit: this._config.limit, orderBy: [], @@ -686,7 +771,9 @@ class JoinImplementation implements JoinQueryBuilder { distinct: this._config.distinct, distinctColumns: this._config.distinctColumns, groupBy: this._config.groupBy, + hasSideEffects: this._config.hasSideEffects, isAliased: this._config.isAliased, + isEmpty: this._config.isEmpty, isJoin: this._config.isJoin, limit: this._config.limit, orderBy: this._config.orderBy, @@ -698,25 +785,203 @@ class JoinImplementation implements JoinQueryBuilder { } } -export interface Projection { - /** - * The SQL for the selection set. - * - * e.g. u.name AS user_name, u.email AS user_email, COUNT(*) AS post_count - */ - readonly query: SQLQuery; - /** - * The column names (in order) returned by this projection. - * - * e.g. ['user_name', 'user_email', 'post_count'] - */ - readonly columnNames: readonly string[]; +class StatementReturning implements SelectQuery { + private readonly _table: TableSchema; + private readonly _statement: SQLQuery | null; + private readonly _projection: Projection | null; + constructor( + table: TableSchema, + statement: SQLQuery | null, + projection: Projection | null, + ) { + this._table = table; + this._statement = statement; + this._projection = projection; + } + + private _query(): FinalQueryConfig { + const selection = this._projection?.query ?? sql`*`; + + return { + query: this._statement + ? sql`${this._statement} RETURNING ${selection}` + : sql`SELECT ${selection} WHERE FALSE`, + isEmpty: !this._statement, + }; + } + + toSql(): SQLQuery { + const {query} = this._query(); + return query; + } + + private _getQuery(): SelectQuery { + const {query, isEmpty} = this._query(); + return new SelectQueryImplementation({ + columns: this._table.columns, + distinct: false, + distinctColumns: [], + groupBy: 0, + hasSideEffects: !isEmpty, + isAliased: false, + isEmpty, + isJoin: false, + limit: null, + orderBy: [], + projection: null, + tableId: sql`(${query}) AS ${sql.ident(this._table.tableName)}`, + tableName: this._table.tableName, + where: [], + }); + } + as( + alias: TAliasTableName, + ): AliasedQuery { + if (!/^[a-z][a-z0-9_]*$/.test(alias)) { + throw new Error( + `Table aliases must start with a lower case letter and only contain letters, numbers and underscores`, + ); + } + const aliasedColumns = aliasColumns(alias, this._table.columns); + const {query, isEmpty} = this._query(); + return new SelectQueryImplementation({ + columns: aliasedColumns, + distinct: false, + distinctColumns: [], + groupBy: 0, + hasSideEffects: !isEmpty, + isAliased: true, + isEmpty, + isJoin: false, + limit: null, + orderBy: [], + projection: null, + tableId: sql`(${query}) AS ${sql.ident(alias)}`, + tableName: alias, + where: [], + }); + } + + where(condition: WhereCondition): SelectQuery { + return this._getQuery().where(condition); + } + select(...selection: any[]): ProjectedQuery { + return this._getQuery().select(...selection) as any; + } + selectAggregate( + aggregation: ( + column: Columns, + ) => AggregatedSelectionSet, + ): AggregatedQuery { + return this._getQuery().selectAggregate(aggregation); + } + groupBy( + ...selection: any[] + ): GroupByQuery> { + return this._getQuery().groupBy(...selection); + } + orderByAscDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._getQuery().orderByAscDistinct(columnName); + } + orderByDescDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._getQuery().orderByDescDistinct(columnName); + } + orderByAsc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._getQuery().orderByAsc(columnName); + } + orderByDesc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._getQuery().orderByDesc(columnName); + } + distinct(): ProjectedDistinctQuery { + return this._getQuery().distinct(); + } + limit(n: number): ProjectedLimitQuery { + return this._getQuery().limit(n); + } + + one(): TypedDatabaseQuery { + return new OneQuery(this._query()); + } + + oneRequired(): TypedDatabaseQuery { + return new OneRequiredQuery(this._query()); + } + + first(): TypedDatabaseQuery { + return new FirstQuery(this._query()); + } + + async executeQuery(database: Queryable): Promise { + const {query, isEmpty} = this._query(); + if (isEmpty) return []; + return await database.query(query); + } +} + +class AggregatedQueryImplementation + implements AggregatedQuery +{ + private readonly _query: ProjectedLimitQuery; + constructor(query: ProjectedLimitQuery) { + this._query = query; + } + toSql(): SQLQuery { + return this._query.toSql(); + } + as( + alias: TAliasTableName, + ): AliasedQuery { + return this._query.as(alias); + } + + async executeQuery(database: Queryable): Promise { + const results = await database.query(this._query.toSql()); + if (results.length !== 1) { + throw new Error( + `Expected exactly one row to be returned by this query because it is "aggregated"`, + ); + } + return results[0]; + } +} + +function selectionSetSpecialToProjection( + selection: SelectionSet, +): Projection { + if (isSpecialValue(selection)) { + switch (selection.__selectionSetType) { + case 'STAR': + return { + query: sql`${sql.ident(selection.tableName)}.*`, + columnNames: [`*`], + }; + case 'MERGED': { + const parts: SQLQuery[] = []; + const columnNames = []; + for (const part of selection.selections) { + const projection = selectionSetSpecialToProjection(part); + if (projection.columnNames.length) { + parts.push(projection.query); + columnNames.push(...projection.columnNames); + } + } + return {query: sql.join(parts, `,`), columnNames}; + } + } + } + return selectionSetToProjection(selection); } -function selectionSetToProjection( - ...selections: (SelectionSet | AggregatedSelectionSet)[] +function selectionSetToProjection( + selection: + | SelectionSetObject + | AggregatedSelectionSet, ): Projection { - const entries = selections.flatMap((selection) => Object.entries(selection)); + const entries = Object.entries(selection); return { query: sql.join( entries.map(([alias, value]) => valueToSelect(alias, value)), diff --git a/packages/pg-typed/src/v2/implementation/Statements.ts b/packages/pg-typed/src/v2/implementation/Statements.ts new file mode 100644 index 00000000..ad99414d --- /dev/null +++ b/packages/pg-typed/src/v2/implementation/Statements.ts @@ -0,0 +1,157 @@ +import {SQLQuery, sql} from '@databases/pg'; +import { + BaseStatement, + InsertStatement, + InsertStatementOnConflictBuilder, + StatementCount, + UpdateStatement, +} from '../types/Statements'; +import {Queryable} from '../types/TypedDatabaseQuery'; +import TableSchema from '../types/TableSchema'; +import {ColumnReference, Columns} from '../types/Columns'; +import {SelectionSetObject} from '../types/SelectionSet'; +import {valueToSql} from './Operators'; +import {SelectQuery} from '../types/Queries'; +import {createStatementReturn} from './Queries'; + +interface AnyStatement + extends InsertStatementOnConflictBuilder, + InsertStatement, + UpdateStatement {} + +class StatementImplementation implements AnyStatement { + private readonly _table: TableSchema; + private readonly _statement: SQLQuery | null; + constructor(table: TableSchema, statement: SQLQuery | null) { + this._table = table; + this._statement = statement; + } + + returningCount(): StatementCount { + return new StatementCountImplementation(this._statement); + } + returning(star: '*'): SelectQuery; + returning( + ...columnNames: TColumnNames + ): SelectQuery>; + returning( + selection: (column: Columns) => SelectionSetObject, + ): SelectQuery; + returning(...selection: any[]): SelectQuery { + return createStatementReturn( + this._table, + this._statement, + selection, + ); + } + + doUpdate(...columns: (keyof TRecord)[]): BaseStatement; + doUpdate( + updates: ( + columns: Columns, + excluded: Columns, + ) => Partial>, + ): BaseStatement; + doUpdate(...updates: any[]): BaseStatement { + if (this._statement === null) return this; + const updatesSql = + typeof updates[0] === 'function' + ? selectionSetToUpdate( + this._table.columns, + updates[0]( + this._table.columns, + // TODO: make these references to EXCLUDED.column_name not table.column_name + this._table.columns, + ), + ) + : sql.join( + (updates as string[]).map( + (key) => sql`${sql.ident(key)}=EXCLUDED.${sql.ident(key)}`, + ), + sql`, `, + ); + return new StatementImplementation( + this._table, + sql`${this._statement} DO UPDATE SET ${updatesSql}`, + ); + } + + onConflict( + ...columns: readonly (keyof TRecord)[] + ): InsertStatementOnConflictBuilder { + if (this._statement === null) return this; + return new StatementImplementation( + this._table, + sql`${this._statement} ON CONFLICT (${sql.join( + columns.map((columnName) => sql.ident(columnName)), + `,`, + )})`, + ); + } + + onConflictDoNothing(): BaseStatement { + if (this._statement === null) return this; + return new StatementImplementation( + this._table, + sql`${this._statement} ON CONFLICT DO NOTHING`, + ); + } + + toSql(): SQLQuery | null { + return this._statement; + } + + async executeQuery(database: Queryable): Promise { + if (this._statement) { + await database.query(this._statement); + } + } +} + +class StatementCountImplementation implements StatementCount { + private readonly _statement: SQLQuery | null; + constructor(statement: SQLQuery | null) { + this._statement = statement; + } + + toSql(): SQLQuery { + return this._statement + ? sql`${this._statement} RETURNING (COUNT(*))::INT AS row_count` + : sql`SELECT 0 AS row_count`; + } + + async executeQuery(database: Queryable): Promise { + if (!this._statement) return 0; + const results = await database.query( + sql`${this._statement} RETURNING (COUNT(*))::INT AS row_count`, + ); + return results.length ? results[0].row_count : 0; + } +} + +export default function createInsertStatement( + table: TableSchema, + statement: SQLQuery | null, +): InsertStatement { + return new StatementImplementation(table, statement); +} + +function selectionSetToUpdate( + columns: Columns, + ...selections: Partial>[] +): SQLQuery { + const entries = selections.flatMap((selection) => Object.entries(selection)); + return sql.join( + entries.map(([columnName, value]) => { + const column = columns[ + columnName as keyof TRecord + ] as ColumnReference; + return sql`${sql.ident(columnName)}=${valueToSql(value, { + parentOperatorPrecedence: null, + toValue: (v) => column.serializeValue(v as any), + tableAlias: () => null, + })}`; + }), + `,`, + ); +} diff --git a/packages/pg-typed/src/v2/implementation/Table.ts b/packages/pg-typed/src/v2/implementation/Table.ts new file mode 100644 index 00000000..046c071f --- /dev/null +++ b/packages/pg-typed/src/v2/implementation/Table.ts @@ -0,0 +1,169 @@ +import {SQLQuery, sql} from '@databases/pg'; +import TableSchema from '../types/TableSchema'; +import Table from '../types/Table'; +import { + AggregatedQuery, + AliasedQuery, + GroupByQuery, + ProjectedDistinctColumnsQuery, + ProjectedDistinctQuery, + ProjectedLimitQuery, + ProjectedQuery, + ProjectedSortedQuery, + SelectQuery, +} from '../types/Queries'; +import createQuery from './Queries'; +import WhereCondition from '../types/WhereCondition'; +import {ColumnReference, Columns} from '../types/Columns'; +import {AggregatedSelectionSet} from '../types/SelectionSet'; +import {InsertStatement} from '../types/Statements'; +import {TypedDatabaseQuery, Queryable} from '../types/TypedDatabaseQuery'; +import createInsertStatement from './Statements'; + +class TableImplementation + implements Table +{ + private readonly _table: TableSchema; + private readonly _query: SelectQuery; + constructor(table: TableSchema) { + this._table = table; + this._query = createQuery(table); + } + + insert(...records: TInsertParameters[]): InsertStatement { + if (records.length === 0) return createInsertStatement(this._table, null); + + const columnNamesSet = new Set(); + for (const record of records) { + for (const columnName of Object.keys(record as any)) { + columnNamesSet.add( + columnName as keyof TRecord & keyof TInsertParameters, + ); + } + } + const columnNames = [...columnNamesSet].sort(); + + const columnNamesSql = sql.join( + columnNames.map((columnName) => sql.ident(columnName)), + `,`, + ); + + const values = records.map( + (record) => + sql`(${sql.join( + columnNames.map((columnName): SQLQuery => { + const column = this._table.columns[ + columnName + ] as ColumnReference; + const value: any = record[columnName]; + if (!column) { + throw new Error(`Unexpected column: ${columnName as string}`); + } + if (value === undefined) { + return sql`DEFAULT`; + } else { + return sql.value(column.serializeValue(value)); + } + }), + `,`, + )})`, + ); + + return createInsertStatement( + this._table, + sql`INSERT INTO ${ + this._table.tableId + } (${columnNamesSql}) VALUES ${sql.join(values, `,`)}`, + ); + } + + // == Methods of SelectQuery == + + toSql(): SQLQuery { + return this._query.toSql(); + } + as( + alias: TAliasTableName, + ): AliasedQuery { + return this._query.as(alias); + } + where(condition: WhereCondition): SelectQuery { + return this._query.where(condition); + } + select(...selection: any[]): ProjectedQuery { + return this._query.select(...selection) as any; + } + selectAggregate( + aggregation: ( + column: Columns, + ) => AggregatedSelectionSet, + ): AggregatedQuery { + return this._query.selectAggregate(aggregation); + } + groupBy( + ...selection: any[] + ): GroupByQuery> { + return this._query.groupBy(...selection); + } + orderByAscDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._query.orderByAscDistinct(columnName); + } + orderByDescDistinct( + columnName: keyof TRecord, + ): ProjectedDistinctColumnsQuery { + return this._query.orderByDescDistinct(columnName); + } + orderByAsc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._query.orderByAsc(columnName); + } + orderByDesc(columnName: keyof TRecord): ProjectedSortedQuery { + return this._query.orderByDesc(columnName); + } + distinct(): ProjectedDistinctQuery { + return this._query.distinct(); + } + limit(n: number): ProjectedLimitQuery { + return this._query.limit(n); + } + + private _whereOptional( + whereCondition?: WhereCondition, + ): SelectQuery { + return whereCondition ? this._query.where(whereCondition) : this._query; + } + one( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery { + return this._whereOptional(whereCondition).one(); + } + + oneRequired( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery { + return this._whereOptional(whereCondition).oneRequired(); + } + + first( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery { + return this._whereOptional(whereCondition).first(); + } + + async executeQuery(database: Queryable): Promise { + return this._query.executeQuery(database); + } +} + +export default function createTableApi( + tableName: string, + tableId: SQLQuery, + columns: Columns, +): Table { + return new TableImplementation({ + columns, + tableId, + tableName, + }); +} diff --git a/packages/pg-typed/src/v2/index.ts b/packages/pg-typed/src/v2/index.ts index c21c5cf7..c06637f5 100644 --- a/packages/pg-typed/src/v2/index.ts +++ b/packages/pg-typed/src/v2/index.ts @@ -1,8 +1,8 @@ import Operators from './implementation/Operators'; import NonAggregatedValue from './types/SpecialValues'; import {JoinQueryBuilder, JoinQuery} from './types/Join'; -import AliasedQuery from './AliasedQuery'; -import {Table} from './Table'; +import AliasedQuery from './types/AliasedQuery.1'; +import {Table} from './implementation/Table'; import {IOperators} from './types/Operators'; export const q: IOperators = Operators; diff --git a/packages/pg-typed/src/v2/types/Columns.ts b/packages/pg-typed/src/v2/types/Columns.ts index 1db2a3d5..60bb7b3f 100644 --- a/packages/pg-typed/src/v2/types/Columns.ts +++ b/packages/pg-typed/src/v2/types/Columns.ts @@ -2,10 +2,11 @@ import {NonAggregatedTypedValue} from './SpecialValues'; export interface ColumnReference extends NonAggregatedTypedValue { readonly sqlType: string | null; + serializeValue(value: T): unknown; setAlias(tableAlias: string): ColumnReference; } -export type Columns = { +export type Columns = {__isSpecialValue: true; __tableName: string} & { [TColumnName in keyof TRecord]: ColumnReference; }; diff --git a/packages/pg-typed/src/v2/types/Join.ts b/packages/pg-typed/src/v2/types/Join.ts deleted file mode 100644 index 31860402..00000000 --- a/packages/pg-typed/src/v2/types/Join.ts +++ /dev/null @@ -1,27 +0,0 @@ -import GroupByQuery from '../GroupByQuery'; -import {JoinableQueryLeft} from './JoinableQuery'; -import {ProjectedQuery} from './Queries'; -import {AggregatedSelectionSet, SelectionSet} from './SelectionSet'; -import {NonAggregatedValue} from './SpecialValues'; - -export interface JoinQueryBuilder { - on( - predicate: (column: TColumns) => NonAggregatedValue, - ): JoinQuery; -} - -export interface JoinQuery extends JoinableQueryLeft { - select( - selection: (column: TColumns) => SelectionSet, - ): ProjectedQuery; - groupBy( - selection: (column: TColumns) => SelectionSet, - ): GroupByQuery; - selectAggregate( - aggregation: (column: TColumns) => AggregatedSelectionSet, - ): ProjectedQuery; - - where( - predicate: (column: TColumns) => NonAggregatedValue, - ): JoinQuery; -} diff --git a/packages/pg-typed/src/v2/types/JoinableQuery.ts b/packages/pg-typed/src/v2/types/JoinableQuery.ts deleted file mode 100644 index 86cd8e2f..00000000 --- a/packages/pg-typed/src/v2/types/JoinableQuery.ts +++ /dev/null @@ -1,51 +0,0 @@ -import {Columns} from './Columns'; -import {JoinQueryBuilder} from './Join'; -import ProjectedLimitQuery from '../ProjectedLimitQuery'; - -export type JoinedColumns< - TLeftTables, - TRightAlias extends string, - TRightRecordColumns, -> = { - [TChildAlias in - | keyof TLeftTables - | TRightAlias]: TChildAlias extends keyof TLeftTables - ? TLeftTables[TChildAlias] - : TRightRecordColumns; -}; - -export type InnerJoinedColumns< - TLeftTables, - TRightAlias extends string, - TRightRecord, -> = JoinedColumns>; - -export type LeftOuterJoinedColumns< - TLeftTables, - TRightAlias extends string, - TRightRecord, -> = JoinedColumns< - TLeftTables, - TRightAlias, - Columns<{ - [TColumnName in keyof TRightRecord]: TRightRecord[TColumnName] | null; - }> ->; - -export interface JoinableQueryLeft { - innerJoin( - otherQuery: JoinableQueryRight, - ): JoinQueryBuilder< - InnerJoinedColumns - >; - leftOuterJoin( - otherQuery: JoinableQueryRight, - ): JoinQueryBuilder< - LeftOuterJoinedColumns - >; -} - -export interface JoinableQueryRight - extends ProjectedLimitQuery { - alias: TAlias; -} diff --git a/packages/pg-typed/src/v2/types/Operators.ts b/packages/pg-typed/src/v2/types/Operators.ts index 3b09b5d8..aaaf3823 100644 --- a/packages/pg-typed/src/v2/types/Operators.ts +++ b/packages/pg-typed/src/v2/types/Operators.ts @@ -1,4 +1,4 @@ -import WhereCondition from '../WhereCondition'; +import WhereCondition from './WhereCondition'; import { AggregatedTypedValue, NonAggregatedValue, @@ -8,6 +8,8 @@ import { AggregatedValue, Value, } from './SpecialValues'; +import {Columns} from './Columns'; +import {SelectionSet, SelectionSetStar} from './SelectionSet'; export interface List { [Symbol.iterator](): IterableIterator; @@ -29,6 +31,16 @@ export interface AggregatedJsonValue { prop(key: TKey): AggregatedJsonValue; } +export type MergedSelectionSet = SelectionSet< + { + [TKey in keyof T]: T[TKey] extends SelectionSet + ? (v: TSelection) => void + : never; + }[number & keyof T] extends (v: infer TResult) => void + ? TResult + : never +>; + // prettier-ignore export interface IOperators { allOf(values: List>): FieldCondition; @@ -50,6 +62,8 @@ export interface IOperators { gte(left: Value, right: Value): Value; gte(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; gte(left: AggregatedValue, right: AggregatedValue): AggregatedValue; + + mergeColumns[]>(...selections: {[key in keyof TSelections]: TSelections[key]}): MergedSelectionSet neq(left: Value, right: Value): Value; neq(left: NonAggregatedValue, right: NonAggregatedValue): NonAggregatedValue; neq(left: AggregatedValue, right: AggregatedValue): AggregatedValue; @@ -83,6 +97,7 @@ export interface IOperators { or(...values: NonAggregatedValue[]): NonAggregatedValue; or(...values: AggregatedValue[]): AggregatedValue; or(...conditions: WhereCondition[]): WhereCondition; + star(columns: Columns): SelectionSetStar; sum(value: NonAggregatedValue): AggregatedTypedValue; upper(value: Value): Value; upper(value: NonAggregatedValue): NonAggregatedValue; diff --git a/packages/pg-typed/src/v2/types/Queries.ts b/packages/pg-typed/src/v2/types/Queries.ts index b8b7c535..a1ffc5ae 100644 --- a/packages/pg-typed/src/v2/types/Queries.ts +++ b/packages/pg-typed/src/v2/types/Queries.ts @@ -1,6 +1,14 @@ import {SQLQuery} from '@databases/pg'; -import AliasedQuery from '../AliasedQuery'; + +import {Columns, InnerJoinedColumns, LeftOuterJoinedColumns} from './Columns'; import {TypedDatabaseQuery} from './TypedDatabaseQuery'; +import {NonAggregatedValue} from './SpecialValues'; +import { + AggregatedSelectionSet, + SelectionSet, + SelectionSetObject, +} from './SelectionSet'; +import WhereCondition from './WhereCondition'; export interface ProjectedLimitQuery extends TypedDatabaseQuery { @@ -68,3 +76,104 @@ export interface ProjectedQuery extends ProjectedDistinctColumnsQuery { distinct(): ProjectedDistinctQuery; } + +export interface AliasedQuery + extends SelectQuery, + JoinableQueryRight, + JoinableQueryLeft<{ + [TKey in TAlias]: Columns; + }> { + where( + condition: WhereCondition>, + ): AliasedQuery; +} + +export interface JoinableQueryLeft { + innerJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + InnerJoinedColumns + >; + leftOuterJoin( + otherQuery: JoinableQueryRight, + ): JoinQueryBuilder< + LeftOuterJoinedColumns + >; +} + +export interface JoinableQueryRight + extends ProjectedLimitQuery { + alias: TAlias; +} + +export interface JoinQueryBuilder { + on( + predicate: (column: TColumns) => NonAggregatedValue, + ): JoinQuery; +} + +export interface JoinQuery extends JoinableQueryLeft { + select( + selection: (column: TColumns) => SelectionSet, + ): ProjectedQuery; + groupBy( + selection: (column: TColumns) => SelectionSetObject, + ): GroupByQuery; + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): AggregatedQuery; + + where( + predicate: (column: TColumns) => NonAggregatedValue, + ): JoinQuery; +} + +export interface GroupByQuery { + selectAggregate( + aggregation: (column: TColumns) => AggregatedSelectionSet, + ): ProjectedSortedQuery; +} + +export interface SelectQuery extends ProjectedQuery { + where( + condition: WhereCondition>, + ): SelectQuery; + + select( + ...columnNames: TColumnNames + ): ProjectedQuery>; + select( + selection: (column: Columns) => SelectionSet, + ): ProjectedQuery; + + groupBy( + ...columnNames: TColumnNames + ): GroupByQuery, Columns>; + groupBy( + selection: (column: Columns) => SelectionSetObject, + ): GroupByQuery>; + + selectAggregate( + aggregation: ( + column: Columns, + ) => AggregatedSelectionSet, + ): AggregatedQuery; +} + +export interface AggregatedQuery extends TypedDatabaseQuery { + /** + * Get the SQL query that would be executed. This is useful if you want to use this query as a sub-query in a query that is not type safe. + */ + toSql(): SQLQuery; + + /** + * If this is a complex query: + * Wrap the entire query in parentheses, and give it an alias. This lets you use joins, group by, etc. as sub-queries. + * + * If this is a simple query: + * Give the table an alias. This lets you use it in a join. + */ + as( + alias: TAliasTableName, + ): AliasedQuery; +} diff --git a/packages/pg-typed/src/v2/types/SelectionSet.ts b/packages/pg-typed/src/v2/types/SelectionSet.ts index 066540be..e022c8fb 100644 --- a/packages/pg-typed/src/v2/types/SelectionSet.ts +++ b/packages/pg-typed/src/v2/types/SelectionSet.ts @@ -1,9 +1,30 @@ import {AggregatedValue, NonAggregatedValue} from './SpecialValues'; -export type SelectionSet = { - [key in keyof TSelection]: NonAggregatedValue; +export type SelectionSetStar = { + readonly __isSpecialValue: true; + readonly __selectionSetType: 'STAR'; + readonly tableName: string; + __getType(): TSelection; }; +export type SelectionSetMerged = { + readonly __isSpecialValue: true; + readonly __selectionSetType: 'MERGED'; + readonly selections: SelectionSet>[]; + __getType(): TSelection; +}; + +export type SelectionSetObject = { + [key in keyof TSelection]: key extends '__isSpecialValue' + ? never + : NonAggregatedValue; +}; + +export type SelectionSet = + | SelectionSetStar + | SelectionSetMerged + | SelectionSetObject; + export type AggregatedSelectionSet = { [key in keyof TSelection]: AggregatedValue; }; diff --git a/packages/pg-typed/src/v2/types/Statements.ts b/packages/pg-typed/src/v2/types/Statements.ts new file mode 100644 index 00000000..c6d43d9a --- /dev/null +++ b/packages/pg-typed/src/v2/types/Statements.ts @@ -0,0 +1,44 @@ +import {SQLQuery} from '@databases/pg'; +import {TypedDatabaseQuery} from './TypedDatabaseQuery'; +import {SelectionSetObject} from './SelectionSet'; +import {Columns} from './Columns'; +import {SelectQuery} from './Queries'; + +export interface StatementCount extends TypedDatabaseQuery { + toSql(): SQLQuery; +} +export interface BaseStatement extends TypedDatabaseQuery { + /** + * Get the SQL query that would be executed. This is useful if you want to use this query as a sub-query in a query that is not type safe. + * + * Returns "null" if the query has no side effects (e.g. an insert with no records to insert) + */ + toSql(): SQLQuery | null; + + returningCount(): StatementCount; + returning(star: '*'): SelectQuery; + returning( + ...columnNames: TColumnNames + ): SelectQuery>; + returning( + selection: (column: Columns) => SelectionSetObject, + ): SelectQuery; +} + +export interface InsertStatementOnConflictBuilder { + doUpdate(...columns: (keyof TRecord)[]): BaseStatement; + doUpdate( + updates: ( + columns: Columns, + excluded: Columns, + ) => Partial>, + ): BaseStatement; +} +export interface InsertStatement extends BaseStatement { + onConflict( + ...columns: readonly (keyof TRecord)[] + ): InsertStatementOnConflictBuilder; + onConflictDoNothing(): BaseStatement; +} + +export interface UpdateStatement extends BaseStatement {} diff --git a/packages/pg-typed/src/v2/types/Table.ts b/packages/pg-typed/src/v2/types/Table.ts new file mode 100644 index 00000000..382fa1ff --- /dev/null +++ b/packages/pg-typed/src/v2/types/Table.ts @@ -0,0 +1,19 @@ +import {InsertStatement} from './Statements'; +import {SelectQuery} from './Queries'; +import {TypedDatabaseQuery} from './TypedDatabaseQuery'; +import WhereCondition from './WhereCondition'; + +export default interface Table + extends SelectQuery { + one( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery; + oneRequired( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery; + first( + whereCondition?: WhereCondition, + ): TypedDatabaseQuery; + + insert(...records: TInsertParameters[]): InsertStatement; +} diff --git a/packages/pg-typed/src/v2/TableSchema.ts b/packages/pg-typed/src/v2/types/TableSchema.ts similarity index 68% rename from packages/pg-typed/src/v2/TableSchema.ts rename to packages/pg-typed/src/v2/types/TableSchema.ts index 9598c44a..e437c40d 100644 --- a/packages/pg-typed/src/v2/TableSchema.ts +++ b/packages/pg-typed/src/v2/types/TableSchema.ts @@ -1,8 +1,8 @@ import {SQLQuery} from '@databases/pg'; -import {Columns} from './types/Columns'; +import {Columns} from './Columns'; export default interface TableSchema { - __getType(): TRecord; + readonly __getType?: () => TRecord; tableName: string; tableId: SQLQuery; columns: Columns; diff --git a/packages/pg-typed/src/v2/WhereCondition.ts b/packages/pg-typed/src/v2/types/WhereCondition.ts similarity index 78% rename from packages/pg-typed/src/v2/WhereCondition.ts rename to packages/pg-typed/src/v2/types/WhereCondition.ts index 9e64423e..162abea9 100644 --- a/packages/pg-typed/src/v2/WhereCondition.ts +++ b/packages/pg-typed/src/v2/types/WhereCondition.ts @@ -1,5 +1,5 @@ -import {NonAggregatedValue, FieldCondition} from './types/SpecialValues'; -import {Columns} from './types/Columns'; +import {NonAggregatedValue, FieldCondition} from './SpecialValues'; +import {Columns} from './Columns'; export type WhereConditionObject = { readonly [key in keyof TRecord]?: FieldCondition; From 1c8f61b6779602e2e81f04b87b9f12831f257670 Mon Sep 17 00:00:00 2001 From: Forbes Lindesay Date: Fri, 19 May 2023 18:11:43 +0100 Subject: [PATCH 5/5] Implement insert/update/delete TODO: - Referencing EXCLUDED in `ON CONFLICT ... DO UPDATE ...` - Bulk operations using `UNNEST` - Script to migrate existing code - Integration tests --- .../pg-typed/src/v2/__tests__/delete.test.ts | 95 ++++++++++++ .../pg-typed/src/v2/__tests__/insert.test.ts | 4 +- .../pg-typed/src/v2/__tests__/update.test.ts | 101 +++++++++++++ .../pg-typed/src/v2/implementation/Queries.ts | 45 +++--- .../src/v2/implementation/Statements.ts | 135 +++++++++++++++--- .../pg-typed/src/v2/implementation/Table.ts | 36 ++++- packages/pg-typed/src/v2/types/Statements.ts | 20 ++- packages/pg-typed/src/v2/types/Table.ts | 11 +- 8 files changed, 401 insertions(+), 46 deletions(-) create mode 100644 packages/pg-typed/src/v2/__tests__/delete.test.ts create mode 100644 packages/pg-typed/src/v2/__tests__/update.test.ts diff --git a/packages/pg-typed/src/v2/__tests__/delete.test.ts b/packages/pg-typed/src/v2/__tests__/delete.test.ts new file mode 100644 index 00000000..34e20c33 --- /dev/null +++ b/packages/pg-typed/src/v2/__tests__/delete.test.ts @@ -0,0 +1,95 @@ +import {escapePostgresIdentifier} from '@databases/escape-identifier'; +import {SQLQuery, sql} from '@databases/pg'; +import {columns} from '../implementation/Columns'; +import createTableApi from '../implementation/Table'; +import {TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; + +interface DbUser { + id: number; + username: string; + profile_image_url: string | null; +} + +const users = createTableApi('users', sql`users`, columns(`users`)); + +const testFormat = { + escapeIdentifier: escapePostgresIdentifier, + formatValue: (value: unknown) => ({ + placeholder: '${ ' + JSON.stringify(value) + ' }', + value: undefined, + }), +}; + +test(`Basic Delete`, async () => { + const mock = {query: jest.fn()}; + + const deleteNoRecords = users.delete(false); + await deleteNoRecords.executeQuery(mock); + expect(mock.query).not.toBeCalled(); + expect(deleteNoRecords.toSql()).toBe(null); + expect(await mockResult(deleteNoRecords)).toBe(undefined); + + const deleteOne = users.delete({id: 1}); + + expect( + await mockResult( + deleteOne, + `DELETE FROM users WHERE "id"=\${ 1 }`, + [], + ), + ).toBe(undefined); + + const deleteReturningStar = deleteOne.returning(); + expect( + await mockResult( + deleteReturningStar.one(), + `DELETE FROM users WHERE "id"=\${ 1 } RETURNING *`, + [{id: 1, username: 'deleted_username', profile_image_url: null}], + ), + ).toEqual({id: 1, username: 'deleted_username', profile_image_url: null}); + + const deleteReturningId = deleteOne.returning(`id`); + expect( + await mockResult<{id: number}[]>( + deleteReturningId, + `DELETE FROM users WHERE "id"=\${ 1 } RETURNING "id"`, + [{id: 1}], + ), + ).toEqual([{id: 1}]); + + const deleteReturningCount = deleteOne.returningCount(); + expect( + await mockResult( + deleteReturningCount, + `DELETE FROM users WHERE "id"=\${ 1 } RETURNING (COUNT(*))::INT AS row_count`, + [{row_count: 1}], + ), + ).toBe(1); +}); + +async function mockResult( + query: TypedDatabaseQuery, + expectedQuery?: string, + results?: any[], +): Promise { + if ((expectedQuery === undefined) !== (results === undefined)) { + throw new Error( + `Mock results should have either an expected query and results, or neither.`, + ); + } + let called = false; + const result = await query.executeQuery({ + query: async (q: SQLQuery) => { + if (expectedQuery === undefined || results === undefined) { + throw new Error(`Did not expect query to be called`); + } + called = true; + expect(q.format(testFormat).text).toEqual(expectedQuery); + return results; + }, + }); + if (expectedQuery) { + expect(called).toBe(true); + } + return result; +} diff --git a/packages/pg-typed/src/v2/__tests__/insert.test.ts b/packages/pg-typed/src/v2/__tests__/insert.test.ts index ce2c9d64..7b90404c 100644 --- a/packages/pg-typed/src/v2/__tests__/insert.test.ts +++ b/packages/pg-typed/src/v2/__tests__/insert.test.ts @@ -49,7 +49,7 @@ test(`Basic Insert`, async () => { ), ).toBe(undefined); - const insertReturningStar = insertOne.returning(`*`); + const insertReturningStar = insertOne.returning(); expect( await mockResult( insertReturningStar.one(), @@ -124,7 +124,7 @@ test(`INNER JOIN`, async () => { .innerJoin( posts .insert({author_id: 1, title: 'test', created_at: new Date(0)}) - .returning(`*`) + .returning() .as(`p`), ) .on(({u, p}) => q.eq(u.id, p.author_id)) diff --git a/packages/pg-typed/src/v2/__tests__/update.test.ts b/packages/pg-typed/src/v2/__tests__/update.test.ts new file mode 100644 index 00000000..dd3ae6d4 --- /dev/null +++ b/packages/pg-typed/src/v2/__tests__/update.test.ts @@ -0,0 +1,101 @@ +import {escapePostgresIdentifier} from '@databases/escape-identifier'; +import {SQLQuery, sql} from '@databases/pg'; +import {columns} from '../implementation/Columns'; +import createTableApi from '../implementation/Table'; +import {TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; + +interface DbUser { + id: number; + username: string; + profile_image_url: string | null; +} + +const users = createTableApi('users', sql`users`, columns(`users`)); + +const testFormat = { + escapeIdentifier: escapePostgresIdentifier, + formatValue: (value: unknown) => ({ + placeholder: '${ ' + JSON.stringify(value) + ' }', + value: undefined, + }), +}; + +test(`Basic Update`, async () => { + const mock = {query: jest.fn()}; + + const updateNoColumns = users.update(true, {}); + await updateNoColumns.executeQuery(mock); + expect(mock.query).not.toBeCalled(); + expect(updateNoColumns.toSql()).toBe(null); + expect(await mockResult(updateNoColumns)).toBe(undefined); + + const updateNoRecords = users.update(false, {username: 'updated_username'}); + await updateNoRecords.executeQuery(mock); + expect(mock.query).not.toBeCalled(); + expect(updateNoRecords.toSql()).toBe(null); + expect(await mockResult(updateNoRecords)).toBe(undefined); + + const updateOne = users.update({id: 1}, {username: 'updated_username'}); + + expect( + await mockResult( + updateOne, + `UPDATE users SET "username"=\${ "updated_username" } WHERE "id"=\${ 1 }`, + [], + ), + ).toBe(undefined); + + const updateReturningStar = updateOne.returning(); + expect( + await mockResult( + updateReturningStar.one(), + `UPDATE users SET "username"=\${ "updated_username" } WHERE "id"=\${ 1 } RETURNING *`, + [{id: 1, username: 'updated_username', profile_image_url: null}], + ), + ).toEqual({id: 1, username: 'updated_username', profile_image_url: null}); + + const updateReturningId = updateOne.returning(`id`); + expect( + await mockResult<{id: number}[]>( + updateReturningId, + `UPDATE users SET "username"=\${ "updated_username" } WHERE "id"=\${ 1 } RETURNING "id"`, + [{id: 1}], + ), + ).toEqual([{id: 1}]); + + const updateReturningCount = updateOne.returningCount(); + expect( + await mockResult( + updateReturningCount, + `UPDATE users SET "username"=\${ "updated_username" } WHERE "id"=\${ 1 } RETURNING (COUNT(*))::INT AS row_count`, + [{row_count: 1}], + ), + ).toBe(1); +}); + +async function mockResult( + query: TypedDatabaseQuery, + expectedQuery?: string, + results?: any[], +): Promise { + if ((expectedQuery === undefined) !== (results === undefined)) { + throw new Error( + `Mock results should have either an expected query and results, or neither.`, + ); + } + let called = false; + const result = await query.executeQuery({ + query: async (q: SQLQuery) => { + if (expectedQuery === undefined || results === undefined) { + throw new Error(`Did not expect query to be called`); + } + called = true; + expect(q.format(testFormat).text).toEqual(expectedQuery); + return results; + }, + }); + if (expectedQuery) { + expect(called).toBe(true); + } + return result; +} diff --git a/packages/pg-typed/src/v2/implementation/Queries.ts b/packages/pg-typed/src/v2/implementation/Queries.ts index cec5e79a..eb7371d5 100644 --- a/packages/pg-typed/src/v2/implementation/Queries.ts +++ b/packages/pg-typed/src/v2/implementation/Queries.ts @@ -66,6 +66,7 @@ export default function createQuery( where: [], }); } + export function createStatementReturn( table: TableSchema, query: SQLQuery | null, @@ -76,14 +77,37 @@ export function createStatementReturn( return new StatementReturning( table as any, query, - typeof selection[0] === 'function' - ? selectionSetToProjection(selection[0](table.columns)) - : selection[0] === '*' + selection.length === 0 ? null + : typeof selection[0] === 'function' + ? selectionSetToProjection(selection[0](table.columns)) : columnNamesToProjection(selection as string[]), ); } +export function whereConditionToPredicates< + TRecord, + TColumns = Columns, +>( + columns: TColumns, + condition: WhereCondition, +): NonAggregatedValue[] { + return sql.isSqlQuery(condition) || + isSpecialValue(condition) || + typeof condition === 'boolean' + ? [condition] + : typeof condition === 'function' + ? [condition(columns)] + : Object.entries(condition).map(([columnName, value]) => + fieldConditionToPredicateValue( + (columns as Columns)[ + columnName as keyof TRecord + ] as ColumnReference, + value as FieldCondition, + ), + ); +} + interface Projection { /** * The SQL for the selection set. @@ -391,20 +415,7 @@ class SelectQueryImplementation< where(condition: WhereCondition): any { const where = [ ...this._config.where, - ...(sql.isSqlQuery(condition) || - isSpecialValue(condition) || - typeof condition === 'boolean' - ? [condition] - : typeof condition === 'function' - ? [condition(this._config.columns)] - : Object.entries(condition).map(([columnName, value]) => - fieldConditionToPredicateValue( - (this._config.columns as Columns)[ - columnName as keyof Columns - ] as ColumnReference, - value as FieldCondition, - ), - )), + ...whereConditionToPredicates(this._config.columns, condition), ]; return new SelectQueryImplementation({ columns: this._config.columns, diff --git a/packages/pg-typed/src/v2/implementation/Statements.ts b/packages/pg-typed/src/v2/implementation/Statements.ts index ad99414d..543ce12f 100644 --- a/packages/pg-typed/src/v2/implementation/Statements.ts +++ b/packages/pg-typed/src/v2/implementation/Statements.ts @@ -1,23 +1,26 @@ import {SQLQuery, sql} from '@databases/pg'; import { BaseStatement, + DeleteStatement, InsertStatement, InsertStatementOnConflictBuilder, StatementCount, UpdateStatement, } from '../types/Statements'; -import {Queryable} from '../types/TypedDatabaseQuery'; +import {Queryable, TypedDatabaseQuery} from '../types/TypedDatabaseQuery'; import TableSchema from '../types/TableSchema'; import {ColumnReference, Columns} from '../types/Columns'; import {SelectionSetObject} from '../types/SelectionSet'; -import {valueToSql} from './Operators'; +import Operators, {valueToSql} from './Operators'; import {SelectQuery} from '../types/Queries'; -import {createStatementReturn} from './Queries'; +import {createStatementReturn, whereConditionToPredicates} from './Queries'; +import WhereCondition from '../types/WhereCondition'; interface AnyStatement extends InsertStatementOnConflictBuilder, InsertStatement, - UpdateStatement {} + UpdateStatement, + DeleteStatement {} class StatementImplementation implements AnyStatement { private readonly _table: TableSchema; @@ -30,7 +33,8 @@ class StatementImplementation implements AnyStatement { returningCount(): StatementCount { return new StatementCountImplementation(this._statement); } - returning(star: '*'): SelectQuery; + + returning(): SelectQuery; returning( ...columnNames: TColumnNames ): SelectQuery>; @@ -45,6 +49,40 @@ class StatementImplementation implements AnyStatement { ); } + returningOne(): TypedDatabaseQuery; + returningOne( + ...columnNames: TColumnNames + ): TypedDatabaseQuery | undefined>; + returningOne( + selection: (column: Columns) => SelectionSetObject, + ): TypedDatabaseQuery; + returningOne( + ...selection: any[] + ): TypedDatabaseQuery { + return createStatementReturn( + this._table, + this._statement, + selection, + ).one(); + } + + returningOneRequired(): TypedDatabaseQuery; + returningOneRequired( + ...columnNames: TColumnNames + ): TypedDatabaseQuery>; + returningOneRequired( + selection: (column: Columns) => SelectionSetObject, + ): TypedDatabaseQuery; + returningOneRequired( + ...selection: any[] + ): TypedDatabaseQuery { + return createStatementReturn( + this._table, + this._statement, + selection, + ).oneRequired(); + } + doUpdate(...columns: (keyof TRecord)[]): BaseStatement; doUpdate( updates: ( @@ -63,7 +101,7 @@ class StatementImplementation implements AnyStatement { // TODO: make these references to EXCLUDED.column_name not table.column_name this._table.columns, ), - ) + ).query : sql.join( (updates as string[]).map( (key) => sql`${sql.ident(key)}=EXCLUDED.${sql.ident(key)}`, @@ -129,29 +167,82 @@ class StatementCountImplementation implements StatementCount { } } -export default function createInsertStatement( +export function createInsertStatement( table: TableSchema, statement: SQLQuery | null, ): InsertStatement { return new StatementImplementation(table, statement); } +export function createUpdateStatement( + table: TableSchema, + condition: WhereCondition, + updateValues: Partial>, +): UpdateStatement { + const predicate = Operators.and( + ...whereConditionToPredicates(table.columns, condition), + ); + const {query: update, columnCount} = selectionSetToUpdate( + table.columns, + updateValues, + ); + + return new StatementImplementation( + table, + predicate === false || columnCount === 0 + ? null + : sql`UPDATE ${table.tableId} SET ${update} WHERE ${valueToSql( + predicate, + { + parentOperatorPrecedence: null, + toValue: (v) => v, + tableAlias: () => null, + }, + )}`, + ); +} + +export function createDeleteStatement( + table: TableSchema, + condition: WhereCondition, +): DeleteStatement { + const predicate = Operators.and( + ...whereConditionToPredicates(table.columns, condition), + ); + + return new StatementImplementation( + table, + predicate === false + ? null + : sql`DELETE FROM ${table.tableId} WHERE ${valueToSql(predicate, { + parentOperatorPrecedence: null, + toValue: (v) => v, + tableAlias: () => null, + })}`, + ); +} + function selectionSetToUpdate( columns: Columns, - ...selections: Partial>[] -): SQLQuery { - const entries = selections.flatMap((selection) => Object.entries(selection)); - return sql.join( - entries.map(([columnName, value]) => { - const column = columns[ - columnName as keyof TRecord - ] as ColumnReference; - return sql`${sql.ident(columnName)}=${valueToSql(value, { - parentOperatorPrecedence: null, - toValue: (v) => column.serializeValue(v as any), - tableAlias: () => null, - })}`; - }), - `,`, + selection: Partial>, +): {query: SQLQuery; columnCount: number} { + const entries = Object.entries(selection).filter( + ([, value]) => value !== undefined, ); + return { + query: sql.join( + entries.map(([columnName, value]) => { + const column = columns[ + columnName as keyof TRecord + ] as ColumnReference; + return sql`${sql.ident(columnName)}=${valueToSql(value, { + parentOperatorPrecedence: null, + toValue: (v) => column.serializeValue(v as any), + tableAlias: () => null, + })}`; + }), + `,`, + ), + columnCount: entries.length, + }; } diff --git a/packages/pg-typed/src/v2/implementation/Table.ts b/packages/pg-typed/src/v2/implementation/Table.ts index 046c071f..854bbd90 100644 --- a/packages/pg-typed/src/v2/implementation/Table.ts +++ b/packages/pg-typed/src/v2/implementation/Table.ts @@ -15,10 +15,21 @@ import { import createQuery from './Queries'; import WhereCondition from '../types/WhereCondition'; import {ColumnReference, Columns} from '../types/Columns'; -import {AggregatedSelectionSet} from '../types/SelectionSet'; -import {InsertStatement} from '../types/Statements'; +import { + AggregatedSelectionSet, + SelectionSetObject, +} from '../types/SelectionSet'; +import { + DeleteStatement, + InsertStatement, + UpdateStatement, +} from '../types/Statements'; import {TypedDatabaseQuery, Queryable} from '../types/TypedDatabaseQuery'; -import createInsertStatement from './Statements'; +import { + createDeleteStatement, + createInsertStatement, + createUpdateStatement, +} from './Statements'; class TableImplementation implements Table @@ -77,6 +88,25 @@ class TableImplementation ); } + update( + condition: WhereCondition, + updateValues: + | Partial + | ((column: Columns) => Partial>), + ): UpdateStatement { + return createUpdateStatement( + this._table, + condition, + typeof updateValues === 'function' + ? updateValues(this._table.columns) + : (updateValues as Partial>), + ); + } + + delete(condition: WhereCondition): DeleteStatement { + return createDeleteStatement(this._table, condition); + } + // == Methods of SelectQuery == toSql(): SQLQuery { diff --git a/packages/pg-typed/src/v2/types/Statements.ts b/packages/pg-typed/src/v2/types/Statements.ts index c6d43d9a..a2bab197 100644 --- a/packages/pg-typed/src/v2/types/Statements.ts +++ b/packages/pg-typed/src/v2/types/Statements.ts @@ -16,13 +16,30 @@ export interface BaseStatement extends TypedDatabaseQuery { toSql(): SQLQuery | null; returningCount(): StatementCount; - returning(star: '*'): SelectQuery; + + returning(): SelectQuery; returning( ...columnNames: TColumnNames ): SelectQuery>; returning( selection: (column: Columns) => SelectionSetObject, ): SelectQuery; + + returningOne(): TypedDatabaseQuery; + returningOne( + ...columnNames: TColumnNames + ): TypedDatabaseQuery | undefined>; + returningOne( + selection: (column: Columns) => SelectionSetObject, + ): TypedDatabaseQuery; + + returningOneRequired(): TypedDatabaseQuery; + returningOneRequired( + ...columnNames: TColumnNames + ): TypedDatabaseQuery>; + returningOneRequired( + selection: (column: Columns) => SelectionSetObject, + ): TypedDatabaseQuery; } export interface InsertStatementOnConflictBuilder { @@ -42,3 +59,4 @@ export interface InsertStatement extends BaseStatement { } export interface UpdateStatement extends BaseStatement {} +export interface DeleteStatement extends BaseStatement {} diff --git a/packages/pg-typed/src/v2/types/Table.ts b/packages/pg-typed/src/v2/types/Table.ts index 382fa1ff..3dac8e3d 100644 --- a/packages/pg-typed/src/v2/types/Table.ts +++ b/packages/pg-typed/src/v2/types/Table.ts @@ -1,7 +1,9 @@ -import {InsertStatement} from './Statements'; +import {DeleteStatement, InsertStatement, UpdateStatement} from './Statements'; import {SelectQuery} from './Queries'; import {TypedDatabaseQuery} from './TypedDatabaseQuery'; import WhereCondition from './WhereCondition'; +import {SelectionSetObject} from './SelectionSet'; +import {Columns} from './Columns'; export default interface Table extends SelectQuery { @@ -16,4 +18,11 @@ export default interface Table ): TypedDatabaseQuery; insert(...records: TInsertParameters[]): InsertStatement; + update( + whereValues: WhereCondition, + updateValues: + | Partial + | ((column: Columns) => Partial>), + ): UpdateStatement; + delete(whereValues: WhereCondition): DeleteStatement; }