Skip to content

Commit 261a531

Browse files
PR feedback and state machine tests
1 parent dfacd3e commit 261a531

File tree

4 files changed

+284
-91
lines changed

4 files changed

+284
-91
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -468,19 +468,31 @@ extension PostgresConnection {
468468
file: String = #fileID,
469469
line: Int = #line
470470
) async throws -> AsyncThrowingMapSequence<PostgresRowSequence, Statement.Row> {
471+
let bindings = preparedStatement.makeBindings()
471472
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
472473
let task = HandlerTask.executePreparedStatement(.init(
473474
name: String(reflecting: Statement.self),
474475
sql: Statement.sql,
475-
bindings: preparedStatement.makeBindings(),
476+
bindings: bindings,
476477
logger: logger,
477478
promise: promise
478479
))
479480
self.channel.write(task, promise: nil)
480-
return try await promise.futureResult
481-
.map { $0.asyncSequence() }
482-
.get()
483-
.map { try preparedStatement.decodeRow($0) }
481+
do {
482+
return try await promise.futureResult
483+
.map { $0.asyncSequence() }
484+
.get()
485+
.map { try preparedStatement.decodeRow($0) }
486+
} catch var error as PSQLError {
487+
error.file = file
488+
error.line = line
489+
error.query = .init(
490+
unsafeSQL: Statement.sql,
491+
binds: bindings
492+
)
493+
throw error // rethrow with more metadata
494+
}
495+
484496
}
485497

486498
/// Execute a prepared statement, taking care of the preparation when necessary
@@ -490,18 +502,30 @@ extension PostgresConnection {
490502
file: String = #fileID,
491503
line: Int = #line
492504
) async throws -> String where Statement.Row == () {
505+
let bindings = preparedStatement.makeBindings()
493506
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
494507
let task = HandlerTask.executePreparedStatement(.init(
495508
name: String(reflecting: Statement.self),
496509
sql: Statement.sql,
497-
bindings: preparedStatement.makeBindings(),
510+
bindings: bindings,
498511
logger: logger,
499512
promise: promise
500513
))
501514
self.channel.write(task, promise: nil)
502-
return try await promise.futureResult
503-
.map { $0.commandTag }
504-
.get()
515+
do {
516+
return try await promise.futureResult
517+
.map { $0.commandTag }
518+
.get()
519+
} catch var error as PSQLError {
520+
error.file = file
521+
error.line = line
522+
error.query = .init(
523+
unsafeSQL: Statement.sql,
524+
binds: bindings
525+
)
526+
throw error // rethrow with more metadata
527+
}
528+
505529
}
506530
}
507531

Sources/PostgresNIO/New/Connection State Machine/PreparedStatementStateMachine.swift

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,36 @@ struct PreparedStatementStateMachine {
77
case error(PSQLError)
88
}
99

10-
var preparedStatements: [String: State]
10+
var preparedStatements: [String: State] = [:]
1111

12-
init() {
13-
self.preparedStatements = [:]
14-
}
15-
1612
enum LookupAction {
1713
case prepareStatement
1814
case waitForAlreadyInFlightPreparation
1915
case executeStatement(RowDescription?)
20-
case executePendingStatements([PreparedStatementContext], RowDescription?)
21-
case returnError([PreparedStatementContext], PSQLError)
16+
case returnError(PSQLError)
2217
}
2318

