Skip to content

Commit 58b4703

Browse files
Initial implementation
1 parent 329ce83 commit 58b4703

File tree

6 files changed

+259
-0
lines changed

6 files changed

+259
-0
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,27 @@ extension PostgresConnection {
448448
self.channel.write(task, promise: nil)
449449
}
450450
}
451+
452+
/// Execute a prepared statement, taking care of the preparation when necessary
453+
public func execute<P: PreparedStatement>(
454+
_ preparedStatement: P,
455+
logger: Logger
456+
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, P.Row>
457+
{
458+
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
459+
let task = HandlerTask.executePreparedStatement(.init(
460+
name: String(reflecting: P.self),
461+
sql: P.sql,
462+
bindings: preparedStatement.makeBindings(),
463+
logger: logger,
464+
promise: promise
465+
))
466+
self.channel.write(task, promise: nil)
467+
return try await promise.futureResult
468+
.map { $0.asyncSequence() }
469+
.get()
470+
.map { try preparedStatement.decodeRow($0) }
471+
}
451472
}
452473

453474
// MARK: EventLoopFuture interface
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import NIOCore
2+
3+
struct PreparedStatementStateMachine {
4+
enum State {
5+
case preparing([PreparedStatementContext])
6+
case prepared(RowDescription?)
7+
// TODO: store errors
8+
// case error(PostgresBackendMessage.ErrorResponse)
9+
}
10+
11+
enum Action {
12+
case none
13+
case prepareStatement
14+
case waitForAlreadyInFlightPreparation
15+
case executePendingStatements([PreparedStatementContext], RowDescription?)
16+
}
17+
18+
var preparedStatements: [String: State]
19+
20+
init() {
21+
self.preparedStatements = [:]
22+
}
23+
24+
mutating func lookup(name: String, context: PreparedStatementContext) -> Action {
25+
if let state = self.preparedStatements[name] {
26+
switch state {
27+
case .preparing(var statements):
28+
statements.append(context)
29+
self.preparedStatements[name] = .preparing(statements)
30+
return .waitForAlreadyInFlightPreparation
31+
case .prepared(let rowDescription):
32+
return .executePendingStatements([context], rowDescription)
33+
}
34+
} else {
35+
self.preparedStatements[name] = .preparing([context])
36+
return .prepareStatement
37+
}
38+
}
39+
40+
mutating func preparationComplete(
41+
name: String,
42+
rowDescription: RowDescription?
43+
) -> Action {
44+
guard let state = self.preparedStatements[name] else {
45+
preconditionFailure("Preparation completed for an unexpected statement")
46+
}
47+
switch state {
48+
case .preparing(let statements):
49+
// When sending the bindings we are going to ask for binary data.
50+
if var rowDescription {
51+
for i in 0..<rowDescription.columns.count {
52+
rowDescription.columns[i].format = .binary
53+
}
54+
self.preparedStatements[name] = .prepared(rowDescription)
55+
return .executePendingStatements(statements, rowDescription)
56+
} else {
57+
self.preparedStatements[name] = .prepared(nil)
58+
return .executePendingStatements(statements, nil)
59+
}
60+
case .prepared(_):
61+
return .none
62+
}
63+
}
64+
}

Sources/PostgresNIO/New/PSQLTask.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import NIOCore
44
enum HandlerTask {
55
case extendedQuery(ExtendedQueryContext)
66
case preparedStatement(PrepareStatementContext)
7+
case executePreparedStatement(PreparedStatementContext)
78
case closeCommand(CloseCommandContext)
89
case startListening(NotificationListener)
910
case cancelListening(String, Int)
@@ -76,6 +77,28 @@ final class PrepareStatementContext {
7677
}
7778
}
7879

80+
final class PreparedStatementContext{
81+
let name: String
82+
let sql: String
83+
let bindings: PostgresBindings
84+
let logger: Logger
85+
let promise: EventLoopPromise<PSQLRowStream>
86+
87+
init(
88+
name: String,
89+
sql: String,
90+
bindings: PostgresBindings,
91+
logger: Logger,
92+
promise: EventLoopPromise<PSQLRowStream>
93+
) {
94+
self.name = name
95+
self.sql = sql
96+
self.bindings = bindings
97+
self.logger = logger
98+
self.promise = promise
99+
}
100+
}
101+
79102
final class CloseCommandContext {
80103
let target: CloseTarget
81104
let logger: Logger

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
2323
private let configureSSLCallback: ((Channel) throws -> Void)?
2424

2525
private var listenState: ListenStateMachine
26+
private var preparedStatementState: PreparedStatementStateMachine
2627

2728
init(
2829
configuration: PostgresConnection.InternalConfiguration,
@@ -33,6 +34,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
3334
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
3435
self.eventLoop = eventLoop
3536
self.listenState = ListenStateMachine()
37+
self.preparedStatementState = PreparedStatementStateMachine()
3638
self.configuration = configuration
3739
self.configureSSLCallback = configureSSLCallback
3840
self.logger = logger
@@ -51,6 +53,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
5153
self.state = state
5254
self.eventLoop = eventLoop
5355
self.listenState = ListenStateMachine()
56+
self.preparedStatementState = PreparedStatementStateMachine()
5457
self.configuration = configuration
5558
self.configureSSLCallback = configureSSLCallback
5659
self.logger = logger
@@ -235,6 +238,46 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
235238
listener.failed(CancellationError())
236239
return
237240
}
241+
case .executePreparedStatement(let preparedStatement):
242+
switch self.preparedStatementState.lookup(
243+
name: preparedStatement.name,
244+
context: preparedStatement
245+
) {
246+
case .none:
247+
return
248+
case .prepareStatement:
249+
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
250+
promise.futureResult.whenSuccess { rowDescription in
251+
self.prepareStatementComplete(
252+
name: preparedStatement.name,
253+
rowDescription: rowDescription,
254+
context: context
255+
)
256+
}
257+
psqlTask = .preparedStatement(.init(
258+
name: preparedStatement.name,
259+
query: preparedStatement.sql,
260+
logger: preparedStatement.logger,
261+
promise: promise
262+
))
263+
case .waitForAlreadyInFlightPreparation:
264+
// The state machine already keeps track of this
265+
// and will execute the statement as soon as it's prepared
266+
return
267+
case .executePendingStatements(let pendingStatements, let rowDescription):
268+
for statement in pendingStatements {
269+
let action = self.state.enqueue(task: .extendedQuery(.init(
270+
executeStatement: .init(
271+
name: statement.name,
272+
binds: statement.bindings,
273+
rowDescription: rowDescription),
274+
logger: statement.logger,
275+
promise: statement.promise
276+
)))
277+
self.run(action, with: context)
278+
}
279+
return
280+
}
238281
}
239282

