Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,10 @@ impl<'a> Decompressor<'a> {
///
/// # Panics
///
/// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`]
/// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`],
/// or if the symbols and lengths tables do not have the same length.
pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self {
assert_eq!(symbols.len(), lengths.len(), "symbols and lengths differ");
assert!(
symbols.len() < FSST_CODE_BASE as usize,
"symbol table cannot have size exceeding 255"
Expand Down Expand Up @@ -303,7 +305,7 @@ impl<'a> Decompressor<'a> {
($code:expr) => {{
out_ptr
.cast::<u64>()
.write_unaligned(self.symbols.get_unchecked($code as usize).to_u64());
.write_unaligned(self.symbols[$code as usize].to_u64());
out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize);
}};
}
Expand Down Expand Up @@ -512,12 +514,12 @@ impl<'a> Decompressor<'a> {
in_ptr = in_ptr.add(1);
out_ptr = out_ptr.add(1);
} else {
let sym = self.symbols[code as usize].to_u64();
let len = *self.lengths.get_unchecked(code as usize) as usize;
assert!(
out_end.offset_from(out_ptr) >= len as isize,
"output buffer sized too small"
);
let sym = self.symbols.get_unchecked(code as usize).to_u64();
let sym_bytes = sym.to_le_bytes();
std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len);
out_ptr = out_ptr.add(len);
Expand Down
22 changes: 22 additions & 0 deletions tests/correctness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ fn test_all_escape_roundtrip() {
assert_eq!(decompressor.decompress(&compressed), input);
}

#[test]
#[should_panic]
fn test_invalid_code_not_in_symbol_table_panics() {
let compressor = CompressorBuilder::new().build();
let decompressor = compressor.decompressor();

// Empty symbol table: code 0 is malformed input, not a valid symbol code.
// Use more than 8 bytes so the unrolled decode loop is exercised.
let _ = decompressor.decompress(&[0; 9]);
}

#[test]
#[should_panic]
fn test_invalid_tail_code_not_in_symbol_table_panics() {
let compressor = CompressorBuilder::new().build();
let decompressor = compressor.decompressor();
let mut decoded = [];

// A one-byte malformed input reaches the final byte-copy fallback path.
let _ = decompressor.decompress_into(&[0], &mut decoded);
}

#[test]
fn test_large_with_rebuild() {
let corpus: Vec<u8> = DECLARATION.bytes().cycle().take(10_240).collect();
Expand Down
Loading