24-
mutating func lookup(name: String, context: PreparedStatementContext) -> LookupAction {
25-
if let state = self.preparedStatements[name] {
19+
mutating func lookup(preparedStatement: PreparedStatementContext) -> LookupAction {
20+
if let state = self.preparedStatements[preparedStatement.name] {
2621
switch state {
2722
case .preparing(var statements):
28-
statements.append(context)
29-
self.preparedStatements[name] = .preparing(statements)
23+
statements.append(preparedStatement)
24+
self.preparedStatements[preparedStatement.name] = .preparing(statements)
3025
return .waitForAlreadyInFlightPreparation
3126
case .prepared(let rowDescription):
3227
return .executeStatement(rowDescription)
3328
case .error(let error):
34-
return .returnError([context], error)
29+
return .returnError(error)
3530
}
3631
} else {
37-
self.preparedStatements[name] = .preparing([context])
32+
self.preparedStatements[preparedStatement.name] = .preparing([preparedStatement])
3833
return .prepareStatement
3934
}
4035
}
4136

42-
enum PreparationCompleteAction {
43-
case executePendingStatements([PreparedStatementContext], RowDescription?)
37+
struct PreparationCompleteAction {
38+
var statements: [PreparedStatementContext]
39+
var rowDescription: RowDescription?
4440
}
4541

4642
mutating func preparationComplete(
@@ -56,22 +52,32 @@ struct PreparedStatementStateMachine {
5652
rowDescription.columns[i].format = .binary
5753
}
5854
self.preparedStatements[name] = .prepared(rowDescription)
59-
return .executePendingStatements(statements, rowDescription)
55+
return PreparationCompleteAction(
56+
statements: statements,
57+
rowDescription: rowDescription
58+
)
6059
} else {
6160
self.preparedStatements[name] = .prepared(nil)
62-
return .executePendingStatements(statements, nil)
61+
return PreparationCompleteAction(
62+
statements: statements,
63+
rowDescription: nil
64+
)
6365
}
6466
}
6567

66-
enum ErrorHappenedAction {
67-
case returnError([PreparedStatementContext], PSQLError)
68+
struct ErrorHappenedAction {
69+
var statements: [PreparedStatementContext]
70+
var error: PSQLError
6871
}
6972

7073
mutating func errorHappened(name: String, error: PSQLError) -> ErrorHappenedAction {
7174
guard case .preparing(let statements) = self.preparedStatements[name] else {
7275
preconditionFailure("Preparation completed for an unexpected statement")
7376
}
7477
self.preparedStatements[name] = .error(error)
75-
return .returnError(statements, error)
78+
return ErrorHappenedAction(
79+
statements: statements,
80+
error: error
81+
)
7682
}
7783
}

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 66 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
2222
private let configuration: PostgresConnection.InternalConfiguration
2323
private let configureSSLCallback: ((Channel) throws -> Void)?
2424

25-
private var listenState: ListenStateMachine
26-
private var preparedStatementState: PreparedStatementStateMachine
25+
private var listenState = ListenStateMachine()
26+
private var preparedStatementState = PreparedStatementStateMachine()
2727

2828
init(
2929
configuration: PostgresConnection.InternalConfiguration,
@@ -33,8 +33,6 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
3333
) {
3434
self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData)
3535
self.eventLoop = eventLoop
36-
self.listenState = ListenStateMachine()
37-
self.preparedStatementState = PreparedStatementStateMachine()
3836
self.configuration = configuration
3937
self.configureSSLCallback = configureSSLCallback
4038
self.logger = logger
@@ -238,62 +236,25 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
238236
}
239237
case .executePreparedStatement(let preparedStatement):
240238
let action = self.preparedStatementState.lookup(
241-
name: preparedStatement.name,
242-
context: preparedStatement
239+
preparedStatement: preparedStatement
243240
)
244241
switch action {
245242
case .prepareStatement:
246-
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
247-
promise.futureResult.whenSuccess { rowDescription in
248-
self.prepareStatementComplete(
249-
name: preparedStatement.name,
250-
rowDescription: rowDescription,
251-
context: context
252-
)
253-
}
254-
promise.futureResult.whenFailure { error in
255-
self.prepareStatementFailed(
256-
name: preparedStatement.name,
257-
error: error as! PSQLError,
258-
context: context
259-
)
260-
}
261-
psqlTask = .extendedQuery(.init(
262-
name: preparedStatement.name,
263-
query: preparedStatement.sql,
264-
logger: preparedStatement.logger,
265-
promise: promise
266-
))
243+
psqlTask = self.makePrepareStatementAction(
244+
preparedStatement: preparedStatement,
245+
context: context
246+
)
267247
case .waitForAlreadyInFlightPreparation:
268248
// The state machine already keeps track of this
269249
// and will execute the statement as soon as it's prepared
270250
return
271251
case .executeStatement(let rowDescription):
272-
psqlTask = .extendedQuery(.init(
273-
executeStatement: .init(
274-
name: preparedStatement.name,
275-
binds: preparedStatement.bindings,
276-
rowDescription: rowDescription),
277-
logger: preparedStatement.logger,
278-
promise: preparedStatement.promise
279-
))
280-
case .executePendingStatements(let pendingStatements, let rowDescription):
281-
for statement in pendingStatements {
282-
let action = self.state.enqueue(task: .extendedQuery(.init(
283-
executeStatement: .init(
284-
name: statement.name,
285-
binds: statement.bindings,
286-
rowDescription: rowDescription),
287-
logger: statement.logger,
288-
promise: statement.promise
289-
)))
290-
self.run(action, with: context)
291-
}
292-
return
293-
case .returnError(let pendingStatements, let error):
294-
for statement in pendingStatements {
295-
statement.promise.fail(error)
296-
}
252+
psqlTask = self.makeExecutPreparedStatementAction(
253+
preparedStatement: preparedStatement,
254+
rowDescription: rowDescription
255+
)
256+
case .returnError(let error):
257+
preparedStatement.promise.fail(error)
297258
return
298259
}
299260
}
@@ -727,6 +688,55 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
727688
}
728689
}
729690

