Skip to content

Commit

Permalink
Improve read_bytes*
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexagon committed Oct 9, 2024
1 parent 8db4c8c commit 249a1b4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "tl-proto"
description = "A collection of traits for working with TL serialization/deserialization"
authors = ["Ivan Kalinin <[email protected]>"]
repository = "https://github.com/broxus/tl-proto"
version = "0.4.9"
version = "0.4.10"
edition = "2021"
include = ["src/**/*.rs", "README.md"]
license = "MIT"
Expand Down
56 changes: 44 additions & 12 deletions src/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,17 +399,33 @@ impl<'a, const N: usize> TlRead<'a> for &'a BoundedBytes<N> {
max_len: usize,
offset: &mut usize,
) -> TlResult<&'a [u8]> {
let current_offset = *offset;
let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset));
if len > max_len {
let packet_len = packet.len();
if unlikely(packet_len + 4 <= *offset) {
return Err(TlError::UnexpectedEof);
}

let first_bytes =
unsafe { packet.as_ptr().add(*offset).cast::<u32>().read_unaligned() };
let (len, have_read) = if first_bytes & 0xff != SIZE_MAGIC as u32 {
((first_bytes & 0xff) as usize, 1)
} else {
((first_bytes >> 8) as usize, 4)
};

if unlikely(len > max_len) {
return Err(TlError::InvalidData);
}

let padding = (4 - (have_read + len) % 4) % 4;
if unlikely(packet_len < *offset + have_read + len + padding) {
return Err(TlError::UnexpectedEof);
}

let result = unsafe {
std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len)
std::slice::from_raw_parts(packet.as_ptr().add(*offset + have_read), len)
};

*offset += prefix_len + len + padding;
*offset += have_read + len + padding;
Ok(result)
}

Expand Down Expand Up @@ -666,8 +682,9 @@ fn read_fixed_bytes<'a, const N: usize>(
}
}

/// Computes the number of bytes required to encode the `[u8]` of the specified length.
#[inline(always)]
fn bytes_max_size_hint(mut len: usize) -> usize {
pub const fn bytes_max_size_hint(mut len: usize) -> usize {
if len < 254 {
len += 1;
} else {
Expand Down Expand Up @@ -712,14 +729,27 @@ where

#[inline(always)]
fn read_bytes<'a>(packet: &'a [u8], offset: &mut usize) -> TlResult<&'a [u8]> {
let current_offset = *offset;
let (prefix_len, len, padding) = ok!(compute_bytes_meta(packet, current_offset));
let packet_len = packet.len();
if unlikely(packet_len + 4 <= *offset) {
return Err(TlError::UnexpectedEof);
}

let result = unsafe {
std::slice::from_raw_parts(packet.as_ptr().add(current_offset + prefix_len), len)
let first_bytes = unsafe { packet.as_ptr().add(*offset).cast::<u32>().read_unaligned() };
let (len, have_read) = if first_bytes & 0xff != SIZE_MAGIC as u32 {
((first_bytes & 0xff) as usize, 1)
} else {
((first_bytes >> 8) as usize, 4)
};

*offset += prefix_len + len + padding;
let padding = (4 - (have_read + len) % 4) % 4;
if unlikely(packet_len < *offset + have_read + len + padding) {
return Err(TlError::UnexpectedEof);
}

let result =
unsafe { std::slice::from_raw_parts(packet.as_ptr().add(*offset + have_read), len) };

*offset += have_read + len + padding;
Ok(result)
}

Expand All @@ -737,7 +767,7 @@ fn compute_bytes_meta(packet: &[u8], offset: usize) -> TlResult<(usize, usize, u
// SAFETY: `current_offset` is guaranteed to be less than `packet_len`
// but the compiler is not able to eliminate bounds check
let first_bytes = unsafe { *packet.get_unchecked(offset) };
let (len, have_read) = if first_bytes != 254 {
let (len, have_read) = if first_bytes != SIZE_MAGIC {
(first_bytes as usize, 1)
} else {
if unlikely(packet_len <= offset + 3) {
Expand All @@ -764,3 +794,5 @@ fn compute_bytes_meta(packet: &[u8], offset: usize) -> TlResult<(usize, usize, u

Ok((have_read, len, padding))
}

const SIZE_MAGIC: u8 = 254;
4 changes: 2 additions & 2 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ impl TlPacket for bytes::BytesMut {
/// A wrapper type for writing to [`std::io::Write`] types.
///
/// Ignores all errors afther the first one.
/// The status can be retrieved using [`Writer::into_parts`].
/// The status can be retrieved using [`IoWriter::into_parts`].
pub struct IoWriter<W> {
writer: W,
status: std::io::Result<()>,
Expand All @@ -234,7 +234,7 @@ impl<W> IoWriter<W> {
&self.writer
}

/// Disassembles the [`Writer<W>`], returning the underlying writer, and the status.
/// Disassembles the [`IoWriter<W>`], returning the underlying writer, and the status.
pub fn into_parts(self) -> (W, std::io::Result<()>) {
(self.writer, self.status)
}
Expand Down

0 comments on commit 249a1b4

Please sign in to comment.