Skip to content

Commit cd5d729

Browse files
committed
Reapply "[Fix] Query Hangs if Connection is Closed (vapor#487)" (vapor#501)
This reverts commit cd5318a.
1 parent 9f84290 commit cd5d729

File tree

3 files changed

+197
-12
lines changed

3 files changed

+197
-12
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

+28-11
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ public final class PostgresConnection: @unchecked Sendable {
222222
promise: promise
223223
)
224224

225-
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
225+
self.write(.extendedQuery(context), cascadingFailureTo: promise)
226226

227227
return promise.futureResult
228228
}
@@ -239,7 +239,8 @@ public final class PostgresConnection: @unchecked Sendable {
239239
promise: promise
240240
)
241241

242-
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
242+
self.write(.extendedQuery(context), cascadingFailureTo: promise)
243+
243244
return promise.futureResult.map { rowDescription in
244245
PSQLPreparedStatement(name: name, query: query, connection: self, rowDescription: rowDescription)
245246
}
@@ -255,15 +256,17 @@ public final class PostgresConnection: @unchecked Sendable {
255256
logger: logger,
256257
promise: promise)
257258

258-
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
259+
self.write(.extendedQuery(context), cascadingFailureTo: promise)
260+
259261
return promise.futureResult
260262
}
261263

262264
func close(_ target: CloseTarget, logger: Logger) -> EventLoopFuture<Void> {
263265
let promise = self.channel.eventLoop.makePromise(of: Void.self)
264266
let context = CloseCommandContext(target: target, logger: logger, promise: promise)
265267

266-
self.channel.write(HandlerTask.closeCommand(context), promise: nil)
268+
self.write(.closeCommand(context), cascadingFailureTo: promise)
269+
267270
return promise.futureResult
268271
}
269272

@@ -426,7 +429,7 @@ extension PostgresConnection {
426429
promise: promise
427430
)
428431

429-
self.channel.write(HandlerTask.extendedQuery(context), promise: nil)
432+
self.write(.extendedQuery(context), cascadingFailureTo: promise)
430433

431434
do {
432435
return try await promise.futureResult.map({ $0.asyncSequence() }).get()
@@ -455,7 +458,11 @@ extension PostgresConnection {
455458

456459
let task = HandlerTask.startListening(listener)
457460

458-
self.channel.write(task, promise: nil)
461+
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
462+
self.channel.write(task, promise: writePromise)
463+
writePromise.futureResult.whenFailure { error in
464+
listener.failed(error)
465+
}
459466
}
460467
} onCancel: {
461468
let task = HandlerTask.cancelListening(channel, id)
@@ -480,7 +487,9 @@ extension PostgresConnection {
480487
logger: logger,
481488
promise: promise
482489
))
483-
self.channel.write(task, promise: nil)
490+
491+
self.write(task, cascadingFailureTo: promise)
492+
484493
do {
485494
return try await promise.futureResult
486495
.map { $0.asyncSequence() }
@@ -515,7 +524,9 @@ extension PostgresConnection {
515524
logger: logger,
516525
promise: promise
517526
))
518-
self.channel.write(task, promise: nil)
527+
528+
self.write(task, cascadingFailureTo: promise)
529+
519530
do {
520531
return try await promise.futureResult
521532
.map { $0.commandTag }
@@ -530,6 +541,12 @@ extension PostgresConnection {
530541
throw error // rethrow with more metadata
531542
}
532543
}
544+
545+
private func write<T>(_ task: HandlerTask, cascadingFailureTo promise: EventLoopPromise<T>) {
546+
let writePromise = self.channel.eventLoop.makePromise(of: Void.self)
547+
self.channel.write(task, promise: writePromise)
548+
writePromise.futureResult.cascadeFailure(to: promise)
549+
}
533550
}
534551