691+
private func makePrepareStatementAction(
692+
preparedStatement: PreparedStatementContext,
693+
context: ChannelHandlerContext
694+
) -> PSQLTask {
695+
let promise = self.eventLoop.makePromise(of: RowDescription?.self)
696+
promise.futureResult.whenComplete { result in
697+
switch result {
698+
case .success(let rowDescription):
699+
self.prepareStatementComplete(
700+
name: preparedStatement.name,
701+
rowDescription: rowDescription,
702+
context: context
703+
)
704+
case .failure(let error):
705+
let psqlError: PSQLError
706+
if let error = error as? PSQLError {
707+
psqlError = error
708+
} else {
709+
psqlError = .connectionError(underlying: error)
710+
}
711+
self.prepareStatementFailed(
712+
name: preparedStatement.name,
713+
error: psqlError,
714+
context: context
715+
)
716+
}
717+
}
718+
return .extendedQuery(.init(
719+
name: preparedStatement.name,
720+
query: preparedStatement.sql,
721+
logger: preparedStatement.logger,
722+
promise: promise
723+
))
724+
}
725+
726+
private func makeExecutPreparedStatementAction(
727+
preparedStatement: PreparedStatementContext,
728+
rowDescription: RowDescription?
729+
) -> PSQLTask {
730+
return .extendedQuery(.init(
731+
executeStatement: .init(
732+
name: preparedStatement.name,
733+
binds: preparedStatement.bindings,
734+
rowDescription: rowDescription),
735+
logger: preparedStatement.logger,
736+
promise: preparedStatement.promise
737+
))
738+
}
739+
730740
private func prepareStatementComplete(
731741
name: String,
732742
rowDescription: RowDescription?,
@@ -736,15 +746,12 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
736746
name: name,
737747
rowDescription: rowDescription
738748
)
739-
guard case .executePendingStatements(let statements, let rowDescription) = action else {
740-
preconditionFailure("Expected to have pending statements to execute")
741-
}
742-
for preparedStatement in statements {
749+
for preparedStatement in action.statements {
743750
let action = self.state.enqueue(task: .extendedQuery(.init(
744751
executeStatement: .init(
745752
name: preparedStatement.name,
746753
binds: preparedStatement.bindings,
747-
rowDescription: rowDescription
754+
rowDescription: action.rowDescription
748755
),
749756
logger: preparedStatement.logger,
750757
promise: preparedStatement.promise
@@ -763,11 +770,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
763770
name: name,
764771
error: error
765772
)
766-
guard case .returnError(let statements, let error) = action else {
767-
preconditionFailure("Expected to have pending statements to execute")
768-
}
769-
for statement in statements {
770-
statement.promise.fail(error)
773+
for statement in action.statements {
774+
statement.promise.fail(action.error)
771775
}
772776
}
773777
}

0 commit comments

Comments
 (0)