Skip to content

Commit abb7e08

Browse files
ivashogDev K0teigalklebanov
committed
Add within group clause support for aggregate function builder (#1024)
Co-authored-by: Ivashkin Olexiy <[email protected]> Co-authored-by: Dev K0te <[email protected]> Co-authored-by: igalklebanov <[email protected]>
1 parent 099b1a7 commit abb7e08

File tree

5 files changed

+98
-8
lines changed

5 files changed

+98
-8
lines changed

src/operation-node/aggregate-function-node.ts

+6-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export interface AggregateFunctionNode extends OperationNode {
1111
readonly aggregated: readonly OperationNode[]
1212
readonly distinct?: boolean
1313
readonly orderBy?: OrderByNode
14+
readonly withinGroup?: OrderByNode
1415
readonly filter?: WhereNode
1516
readonly over?: OverNode
1617
}
@@ -46,11 +47,14 @@ export const AggregateFunctionNode = freeze({
4647
cloneWithOrderBy(
4748
aggregateFunctionNode: AggregateFunctionNode,
4849
orderItems: ReadonlyArray<OrderByItemNode>,
50+
withinGroup = false,
4951
): AggregateFunctionNode {
52+
const prop = withinGroup ? 'withinGroup' : 'orderBy'
53+
5054
return freeze({
5155
...aggregateFunctionNode,
52-
orderBy: aggregateFunctionNode.orderBy
53-
? OrderByNode.cloneWithItems(aggregateFunctionNode.orderBy, orderItems)
56+
[prop]: aggregateFunctionNode[prop]
57+
? OrderByNode.cloneWithItems(aggregateFunctionNode[prop], orderItems)
5458
: OrderByNode.create(orderItems),
5559
})
5660
},

src/operation-node/operation-node-transformer.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -904,11 +904,12 @@ export class OperationNodeTransformer {
904904
): AggregateFunctionNode {
905905
return requireAllProps({
906906
kind: 'AggregateFunctionNode',
907+
func: node.func,
907908
aggregated: this.transformNodeList(node.aggregated),
908909
distinct: node.distinct,
909910
orderBy: this.transformNode(node.orderBy),
911+
withinGroup: this.transformNode(node.withinGroup),
910912
filter: this.transformNode(node.filter),
911-
func: node.func,
912913
over: this.transformNode(node.over),
913914
})
914915
}

src/query-builder/aggregate-function-builder.ts

+44-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import {
1111
} from '../expression/expression.js'
1212
import {
1313
ReferenceExpression,
14-
StringReference,
14+
SimpleReferenceExpression,
1515
} from '../parser/reference-parser.js'
1616
import {
1717
ComparisonOperatorExpression,
@@ -21,7 +21,6 @@ import {
2121
} from '../parser/binary-operation-parser.js'
2222
import { SqlBool } from '../util/type-utils.js'
2323
import { ExpressionOrFactory } from '../parser/expression-parser.js'
24-
import { DynamicReferenceBuilder } from '../dynamic/dynamic-reference-builder.js'
2524
import {
2625
OrderByDirectionExpression,
2726
parseOrderBy,
@@ -125,7 +124,7 @@ export class AggregateFunctionBuilder<DB, TB extends keyof DB, O = unknown>
125124
* inner join "pet" ON "pet"."owner_id" = "person"."id"
126125
* ```
127126
*/
128-
orderBy<OE extends StringReference<DB, TB> | DynamicReferenceBuilder<any>>(
127+
orderBy<OE extends SimpleReferenceExpression<DB, TB>>(
129128
orderBy: OE,
130129
direction?: OrderByDirectionExpression,
131130
): AggregateFunctionBuilder<DB, TB, O> {
@@ -138,6 +137,48 @@ export class AggregateFunctionBuilder<DB, TB extends keyof DB, O = unknown>
138137
})
139138
}
140139

140+
/**
141+
* Adds a `withing group` clause with a nested `order by` clause after the function.
142+
*
143+
* This is only supported by some dialects like PostgreSQL or MS SQL Server.
144+
*
145+
* ### Examples
146+
*
147+
* Most frequent person name:
148+
*
149+
* ```ts
150+
* const result = await db
151+
* .selectFrom('person')
152+
* .select((eb) => [
153+
* eb.fn
154+
* .agg<string>('mode')
155+
* .withinGroupOrderBy('person.first_name')
156+
* .as('most_frequent_name')
157+
* ])
158+
* .executeTakeFirstOrThrow()
159+
* ```
160+
*
161+
* The generated SQL (PostgreSQL):
162+
*
163+
* ```sql
164+
* select mode() within group (order by "person"."first_name") as "most_frequent_name"
165+
* from "person"
166+
* ```
167+
*/
168+
withinGroupOrderBy<OE extends SimpleReferenceExpression<DB, TB>>(
169+
orderBy: OE,
170+
direction?: OrderByDirectionExpression,
171+
): AggregateFunctionBuilder<DB, TB, O> {
172+
return new AggregateFunctionBuilder({
173+
...this.#props,
174+
aggregateFunctionNode: AggregateFunctionNode.cloneWithOrderBy(
175+
this.#props.aggregateFunctionNode,
176+
parseOrderBy([orderBy, direction]),
177+
true,
178+
),
179+
})
180+
}
181+
141182
/**
142183
* Adds a `filter` clause with a nested `where` clause after the function.
143184
*

src/query-compiler/default-query-compiler.ts

+6
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,12 @@ export class DefaultQueryCompiler
14181418

14191419
this.append(')')
14201420

1421+
if (node.withinGroup) {
1422+
this.append(' within group (')
1423+
this.visitNode(node.withinGroup)
1424+
this.append(')')
1425+
}
1426+
14211427
if (node.filter) {
14221428
this.append(' filter(')
14231429
this.visitNode(node.filter)

test/node/src/aggregate-function.test.ts

+40-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
SimpleReferenceExpression,
55
ReferenceExpression,
66
sql,
7+
expressionBuilder,
78
} from '../../../'
89
import {
910
Database,
@@ -1108,8 +1109,12 @@ for (const dialect of DIALECTS) {
11081109
await query.execute()
11091110
})
11101111

1111-
describe(`should execute order-sensitive aggregate functions`, () => {
1112-
if (dialect === 'postgres' || dialect === 'mysql' || dialect === 'sqlite') {
1112+
describe('should execute order-sensitive aggregate functions', () => {
1113+
if (
1114+
dialect === 'postgres' ||
1115+
dialect === 'mysql' ||
1116+
dialect === 'sqlite'
1117+
) {
11131118
const isMySql = dialect === 'mysql'
11141119
const funcName = isMySql ? 'group_concat' : 'string_agg'
11151120
const funcArgs: Array<ReferenceExpression<Database, 'person'>> = [
@@ -1157,6 +1162,39 @@ for (const dialect of DIALECTS) {
11571162
await query.execute()
11581163
})
11591164
}
1165+
1166+
if (dialect === 'postgres' || dialect === 'mssql') {
1167+
it(`should execute a query with within group (order by column) in select clause`, async () => {
1168+
const query = ctx.db.selectFrom('toy').select((eb) =>
1169+
eb.fn
1170+
.agg('percentile_cont', [sql.lit(0.5)])
1171+
.withinGroupOrderBy('toy.price')
1172+
.$call((ab) => (dialect === 'mssql' ? ab.over() : ab))
1173+
.as('median_price'),
1174+
)
1175+
1176+
testSql(query, dialect, {
1177+
postgres: {
1178+
sql: [
1179+
`select percentile_cont(0.5) within group (order by "toy"."price") as "median_price"`,
1180+
`from "toy"`,
1181+
],
1182+
parameters: [],
1183+
},
1184+
mysql: NOT_SUPPORTED,
1185+
mssql: {
1186+
sql: [
1187+
`select percentile_cont(0.5) within group (order by "toy"."price") over() as "median_price"`,
1188+
`from "toy"`,
1189+
],
1190+
parameters: [],
1191+
},
1192+
sqlite: NOT_SUPPORTED,
1193+
})
1194+
1195+
await query.execute()
1196+
})
1197+
}
11601198
})
11611199
})
11621200
}

0 commit comments

Comments
 (0)