diff --git a/Sources/HTTPServer/HTTPRequestConcludingAsyncReader.swift b/Sources/HTTPServer/HTTPRequestConcludingAsyncReader.swift index dbb6646..984a42e 100644 --- a/Sources/HTTPServer/HTTPRequestConcludingAsyncReader.swift +++ b/Sources/HTTPServer/HTTPRequestConcludingAsyncReader.swift @@ -41,8 +41,93 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable /// The HTTP trailer fields captured at the end of the request. fileprivate var state: ReaderState - /// The iterator that provides HTTP request parts from the underlying channel. - private var iterator: NIOAsyncChannelInboundStream.AsyncIterator + struct RequestBodyStateMachine { + enum State { + // The request body is currently being read: expecting more request body parts or a request end part. + case readingBody(ReadingBodyState) + + // The request end part was received. We have finished. + case finished + + enum ReadingBodyState { + // Not yet received any request body parts + case initial + + // `read` was called with a `maximumCount` value that was lower than the bytes available. The excess + // bytes are stored here so they can be dispensed in future calls to `read`. + case excess(ByteBuffer) + + // No excess bytes currently needing to be stored + case noExcess + } + } + + private var state: State + + /// The iterator that provides HTTP request parts from the underlying channel. + private var iterator: NIOAsyncChannelInboundStream.AsyncIterator + + init(iterator: NIOAsyncChannelInboundStream.AsyncIterator) { + self.state = .readingBody(.initial) + self.iterator = iterator + } + + enum ReadResult { + case readBody(ByteBuffer) + case readEnd(HTTPFields?) + case streamFinished + } + + mutating func read(limit: Int?) async throws -> ReadResult { + switch self.state { + case .readingBody(let readingBodyState): + var bodyElement: ByteBuffer + + switch readingBodyState { + case .excess(let excessElement): + // There was an excess of bytes from the previous call to `read`. We read directly from this + // excess and don't advance the iterator. + bodyElement = excessElement + + case .initial, .noExcess: + // There is no excess from previous reads. We obtain the next element from the stream. + let requestPart = try await self.iterator.next(isolation: #isolation) + + switch requestPart { + case .head: + fatalError("Unexpectedly received a request head.") + + case .none: + fatalError("The stream unexpectedly ended before receiving a request end.") + + case .body(let element): + bodyElement = element + + case .end(let trailers): + self.state = .finished + return .readEnd(trailers) + } + } + + if let limit, limit < bodyElement.readableBytes, + let truncated = bodyElement.readSlice(length: limit) + { + // There are more bytes available than `limit`. We must store the excess in a buffer for it to + // be consumed in the next call to `read`. + self.state = .readingBody(.excess(bodyElement)) + return .readBody(truncated) + } + + self.state = .readingBody(.noExcess) + return .readBody(bodyElement) + + case .finished: + return .streamFinished + } + } + } + + var requestBodyStateMachine: RequestBodyStateMachine /// Initializes a new request body reader with the given NIO async channel iterator. /// @@ -51,7 +136,7 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable iterator: consuming sending NIOAsyncChannelInboundStream.AsyncIterator, readerState: ReaderState ) { - self.iterator = iterator + self.requestBodyStateMachine = .init(iterator: iterator) self.state = readerState } @@ -65,26 +150,26 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable maximumCount: Int?, body: nonisolated(nonsending) (consuming Span) async throws(Failure) -> Return ) async throws(EitherError) -> Return { - let requestPart: HTTPRequestPart? + let readResult: RequestBodyStateMachine.ReadResult do { - requestPart = try await self.iterator.next(isolation: #isolation) + readResult = try await self.requestBodyStateMachine.read(limit: maximumCount) } catch { throw .first(error) } do { - switch requestPart { - case .head: - fatalError() - case .body(let element): - return try await body(Array(buffer: element).span) - case .end(let trailers): + switch readResult { + case .readBody(let readElement): + return try await body(Array(buffer: readElement).span) + + case .readEnd(let trailers): self.state.wrapped.withLock { state in state.trailers = trailers state.finishedReading = true } return try await body(.init()) - case .none: + + case .streamFinished: return try await body(.init()) } } catch { diff --git a/Tests/HTTPServerTests/HTTPRequestConcludingAsyncReaderTests.swift b/Tests/HTTPServerTests/HTTPRequestConcludingAsyncReaderTests.swift index 82baa80..75d5570 100644 --- a/Tests/HTTPServerTests/HTTPRequestConcludingAsyncReaderTests.swift +++ b/Tests/HTTPServerTests/HTTPRequestConcludingAsyncReaderTests.swift @@ -42,7 +42,32 @@ struct HTTPRequestConcludingAsyncReaderTests { _ = try await requestReader.consumeAndConclude { bodyReader in var bodyReader = bodyReader - try await bodyReader.read(maximumCount: nil) { element in () } + try await bodyReader.read(maximumCount: nil) { _ in } + } + } + } + + @Test("Stream cannot be finished before writing request end part") + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + func testNotWritingRequestEndPartFatalError() async throws { + await #expect(processExitsWith: .failure) { + let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() + + // Only write a request body part; do not write an end part. + source.yield(.body(.init())) + source.finish() + + let requestReader = HTTPRequestConcludingAsyncReader( + iterator: stream.makeAsyncIterator(), + readerState: .init() + ) + + _ = try await requestReader.consumeAndConclude { bodyReader in + var bodyReader = bodyReader + + try await bodyReader.read(maximumCount: nil) { _ in } + // The stream has finished without an end part. Calling `read` now should result in a fatal error. + try await bodyReader.read(maximumCount: nil) { _ in } } } } @@ -172,26 +197,125 @@ struct HTTPRequestConcludingAsyncReaderTests { @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) @Test("More bytes available than consumption limit") func testCollectMoreBytesThanAvailable() async throws { - await #expect(processExitsWith: .failure) { - let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() + let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() - // Write 10 bytes - source.yield(.body(.init(repeating: 5, count: 10))) - source.finish() + // Write 10 bytes + source.yield(.body(.init(repeating: 5, count: 10))) + source.finish() - let requestReader = HTTPRequestConcludingAsyncReader( - iterator: stream.makeAsyncIterator(), - readerState: .init() - ) + let requestReader = HTTPRequestConcludingAsyncReader( + iterator: stream.makeAsyncIterator(), + readerState: .init() + ) + + _ = try await requestReader.consumeAndConclude { requestBodyReader in + var requestBodyReader = requestBodyReader + + // There are more bytes available than our limit. + let collected = try await requestBodyReader.collect(upTo: 9) { element in + var buffer = ByteBuffer() + buffer.writeBytes(element.bytes) + return buffer + } + + // We should only collect up to the limit (the first 9 bytes). + #expect(collected == .init(repeating: 5, count: 9)) + } + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + @Test("Multiple body chunks; multiple reads with limits") + func testReadWithLimits() async throws { + let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() + + // First write 10 bytes; + source.yield(.body(.init(repeating: 1, count: 10))) + // Then write another 5 bytes. + source.yield(.body(.init(repeating: 2, count: 5))) + source.yield(.end(nil)) + source.finish() + + let streamIterator = stream.makeAsyncIterator() + + let requestReader = HTTPRequestConcludingAsyncReader(iterator: streamIterator, readerState: .init()) + _ = try await requestReader.consumeAndConclude { requestBodyReader in + var requestBodyReader = requestBodyReader + + // Collect 8 bytes (partial of first write). + let collectedPartOne = try await requestBodyReader.collect(upTo: 8) { element in + var buffer = ByteBuffer() + buffer.writeBytes(element.bytes) + return buffer + } + + // Then collect 4 more bytes (overlap of first and second write). + let collectedPartTwo = try await requestBodyReader.collect(upTo: 4) { element in + var buffer = ByteBuffer() + buffer.writeBytes(element.bytes) + return buffer + } + + // Then collect 3 more bytes (partial of second write). + let collectedPartThree = try await requestBodyReader.collect(upTo: 3) { element in + var buffer = ByteBuffer() + buffer.writeBytes(element.bytes) + return buffer + } - _ = try await requestReader.consumeAndConclude { requestBodyReader in - var requestBodyReader = requestBodyReader + #expect(collectedPartOne == .init(repeating: 1, count: 8)) + #expect(collectedPartTwo == .init([1, 1, 2, 2])) + #expect(collectedPartThree == .init(repeating: 2, count: 3)) + } + } + + @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) + @Test("Multiple random-length chunks; multiple reads with random limits") + func testMultipleReadsWithRandomLimits() async throws { + let (stream, source) = NIOAsyncChannelInboundStream.makeTestingStream() + + // Generate random ByteBuffers of varying length and write them to the stream. + var randomBuffer = ByteBuffer() + for _ in 0..<100 { + let randomNumber = UInt8.random(in: 1...50) + let randomCount = Int.random(in: 1...50) + + let randomData = ByteBuffer(repeating: randomNumber, count: randomCount) + // Store the data so we can track what we have wrote + randomBuffer.writeImmutableBuffer(randomData) + + source.yield(.body(randomData)) + } + source.yield(.end(nil)) + source.finish() + + let streamIterator = stream.makeAsyncIterator() + + let requestReader = HTTPRequestConcludingAsyncReader(iterator: streamIterator, readerState: .init()) + _ = try await requestReader.consumeAndConclude { requestBodyReader in + var requestBodyReader = requestBodyReader - // Since there are more bytes than requested, this should fail. - try await requestBodyReader.collect(upTo: 9) { element in - () + var collectedBuffer = ByteBuffer() + while true { + let randomMaxCount = Int.random(in: 1...100) + + let collected = try await requestBodyReader.collect(upTo: randomMaxCount) { element in + var localBuffer = ByteBuffer() + localBuffer.writeBytes(element.bytes) + return localBuffer } + + if collected.readableBytes == 0 { + break + } + + // The collected buffer should never exceed the specified max count. + try #require(collected.readableBytes <= randomMaxCount) + + collectedBuffer.writeImmutableBuffer(collected) } + + // Check if the collected buffer exactly matches what was written to the stream. + try #require(randomBuffer == collectedBuffer) } } }