diff --git a/src/quic/packet.zig b/src/quic/packet.zig index 2db6153..4da6b68 100644 --- a/src/quic/packet.zig +++ b/src/quic/packet.zig @@ -308,6 +308,24 @@ pub fn parseIncoming(bytes: []const u8) void { _ = bytes; } +fn buildHeaderBytes(header_bytes_buf: []u8, first_byte: u8, packet_buf: []const u8, packet_start: usize, header_end: usize) error{InvalidPacket}![]const u8 { + if (header_end <= packet_start or header_end > packet_buf.len) { + return error.InvalidPacket; + } + + const header_len = header_end - packet_start; + if (header_len > header_bytes_buf.len) { + return error.InvalidPacket; + } + + header_bytes_buf[0] = first_byte; + if (header_len > 1) { + @memcpy(header_bytes_buf[1..header_len], packet_buf[(packet_start + 1)..header_end]); + } + + return header_bytes_buf[0..header_len]; +} + pub fn decrypt(header: *Header, fbs: anytype, space: PacketNumSpace) ![]u8 { // We need at least 4 bytes for packet number + 16 for sample if (fbs.pos + 4 + crypto.SAMPLE_LEN > fbs.buffer.len) { @@ -375,19 +393,20 @@ pub fn decrypt(header: *Header, fbs: anytype, space: PacketNumSpace) ![]u8 { // Skip the packet number bytes in the buffer try fbs.seekBy(@intCast(header.packet_number_len)); + if (header.remainder_len < header.packet_number_len) { + return error.InvalidPacket; + } const payload_len = header.remainder_len - header.packet_number_len; + if (payload_len > (fbs.buffer.len - fbs.pos)) { + return error.InvalidPacket; + } // RFC 9001 Section 5.2: AD includes the unprotected first byte and everything up to and including packet number // For coalesced packets, use packet_start to get the correct offset within the buffer const pkt_start = header.packet_start; - const header_len = fbs.pos - pkt_start; // total header length including packet number var header_bytes_buf: [512]u8 = undefined; - // Copy first byte as unprotected - header_bytes_buf[0] = first_byte; - // Copy the rest of the header and packet number (from byte after first to current position) - @memcpy(header_bytes_buf[1..][0..(header_len - 1)], fbs.buffer[(pkt_start + 1)..fbs.pos]); - const header_bytes = header_bytes_buf[0..header_len]; - const encrypted_payload = fbs.buffer[(fbs.pos)..(fbs.pos + payload_len)]; + const header_bytes = try buildHeaderBytes(&header_bytes_buf, first_byte, fbs.buffer, pkt_start, fbs.pos); + const encrypted_payload = fbs.buffer[fbs.pos..(fbs.pos + payload_len)]; // Decode packet number header.packet_number = decodePacketNumber(space.next_packet_number, truncated_packet_number, header.packet_number_len * 8); @@ -457,16 +476,19 @@ pub fn decryptWithKeyUpdate(header: *Header, fbs: anytype, space: *PacketNumSpac try fbs.seekBy(@intCast(header.packet_number_len)); + if (header.remainder_len < header.packet_number_len) { + return error.InvalidPacket; + } const payload_len = header.remainder_len - header.packet_number_len; + if (payload_len > (fbs.buffer.len - fbs.pos)) { + return error.InvalidPacket; + } // Build associated data const pkt_start = header.packet_start; - const header_len = fbs.pos - pkt_start; var header_bytes_buf: [512]u8 = undefined; - header_bytes_buf[0] = first_byte; - @memcpy(header_bytes_buf[1..][0..(header_len - 1)], fbs.buffer[(pkt_start + 1)..fbs.pos]); - const header_bytes = header_bytes_buf[0..header_len]; - const encrypted_payload = fbs.buffer[(fbs.pos)..(fbs.pos + payload_len)]; + const header_bytes = try buildHeaderBytes(&header_bytes_buf, first_byte, fbs.buffer, pkt_start, fbs.pos); + const encrypted_payload = fbs.buffer[fbs.pos..(fbs.pos + payload_len)]; // Decode packet number header.packet_number = decodePacketNumber(space.next_packet_number, truncated_packet_number, header.packet_number_len * 8); @@ -1061,6 +1083,105 @@ fn decodePacketNumber(expected_pkt_num: u64, truncated_pkt_num: u64, num_bits: u return candidate; } +fn testOpen() crypto.Open { + return .{ + .key = .{0} ** crypto.max_key_len, + .hp_key = .{0} ** crypto.max_key_len, + .nonce = .{0} ** crypto.nonce_len, + }; +} + +fn testKeyUpdateManager() crypto.KeyUpdateManager { + return crypto.KeyUpdateManager.init( + .{0} ** 32, + .{0} ** 32, + .{0} ** crypto.max_key_len, + .{0} ** crypto.max_key_len, + ); +} + +fn initProtectedShortHeaderPacket(buffer: []u8, hp_key: [crypto.max_key_len]u8, unprotected_first_byte: u8, truncated_packet_number: u8) void { + @memset(buffer, 0); + + const pn_offset = 1; + const sample: *const [crypto.SAMPLE_LEN]u8 = buffer[(pn_offset + 4)..][0..crypto.SAMPLE_LEN]; + const mask = crypto.computeHpMask(sample, hp_key, .aes_128_gcm_sha256); + + buffer[0] = unprotected_first_byte ^ (mask[0] & 0x1f); + buffer[pn_offset] = truncated_packet_number ^ mask[1]; +} + +test "QUIC: buildHeaderBytes rejects oversized associated data headers" { + var header_bytes_buf: [512]u8 = undefined; + var packet_buf: [513]u8 = .{0} ** 513; + + try std.testing.expectError( + error.InvalidPacket, + buildHeaderBytes(&header_bytes_buf, FIXED_BIT, &packet_buf, 0, packet_buf.len), + ); +} + +test "QUIC: buildHeaderBytes keeps the unprotected first byte" { + var header_bytes_buf: [8]u8 = undefined; + const packet_buf = [_]u8{ 0xff, 0xaa, 0xbb, 0xcc }; + + const header_bytes = try buildHeaderBytes(&header_bytes_buf, FIXED_BIT, &packet_buf, 0, packet_buf.len); + + try std.testing.expectEqualSlices(u8, &[_]u8{ FIXED_BIT, 0xaa, 0xbb, 0xcc }, header_bytes); +} + +test "QUIC: decrypt rejects malformed payload lengths" { + const open = testOpen(); + const space = PacketNumSpace{ .crypto_open = open }; + const unprotected_first_byte = FIXED_BIT; + + { + var buf: [32]u8 = undefined; + initProtectedShortHeaderPacket(&buf, open.hp_key, unprotected_first_byte, 0); + var fbs = io.fixedBufferStream(&buf); + fbs.pos = 1; + var header = Header{ .packet_start = 0, .remainder_len = 0 }; + + try std.testing.expectError(error.InvalidPacket, decrypt(&header, &fbs, space)); + } + + { + var buf: [32]u8 = undefined; + initProtectedShortHeaderPacket(&buf, open.hp_key, unprotected_first_byte, 0); + var fbs = io.fixedBufferStream(&buf); + fbs.pos = 1; + var header = Header{ .packet_start = 0, .remainder_len = buf.len }; + + try std.testing.expectError(error.InvalidPacket, decrypt(&header, &fbs, space)); + } +} + +test "QUIC: decryptWithKeyUpdate rejects malformed payload lengths" { + var ku = testKeyUpdateManager(); + var space = PacketNumSpace{}; + const unprotected_first_byte = FIXED_BIT; + + { + var buf: [32]u8 = undefined; + initProtectedShortHeaderPacket(&buf, ku.hp_open, unprotected_first_byte, 0); + var fbs = io.fixedBufferStream(&buf); + fbs.pos = 1; + var header = Header{ .packet_start = 0, .remainder_len = 0 }; + + try std.testing.expectError(error.InvalidPacket, decryptWithKeyUpdate(&header, &fbs, &space, &ku)); + } + + { + var buf: [32]u8 = undefined; + initProtectedShortHeaderPacket(&buf, ku.hp_open, unprotected_first_byte, 0); + var fbs = io.fixedBufferStream(&buf); + fbs.pos = 1; + var header = Header{ .packet_start = 0, .remainder_len = buf.len }; + + try std.testing.expectError(error.InvalidPacket, decryptWithKeyUpdate(&header, &fbs, &space, &ku)); + } +} + // Retry token roundtrip test test "Retry token: generate and validate roundtrip" { var token_key: [crypto.key_len]u8 = undefined;