240283
let action = self.state.enqueue(task: psqlTask)
@@ -666,6 +709,32 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
666709
}
667710
}
668711

712+
private func prepareStatementComplete(
713+
name: String,
714+
rowDescription: RowDescription?,
715+
context: ChannelHandlerContext
716+
) {
717+
let action = self.preparedStatementState.preparationComplete(
718+
name: name,
719+
rowDescription: rowDescription
720+
)
721+
guard case .executePendingStatements(let statements, let rowDescription) = action else {
722+
preconditionFailure("Expected to have pending statements to execute")
723+
}
724+
for preparedStatement in statements {
725+
let action = self.state.enqueue(task: .extendedQuery(.init(
726+
executeStatement: .init(
727+
name: preparedStatement.name,
728+
binds: preparedStatement.bindings,
729+
rowDescription: rowDescription
730+
),
731+
logger: preparedStatement.logger,
732+
promise: preparedStatement.promise
733+
))
734+
)
735+
self.run(action, with: context)
736+
}
737+
}
669738
}
670739

671740
extension PostgresChannelHandler: PSQLRowsDataSource {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/// A prepared statement.
2+
///
3+
/// Structs conforming to this protocol will need to provide the SQL statement to
4+
/// send to the server and a way of creating bindings are decoding the result.
5+
///
6+
/// As an example, consider this struct:
7+
/// ```swift
8+
/// struct Example: PreparedStatement {
9+
/// static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
10+
/// typealias Row = (Int, String)
11+
///
12+
/// var state: String
13+
///
14+
/// func makeBindings() -> PostgresBindings {
15+
/// var bindings = PostgresBindings()
16+
/// bindings.append(.init(string: self.state))
17+
/// return bindings
18+
/// }
19+
///
20+
/// func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
21+
/// try row.decode(Row.self)
22+
/// }
23+
/// }
24+
/// ```
25+
///
26+
/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`,
27+
/// which will take care of preparing the statement on the server side and executing it.
28+
public protocol PreparedStatement {
29+
/// The type rows returned by the statement will be decoded into
30+
associatedtype Row
31+
32+
/// The SQL statement to prepare on the database server.
33+
static var sql: String { get }
34+
35+
/// Make the bindings to provided concrete values to use when executing the prepared SQL statement
36+
func makeBindings() -> PostgresBindings
37+
38+
/// Decode a row returned by the database into an instance of `Row`
39+
func decodeRow(_ row: PostgresRow) throws -> Row
40+
}

Tests/IntegrationTests/AsyncTests.swift

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,48 @@ final class AsyncPostgresConnectionTests: XCTestCase {
315315
try await connection.query("SELECT 1;", logger: .psqlTest)
316316
}
317317
}
318+
319+
func testPreparedStatement() async throws {
320+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
321+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
322+
let eventLoop = eventLoopGroup.next()
323+
324+
struct TestPreparedStatement: PreparedStatement {
325+
static var sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
326+
typealias Row = (Int, String)
327+
328+
var state: String
329+
330+
func makeBindings() -> PostgresBindings {
331+
var bindings = PostgresBindings()
332+
bindings.append(.init(string: self.state))
333+
return bindings
334+
}
335+
336+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
337+
try row.decode(Row.self)
338+
}
339+
}
340+
let preparedStatement = TestPreparedStatement(state: "active")
341+
try await withTestConnection(on: eventLoop) { connection in
342+
var results = try await connection.execute(preparedStatement, logger: .psqlTest)
343+
var counter = 0
344+
345+
for try await element in results {
346+
XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database")
347+
counter += 1
348+
}
349+
350+
XCTAssertGreaterThanOrEqual(counter, 1)
351+
352+
// Second execution, which reuses the existing prepared statement
353+
results = try await connection.execute(preparedStatement, logger: .psqlTest)
354+
for try await element in results {
355+
XCTAssertEqual(element.1, env("POSTGRES_DB") ?? "test_database")
356+
counter += 1
357+
}
358+
}
359+
}
318360
}
319361

320362
extension XCTestCase {

0 commit comments

Comments
 (0)