@@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase {
224
224
}
225
225
}
226
226
227
+ func testSimpleListenFailsIfConnectionIsClosed( ) async throws {
228
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
229
+
230
+ try await connection. closeGracefully ( )
231
+
232
+ XCTAssertEqual ( channel. isActive, false )
233
+
234
+ do {
235
+ _ = try await connection. listen ( " test_channel " )
236
+ XCTFail ( " Expected to fail " )
237
+ } catch let error as ChannelError {
238
+ XCTAssertEqual ( error, . ioOnClosedChannel)
239
+ }
240
+ }
241
+
242
+ func testSimpleListenFailsIfConnectionIsClosedWhileListening( ) async throws {
243
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
244
+
245
+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
246
+ taskGroup. addTask {
247
+ let events = try await connection. listen ( " foo " )
248
+ var iterator = events. makeAsyncIterator ( )
249
+ let first = try await iterator. next ( )
250
+ XCTAssertEqual ( first? . payload, " wooohooo " )
251
+ do {
252
+ _ = try await iterator. next ( )
253
+ XCTFail ( " Did not expect to not throw " )
254
+ } catch let error as PSQLError {
255
+ XCTAssertEqual ( error. code, . clientClosedConnection)
256
+ }
257
+ }
258
+
259
+ let listenMessage = try await channel. waitForUnpreparedRequest ( )
260
+ XCTAssertEqual ( listenMessage. parse. query, #"LISTEN "foo";"# )
261
+
262
+ try await channel. writeInbound ( PostgresBackendMessage . parseComplete)
263
+ try await channel. writeInbound ( PostgresBackendMessage . parameterDescription ( . init( dataTypes: [ ] ) ) )
264
+ try await channel. writeInbound ( PostgresBackendMessage . noData)
265
+ try await channel. writeInbound ( PostgresBackendMessage . bindComplete)
266
+ try await channel. writeInbound ( PostgresBackendMessage . commandComplete ( " LISTEN " ) )
267
+ try await channel. writeInbound ( PostgresBackendMessage . readyForQuery ( . idle) )
268
+
269
+ try await channel. writeInbound ( PostgresBackendMessage . notification ( . init( backendPID: 12 , channel: " foo " , payload: " wooohooo " ) ) )
270
+
271
+ try await connection. close ( )
272
+
273
+ XCTAssertEqual ( channel. isActive, false )
274
+
275
+ switch await taskGroup. nextResult ( ) ! {
276
+ case . success:
277
+ break
278
+ case . failure( let failure) :
279
+ XCTFail ( " Unexpected error: \( failure) " )
280
+ }
281
+ }
282
+ }
283
+
227
284
func testCloseGracefullyClosesWhenInternalQueueIsEmpty( ) async throws {
228
285
let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
229
286
try await withThrowingTaskGroup ( of: Void . self) { [ logger] taskGroup async throws -> ( ) in
@@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase {
638
695
}
639
696
}
640
697
698
+ func testQueryFailsIfConnectionIsClosed( ) async throws {
699
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
700
+
701
+ try await connection. closeGracefully ( )
702
+
703
+ XCTAssertEqual ( channel. isActive, false )
704
+
705
+ do {
706
+ _ = try await connection. query ( " SELECT version; " , logger: self . logger)
707
+ XCTFail ( " Expected to fail " )
708
+ } catch let error as ChannelError {
709
+ XCTAssertEqual ( error, . ioOnClosedChannel)
710
+ }
711
+ }
712
+
713
+ func testPrepareStatementFailsIfConnectionIsClosed( ) async throws {
714
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
715
+
716
+ try await connection. closeGracefully ( )
717
+
718
+ XCTAssertEqual ( channel. isActive, false )
719
+
720
+ do {
721
+ _ = try await connection. prepareStatement ( " SELECT version; " , with: " test_query " , logger: . psqlTest) . get ( )
722
+ XCTFail ( " Expected to fail " )
723
+ } catch let error as ChannelError {
724
+ XCTAssertEqual ( error, . ioOnClosedChannel)
725
+ }
726
+ }
727
+
728
+ func testExecuteFailsIfConnectionIsClosed( ) async throws {
729
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
730
+
731
+ try await connection. closeGracefully ( )
732
+
733
+ XCTAssertEqual ( channel. isActive, false )
734
+
735
+ do {
736
+ let statement = PSQLExecuteStatement ( name: " SELECT version; " , binds: . init( ) , rowDescription: nil )
737
+ _ = try await connection. execute ( statement, logger: . psqlTest) . get ( )
738
+ XCTFail ( " Expected to fail " )
739
+ } catch let error as ChannelError {
740
+ XCTAssertEqual ( error, . ioOnClosedChannel)
741
+ }
742
+ }
743
+
744
+ func testExecutePreparedStatementFailsIfConnectionIsClosed( ) async throws {
745
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
746
+
747
+ try await connection. closeGracefully ( )
748
+
749
+ XCTAssertEqual ( channel. isActive, false )
750
+
751
+ struct TestPreparedStatement : PostgresPreparedStatement {
752
+ static let sql = " SELECT pid, datname FROM pg_stat_activity WHERE state = $1 "
753
+ typealias Row = ( Int , String )
754
+
755
+ var state : String
756
+
757
+ func makeBindings( ) -> PostgresBindings {
758
+ var bindings = PostgresBindings ( )
759
+ bindings. append ( self . state)
760
+ return bindings
761
+ }
762
+
763
+ func decodeRow( _ row: PostgresNIO . PostgresRow ) throws -> Row {
764
+ try row. decode ( Row . self)
765
+ }
766
+ }
767
+
768
+ do {
769
+ let preparedStatement = TestPreparedStatement ( state: " active " )
770
+ _ = try await connection. execute ( preparedStatement, logger: . psqlTest)
771
+ XCTFail ( " Expected to fail " )
772
+ } catch let error as ChannelError {
773
+ XCTAssertEqual ( error, . ioOnClosedChannel)
774
+ }
775
+ }
776
+
777
+ func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed( ) async throws {
778
+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
779
+
780
+ try await connection. closeGracefully ( )
781
+
782
+ XCTAssertEqual ( channel. isActive, false )
783
+
784
+ struct TestPreparedStatement : PostgresPreparedStatement {
785
+ static let sql = " SELECT * FROM pg_stat_activity WHERE state = $1 "
786
+ typealias Row = ( )
787
+
788
+ var state : String
789
+
790
+ func makeBindings( ) -> PostgresBindings {
791
+ var bindings = PostgresBindings ( )
792
+ bindings. append ( self . state)
793
+ return bindings
794
+ }
795
+
796
+ func decodeRow( _ row: PostgresNIO . PostgresRow ) throws -> Row {
797
+ ( )
798
+ }
799
+ }
800
+
801
+ do {
802
+ let preparedStatement = TestPreparedStatement ( state: " active " )
803
+ _ = try await connection. execute ( preparedStatement, logger: . psqlTest)
804
+ XCTFail ( " Expected to fail " )
805
+ } catch let error as ChannelError {
806
+ XCTAssertEqual ( error, . ioOnClosedChannel)
807
+ }
808
+ }
809
+
641
810
func makeTestConnectionWithAsyncTestingChannel( ) async throws -> ( PostgresConnection , NIOAsyncTestingChannel ) {
642
811
let eventLoop = NIOAsyncTestingEventLoop ( )
643
812
let channel = await NIOAsyncTestingChannel ( handlers: [
0 commit comments