Skip to content

Commit 7aaae6e

Browse files
committed
Properly fulfill write promises
1 parent d18b137 commit 7aaae6e

13 files changed

+300
-231
lines changed

Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift

+10-10
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ struct AuthenticationStateMachine {
1616
}
1717

1818
enum Action {
19-
case sendStartupMessage(AuthContext)
20-
case sendPassword(PasswordAuthencationMode, AuthContext)
21-
case sendSaslInitialResponse(name: String, initialResponse: [UInt8])
22-
case sendSaslResponse([UInt8])
19+
case sendStartupMessage(AuthContext, promise: EventLoopPromise<Void>?)
20+
case sendPassword(PasswordAuthencationMode, AuthContext, promise: EventLoopPromise<Void>?)
21+
case sendSaslInitialResponse(name: String, initialResponse: [UInt8], promise: EventLoopPromise<Void>?)
22+
case sendSaslResponse([UInt8], promise: EventLoopPromise<Void>?)
2323
case wait
2424
case authenticated
2525

@@ -34,12 +34,12 @@ struct AuthenticationStateMachine {
3434
self.state = .initialized
3535
}
3636

37-
mutating func start() -> Action {
37+
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
3838
guard case .initialized = self.state else {
3939
preconditionFailure("Unexpected state")
4040
}
4141
self.state = .startupMessageSent
42-
return .sendStartupMessage(self.authContext)
42+
return .sendStartupMessage(self.authContext, promise: promise)
4343
}
4444

4545
mutating func authenticationMessageReceived(_ message: PostgresBackendMessage.Authentication) -> Action {
@@ -54,10 +54,10 @@ struct AuthenticationStateMachine {
5454
return self.setAndFireError(PSQLError(code: .authMechanismRequiresPassword))
5555
}
5656
self.state = .passwordAuthenticationSent
57-
return .sendPassword(.md5(salt: salt), self.authContext)
57+
return .sendPassword(.md5(salt: salt), self.authContext, promise: nil)
5858
case .plaintext:
5959
self.state = .passwordAuthenticationSent
60-
return .sendPassword(.cleartext, authContext)
60+
return .sendPassword(.cleartext, authContext, promise: nil)
6161
case .kerberosV5:
6262
return self.setAndFireError(.unsupportedAuthMechanism(.kerberosV5))
6363
case .scmCredential:
@@ -89,7 +89,7 @@ struct AuthenticationStateMachine {
8989
}
9090

9191
self.state = .saslInitialResponseSent(saslManager)
92-
return .sendSaslInitialResponse(name: SASLMechanism.SCRAM.SHA256.name, initialResponse: output)
92+
return .sendSaslInitialResponse(name: SASLMechanism.SCRAM.SHA256.name, initialResponse: output, promise: nil)
9393
} catch {
9494
return self.setAndFireError(.sasl(underlying: error))
9595
}
@@ -122,7 +122,7 @@ struct AuthenticationStateMachine {
122122
}
123123

124124
self.state = .saslChallengeResponseSent(saslManager)
125-
return .sendSaslResponse(output)
125+
return .sendSaslResponse(output, promise: nil)
126126
} catch {
127127
return self.setAndFireError(.sasl(underlying: error))
128128
}

Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ struct CloseStateMachine {
1010
}
1111

