Skip to content

Commit 06e5e25

Browse files
authored
Don't let client close the connection wait for server to initiate close (#10)
* Don't let client close the connection wait for server * Add timeout for close * Change Closing websocket debug to trace * Disable autobahn tests in CI
1 parent e9763cf commit 06e5e25

File tree

5 files changed

+126
-59
lines changed

5 files changed

+126
-59
lines changed

Sources/WSClient/WebSocketClientChannel.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ struct WebSocketClientChannel: ClientConnectionChannel {
110110
configuration: .init(
111111
extensions: extensions,
112112
autoPing: self.configuration.autoPing,
113+
closeTimeout: self.configuration.closeTimeout,
113114
validateUTF8: self.configuration.validateUTF8
114115
),
115116
asyncChannel: webSocketChannel,

Sources/WSClient/WebSocketClientConfiguration.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ public struct WebSocketClientConfiguration: Sendable {
2323
public var additionalHeaders: HTTPFields
2424
/// WebSocket extensions
2525
public var extensions: [any WebSocketExtensionBuilder]
26+
/// Close timeout
27+
public var closeTimeout: Duration
2628
/// Automatic ping setup
2729
public var autoPing: AutoPingSetup
2830
/// Should text be validated to be UTF8
@@ -39,12 +41,14 @@ public struct WebSocketClientConfiguration: Sendable {
3941
maxFrameSize: Int = (1 << 14),
4042
additionalHeaders: HTTPFields = .init(),
4143
extensions: [WebSocketExtensionFactory] = [],
44+
closeTimeout: Duration = .seconds(15),
4245
autoPing: AutoPingSetup = .disabled,
4346
validateUTF8: Bool = false
4447
) {
4548
self.maxFrameSize = maxFrameSize
4649
self.additionalHeaders = additionalHeaders
4750
self.extensions = extensions.map { $0.build() }
51+
self.closeTimeout = closeTimeout
4852
self.autoPing = autoPing
4953
self.validateUTF8 = validateUTF8
5054
}

Sources/WSCore/WebSocketHandler.swift

Lines changed: 74 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,17 @@ public struct WebSocketCloseFrame: Sendable {
7272
let autoPing: AutoPingSetup
7373
let validateUTF8: Bool
7474
let reservedBits: WebSocketFrame.ReservedBits
75+
let closeTimeout: Duration
7576

76-
@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) {
77+
@_spi(WSInternal) public init(
78+
extensions: [any WebSocketExtension],
79+
autoPing: AutoPingSetup,
80+
closeTimeout: Duration = .seconds(15),
81+
validateUTF8: Bool
82+
) {
7783
self.extensions = extensions
7884
self.autoPing = autoPing
85+
self.closeTimeout = closeTimeout
7986
self.validateUTF8 = validateUTF8
8087
// store reserved bits used by this handler
8188
self.reservedBits = extensions.reduce(.init()) { partialResult, `extension` in
@@ -118,32 +125,24 @@ public struct WebSocketCloseFrame: Sendable {
118125
}
119126
do {
120127
let rt = try await asyncChannel.executeThenClose { inbound, outbound in
121-
try await withTaskCancellationHandler {
122-
try await withThrowingTaskGroup(of: WebSocketCloseFrame.self) { group in
123-
let webSocketHandler = Self(
124-
channel: asyncChannel.channel,
125-
outbound: outbound,
126-
type: type,
127-
configuration: configuration,
128-
context: context
129-
)
130-
if case .enabled = configuration.autoPing.value {
131-
/// Add task sending ping frames every so often and verifying a pong frame was sent back
132-
group.addTask {
133-
try await webSocketHandler.runAutoPingLoop()
134-
return .init(closeCode: .goingAway, reason: "Ping timeout")
135-
}
136-
}
137-
let rt = try await webSocketHandler.handle(
138-
type: type,
139-
inbound: inbound,
140-
outbound: outbound,
141-
handler: handler,
142-
context: context
143-
)
144-
group.cancelAll()
145-
return rt
146-
}
128+
defer {
129+
context.logger.trace("Closing WebSocket")
130+
}
131+
return try await withTaskCancellationHandler {
132+
let webSocketHandler = Self(
133+
channel: asyncChannel.channel,
134+
outbound: outbound,
135+
type: type,
136+
configuration: configuration,
137+
context: context
138+
)
139+
return try await webSocketHandler.handle(
140+
type: type,
141+
inbound: inbound,
142+
outbound: outbound,
143+
handler: handler,
144+
context: context
145+
)
147146
} onCancel: {
148147
Task {
149148
try await asyncChannel.channel.close(mode: .input)
@@ -166,39 +165,57 @@ public struct WebSocketCloseFrame: Sendable {
166165
context: Context
167166
) async throws -> WebSocketCloseFrame? {
168167
try await withGracefulShutdownHandler {
169-
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
170-
var inboundIterator = inbound.makeAsyncIterator()
171-
let webSocketInbound = WebSocketInboundStream(
172-
iterator: inboundIterator,
173-
handler: self
174-
)
175-
let closeCode: WebSocketErrorCode
176-
var clientError: Error?
177-
do {
178-
// handle websocket data and text
179-
try await handler(webSocketInbound, webSocketOutbound, context)
180-
closeCode = .normalClosure
181-
} catch InternalError.close(let code) {
182-
closeCode = code
183-
} catch {
184-
clientError = error
185-
closeCode = .unexpectedServerError
186-
}
187-
do {
188-
try await self.close(code: closeCode)
189-
if case .closing = self.stateMachine.state {
190-
// Close handshake. Wait for responding close or until inbound ends
191-
while let frame = try await inboundIterator.next() {
192-
if case .connectionClose = frame.opcode {
193-
try await self.receivedClose(frame)
194-
break
168+
try await withThrowingTaskGroup(of: Void.self) { group in
169+
if case .enabled = configuration.autoPing.value {
170+
/// Add task sending ping frames every so often and verifying a pong frame was sent back
171+
group.addTask {
172+
try await self.runAutoPingLoop()
173+
}
174+
}
175+
let webSocketOutbound = WebSocketOutboundWriter(handler: self)
176+
var inboundIterator = inbound.makeAsyncIterator()
177+
let webSocketInbound = WebSocketInboundStream(
178+
iterator: inboundIterator,
179+
handler: self
180+
)
181+
let closeCode: WebSocketErrorCode
182+
var clientError: Error?
183+
do {
184+
// handle websocket data and text
185+
try await handler(webSocketInbound, webSocketOutbound, context)
186+
closeCode = .normalClosure
187+
} catch InternalError.close(let code) {
188+
closeCode = code
189+
} catch {
190+
clientError = error
191+
closeCode = .unexpectedServerError
192+
}
193+
do {
194+
try await self.close(code: closeCode)
195+
if case .closing = self.stateMachine.state {
196+
group.addTask {
197+
try await Task.sleep(for: self.configuration.closeTimeout)
198+
try await self.channel.close(mode: .input)
199+
}
200+
// Close handshake. Wait for responding close or until inbound ends
201+
while let frame = try await inboundIterator.next() {
202+
if case .connectionClose = frame.opcode {
203+
try await self.receivedClose(frame)
204+
// only the server can close the connection, so clients
205+
// should continue reading from inbound until it is closed
206+
if type == .server {
207+
break
208+
}
209+
}
195210
}
196211
}
212+
// don't propagate error if channel is already closed
213+
} catch ChannelError.ioOnClosedChannel {}
214+
if type == .client, let clientError {
215+
throw clientError
197216
}
198-
// don't propagate error if channel is already closed
199-
} catch ChannelError.ioOnClosedChannel {}
200-
if type == .client, let clientError {
201-
throw clientError
217+
218+
group.cancelAll()
202219
}
203220
} onGracefulShutdown: {
204221
Task {

Tests/WebSocketTests/AutobahnTests.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import XCTest
2323
/// The Autobahn|Testsuite provides a fully automated test suite to verify client and server
2424
/// implementations of The WebSocket Protocol for specification conformance and implementation robustness.
2525
/// You can find out more at https://github.com/crossbario/autobahn-testsuite
26+
///
27+
/// Before running these tests run `./scripts/autobahn-server.sh` to running the test server.
2628
final class AutobahnTests: XCTestCase {
2729
/// To run all the autobahn compression tests takes a long time. By default we only run a selection.
2830
/// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them.
@@ -58,6 +60,9 @@ final class AutobahnTests: XCTestCase {
5860
cases: Set<Int>,
5961
extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]
6062
) async throws {
63+
// These are broken in CI currently
64+
try XCTSkipIf(ProcessInfo.processInfo.environment["CI"] != nil)
65+
6166
struct CaseInfo: Decodable {
6267
let id: String
6368
let description: String
@@ -121,8 +126,6 @@ final class AutobahnTests: XCTestCase {
121126
}
122127

123128
func test_3_ReservedBits() async throws {
124-
// Reserved bits tests fail
125-
try XCTSkipIf(true)
126129
try await self.autobahnTests(cases: .init(28..<35))
127130
}
128131

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Hummingbird server framework project
4+
//
5+
// Copyright (c) 2024 the Hummingbird authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import Logging
16+
import NIOCore
17+
import NIOSSL
18+
import NIOWebSocket
19+
import WSClient
20+
import XCTest
21+
22+
final class WebSocketClientTests: XCTestCase {
23+
24+
func testEchoServer() async throws {
25+
let clientLogger = {
26+
var logger = Logger(label: "client")
27+
logger.logLevel = .trace
28+
return logger
29+
}()
30+
try await WebSocketClient.connect(
31+
url: "wss://echo.websocket.org/",
32+
tlsConfiguration: TLSConfiguration.makeClientConfiguration(),
33+
logger: clientLogger
34+
) { inbound, outbound, _ in
35+
var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator()
36+
try await outbound.write(.text("hello"))
37+
if let msg = try await inboundIterator.next() {
38+
print(msg)
39+
}
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)