Skip to content

Commit 759587c

Browse files
Add test to PostgresConnectionTests to test prepared statements
1 parent 261a531 commit 759587c

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ extension PostgresFrontendMessage {
142142
}
143143

144144
let parameters = (0..<parameterCount).map { _ -> ByteBuffer? in
145-
let length = buffer.readInteger(as: UInt16.self)
145+
let length = buffer.readInteger(as: UInt32.self)
146146
switch length {
147147
case .some(..<0):
148148
return nil

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,97 @@ class PostgresConnectionTests: XCTestCase {
275275
}
276276
}
277277

278+
struct TestPrepareStatement: PostgresPreparedStatement {
279+
static var sql = "SELECT datname FROM pg_stat_activity WHERE state = $1"
280+
typealias Row = String
281+
282+
var state: String
283+
284+
func makeBindings() -> PostgresBindings {
285+
var bindings = PostgresBindings()
286+
bindings.append(.init(string: self.state))
287+
return bindings
288+
}
289+
290+
func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
291+
try row.decode(Row.self)
292+
}
293+
}
294+
295+
func testPreparedStatement() async throws {
296+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
297+
298+
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
299+
taskGroup.addTask {
300+
let preparedStatement = TestPrepareStatement(state: "active")
301+
let result = try await connection.execute(preparedStatement, logger: .psqlTest)
302+
var rows = 0
303+
for try await database in result {
304+
rows += 1
305+
XCTAssertEqual("test_database", database)
306+
}
307+
XCTAssertEqual(rows, 1)
308+
}
309+
// Wait for the PREPARE request from the client
310+
guard case .parse(let parse) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
311+
fatalError("Unexpected message")
312+
}
313+
XCTAssertEqual(parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1")
314+
XCTAssertEqual(parse.parameters.count, 0)
315+
guard case .describe(.preparedStatement(let name)) = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
316+
fatalError("Unexpected message")
317+
}
318+
XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self))
319+
guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
320+
fatalError("Unexpected message")
321+
}
322+
323+
// Respond to the PREPARE request
324+
try await channel.writeInbound(PostgresBackendMessage.parseComplete)
325+
try await channel.testingEventLoop.executeInContext { channel.read() }
326+
try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: [
327+
PostgresDataType.text
328+
])))
329+
try await channel.testingEventLoop.executeInContext { channel.read() }
330+
let rowDescription = RowDescription(columns: [
331+
.init(
332+
name: "datname",
333+
tableOID: 12222,
334+
columnAttributeNumber: 2,
335+
dataType: .name,
336+
dataTypeSize: 64,
337+
dataTypeModifier: -1,
338+
format: .text
339+
)
340+
])
341+
try await channel.writeInbound(PostgresBackendMessage.rowDescription(rowDescription))
342+
try await channel.testingEventLoop.executeInContext { channel.read() }
343+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
344+
try await channel.testingEventLoop.executeInContext { channel.read() }
345+
346+
// Wait for the EXECUTE request
347+
guard case .bind = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
348+
fatalError("Unexpected message")
349+
}
350+
guard case .execute = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
351+
fatalError("Unexpected message")
352+
}
353+
guard case .sync = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) else {
354+
fatalError("Unexpected message")
355+
}
356+
// Respond to the EXECUTE request
357+
try await channel.writeInbound(PostgresBackendMessage.bindComplete)
358+
try await channel.testingEventLoop.executeInContext { channel.read() }
359+
let dataRow = DataRow(arrayLiteral: "test_database")
360+
try await channel.writeInbound(PostgresBackendMessage.dataRow(dataRow))
361+
try await channel.testingEventLoop.executeInContext { channel.read() }
362+
try await channel.writeInbound(PostgresBackendMessage.commandComplete(TestPrepareStatement.sql))
363+
try await channel.testingEventLoop.executeInContext { channel.read() }
364+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
365+
try await channel.testingEventLoop.executeInContext { channel.read() }
366+
}
367+
}
368+
278369
func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
279370
let eventLoop = NIOAsyncTestingEventLoop()
280371
let channel = await NIOAsyncTestingChannel(handlers: [

0 commit comments

Comments
 (0)