From b0f2fbd57ebe3e7fed29e046d72fd0a807e5244b Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 24 Jun 2026 00:45:51 +0100 Subject: [PATCH 1/4] Handle codes not in the dictionary Signed-off-by: Robert Kruszewski --- src/lib.rs | 8 +++++--- tests/correctness.rs | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8a3fe94..8fe5c1a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -238,8 +238,10 @@ impl<'a> Decompressor<'a> { /// /// # Panics /// - /// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`] + /// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`], + /// or if the symbols and lengths tables do not have the same length. pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self { + assert_eq!(symbols.len(), lengths.len(), "symbols and lengths differ"); assert!( symbols.len() < FSST_CODE_BASE as usize, "symbol table cannot have size exceeding 255" @@ -303,7 +305,7 @@ impl<'a> Decompressor<'a> { ($code:expr) => {{ out_ptr .cast::<u64>() - .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64()); + .write_unaligned(self.symbols[$code as usize].to_u64()); out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize); }}; } @@ -512,12 +514,12 @@ impl<'a> Decompressor<'a> { in_ptr = in_ptr.add(1); out_ptr = out_ptr.add(1); } else { + let sym = self.symbols[code as usize].to_u64(); let len = *self.lengths.get_unchecked(code as usize) as usize; assert!( out_end.offset_from(out_ptr) >= len as isize, "output buffer sized too small" ); - let sym = self.symbols.get_unchecked(code as usize).to_u64(); let sym_bytes = sym.to_le_bytes(); std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len); out_ptr = out_ptr.add(len); diff --git a/tests/correctness.rs b/tests/correctness.rs index daf80b3..24bc222 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -97,6 +97,28 @@ fn
test_all_escape_roundtrip() { assert_eq!(decompressor.decompress(&compressed), input); } +#[test] +#[should_panic] +fn test_invalid_code_not_in_symbol_table_panics() { + let compressor = CompressorBuilder::new().build(); + let decompressor = compressor.decompressor(); + + // Empty symbol table: code 0 is malformed input, not a valid symbol code. + // Use more than 8 bytes so the unrolled decode loop is exercised. + let _ = decompressor.decompress(&[0; 9]); +} + +#[test] +#[should_panic] +fn test_invalid_tail_code_not_in_symbol_table_panics() { + let compressor = CompressorBuilder::new().build(); + let decompressor = compressor.decompressor(); + let mut decoded = []; + + // A one-byte malformed input reaches the final byte-copy fallback path. + let _ = decompressor.decompress_into(&[0], &mut decoded); +} + #[test] fn test_large_with_rebuild() { let corpus: Vec<u8> = DECLARATION.bytes().cycle().take(10_240).collect(); From 6705698f2bd251a75fbd619d68f2e1102f27bd4d Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 24 Jun 2026 12:07:13 +0100 Subject: [PATCH 2/4] Pad symbol tables to 255 entries so every code is in bounds Signed-off-by: Robert Kruszewski --- src/builder.rs | 19 ++++++++------ src/lib.rs | 62 ++++++++++++++++++++++---------------------- tests/correctness.rs | 7 +++-- 3 files changed, 45 insertions(+), 43 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 16e59ef..6dca204 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -382,7 +382,7 @@ impl CompressorBuilder { /// /// Also returns the lengths vector, which is of length `n_symbols` and contains the /// length for each of the values. - fn finalize(&mut self) -> (u8, Vec<u8>) { + fn finalize(&mut self) -> (u8, [u8; 255]) { // Create a cumulative sum of each of the elements of the input line numbers. // Do a map that includes the previously seen value as well. // Regroup symbols based on their lengths. @@ -449,7 +449,7 @@ impl CompressorBuilder { } // Truncate the symbol table to only include the "true" symbols.
- self.symbols.truncate(self.n_symbols as usize); + self.symbols.truncate(FSST_CODE_BASE as usize - 1); // Rewrite the codes_one_byte table to point at the new code values. // Replace pseudocodes with escapes. @@ -481,9 +481,9 @@ impl CompressorBuilder { self.lossy_pht.renumber(&new_codes); // Pre-compute the lengths - let mut lengths = Vec::with_capacity(self.n_symbols as usize); - for symbol in &self.symbols { - lengths.push(symbol.len() as u8); + let mut lengths = [0u8; 255]; + for (len, symbol) in lengths.iter_mut().zip(&self.symbols) { + *len = symbol.len() as u8; } (has_suffix_code, lengths) @@ -497,7 +497,10 @@ impl CompressorBuilder { let (has_suffix_code, lengths) = self.finalize(); Compressor { - symbols: self.symbols, + symbols: self + .symbols + .try_into() + .expect("Symbol table should be exactly 255 elements in length"), lengths, n_symbols: self.n_symbols, has_suffix_code, @@ -948,8 +951,8 @@ mod test { let corpus: Vec<&[u8]> = std::iter::repeat_n(text.as_slice(), 100).collect(); let compressor = Compressor::train(&corpus); - let symbols = compressor.symbol_table(); - let lengths = compressor.symbol_lengths(); + let symbols = &compressor.symbol_table()[0..compressor.n_symbols()]; + let lengths = &compressor.symbol_lengths()[0..compressor.n_symbols()]; // Collect all 1-byte symbols and check for duplicates. let one_byte: Vec = symbols diff --git a/src/lib.rs b/src/lib.rs index 8fe5c1a..ac68e76 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -227,26 +227,15 @@ impl Debug for Code { #[derive(Clone)] pub struct Decompressor<'a> { /// Slice mapping codes to symbols. - pub(crate) symbols: &'a [Symbol], + pub(crate) symbols: &'a [Symbol; 255], /// Slice containing the length of each symbol in the `symbols` slice. - pub(crate) lengths: &'a [u8], + pub(crate) lengths: &'a [u8; 255], } impl<'a> Decompressor<'a> { /// Returns a new decompressor that uses the provided symbol table. 
- /// - /// # Panics - /// - /// If the provided symbol table has length greater than or equal to [`FSST_CODE_BASE`], - /// or if the symbols and lengths tables do not have the same length. - pub fn new(symbols: &'a [Symbol], lengths: &'a [u8]) -> Self { - assert_eq!(symbols.len(), lengths.len(), "symbols and lengths differ"); - assert!( - symbols.len() < FSST_CODE_BASE as usize, - "symbol table cannot have size exceeding 255" - ); - + pub fn new(symbols: &'a [Symbol; 255], lengths: &'a [u8; 255]) -> Self { Self { symbols, lengths } } @@ -305,7 +294,7 @@ impl<'a> Decompressor<'a> { ($code:expr) => {{ out_ptr .cast::<u64>() - .write_unaligned(self.symbols[$code as usize].to_u64()); + .write_unaligned(self.symbols.get_unchecked($code as usize).to_u64()); out_ptr = out_ptr.add(*self.lengths.get_unchecked($code as usize) as usize); }}; } @@ -514,12 +503,12 @@ impl<'a> Decompressor<'a> { in_ptr = in_ptr.add(1); out_ptr = out_ptr.add(1); } else { - let sym = self.symbols[code as usize].to_u64(); let len = *self.lengths.get_unchecked(code as usize) as usize; assert!( out_end.offset_from(out_ptr) >= len as isize, "output buffer sized too small" ); + let sym = self.symbols.get_unchecked(code as usize).to_u64(); let sym_bytes = sym.to_le_bytes(); std::ptr::copy_nonoverlapping(sym_bytes.as_ptr(), out_ptr, len); out_ptr = out_ptr.add(len); @@ -568,10 +557,10 @@ impl<'a> Decompressor<'a> { #[derive(Clone)] pub struct Compressor { /// Table mapping codes to symbols. - pub(crate) symbols: Vec<Symbol>, + pub(crate) symbols: [Symbol; 255], /// Length of each symbol, values range from 1-8. - pub(crate) lengths: Vec<u8>, + pub(crate) lengths: [u8; 255], /// The number of entries in the symbol table that have been populated, not counting /// the escape values. @@ -795,40 +784,50 @@ impl Compressor { /// Returns a readonly slice of the current symbol table. /// /// The returned slice will have length of `n_symbols`.
- pub fn symbol_table(&self) -> &[Symbol] { - &self.symbols[0..self.n_symbols as usize] + pub fn symbol_table(&self) -> &[Symbol; 255] { + &self.symbols } /// Returns a readonly slice where index `i` contains the /// length of the symbol represented by code `i`. /// /// Values range from 1-8. - pub fn symbol_lengths(&self) -> &[u8] { - &self.lengths[0..self.n_symbols as usize] + pub fn symbol_lengths(&self) -> &[u8; 255] { + &self.lengths + } + + /// Number of symbols present in the compressor's symbol table. + /// + /// Since the symbol table and length are padded to 255 elements, this value indicates the number of valid entries. + pub fn n_symbols(&self) -> usize { + self.n_symbols as usize } /// Rebuild a compressor from an existing symbol table. /// /// This will not attempt to optimize or re-order the codes. pub fn rebuild_from(symbols: impl AsRef<[Symbol]>, symbol_lens: impl AsRef<[u8]>) -> Self { - let symbols = symbols.as_ref(); + let symbols_slice = symbols.as_ref(); let symbol_lens = symbol_lens.as_ref(); assert_eq!( - symbols.len(), + symbols_slice.len(), symbol_lens.len(), "symbols and lengths differ" ); assert!( - symbols.len() <= 255, + symbols_slice.len() <= 255, "symbol table len must be <= 255, was {}", - symbols.len() + symbols_slice.len() ); validate_symbol_order(symbol_lens); + let n_symbols = symbols_slice.len(); // Insert the symbols in their given order into the FSST lookup structures. 
- let symbols = symbols.to_vec(); - let lengths = symbol_lens.to_vec(); + let mut symbols = [Symbol::ZERO; 255]; + symbols[0..n_symbols].copy_from_slice(symbols_slice); + let mut lengths = [0u8; 255]; + lengths[0..n_symbols].copy_from_slice(symbol_lens); let mut lossy_pht = LossyPHT::new(); let mut codes_one_byte = [Code::UNUSED; 256]; @@ -878,7 +877,7 @@ impl Compressor { } Compressor { - n_symbols: symbols.len() as u8, + n_symbols: n_symbols as u8, symbols, lengths, codes_two_byte, @@ -970,7 +969,8 @@ mod test { .collect(); let compressor = Compressor::rebuild_from(symbols, lens); - let built_symbols: &[u64] = unsafe { mem::transmute(compressor.symbol_table()) }; + let built_symbols: &[u64] = + unsafe { mem::transmute(&compressor.symbol_table()[0..compressor.n_symbols()]) }; assert_eq!(built_symbols, symbols_u64); } @@ -999,7 +999,7 @@ mod test { builder.insert(symbol, *len as usize); } let compressor = builder.build(); - let built_symbols: &[u64] = unsafe { mem::transmute(compressor.symbol_table()) }; + let built_symbols: &[u64] = unsafe { mem::transmute(&compressor.symbol_table()[..]) }; assert_eq!(built_symbols, symbols); } } diff --git a/tests/correctness.rs b/tests/correctness.rs index 24bc222..dd8bb77 100644 --- a/tests/correctness.rs +++ b/tests/correctness.rs @@ -98,8 +98,7 @@ fn test_all_escape_roundtrip() { } #[test] -#[should_panic] -fn test_invalid_code_not_in_symbol_table_panics() { +fn test_invalid_code_not_in_symbol_works() { let compressor = CompressorBuilder::new().build(); let decompressor = compressor.decompressor(); @@ -164,7 +163,7 @@ fn test_pruning_small_input() { // merged symbol reaches 4 bytes instead of 8. 
#[cfg(not(miri))] assert_eq!( - compressor.symbol_table(), + &compressor.symbol_table()[0..compressor.n_symbols()], &[ Symbol::from_slice(b"aa\0\0\0\0\0\0"), Symbol::from_slice(b"aaaaaaaa"), @@ -174,7 +173,7 @@ fn test_pruning_small_input() { ); #[cfg(miri)] assert_eq!( - compressor.symbol_table(), + &compressor.symbol_table()[0..compressor.n_symbols()], &[ Symbol::from_slice(b"aa\0\0\0\0\0\0"), Symbol::from_slice(b"aaaa\0\0\0\0"), From 340ef6ee056fa6565debfdd19ae889af76364b07 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 24 Jun 2026 12:54:15 +0100 Subject: [PATCH 3/4] Compare only populated symbols in builder test Signed-off-by: Robert Kruszewski --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index ac68e76..10e14e9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -999,7 +999,7 @@ mod test { builder.insert(symbol, *len as usize); } let compressor = builder.build(); - let built_symbols: &[u64] = unsafe { mem::transmute(&compressor.symbol_table()[..]) }; + let built_symbols: &[u64] = unsafe { mem::transmute(&compressor.symbol_table()[0..compressor.n_symbols()]) }; assert_eq!(built_symbols, symbols); } } From 302e7cba8f0bc22e2924d98a4a6388667cb731c1 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Wed, 24 Jun 2026 12:54:31 +0100 Subject: [PATCH 4/4] format Signed-off-by: Robert Kruszewski --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 10e14e9..c40339a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -999,7 +999,8 @@ mod test { builder.insert(symbol, *len as usize); } let compressor = builder.build(); - let built_symbols: &[u64] = unsafe { mem::transmute(&compressor.symbol_table()[0..compressor.n_symbols()]) }; + let built_symbols: &[u64] = + unsafe { mem::transmute(&compressor.symbol_table()[0..compressor.n_symbols()]) }; assert_eq!(built_symbols, symbols); } }