@@ -23,6 +23,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
23
23
private let configureSSLCallback : ( ( Channel ) throws -> Void ) ?
24
24
25
25
private var listenState : ListenStateMachine
26
+ private var preparedStatementState : PreparedStatementStateMachine
26
27
27
28
init (
28
29
configuration: PostgresConnection . InternalConfiguration ,
@@ -33,6 +34,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
33
34
self . state = ConnectionStateMachine ( requireBackendKeyData: configuration. options. requireBackendKeyData)
34
35
self . eventLoop = eventLoop
35
36
self . listenState = ListenStateMachine ( )
37
+ self . preparedStatementState = PreparedStatementStateMachine ( )
36
38
self . configuration = configuration
37
39
self . configureSSLCallback = configureSSLCallback
38
40
self . logger = logger
@@ -51,6 +53,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
51
53
self . state = state
52
54
self . eventLoop = eventLoop
53
55
self . listenState = ListenStateMachine ( )
56
+ self . preparedStatementState = PreparedStatementStateMachine ( )
54
57
self . configuration = configuration
55
58
self . configureSSLCallback = configureSSLCallback
56
59
self . logger = logger
@@ -233,6 +236,56 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
233
236
listener. failed ( CancellationError ( ) )
234
237
return
235
238
}
239
+ case . executePreparedStatement( let preparedStatement) :
240
+ switch self . preparedStatementState. lookup (
241
+ name: preparedStatement. name,
242
+ context: preparedStatement
243
+ ) {
244
+ case . prepareStatement:
245
+ let promise = self . eventLoop. makePromise ( of: RowDescription ? . self)
246
+ promise. futureResult. whenSuccess { rowDescription in
247
+ self . prepareStatementComplete (
248
+ name: preparedStatement. name,
249
+ rowDescription: rowDescription,
250
+ context: context
251
+ )
252
+ }
253
+ promise. futureResult. whenFailure { error in
254
+ self . prepareStatementFailed (
255
+ name: preparedStatement. name,
256
+ error: error as! PSQLError ,
257
+ context: context
258
+ )
259
+ }
260
+ psqlTask = . extendedQuery( . init(
261
+ name: preparedStatement. name,
262
+ query: preparedStatement. sql,
263
+ logger: preparedStatement. logger,
264
+ promise: promise
265
+ ) )
266
+ case . waitForAlreadyInFlightPreparation:
267
+ // The state machine already keeps track of this
268
+ // and will execute the statement as soon as it's prepared
269
+ return
270
+ case . executePendingStatements( let pendingStatements, let rowDescription) :
271
+ for statement in pendingStatements {
272
+ let action = self . state. enqueue ( task: . extendedQuery( . init(
273
+ executeStatement: . init(
274
+ name: statement. name,
275
+ binds: statement. bindings,
276
+ rowDescription: rowDescription) ,
277
+ logger: statement. logger,
278
+ promise: statement. promise
279
+ ) ) )
280
+ self . run ( action, with: context)
281
+ }
282
+ return
283
+ case . returnError( let pendingStatements, let error) :
284
+ for statement in pendingStatements {
285
+ statement. promise. fail ( error)
286
+ }
287
+ return
288
+ }
236
289
}
237
290
238
291
let action = self . state. enqueue ( task: psqlTask)
@@ -664,6 +717,49 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
664
717
}
665
718
}
666
719
720
+ private func prepareStatementComplete(
721
+ name: String ,
722
+ rowDescription: RowDescription ? ,
723
+ context: ChannelHandlerContext
724
+ ) {
725
+ let action = self . preparedStatementState. preparationComplete (
726
+ name: name,
727
+ rowDescription: rowDescription
728
+ )
729
+ guard case . executePendingStatements( let statements, let rowDescription) = action else {
730
+ preconditionFailure ( " Expected to have pending statements to execute " )
731
+ }
732
+ for preparedStatement in statements {
733
+ let action = self . state. enqueue ( task: . extendedQuery( . init(
734
+ executeStatement: . init(
735
+ name: preparedStatement. name,
736
+ binds: preparedStatement. bindings,
737
+ rowDescription: rowDescription
738
+ ) ,
739
+ logger: preparedStatement. logger,
740
+ promise: preparedStatement. promise
741
+ ) )
742
+ )
743
+ self . run ( action, with: context)
744
+ }
745
+ }
746
+
747
+ private func prepareStatementFailed(
748
+ name: String ,
749
+ error: PSQLError ,
750
+ context: ChannelHandlerContext
751
+ ) {
752
+ let action = self . preparedStatementState. errorHappened (
753
+ name: name,
754
+ error: error
755
+ )
756
+ guard case . returnError( let statements, let error) = action else {
757
+ preconditionFailure ( " Expected to have pending statements to execute " )
758
+ }
759
+ for statement in statements {
760
+ statement. promise. fail ( error)
761
+ }
762
+ }
667
763
}
668
764
669
765
extension PostgresChannelHandler : PSQLRowsDataSource {
0 commit comments