Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 133 additions & 12 deletions src/quic/packet.zig
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,24 @@ pub fn parseIncoming(bytes: []const u8) void {
_ = bytes;
}

fn buildHeaderBytes(header_bytes_buf: []u8, first_byte: u8, packet_buf: []const u8, packet_start: usize, header_end: usize) error{InvalidPacket}![]const u8 {
if (header_end <= packet_start or header_end > packet_buf.len) {
return error.InvalidPacket;
}

const header_len = header_end - packet_start;
if (header_len > header_bytes_buf.len) {
return error.InvalidPacket;
}

header_bytes_buf[0] = first_byte;
if (header_len > 1) {
@memcpy(header_bytes_buf[1..header_len], packet_buf[(packet_start + 1)..header_end]);
}

return header_bytes_buf[0..header_len];
}

pub fn decrypt(header: *Header, fbs: anytype, space: PacketNumSpace) ![]u8 {
// We need at least 4 bytes for packet number + 16 for sample
if (fbs.pos + 4 + crypto.SAMPLE_LEN > fbs.buffer.len) {
Expand Down Expand Up @@ -375,19 +393,20 @@ pub fn decrypt(header: *Header, fbs: anytype, space: PacketNumSpace) ![]u8 {
// Skip the packet number bytes in the buffer
try fbs.seekBy(@intCast(header.packet_number_len));

if (header.remainder_len < header.packet_number_len) {
return error.InvalidPacket;
}
const payload_len = header.remainder_len - header.packet_number_len;
if (payload_len > (fbs.buffer.len - fbs.pos)) {
return error.InvalidPacket;
}

// RFC 9001 Section 5.2: AD includes the unprotected first byte and everything up to and including packet number
// For coalesced packets, use packet_start to get the correct offset within the buffer
const pkt_start = header.packet_start;
const header_len = fbs.pos - pkt_start; // total header length including packet number
var header_bytes_buf: [512]u8 = undefined;
// Copy first byte as unprotected
header_bytes_buf[0] = first_byte;
// Copy the rest of the header and packet number (from byte after first to current position)
@memcpy(header_bytes_buf[1..][0..(header_len - 1)], fbs.buffer[(pkt_start + 1)..fbs.pos]);
const header_bytes = header_bytes_buf[0..header_len];
const encrypted_payload = fbs.buffer[(fbs.pos)..(fbs.pos + payload_len)];
const header_bytes = try buildHeaderBytes(&header_bytes_buf, first_byte, fbs.buffer, pkt_start, fbs.pos);
const encrypted_payload = fbs.buffer[fbs.pos..(fbs.pos + payload_len)];

// Decode packet number
header.packet_number = decodePacketNumber(space.next_packet_number, truncated_packet_number, header.packet_number_len * 8);
Expand Down Expand Up @@ -457,16 +476,19 @@ pub fn decryptWithKeyUpdate(header: *Header, fbs: anytype, space: *PacketNumSpac

try fbs.seekBy(@intCast(header.packet_number_len));

if (header.remainder_len < header.packet_number_len) {
return error.InvalidPacket;
}
const payload_len = header.remainder_len - header.packet_number_len;
if (payload_len > (fbs.buffer.len - fbs.pos)) {
return error.InvalidPacket;
}

// Build associated data
const pkt_start = header.packet_start;
const header_len = fbs.pos - pkt_start;
var header_bytes_buf: [512]u8 = undefined;
header_bytes_buf[0] = first_byte;
@memcpy(header_bytes_buf[1..][0..(header_len - 1)], fbs.buffer[(pkt_start + 1)..fbs.pos]);
const header_bytes = header_bytes_buf[0..header_len];
const encrypted_payload = fbs.buffer[(fbs.pos)..(fbs.pos + payload_len)];
const header_bytes = try buildHeaderBytes(&header_bytes_buf, first_byte, fbs.buffer, pkt_start, fbs.pos);
const encrypted_payload = fbs.buffer[fbs.pos..(fbs.pos + payload_len)];

// Decode packet number
header.packet_number = decodePacketNumber(space.next_packet_number, truncated_packet_number, header.packet_number_len * 8);
Expand Down Expand Up @@ -1061,6 +1083,105 @@ fn decodePacketNumber(expected_pkt_num: u64, truncated_pkt_num: u64, num_bits: u
return candidate;
}

fn testOpen() crypto.Open {
return .{
.key = .{0} ** crypto.max_key_len,
.hp_key = .{0} ** crypto.max_key_len,
.nonce = .{0} ** crypto.nonce_len,
};
}

fn testKeyUpdateManager() crypto.KeyUpdateManager {
return crypto.KeyUpdateManager.init(
.{0} ** 32,
.{0} ** 32,
.{0} ** crypto.max_key_len,
.{0} ** crypto.max_key_len,
);
}

fn initProtectedShortHeaderPacket(buffer: []u8, hp_key: [crypto.max_key_len]u8, unprotected_first_byte: u8, truncated_packet_number: u8) void {
@memset(buffer, 0);

const pn_offset = 1;
const sample: *const [crypto.SAMPLE_LEN]u8 = buffer[(pn_offset + 4)..][0..crypto.SAMPLE_LEN];
const mask = crypto.computeHpMask(sample, hp_key, .aes_128_gcm_sha256);

buffer[0] = unprotected_first_byte ^ (mask[0] & 0x1f);
buffer[pn_offset] = truncated_packet_number ^ mask[1];
}

test "QUIC: buildHeaderBytes rejects oversized associated data headers" {
var header_bytes_buf: [512]u8 = undefined;
var packet_buf: [513]u8 = .{0} ** 513;

try std.testing.expectError(
error.InvalidPacket,
buildHeaderBytes(&header_bytes_buf, FIXED_BIT, &packet_buf, 0, packet_buf.len),
);
}

test "QUIC: buildHeaderBytes keeps the unprotected first byte" {
var header_bytes_buf: [8]u8 = undefined;
const packet_buf = [_]u8{ 0xff, 0xaa, 0xbb, 0xcc };

const header_bytes = try buildHeaderBytes(&header_bytes_buf, FIXED_BIT, &packet_buf, 0, packet_buf.len);

try std.testing.expectEqualSlices(u8, &[_]u8{ FIXED_BIT, 0xaa, 0xbb, 0xcc }, header_bytes);
}

test "QUIC: decrypt rejects malformed payload lengths" {
const open = testOpen();
const space = PacketNumSpace{ .crypto_open = open };
const unprotected_first_byte = FIXED_BIT;

{
var buf: [32]u8 = undefined;
initProtectedShortHeaderPacket(&buf, open.hp_key, unprotected_first_byte, 0);
var fbs = io.fixedBufferStream(&buf);
fbs.pos = 1;
var header = Header{ .packet_start = 0, .remainder_len = 0 };

try std.testing.expectError(error.InvalidPacket, decrypt(&header, &fbs, space));
}

{
var buf: [32]u8 = undefined;
initProtectedShortHeaderPacket(&buf, open.hp_key, unprotected_first_byte, 0);
var fbs = io.fixedBufferStream(&buf);
fbs.pos = 1;
var header = Header{ .packet_start = 0, .remainder_len = buf.len };

try std.testing.expectError(error.InvalidPacket, decrypt(&header, &fbs, space));
}
}

test "QUIC: decryptWithKeyUpdate rejects malformed payload lengths" {
var ku = testKeyUpdateManager();
var space = PacketNumSpace{};
const unprotected_first_byte = FIXED_BIT;

{
var buf: [32]u8 = undefined;
initProtectedShortHeaderPacket(&buf, ku.hp_open, unprotected_first_byte, 0);
var fbs = io.fixedBufferStream(&buf);
fbs.pos = 1;
var header = Header{ .packet_start = 0, .remainder_len = 0 };

try std.testing.expectError(error.InvalidPacket, decryptWithKeyUpdate(&header, &fbs, &space, &ku));
}

{
var buf: [32]u8 = undefined;
initProtectedShortHeaderPacket(&buf, ku.hp_open, unprotected_first_byte, 0);
var fbs = io.fixedBufferStream(&buf);
fbs.pos = 1;
var header = Header{ .packet_start = 0, .remainder_len = buf.len };

try std.testing.expectError(error.InvalidPacket, decryptWithKeyUpdate(&header, &fbs, &space, &ku));
}
}

// Retry token roundtrip test
test "Retry token: generate and validate roundtrip" {
var token_key: [crypto.key_len]u8 = undefined;
Expand Down