Skip to content

Commit fee8d4f

Browse files
adam-fowlerJoannis
andauthored
Add UTF8 validation (#2)
* Add UTF8 validation * Add UTF8 validation * Wrap validation code in compiler(>=6) * Disable some tests for swift 5.10 * Duplicate API being submitted to SwiftNIO * Update docs to use WSCore * Update Sources/WSCore/ByteBuffer+validatingString.swift Co-authored-by: Joannis Orlandos <[email protected]> * Minor updates * Fix CI --------- Co-authored-by: Joannis Orlandos <[email protected]>
1 parent 7c53f09 commit fee8d4f

11 files changed

+195
-29
lines changed

Sources/WSClient/WebSocketClient.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ import WSCore
4040
/// }
4141
/// ```
4242
public struct WebSocketClient {
43-
/// Basic context implementation of ``/HummingbirdWSCore/WebSocketContext``.
44-
/// Used by non-router web socket handle function
43+
/// Client implementation of ``/WSCore/WebSocketContext``.
4544
public struct Context: WebSocketContext {
4645
public let logger: Logger
4746

Sources/WSClient/WebSocketClientChannel.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ struct WebSocketClientChannel: ClientConnectionChannel {
101101
type: .client,
102102
configuration: .init(
103103
extensions: extensions,
104-
autoPing: self.configuration.autoPing
104+
autoPing: self.configuration.autoPing,
105+
validateUTF8: self.configuration.validateUTF8
105106
),
106107
asyncChannel: webSocketChannel,
107108
context: WebSocketClient.Context(logger: logger),

Sources/WSClient/WebSocketClientConfiguration.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ public struct WebSocketClientConfiguration: Sendable {
2525
public var extensions: [any WebSocketExtensionBuilder]
2626
/// Automatic ping setup
2727
public var autoPing: AutoPingSetup
28+
/// Should text be validated to be UTF8
29+
public var validateUTF8: Bool
2830

2931
/// Initialize WebSocketClient configuration
3032
/// - Paramters
@@ -34,11 +36,13 @@ public struct WebSocketClientConfiguration: Sendable {
3436
maxFrameSize: Int = (1 << 14),
3537
additionalHeaders: HTTPFields = .init(),
3638
extensions: [WebSocketExtensionFactory] = [],
37-
autoPing: AutoPingSetup = .disabled
39+
autoPing: AutoPingSetup = .disabled,
40+
validateUTF8: Bool = false
3841
) {
3942
self.maxFrameSize = maxFrameSize
4043
self.additionalHeaders = additionalHeaders
4144
self.extensions = extensions.map { $0.build() }
4245
self.autoPing = autoPing
46+
self.validateUTF8 = validateUTF8
4347
}
4448
}
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Hummingbird server framework project
4+
//
5+
// Copyright (c) 2024 the Hummingbird authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import NIOCore
16+
17+
#if compiler(>=6)
18+
extension ByteBuffer {
19+
/// Get the string at `index` from this `ByteBuffer` decoding using the UTF-8 encoding. Does not move the reader index.
20+
/// The selected bytes must be readable or else `nil` will be returned.
21+
///
22+
/// This is an alternative to `ByteBuffer.getString(at:length:)` which ensures the returned string is valid UTF8. If the
23+
/// string is not valid UTF8 then a `ReadUTF8ValidationError` error is thrown.
24+
///
25+
/// - Parameters:
26+
/// - index: The starting index into `ByteBuffer` containing the string of interest.
27+
/// - length: The number of bytes making up the string.
28+
/// - Returns: A `String` value containing the UTF-8 decoded selected bytes from this `ByteBuffer` or `nil` if
29+
/// the requested bytes are not readable.
30+
@inlinable
31+
@available(macOS 15, iOS 18, tvOS 18, watchOS 11, *)
32+
public func getUTF8ValidatedString(at index: Int, length: Int) throws -> String? {
33+
guard let range = self.rangeWithinReadableBytes(index: index, length: length) else {
34+
return nil
35+
}
36+
return try self.withUnsafeReadableBytes { pointer in
37+
assert(range.lowerBound >= 0 && (range.upperBound - range.lowerBound) <= pointer.count)
38+
guard
39+
let string = String(
40+
validating: UnsafeRawBufferPointer(fastRebase: pointer[range]),
41+
as: Unicode.UTF8.self
42+
)
43+
else {
44+
throw ReadUTF8ValidationError.invalidUTF8
45+
}
46+
return string
47+
}
48+
}
49+
50+
/// Read `length` bytes off this `ByteBuffer`, decoding it as `String` using the UTF-8 encoding. Move the reader index
51+
/// forward by `length`.
52+
///
53+
/// This is an alternative to `ByteBuffer.readString(length:)` which ensures the returned string is valid UTF8. If the
54+
/// string is not valid UTF8 then a `ReadUTF8ValidationError` error is thrown and the reader index is not advanced.
55+
///
56+
/// - Parameters:
57+
/// - length: The number of bytes making up the string.
58+
/// - Returns: A `String` value deserialized from this `ByteBuffer` or `nil` if there aren't at least `length` bytes readable.
59+
@inlinable
60+
@available(macOS 15, iOS 18, tvOS 18, watchOS 11, *)
61+
public mutating func readUTF8ValidatedString(length: Int) throws -> String? {
62+
guard let result = try self.getUTF8ValidatedString(at: self.readerIndex, length: length) else {
63+
return nil
64+
}
65+
self.moveReaderIndex(forwardBy: length)
66+
return result
67+
}
68+
69+
/// Errors thrown when calling `readUTF8ValidatedString` or `getUTF8ValidatedString`.
70+
public struct ReadUTF8ValidationError: Error, Equatable {
71+
private enum BaseError: Hashable {
72+
case invalidUTF8
73+
}
74+
75+
private var baseError: BaseError
76+
77+
/// The length of the bytes to copy was negative.
78+
public static let invalidUTF8: ReadUTF8ValidationError = .init(baseError: .invalidUTF8)
79+
}
80+
81+
@inlinable
82+
func rangeWithinReadableBytes(index: Int, length: Int) -> Range<Int>? {
83+
guard index >= self.readerIndex, length >= 0 else {
84+
return nil
85+
}
86+
87+
// both these &-s are safe, they can't underflow because both left & right side are >= 0 (and index >= readerIndex)
88+
let indexFromReaderIndex = index &- self.readerIndex
89+
assert(indexFromReaderIndex >= 0)
90+
guard indexFromReaderIndex <= self.readableBytes &- length else {
91+
return nil
92+
}
93+
94+
let upperBound = indexFromReaderIndex &+ length // safe, can't overflow, we checked it above.
95+
96+
// uncheckedBounds is safe because `length` is >= 0, so the lower bound will always be lower/equal to upper
97+
return Range<Int>(uncheckedBounds: (lower: indexFromReaderIndex, upper: upperBound))
98+
}
99+
}
100+
101+
extension UnsafeRawBufferPointer {
102+
@inlinable
103+
init(fastRebase slice: Slice<UnsafeRawBufferPointer>) {
104+
let base = slice.base.baseAddress?.advanced(by: slice.startIndex)
105+
self.init(start: base, count: slice.endIndex &- slice.startIndex)
106+
}
107+
}
108+
109+
#endif // compiler(>=6)
110+
111+
extension String {
112+
init?(buffer: ByteBuffer, validateUTF8: Bool) {
113+
#if compiler(>=6)
114+
if #available(macOS 15, iOS 18, tvOS 18, watchOS 11, *), validateUTF8 {
115+
do {
116+
var buffer = buffer
117+
self = try buffer.readUTF8ValidatedString(length: buffer.readableBytes)!
118+
} catch {
119+
return nil
120+
}
121+
} else {
122+
self = .init(buffer: buffer)
123+
}
124+
#else
125+
self = .init(buffer: buffer)
126+
#endif // compiler(>=6)
127+
}
128+
}

Sources/WSCore/WebSocketFrameSequence.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ struct WebSocketFrameSequence {
4646
}
4747
}
4848

49-
var message: WebSocketMessage {
50-
.init(frame: self.collated)!
49+
func getMessage(validateUTF8: Bool) -> WebSocketMessage? {
50+
.init(frame: self.collated, validate: validateUTF8)
5151
}
5252

5353
var collated: WebSocketDataFrame {

Sources/WSCore/WebSocketHandler.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ public struct WebSocketCloseFrame: Sendable {
7070
@_spi(WSInternal) public struct Configuration: Sendable {
7171
let extensions: [any WebSocketExtension]
7272
let autoPing: AutoPingSetup
73+
let validateUTF8: Bool
7374

74-
@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup) {
75+
@_spi(WSInternal) public init(extensions: [any WebSocketExtension], autoPing: AutoPingSetup, validateUTF8: Bool) {
7576
self.extensions = extensions
7677
self.autoPing = autoPing
78+
self.validateUTF8 = validateUTF8
7779
}
7880
}
7981

@@ -287,7 +289,7 @@ public struct WebSocketCloseFrame: Sendable {
287289
}
288290

289291
func receivedClose(_ frame: WebSocketFrame) async throws {
290-
switch self.stateMachine.receivedClose(frameData: frame.unmaskedData) {
292+
switch self.stateMachine.receivedClose(frameData: frame.unmaskedData, validateUTF8: self.configuration.validateUTF8) {
291293
case .sendClose(let errorCode):
292294
try await self.sendClose(code: errorCode, reason: nil)
293295
// Only server should initiate a connection close. Clients should wait for the

Sources/WSCore/WebSocketInboundStream.swift

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,28 @@ public final class WebSocketInboundStream: AsyncSequence, Sendable {
108108
case .text, .binary:
109109
frameSequence = .init(frame: frame)
110110
if frame.fin {
111-
return frameSequence.message
111+
guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else {
112+
throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage)
113+
}
114+
return message
112115
}
113116
default:
114-
try await self.handler.close(code: .protocolError)
115-
return nil
117+
throw WebSocketHandler.InternalError.close(.protocolError)
116118
}
117119
// parse continuation frames until we get a frame with a FIN flag
118120
while let frame = try await self.next() {
119121
guard frame.opcode == .continuation else {
120-
try await self.handler.close(code: .protocolError)
121-
return nil
122+
throw WebSocketHandler.InternalError.close(.protocolError)
122123
}
123124
guard frameSequence.size + frame.data.readableBytes <= maxSize else {
124-
try await self.handler.close(code: .messageTooLarge)
125-
return nil
125+
throw WebSocketHandler.InternalError.close(.messageTooLarge)
126126
}
127127
frameSequence.append(frame)
128128
if frame.fin {
129-
return frameSequence.message
129+
guard let message = frameSequence.getMessage(validateUTF8: self.handler.configuration.validateUTF8) else {
130+
throw WebSocketHandler.InternalError.close(.dataInconsistentWithMessage)
131+
}
132+
return message
130133
}
131134
}
132135
return nil

Sources/WSCore/WebSocketMessage.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ public enum WebSocketMessage: Equatable, Sendable, CustomStringConvertible, Cust
2020
case text(String)
2121
case binary(ByteBuffer)
2222

23-
init?(frame: WebSocketDataFrame) {
23+
init?(frame: WebSocketDataFrame, validate: Bool) {
2424
switch frame.opcode {
2525
case .text:
26-
self = .text(String(buffer: frame.data))
26+
guard let string = String(buffer: frame.data, validateUTF8: validate) else {
27+
return nil
28+
}
29+
self = .text(string)
2730
case .binary:
2831
self = .binary(frame.data)
2932
default:

Sources/WSCore/WebSocketStateMachine.swift

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,24 @@ struct WebSocketStateMachine {
5454

5555
// we received a connection close.
5656
// send a close back if it hasn't already been send and exit
57-
mutating func receivedClose(frameData: ByteBuffer) -> ReceivedCloseResult {
57+
mutating func receivedClose(frameData: ByteBuffer, validateUTF8: Bool) -> ReceivedCloseResult {
5858
var frameData = frameData
5959
let dataSize = frameData.readableBytes
6060
// read close code and close reason
6161
let closeCode = frameData.readWebSocketErrorCode()
62-
let reason = frameData.readableBytes > 0
63-
? frameData.readString(length: frameData.readableBytes)
64-
: nil
62+
let hasReason = frameData.readableBytes > 0
63+
let reason: String? = if hasReason {
64+
String(buffer: frameData, validateUTF8: validateUTF8)
65+
} else {
66+
nil
67+
}
6568

6669
switch self.state {
6770
case .open:
71+
if hasReason, reason == nil {
72+
self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })
73+
return .sendClose(.protocolError)
74+
}
6875
self.state = .closed(closeCode.map { .init(closeCode: $0, reason: reason) })
6976
let code: WebSocketErrorCode = if dataSize == 0 || closeCode != nil {
7077
// codes 3000 - 3999 are reserved for use by libraries, frameworks

Tests/WebSocketTests/AutobahnTests.swift

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ import WSClient
2020
import WSCompression
2121
import XCTest
2222

23+
/// The Autobahn|Testsuite provides a fully automated test suite to verify client and server
24+
/// implementations of The WebSocket Protocol for specification conformance and implementation robustness.
25+
/// You can find out more at https://github.com/crossbario/autobahn-testsuite
2326
final class AutobahnTests: XCTestCase {
24-
/// To run all the autobahn tests takes a long time. By default we only run a selection.
27+
/// To run all the autobahn compression tests takes a long time. By default we only run a selection.
2528
/// The `AUTOBAHN_ALL_TESTS` environment flag triggers running all of them.
2629
var runAllTests: Bool { ProcessInfo.processInfo.environment["AUTOBAHN_ALL_TESTS"] == "true" }
2730
var autobahnServer: String { ProcessInfo.processInfo.environment["FUZZING_SERVER"] ?? "localhost" }
@@ -30,6 +33,7 @@ final class AutobahnTests: XCTestCase {
3033
let result: NIOLockedValueBox<T?> = .init(nil)
3134
try await WebSocketClient.connect(
3235
url: .init("ws://\(self.autobahnServer):9001/\(path)"),
36+
configuration: .init(validateUTF8: true),
3337
logger: Logger(label: "Autobahn")
3438
) { inbound, _, _ in
3539
var inboundIterator = inbound.messages(maxSize: .max).makeAsyncIterator()
@@ -49,6 +53,7 @@ final class AutobahnTests: XCTestCase {
4953
return try result.withLockedValue { try XCTUnwrap($0) }
5054
}
5155

56+
/// Run a number of autobahn tests
5257
func autobahnTests(
5358
cases: Set<Int>,
5459
extensions: [WebSocketExtensionFactory] = [.perMessageDeflate(maxDecompressedFrameSize: 16_777_216)]
@@ -73,7 +78,11 @@ final class AutobahnTests: XCTestCase {
7378
// run case
7479
try await WebSocketClient.connect(
7580
url: .init("ws://\(self.autobahnServer):9001/runCase?case=\(index)&agent=swift-websocket"),
76-
configuration: .init(maxFrameSize: 16_777_216, extensions: extensions),
81+
configuration: .init(
82+
maxFrameSize: 16_777_216,
83+
extensions: extensions,
84+
validateUTF8: true
85+
),
7786
logger: logger
7887
) { inbound, outbound, _ in
7988
for try await msg in inbound.messages(maxSize: .max) {
@@ -88,7 +97,11 @@ final class AutobahnTests: XCTestCase {
8897

8998
// get case status
9099
let status = try await getValue("getCaseStatus?case=\(index)&agent=swift-websocket", as: CaseStatus.self)
91-
XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL")
100+
XCTAssert(status.behavior == "OK" || status.behavior == "INFORMATIONAL" || status.behavior == "NON-STRICT")
101+
}
102+
103+
try await WebSocketClient.connect(url: .init("ws://\(self.autobahnServer):9001/updateReports?agent=HB"), logger: logger) { inbound, _, _ in
104+
for try await _ in inbound {}
92105
}
93106
} catch let error as NIOConnectionError {
94107
logger.error("Autobahn tests require a running Autobahn fuzzing server. Run ./scripts/autobahn-server.sh")
@@ -119,15 +132,21 @@ final class AutobahnTests: XCTestCase {
119132
}
120133

121134
func test_6_UTF8Handling() async throws {
122-
// UTF8 validation fails
135+
// UTF8 validation is available on swift 5.10 or earlier
136+
#if compiler(<6)
123137
try XCTSkipIf(true)
138+
#endif
124139
try await self.autobahnTests(cases: .init(65..<210))
125140
}
126141

127142
func test_7_CloseHandling() async throws {
143+
// UTF8 validation is available on swift 5.10 or earlier
144+
#if compiler(<6)
128145
try await self.autobahnTests(cases: .init(210..<222))
129-
// UTF8 validation fails so skip 222
130146
try await self.autobahnTests(cases: .init(223..<247))
147+
#else
148+
try await self.autobahnTests(cases: .init(210..<247))
149+
#endif
131150
}
132151

133152
func test_9_Performance() async throws {

0 commit comments

Comments
 (0)