Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,17 @@ extension PostgresConnection {
try await self.close().get()
}

/// Closes the connection to the server, _after all queries_ that have been created on this connection have been run.
public func closeGracefully() async throws {
try await withTaskCancellationHandler { () async throws -> () in
let promise = self.eventLoop.makePromise(of: Void.self)
self.channel.triggerUserOutboundEvent(PSQLOutgoingEvent.gracefulShutdown, promise: promise)
return try await promise.futureResult.get()
} onCancel: {
_ = self.close()
}
}

/// Run a query on the Postgres server the connection is connected to.
///
/// - Parameters:
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@ struct ListenStateMachine {
}

mutating func stopListeningSucceeded(channel: String) -> StopListeningSuccessAction {
return self.channels[channel, default: .init()].stopListeningSucceeded()
switch self.channels[channel]!.stopListeningSucceeded() {
case .none:
self.channels.removeValue(forKey: channel)
return .none

case .startListening:
return .startListening
}
}

enum CancelAction {
Expand All @@ -46,7 +53,7 @@ struct ListenStateMachine {
}

mutating func cancelNotificationListener(channel: String, id: Int) -> CancelAction {
return self.channels[channel, default: .init()].cancelListening(id: id)
return self.channels[channel]?.cancelListening(id: id) ?? .none
}

mutating func fail(_ error: Error) -> [NotificationListener] {
Expand Down
54 changes: 39 additions & 15 deletions Sources/PostgresNIO/New/PSQLError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public struct PSQLError: Error {

case queryCancelled
case tooManyParameters
case connectionQuiescing
case connectionClosed
case clientClosesConnection
case clientClosedConnection
case serverClosedConnection
case connectionError
case uncleanShutdown

Expand All @@ -45,13 +46,20 @@ public struct PSQLError: Error {
public static let invalidCommandTag = Self(.invalidCommandTag)
public static let queryCancelled = Self(.queryCancelled)
public static let tooManyParameters = Self(.tooManyParameters)
public static let connectionQuiescing = Self(.connectionQuiescing)
public static let connectionClosed = Self(.connectionClosed)
public static let clientClosesConnection = Self(.clientClosesConnection)
public static let clientClosedConnection = Self(.clientClosedConnection)
public static let serverClosedConnection = Self(.serverClosedConnection)
public static let connectionError = Self(.connectionError)
public static let uncleanShutdown = Self.init(.uncleanShutdown)
public static let listenFailed = Self.init(.listenFailed)
public static let unlistenFailed = Self.init(.unlistenFailed)

@available(*, deprecated, renamed: "clientClosesConnection")
public static let connectionQuiescing = Self.clientClosesConnection

@available(*, deprecated, message: "Use the more specific `serverClosedConnection` or `clientClosedConnection` instead")
Comment on lines +57 to +60
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 @gwynne

public static let connectionClosed = Self.serverClosedConnection

public var description: String {
switch self.base {
case .sslUnsupported:
Expand All @@ -78,10 +86,12 @@ public struct PSQLError: Error {
return "queryCancelled"
case .tooManyParameters:
return "tooManyParameters"
case .connectionQuiescing:
return "connectionQuiescing"
case .connectionClosed:
return "connectionClosed"
case .clientClosesConnection:
return "clientClosesConnection"
case .clientClosedConnection:
return "clientClosedConnection"
case .serverClosedConnection:
return "serverClosedConnection"
case .connectionError:
return "connectionError"
case .uncleanShutdown:
Expand Down Expand Up @@ -377,19 +387,33 @@ public struct PSQLError: Error {
return new
}

static var connectionQuiescing: PSQLError { PSQLError(code: .connectionQuiescing) }
static func clientClosesConnection(underlying: Error?) -> PSQLError {
var error = PSQLError(code: .clientClosesConnection)
error.underlying = underlying
return error
}

static func clientClosedConnection(underlying: Error?) -> PSQLError {
var error = PSQLError(code: .clientClosedConnection)
error.underlying = underlying
return error
}

static var connectionClosed: PSQLError { PSQLError(code: .connectionClosed) }
static func serverClosedConnection(underlying: Error?) -> PSQLError {
var error = PSQLError(code: .serverClosedConnection)
error.underlying = underlying
return error
}

static var authMechanismRequiresPassword: PSQLError { PSQLError(code: .authMechanismRequiresPassword) }
static let authMechanismRequiresPassword = PSQLError(code: .authMechanismRequiresPassword)

static var sslUnsupported: PSQLError { PSQLError(code: .sslUnsupported) }
static let sslUnsupported = PSQLError(code: .sslUnsupported)

static var queryCancelled: PSQLError { PSQLError(code: .queryCancelled) }
static let queryCancelled = PSQLError(code: .queryCancelled)

static var uncleanShutdown: PSQLError { PSQLError(code: .uncleanShutdown) }
static let uncleanShutdown = PSQLError(code: .uncleanShutdown)

static var receivedUnencryptedDataAfterSSLRequest: PSQLError { PSQLError(code: .receivedUnencryptedDataAfterSSLRequest) }
static let receivedUnencryptedDataAfterSSLRequest = PSQLError(code: .receivedUnencryptedDataAfterSSLRequest)

static func server(_ response: PostgresBackendMessage.ErrorResponse) -> PSQLError {
var error = PSQLError(code: .server)
Expand Down
2 changes: 2 additions & 0 deletions Sources/PostgresNIO/New/PSQLEventsHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ enum PSQLOutgoingEvent {
///
/// this shall be removed with the next breaking change and always supplied with `PSQLConnection.Configuration`
case authenticate(AuthContext)

case gracefulShutdown
}

enum PSQLEvent {
Expand Down
7 changes: 6 additions & 1 deletion Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
return
}

let action = self.state.close(promise)
let action = self.state.close(promise: promise)
self.run(action, with: context)
}

Expand All @@ -258,6 +258,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
case PSQLOutgoingEvent.authenticate(let authContext):
let action = self.state.provideAuthenticationContext(authContext)
self.run(action, with: context)

case PSQLOutgoingEvent.gracefulShutdown:
let action = self.state.gracefulClose(promise)
self.run(action, with: context)

default:
context.triggerUserOutboundEvent(event, promise: promise)
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/PostgresNIO/Postgres+PSQLCompat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ extension PSQLError {
return self.underlying ?? self
case .tooManyParameters, .invalidCommandTag:
return self
case .connectionQuiescing:
return PostgresError.connectionClosed
case .connectionClosed:
case .clientClosesConnection,
.clientClosedConnection,
.serverClosedConnection:
return PostgresError.connectionClosed
case .connectionError:
return self.underlying ?? self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ class ConnectionStateMachineTests: XCTestCase {

func testErrorIsIgnoredWhenClosingConnection() {
// test ignore unclean shutdown when closing connection
var stateIgnoreChannelError = ConnectionStateMachine(.closing)
var stateIgnoreChannelError = ConnectionStateMachine(.closing(nil))

XCTAssertEqual(stateIgnoreChannelError.errorHappened(.connectionError(underlying: NIOSSLError.uncleanShutdown)), .wait)
XCTAssertEqual(stateIgnoreChannelError.closed(), .fireChannelInactive)

// test ignore any other error when closing connection

var stateIgnoreErrorMessage = ConnectionStateMachine(.closing)
var stateIgnoreErrorMessage = ConnectionStateMachine(.closing(nil))
XCTAssertEqual(stateIgnoreErrorMessage.errorReceived(.init(fields: [:])), .wait)
XCTAssertEqual(stateIgnoreErrorMessage.closed(), .fireChannelInactive)
}
Expand Down
4 changes: 2 additions & 2 deletions Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ final class PSQLRowStreamTests: XCTestCase {

func testFailedStream() {
let stream = PSQLRowStream(
source: .noRows(.failure(PSQLError.connectionClosed)),
source: .noRows(.failure(PSQLError.serverClosedConnection(underlying: nil))),
eventLoop: self.eventLoop,
logger: self.logger
)

XCTAssertThrowsError(try stream.all().wait()) {
XCTAssertEqual($0 as? PSQLError, .connectionClosed)
XCTAssertEqual($0 as? PSQLError, .serverClosedConnection(underlying: nil))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ class PostgresChannelHandlerTests: XCTestCase {
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
handler
], loop: self.eventLoop)
defer { XCTAssertNoThrow(try embedded.finish()) }

defer {
do { try embedded.finish() }
catch { print("\(String(reflecting: error))") }
}

var maybeMessage: PostgresFrontendMessage?
XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil))
XCTAssertNoThrow(maybeMessage = try embedded.readOutbound(as: PostgresFrontendMessage.self))
Expand Down
92 changes: 92 additions & 0 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,98 @@ class PostgresConnectionTests: XCTestCase {
}
}

func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
let rows = try await connection.query("SELECT 1;", logger: self.logger)
var iterator = rows.decode(Int.self).makeAsyncIterator()
let first = try await iterator.next()
XCTAssertEqual(first, 1)
let second = try await iterator.next()
XCTAssertNil(second)
}
}

for i in 0...1 {
let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")

if i == 0 {
taskGroup.addTask {
try await connection.closeGracefully()
}
}

try await channel.writeInbound(PostgresBackendMessage.parseComplete)
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [])))
let intDescription = RowDescription.Column(
name: "",
tableOID: 0,
columnAttributeNumber: 0,
dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary
)
try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription])))
try await channel.testingEventLoop.executeInContext { channel.read() }
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
try await channel.testingEventLoop.executeInContext { channel.read() }
try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)]))
try await channel.testingEventLoop.executeInContext { channel.read() }
try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1"))
try await channel.testingEventLoop.executeInContext { channel.read() }
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
}