535552
// MARK: EventLoopFuture interface
@@ -674,7 +691,7 @@ internal enum PostgresCommands: PostgresRequest {
674691

675692
/// Context for receiving NotificationResponse messages on a connection, used for PostgreSQL's `LISTEN`/`NOTIFY` support.
676693
public final class PostgresListenContext: Sendable {
677-
private let promise: EventLoopPromise<Void>
694+
let promise: EventLoopPromise<Void>
678695

679696
var future: EventLoopFuture<Void> {
680697
self.promise.futureResult
@@ -713,8 +730,7 @@ extension PostgresConnection {
713730
closure: notificationHandler
714731
)
715732

716-
let task = HandlerTask.startListening(listener)
717-
self.channel.write(task, promise: nil)
733+
self.write(.startListening(listener), cascadingFailureTo: listenContext.promise)
718734

719735
listenContext.future.whenComplete { _ in
720736
let task = HandlerTask.cancelListening(channel, id)
@@ -761,3 +777,4 @@ extension PostgresConnection {
761777
#endif
762778
}
763779
}
780+

Tests/IntegrationTests/PSQLIntegrationTests.swift

-1
Original file line numberDiff line numberDiff line change
@@ -378,5 +378,4 @@ final class IntegrationTests: XCTestCase {
378378
XCTAssertEqual(obj?.bar, 2)
379379
}
380380
}
381-
382381
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

+169
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase {
224224
}
225225
}
226226

227+
func testSimpleListenFailsIfConnectionIsClosed() async throws {
228+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
229+
230+
try await connection.closeGracefully()
231+
232+
XCTAssertEqual(channel.isActive, false)
233+
234+
do {
235+
_ = try await connection.listen("test_channel")
236+
XCTFail("Expected to fail")
237+
} catch let error as ChannelError {
238+
XCTAssertEqual(error, .ioOnClosedChannel)
239+
}
240+
}
241+
242+
func testSimpleListenFailsIfConnectionIsClosedWhileListening() async throws {
243+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
244+
245+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
246+
taskGroup.addTask {
247+
let events = try await connection.listen("foo")
248+
var iterator = events.makeAsyncIterator()
249+
let first = try await iterator.next()
250+
XCTAssertEqual(first?.payload, "wooohooo")
251+
do {
252+
_ = try await iterator.next()
253+
XCTFail("Did not expect to not throw")
254+
} catch let error as PSQLError {
255+
XCTAssertEqual(error.code, .clientClosedConnection)
256+
}
257+
}
258+
259+
let listenMessage = try await channel.waitForUnpreparedRequest()
260+
XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#)
261+
262+
try await channel.writeInbound(PostgresBackendMessage.parseComplete)
263+
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
264+
try await channel.writeInbound(PostgresBackendMessage.noData)
265+
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
266+
try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN"))
267+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
268+
269+
try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo")))
270+
271+
try await connection.close()
272+
273+
XCTAssertEqual(channel.isActive, false)
274+
275+
switch await taskGroup.nextResult()! {
276+
case .success:
277+
break
278+
case .failure(let failure):
279+
XCTFail("Unexpected error: \(failure)")
280+
}
281+
}
282+
}
283+
227284
func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
228285
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
229286
try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in
@@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase {
638695
}
639696
}
640697

698+
func testQueryFailsIfConnectionIsClosed() async throws {
699+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
700+
701+
try await connection.closeGracefully()
702+
703+
XCTAssertEqual(channel.isActive, false)
704+
705+
do {
706+
_ = try await connection.query("SELECT version;", logger: self.logger)
707+
XCTFail("Expected to fail")
708+
} catch let error as ChannelError {
709+
XCTAssertEqual(error, .ioOnClosedChannel)
710+
}
711+
}
712+
713+
func testPrepareStatementFailsIfConnectionIsClosed() async throws {
714+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
715+
716+
try await connection.closeGracefully()
717+
718+
XCTAssertEqual(channel.isActive, false)
719+
720+
do {
721+
_ = try await connection.prepareStatement("SELECT version;", with: "test_query", logger: .psqlTest).get()
722+
XCTFail("Expected to fail")
723+
} catch let error as ChannelError {
724+
XCTAssertEqual(error, .ioOnClosedChannel)
725+
}
726+
}
727+
728+
func testExecuteFailsIfConnectionIsClosed() async throws {
729+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
730+
731+
try await connection.closeGracefully()
732+
733+
XCTAssertEqual(channel.isActive, false)
734+
735+
do {
736+
let statement = PSQLExecuteStatement(name: "SELECT version;", binds: .init(), rowDescription: nil)
737+
_ = try await connection.execute(statement, logger: .psqlTest).get()
738+
XCTFail("Expected to fail")
739+
} catch let error as ChannelError {
740+
XCTAssertEqual(error, .ioOnClosedChannel)
741+
}
742+
}
743+
744+
func testExecutePreparedStatementFailsIfConnectionIsClosed() async throws {
745+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
746+
747+
try await connection.closeGracefully()
748+
749+
XCTAssertEqual(channel.isActive, false)
750+
751+
struct TestPreparedStatement: PostgresPreparedStatement {
752+
static let sql = "SELECT pid, datname FROM pg_stat_activity WHERE state = $1"
753+
typealias Row = (Int, String)
754+
755+
var state: String
756+
757+
func makeBindings() -> PostgresBindings {
758+
var bindings = PostgresBindings()
759+
bindings.append(self.state)
760+
return bindings
761+
}
762+
763+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
764+
try row.decode(Row.self)
765+
}
766+
}
767+
768+
do {
769+
let preparedStatement = TestPreparedStatement(state: "active")
770+
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
771+
XCTFail("Expected to fail")
772+
} catch let error as ChannelError {
773+
XCTAssertEqual(error, .ioOnClosedChannel)
774+
}
775+
}
776+
777+
func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed() async throws {
778+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
779+
780+
try await connection.closeGracefully()
781+
782+
XCTAssertEqual(channel.isActive, false)
783+
784+
struct TestPreparedStatement: PostgresPreparedStatement {
785+
static let sql = "SELECT * FROM pg_stat_activity WHERE state = $1"
786+
typealias Row = ()
787+
788+
var state: String
789+
790+
func makeBindings() -> PostgresBindings {
791+
var bindings = PostgresBindings()
792+
bindings.append(self.state)
793+
return bindings
794+
}
795+
796+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
797+
()
798+
}
799+
}
800+
801+
do {
802+
let preparedStatement = TestPreparedStatement(state: "active")
803+
_ = try await connection.execute(preparedStatement, logger: .psqlTest)
804+
XCTFail("Expected to fail")
805+
} catch let error as ChannelError {
806+
XCTAssertEqual(error, .ioOnClosedChannel)
807+
}
808+
}
809+
641810
func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
642811
let eventLoop = NIOAsyncTestingEventLoop()
643812
let channel = await NIOAsyncTestingChannel(handlers: [

0 commit comments

Comments
 (0)