1212
enum Action {
13-
case sendCloseSync(CloseTarget)
13+
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
1414
case succeedClose(CloseCommandContext)
1515
case failClose(CloseCommandContext, with: PSQLError)
1616

@@ -24,14 +24,14 @@ struct CloseStateMachine {
2424
self.state = .initialized(closeContext)
2525
}
2626

27-
mutating func start() -> Action {
27+
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
2828
guard case .initialized(let closeContext) = self.state else {
2929
preconditionFailure("Start should only be called, if the query has been initialized")
3030
}
3131

3232
self.state = .closeSyncSent(closeContext)
3333

34-
return .sendCloseSync(closeContext.target)
34+
return .sendCloseSync(closeContext.target, promise: promise)
3535
}
3636

3737
mutating func closeCompletedReceived() -> Action {

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

+49-41
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ struct ConnectionStateMachine {
6464

6565
case read
6666
case wait
67-
case sendSSLRequest
68-
case establishSSLConnection
67+
case sendSSLRequest(EventLoopPromise<Void>?)
68+
case establishSSLConnection(EventLoopPromise<Void>?)
6969
case provideAuthenticationContext
7070
case forwardNotificationToListeners(PostgresBackendMessage.NotificationResponse)
7171
case fireEventReadyForQuery
@@ -77,16 +77,16 @@ struct ConnectionStateMachine {
7777
case closeConnectionAndCleanup(CleanUpContext)
7878

7979
// Auth Actions
80-
case sendStartupMessage(AuthContext)
81-
case sendPasswordMessage(PasswordAuthencationMode, AuthContext)
82-
case sendSaslInitialResponse(name: String, initialResponse: [UInt8])
83-
case sendSaslResponse([UInt8])
80+
case sendStartupMessage(AuthContext, promise: EventLoopPromise<Void>?)
81+
case sendPasswordMessage(PasswordAuthencationMode, AuthContext, promise: EventLoopPromise<Void>?)
82+
case sendSaslInitialResponse(name: String, initialResponse: [UInt8], promise: EventLoopPromise<Void>?)
83+
case sendSaslResponse([UInt8], promise: EventLoopPromise<Void>?)
8484

8585
// Connection Actions
8686

8787
// --- general actions
88-
case sendParseDescribeBindExecuteSync(PostgresQuery)
89-
case sendBindExecuteSync(PSQLExecuteStatement)
88+
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
89+
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)
9090
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
9191
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
9292

@@ -97,12 +97,12 @@ struct ConnectionStateMachine {
9797
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)
9898

9999
// Prepare statement actions
100-
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
100+
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
101101
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
102102
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)
103103

104104
// Close actions
105-
case sendCloseSync(CloseTarget)
105+
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
106106
case succeedClose(CloseCommandContext)
107107
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
108108
}
@@ -131,7 +131,7 @@ struct ConnectionStateMachine {
131131
case require
132132
}
133133

134-
mutating func connected(tls: TLSConfiguration) -> ConnectionAction {
134+
mutating func connected(tls: TLSConfiguration, promise: EventLoopPromise<Void>?) -> ConnectionAction {
135135
switch self.state {
136136
case .initialized:
137137
switch tls {
@@ -141,11 +141,11 @@ struct ConnectionStateMachine {
141141

142142
case .prefer:
143143
self.state = .sslRequestSent(.prefer)
144-
return .sendSSLRequest
144+
return .sendSSLRequest(promise)
145145

146146
case .require:
147147
self.state = .sslRequestSent(.require)
148-
return .sendSSLRequest
148+
return .sendSSLRequest(promise)
149149
}
150150

151151
case .sslRequestSent,
@@ -164,8 +164,11 @@ struct ConnectionStateMachine {
164164
}
165165
}
166166

167-
mutating func provideAuthenticationContext(_ authContext: AuthContext) -> ConnectionAction {
168-
self.startAuthentication(authContext)
167+
mutating func provideAuthenticationContext(
168+
_ authContext: AuthContext,
169+
promise: EventLoopPromise<Void>?
170+
) -> ConnectionAction {
171+
self.startAuthentication(authContext, promise: promise)
169172
}
170173

171174
mutating func gracefulClose(_ promise: EventLoopPromise<Void>?) -> ConnectionAction {
@@ -233,8 +236,8 @@ struct ConnectionStateMachine {
233236
return self.closeConnectionAndCleanup(.receivedUnencryptedDataAfterSSLRequest)
234237
}
235238
self.state = .sslNegotiated
236-
return .establishSSLConnection
237-
239+
return .establishSSLConnection(nil)
240+
238241
case .initialized,
239242
.sslNegotiated,
240243
.sslHandlerAdded,
@@ -583,14 +586,16 @@ struct ConnectionStateMachine {
583586
}
584587

585588
switch task {
586-
case .extendedQuery(let queryContext):
589+
case .extendedQuery(let queryContext, let writePromise):
590+
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
587591
switch queryContext.query {
588592
case .executeStatement(_, let promise), .unnamed(_, let promise):
589593
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
590594
case .prepareStatement(_, _, _, let promise):
591595
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
592596
}
593-
case .closeCommand(let closeContext):
597+
case .closeCommand(let closeContext, let writePromise):
598+
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
594599
return .failClose(closeContext, with: psqlErrror, cleanupContext: nil)
595600
}
596601
}
@@ -800,14 +805,17 @@ struct ConnectionStateMachine {
800805

801806
// MARK: - Private Methods -
802807

803-
private mutating func startAuthentication(_ authContext: AuthContext) -> ConnectionAction {
808+
private mutating func startAuthentication(
809+
_ authContext: AuthContext,
810+
promise: EventLoopPromise<Void>?
811+
) -> ConnectionAction {
804812
guard case .waitingToStartAuthentication = self.state else {
805813
preconditionFailure("Can only start authentication after connect or ssl establish")
806814
}
807815

808816
self.state = .modifying // avoid CoW
809817
var authState = AuthenticationStateMachine(authContext: authContext)
810-
let action = authState.start()
818+
let action = authState.start(promise)
811819
self.state = .authenticating(authState)
812820
return self.modify(with: action)
813821
}
@@ -934,17 +942,17 @@ struct ConnectionStateMachine {
934942
}
935943

936944
switch task {
937-
case .extendedQuery(let queryContext):
945+
case .extendedQuery(let queryContext, let promise):
938946
self.state = .modifying // avoid CoW
939947
var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext)
940-
let action = extendedQuery.start()
948+
let action = extendedQuery.start(promise)
941949
self.state = .extendedQuery(extendedQuery, connectionContext)
942950
return self.modify(with: action)
943951

944-
case .closeCommand(let closeContext):
952+
case .closeCommand(let closeContext, let promise):
945953
self.state = .modifying // avoid CoW
946954
var closeStateMachine = CloseStateMachine(closeContext: closeContext)
947-
let action = closeStateMachine.start()
955+
let action = closeStateMachine.start(promise)
948956
self.state = .closeCommand(closeStateMachine, connectionContext)
949957
return self.modify(with: action)
950958
}
@@ -1031,10 +1039,10 @@ extension ConnectionStateMachine {
10311039
extension ConnectionStateMachine {
10321040
mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
10331041
switch action {
1034-
case .sendParseDescribeBindExecuteSync(let query):
1035-
return .sendParseDescribeBindExecuteSync(query)
1036-
case .sendBindExecuteSync(let executeStatement):
1037-
return .sendBindExecuteSync(executeStatement)
1042+
case .sendParseDescribeBindExecuteSync(let query, let promise):
1043+
return .sendParseDescribeBindExecuteSync(query, promise: promise)
1044+
case .sendBindExecuteSync(let executeStatement, let promise):
1045+
return .sendBindExecuteSync(executeStatement, promise: promise)
10381046
case .failQuery(let requestContext, with: let error):
10391047
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
10401048
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
@@ -1057,8 +1065,8 @@ extension ConnectionStateMachine {
10571065
return .read
10581066
case .wait:
10591067
return .wait
1060-
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
1061-
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
1068+
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes, let promise):
1069+
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
10621070
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
10631071
return .succeedPreparedStatementCreation(promise, with: rowDescription)
10641072
case .failPreparedStatementCreation(let promise, with: let error):
@@ -1071,14 +1079,14 @@ extension ConnectionStateMachine {
10711079
extension ConnectionStateMachine {
10721080
mutating func modify(with action: AuthenticationStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
10731081
switch action {
1074-
case .sendStartupMessage(let authContext):
1075-
return .sendStartupMessage(authContext)
1076-
case .sendPassword(let mode, let authContext):
1077-
return .sendPasswordMessage(mode, authContext)
1078-
case .sendSaslInitialResponse(let name, let initialResponse):
1079-
return .sendSaslInitialResponse(name: name, initialResponse: initialResponse)
1080-
case .sendSaslResponse(let bytes):
1081-
return .sendSaslResponse(bytes)
1082+
case .sendStartupMessage(let authContext, let promise):
1083+
return .sendStartupMessage(authContext, promise: promise)
1084+
case .sendPassword(let mode, let authContext, let promise):
1085+
return .sendPasswordMessage(mode, authContext, promise: promise)
1086+
case .sendSaslInitialResponse(let name, let initialResponse, let promise):
1087+
return .sendSaslInitialResponse(name: name, initialResponse: initialResponse, promise: promise)
1088+
case .sendSaslResponse(let bytes, let promise):
1089+
return .sendSaslResponse(bytes, promise: promise)
10821090
case .authenticated:
10831091
self.state = .authenticated(nil, [:])
10841092
return .wait
@@ -1094,8 +1102,8 @@ extension ConnectionStateMachine {
10941102
extension ConnectionStateMachine {
10951103
mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
10961104
switch action {
1097-
case .sendCloseSync(let sendClose):
1098-
return .sendCloseSync(sendClose)
1105+
case .sendCloseSync(let sendClose, let promise):
1106+
return .sendCloseSync(sendClose, promise: promise)
10991107
case .succeedClose(let closeContext):
11001108
return .succeedClose(closeContext)
11011109
case .failClose(let closeContext, with: let error):

Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift

+9-8
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ struct ExtendedQueryStateMachine {
2525
}
2626

2727
enum Action {
28-
case sendParseDescribeBindExecuteSync(PostgresQuery)
29-
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
30-
case sendBindExecuteSync(PSQLExecuteStatement)
31-
28+
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
29+
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
30+
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)
31+
3232
// --- general actions
3333
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
3434
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
@@ -56,7 +56,7 @@ struct ExtendedQueryStateMachine {
5656
self.state = .initialized(queryContext)
5757
}
5858

59-
mutating func start() -> Action {
59+
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
6060
guard case .initialized(let queryContext) = self.state else {
6161
preconditionFailure("Start should only be called, if the query has been initialized")
6262
}
@@ -65,7 +65,7 @@ struct ExtendedQueryStateMachine {
6565
case .unnamed(let query, _):
6666
return self.avoidingStateMachineCoW { state -> Action in
6767
state = .messagesSent(queryContext)
68-
return .sendParseDescribeBindExecuteSync(query)
68+
return .sendParseDescribeBindExecuteSync(query, promise: promise)
6969
}
7070

7171
case .executeStatement(let prepared, _):
@@ -76,13 +76,14 @@ struct ExtendedQueryStateMachine {
7676
case .none:
7777
state = .noDataMessageReceived(queryContext)
7878
}
79-
return .sendBindExecuteSync(prepared)
79+
return .sendBindExecuteSync(prepared, promise: promise)
8080
}
8181

82+
/// Not my code, but this is ignoring the last argument which is a promise? is that fine?
8283
case .prepareStatement(let name, let query, let bindingDataTypes, _):
8384
return self.avoidingStateMachineCoW { state -> Action in
8485
state = .messagesSent(queryContext)
85-
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
86+
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
8687
}
8788
}
8889
}

Sources/PostgresNIO/New/NotificationListener.swift

+10-5
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ final class NotificationListener: @unchecked Sendable {
4242
self.state = .closure(context, closure)
4343
}
4444

45-
func startListeningSucceeded(handler: PostgresChannelHandler) {
45+
func startListeningSucceeded(
46+
handler: PostgresChannelHandler,
47+
writePromise: EventLoopPromise<Void>?
48+
) {
4649
self.eventLoop.preconditionInEventLoop()
4750
let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop)
4851

@@ -56,26 +59,28 @@ final class NotificationListener: @unchecked Sendable {
5659
switch reason {
5760
case .cancelled:
5861
eventLoop.execute {
59-
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID)
62+
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID, writePromise: nil)
6063
}
6164

6265
case .finished:
63-
break
66+
writePromise?.succeed()
6467

6568
@unknown default:
66-
break
69+
writePromise?.succeed()
6770
}
6871
}
6972
self.state = .streamListening(continuation)
7073

7174
let notificationSequence = PostgresNotificationSequence(base: stream)
7275
checkedContinuation.resume(returning: notificationSequence)
76+
writePromise?.succeed(())
7377

7478
case .streamListening, .done:
7579
fatalError("Invalid state: \(self.state)")
7680

7781
case .closure:
78-
break // ignore
82+
writePromise?.succeed(())
83+
// ignore
7984
}
8085
}
8186

0 commit comments

Comments
 (0)