let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
XCTAssertEqual(terminate, .terminate)
try await channel.closeFuture.get()
XCTAssertEqual(channel.isActive, false)

while let taskResult = await taskGroup.nextResult() {
switch taskResult {
case .success:
break
case .failure(let failure):
XCTFail("Unexpected error: \(failure)")
}
}
}
}

func testCloseClosesImmediatly() async throws {
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()

try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
for _ in 1...2 {
taskGroup.addTask {
try await connection.query("SELECT 1;", logger: self.logger)
}
}

let listenMessage = try await channel.waitForUnpreparedRequest()
XCTAssertEqual(listenMessage.parse.query, "SELECT 1;")

async let close: () = connection.close()

try await channel.closeFuture.get()
XCTAssertEqual(channel.isActive, false)

try await close

while let taskResult = await taskGroup.nextResult() {
switch taskResult {
case .success:
XCTFail("Expected queries to fail")
case .failure(let failure):
guard let error = failure as? PSQLError else {
return XCTFail("Unexpected error type: \(failure)")
}
XCTAssertEqual(error.code, .clientClosedConnection)
}
}
}
}

func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
let eventLoop = NIOAsyncTestingEventLoop()
Expand Down
8 changes: 4 additions & 4 deletions Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ final class PostgresRowSequenceTests: XCTestCase {
logger: self.logger
)

stream.receive(completion: .failure(PSQLError.connectionClosed))
stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil)))

let rowSequence = stream.asyncSequence()

Expand All @@ -194,7 +194,7 @@ final class PostgresRowSequenceTests: XCTestCase {
}
XCTFail("Expected that an error was thrown before.")
} catch {
XCTAssertEqual(error as? PSQLError, .connectionClosed)
XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil))
}
}

Expand Down Expand Up @@ -255,14 +255,14 @@ final class PostgresRowSequenceTests: XCTestCase {
XCTAssertEqual(try row1?.decode(Int.self, context: .default), 0)

DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) {
stream.receive(completion: .failure(PSQLError.connectionClosed))
stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil)))
}

do {
_ = try await rowIterator.next()
XCTFail("Expected that an error was thrown before.")
} catch {
XCTAssertEqual(error as? PSQLError, .connectionClosed)
XCTAssertEqual(error as? PSQLError, .serverClosedConnection(underlying: nil))
}
}

Expand Down