Skip to content

Commit d5c5258

Browse files
async/await prepared statement API (#390)
This patch adds a new `PreparedStatement` protocol to represent prepared statements and an `execute` function on `PostgresConnection` to prepare and execute statements. To implement the features the patch also introduces a `PreparedStatementStateMachine` that keeps track of the state of a prepared statement at the connection level. This ensures that, for each connection, each statement is prepared once at time of first use and then subsequent uses are going to skip the preparation step and just execute it. ## Example usage First define the struct to represent the prepared statement: ```swift struct ExamplePreparedStatement: PreparedStatement { static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1" typealias Row = (Int, String) var state: String func makeBindings() -> PostgresBindings { var bindings = PostgresBindings() bindings.append(self.state) return bindings } func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { try row.decode(Row.self) } } ``` then, assuming you already have a `PostgresConnection` you can execute it: ```swift let preparedStatement = ExamplePreparedStatement(state: "active") let results = try await connection.execute(preparedStatement, logger: logger) for (pid, database) in results { print("PID: \(pid), database: \(database)") } ``` --------- Co-authored-by: Fabian Fett <[email protected]>
1 parent 5217ba7 commit d5c5258

File tree

10 files changed

+898
-4
lines changed

10 files changed

+898
-4
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,72 @@ extension PostgresConnection {
460460
self.channel.write(task, promise: nil)
461461
}
462462
}
463+
464+
/// Execute a prepared statement, taking care of the preparation when necessary
465+
public func execute<Statement: PostgresPreparedStatement, Row>(
466+
_ preparedStatement: Statement,
467+
logger: Logger,
468+
file: String = #fileID,
469+
line: Int = #line
470+
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Row> where Row == Statement.Row {
471+
let bindings = try preparedStatement.makeBindings()
472+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
473+
let task = HandlerTask.executePreparedStatement(.init(
474+
name: String(reflecting: Statement.self),
475+
sql: Statement.sql,
476+
bindings: bindings,
477+
logger: logger,
478+
promise: promise
479+
))
480+
self.channel.write(task, promise: nil)
481+
do {
482+
return try await promise.futureResult
483+
.map { $0.asyncSequence() }
484+
.get()
485+
.map { try preparedStatement.decodeRow($0) }
486+
} catch var error as PSQLError {
487+
error.file = file
488+
error.line = line
489+
error.query = .init(
490+
unsafeSQL: Statement.sql,
491+
binds: bindings
492+
)
493+
throw error // rethrow with more metadata
494+
}
495+
496+
}
497+
498+
/// Execute a prepared statement, taking care of the preparation when necessary
499+
public func execute<Statement: PostgresPreparedStatement>(
500+
_ preparedStatement: Statement,
501+
logger: Logger,
502+
file: String = #fileID,
503+
line: Int = #line
504+
) async throws -> String where Statement.Row == () {
505+
let bindings = try preparedStatement.makeBindings()
506+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
507+
let task = HandlerTask.executePreparedStatement(.init(
508+
name: String(reflecting: Statement.self),
509+
sql: Statement.sql,
510+
bindings: bindings,
511+
logger: logger,
512+
promise: promise
513+
))
514+
self.channel.write(task, promise: nil)
515+
do {
516+
return try await promise.futureResult
517+
.map { $0.commandTag }
518+
.get()
519+
} catch var error as PSQLError {
520+
error.file = file
521+
error.line = line
522+
error.query = .init(
523+
unsafeSQL: Statement.sql,
524+
binds: bindings
525+
)
526+
throw error // rethrow with more metadata
527+
}
528+
}
463529
}
464530

465531
// MARK: EventLoopFuture interface
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import NIOCore
2+
3+
struct PreparedStatementStateMachine {
4+
enum State {
5+
case preparing([PreparedStatementContext])
6+
case prepared(RowDescription?)
7+
case error(PSQLError)
8+
}
9+
10+
var preparedStatements: [String: State] = [:]
11+
12+
enum LookupAction {
13+
case prepareStatement
14+
case waitForAlreadyInFlightPreparation
15+
case executeStatement(RowDescription?)
16+
case returnError(PSQLError)
17+
}
18+
19+
mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction {
20+
if let state = self.preparedStatements[preparedStatement.name] {
21+
switch state {
22+
case .preparing(var statements):
23+
statements.append(preparedStatement)
24+
self.preparedStatements[preparedStatement.name] = .preparing(statements)
25+
return .waitForAlreadyInFlightPreparation
26+
case .prepared(let rowDescription):
27+
return .executeStatement(rowDescription)
28+
case .error(let error):
29+
return .returnError(error)
30+
}
31+
} else {
32+
self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement])
33+
return .prepareStatement
34+
}
35+
}
36+
37+
struct PreparationCompleteAction {
38+
var statements: [PreparedStatementContext]
39+
var rowDescription: RowDescription?
40+
}
41+
42+
mutating func preparationComplete(
43+
name: String,
44+
rowDescription: RowDescription?
45+
) -> PreparationCompleteAction {
46+
guard let state = self.preparedStatements[name] else {
47+
fatalError("Unknown prepared statement \(name)")
48+
}
49+
switch state {
50+
case .preparing(let statements):
51+
// When sending the bindings we are going to ask for binary data.
52+
if var rowDescription = rowDescription {
53+
for i in 0..<rowDescription.columns.count {
54+
rowDescription.columns[i].format = .binary
55+
}
56+
self.preparedStatements[name] = .prepared(rowDescription)
57+
return PreparationCompleteAction(
58+
statements: statements,
59+
rowDescription: rowDescription
60+
)
61+
} else {
62+
self.preparedStatements[name] = .prepared(nil)
63+
return PreparationCompleteAction(
64+
statements: statements,
65+
rowDescription: nil
66+
)
67+
}
68+
case .prepared, .error:
69+
preconditionFailure("Preparation completed happened in an unexpected state \(state)")
70+
}
71+
}
72+
73+
struct ErrorHappenedAction {
74+
var statements: [PreparedStatementContext]
75+
var error: PSQLError
76+
}
77+
78+
mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction {
79+
guard let state = self.preparedStatements[name] else {
80+
fatalError("Unknown prepared statement \(name)")
81+
}
82+
switch state {
83+
case .preparing(let statements):
84+
self.preparedStatements[name] = .error(error)
85+
return ErrorHappenedAction(
86+
statements: statements,
87+
error: error
88+
)
89+
case .prepared, .error:
90+
preconditionFailure("Error happened in an unexpected state \(state)")
91+
}
92+
}
93+
}

Sources/PostgresNIO/New/PSQLTask.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ enum HandlerTask {
66
case closeCommand(CloseCommandContext)
77
case startListening(NotificationListener)
88
case cancelListening(String, Int)
9+
case executePreparedStatement(PreparedStatementContext)
910
}
1011

1112
enum PSQLTask {
@@ -69,6 +70,28 @@ final class ExtendedQueryContext {
6970
}
7071
}
7172

73+
final class PreparedStatementContext{
74+
let name: String
75+
let sql: String
76+
let bindings: PostgresBindings
77+
let logger: Logger
78+
let promise: EventLoopPromise<PSQLRowStream>
79+
80+
init(
81+
name: String,
82+
sql: String,
83+
bindings: PostgresBindings,
84+
logger: Logger,
85+
promise: EventLoopPromise<PSQLRowStream>
86+
) {
87+
self.name = name
88+
self.sql = sql
89+
self.bindings = bindings
90+
self.logger = logger
91+
self.promise = promise
92+
}
93+
}
94+
7295
final class CloseCommandContext {
7396
let target: CloseTarget
7497
let logger: Logger

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
2222
private let configuration: PostgresConnection.InternalConfiguration
2323
private let configureSSLCallback: ((Channel) throws -> Void)?
2424

25-
private var listenState: ListenStateMachine
25+
private var listenState = ListenStateMachine()
26+
private var preparedStatementState = PreparedStatementStateMachine()
2627

2728
init(
2829
configuration: PostgresConnection.InternalConfiguration,
@@ -32,7 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
3233
) {
3334
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
3435
self.eventLoop = eventLoop
35-
self.listenState = ListenStateMachine()
3636
self.configuration = configuration
3737
self.configureSSLCallback = configureSSLCallback
3838
self.logger = logger
@@ -50,7 +50,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
5050
) {
5151
self.state = state
5252
self.eventLoop = eventLoop
53-
self.listenState = ListenStateMachine()
5453
self.configuration = configuration
5554
self.configureSSLCallback = configureSSLCallback
5655
self.logger = logger
@@ -233,6 +232,29 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
233232
listener.failed(CancellationError())
234233
return
235234
}
235+
case .executePreparedStatement(let preparedStatement):
236+
let action = self.preparedStatementState.lookup(
237+
preparedStatement: preparedStatement
238+
)
239+
switch action {
240+
case .prepareStatement:
241+
psqlTask = self.makePrepareStatementTask(
242+
preparedStatement: preparedStatement,
243+
context: context
244+
)
245+
case .waitForAlreadyInFlightPreparation:
246+
// The state machine already keeps track of this
247+
// and will execute the statement as soon as it's prepared
248+
return
249+
case .executeStatement(let rowDescription):
250+
psqlTask = self.makeExecutePreparedStatementTask(
251+
preparedStatement: preparedStatement,
252+
rowDescription: rowDescription
253+
)
254+
case .returnError(let error):
255+
preparedStatement.promise.fail(error)
256+
return
257+
}
236258
}
237259

238260
let action = self.state.enqueue(task: psqlTask)
@@ -664,6 +686,93 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
664686
}
665687
}
666688

689+
private func makePrepareStatementTask(
690+
preparedStatement: PreparedStatementContext,
691+
context: ChannelHandlerContext
692+
) -> PSQLTask {
693+
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
694+
promise.futureResult.whenComplete { result in
695+
switch result {
696+
case .success(let rowDescription):
697+
self.prepareStatementComplete(
698+
name: preparedStatement.name,
699+
rowDescription: rowDescription,
700+
context: context
701+
)
702+
case .failure(let error):
703+
let psqlError: PSQLError
704+
if let error = error as? PSQLError {
705+
psqlError = error
706+
} else {
707+
psqlError = .connectionError(underlying: error)
708+
}
709+
self.prepareStatementFailed(
710+
name: preparedStatement.name,
711+
error: psqlError,
712+
context: context
713+
)
714+
}
715+
}
716+
return .extendedQuery(.init(
717+
name: preparedStatement.name,
718+
query: preparedStatement.sql,
719+
logger: preparedStatement.logger,
720+
promise: promise
721+
))
722+
}
723+
724+
private func makeExecutePreparedStatementTask(
725+
preparedStatement: PreparedStatementContext,
726+
rowDescription: RowDescription?
727+
) -> PSQLTask {
728+
return .extendedQuery(.init(
729+
executeStatement: .init(
730+
name: preparedStatement.name,
731+
binds: preparedStatement.bindings,
732+
rowDescription: rowDescription
733+
),
734+
logger: preparedStatement.logger,
735+
promise: preparedStatement.promise
736+
))
737+
}
738+
739+
private func prepareStatementComplete(
740+
name: String,
741+
rowDescription: RowDescription?,
742+
context: ChannelHandlerContext
743+
) {
744+
let action = self.preparedStatementState.preparationComplete(
745+
name: name,
746+
rowDescription: rowDescription
747+
)
748+
for preparedStatement in action.statements {
749+
let action = self.state.enqueue(task: .extendedQuery(.init(
750+
executeStatement: .init(
751+
name: preparedStatement.name,
752+
binds: preparedStatement.bindings,
753+
rowDescription: action.rowDescription
754+
),
755+
logger: preparedStatement.logger,
756+
promise: preparedStatement.promise
757+
))
758+
)
759+
self.run(action, with: context)
760+
}
761+
}
762+
763+
private func prepareStatementFailed(
764+
name: String,
765+
error: PSQLError,
766+
context: ChannelHandlerContext
767+
) {
768+
let action = self.preparedStatementState.errorHappened(
769+
name: name,
770+
error: error
771+
)
772+
for statement in action.statements {
773+
statement.promise.fail(action.error)
774+
}
775+
}
667776
}
668777

669778
extension PostgresChannelHandler: PSQLRowsDataSource {

Sources/PostgresNIO/New/PostgresQuery.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ public struct PostgresBindings: Sendable, Hashable {
167167
self.metadata.append(.init(dataType: .null, format: .binary, protected: true))
168168
}
169169

170+
@inlinable
171+
public mutating func append<Value: PostgresEncodable>(_ value: Value) throws {
172+
try self.append(value, context: .default)
173+
}
174+
170175
@inlinable
171176
public mutating func append<Value: PostgresEncodable, JSONEncoder: PostgresJSONEncoder>(
172177
_ value: Value,
@@ -176,6 +181,11 @@ public struct PostgresBindings: Sendable, Hashable {
176181
self.metadata.append(.init(value: value, protected: true))
177182
}
178183

184+
@inlinable
185+
public mutating func append<Value: PostgresNonThrowingEncodable>(_ value: Value) {
186+
self.append(value, context: .default)
187+
}
188+
179189
@inlinable
180190
public mutating func append<Value: PostgresNonThrowingEncodable, JSONEncoder: PostgresJSONEncoder>(
181191
_ value: Value,

0 commit comments

Comments
 (0)