Skip to content

Commit

Permalink
Speed-up html-escaping using jetscii
Browse files Browse the repository at this point in the history
```text
$ cargo bench --bench escape

Before the PR:           [3.6464 µs 3.6512 µs 3.6564 µs]
Impl. without `jetscii`: [3.4837 µs 3.4899 µs 3.4968 µs]
Impl with `jetscii`:     [2.0264 µs 2.0335 µs 2.0418 µs]
```

Until portable SIMD gets stabilized, I don't think we can do much for
non-X86 platforms. And even after it is stabilized, I guess any
optimizations should be implemented upstream in memchr and/or jetscii.
  • Loading branch information
Kijewski committed Jul 28, 2024
1 parent bf03e44 commit fe8750f
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 42 deletions.
3 changes: 3 additions & 0 deletions rinja/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ percent-encoding = { version = "2.1.0", optional = true }
serde = { version = "1.0", optional = true }
serde_json = { version = "1.0", optional = true }

[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
criterion = "0.5"

Expand Down
216 changes: 174 additions & 42 deletions rinja/src/html.rs
Original file line number Diff line number Diff line change
@@ -1,71 +1,203 @@
use std::fmt;
use std::num::NonZeroU8;
use std::{fmt, str};

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[allow(unused)]
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
let mut escaped_buf = *b"&#__;";
// Even though [`jetscii`] ships a generic implementation for unsupported platforms,
// it is not well optimized for this case. This implementation should work well enough in
// the meantime, until portable SIMD gets stabilized.

// Instead of testing the platform, we could test the CPU features. But given that the needed
// instruction set SSE 4.2 was introduced in 2008, that it has an 99.61 % availability rate
// in Steam's June 2024 hardware survey, and is a prerequisite to run Windows 11, I don't
// think we need to care.

let mut escaped_buf = ESCAPED_BUF_INIT;
let mut last = 0;

for (index, byte) in string.bytes().enumerate() {
let escaped = match byte {
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
_ => None,
_ => 0,
};
if let Some(escaped) = escaped {
escaped_buf[2] = escaped[0].get();
escaped_buf[3] = escaped[1].get();
fmt.write_str(&string[last..index])?;
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
if escaped != 0 {
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
write_str_if_nonempty(&mut fmt, &string[last..index])?;
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe {
std::str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN])
})?;
last = index + 1;
}
}
fmt.write_str(&string[last..])
write_str_if_nonempty(&mut fmt, &string[last..])
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unused)]
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, mut string: &str) -> fmt::Result {
let jetscii = jetscii::bytes!(b'"', b'&', b'\'', b'<', b'>');

let mut escaped_buf = ESCAPED_BUF_INIT;
loop {
if string.is_empty() {
return Ok(());
}

let found = if string.len() >= 16 {
// Only strings of at least 16 bytes can be escaped using SSE instructions.
match jetscii.find(string.as_bytes()) {
Some(index) => {
let escaped = TABLE.lookup[(string.as_bytes()[index] - MIN_CHAR) as usize];
Some((index, escaped))
}
None => None,
}
} else {
// The small-string fallback of [`jetscii`] is quite slow, so we roll our own
// implementation.
string.as_bytes().iter().find_map(|byte: &u8| {
let escaped = get_escaped(*byte)?;
let index = (byte as *const u8 as usize) - (string.as_ptr() as usize);
Some((index, escaped))
})
};
let Some((index, escaped)) = found else {
return fmt.write_str(string);
};

[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();

// SAFETY: index points at an ASCII char in `string`
let front;
(front, string) = unsafe {
(
string.get_unchecked(..index),
string.get_unchecked(index + 1..),
)
};

write_str_if_nonempty(&mut fmt, front)?;
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?;
}
}

#[allow(unused)]
pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
fmt.write_str(match (c.is_ascii(), c as u8) {
(true, b'"') => "&#34;",
(true, b'&') => "&#38;",
(true, b'\'') => "&#39;",
(true, b'<') => "&#60;",
(true, b'>') => "&#62;",
_ => return fmt.write_char(c),
})
if !c.is_ascii() {
fmt.write_char(c)
} else if let Some(escaped) = get_escaped(c as u8) {
let mut escaped_buf = ESCAPED_BUF_INIT;
[escaped_buf[2], escaped_buf[3]] = escaped.to_ne_bytes();
// SAFETY: the content of `escaped_buf` is pure ASCII
fmt.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })
} else {
// RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()`
fmt.write_char(c)
}
}

const MIN_CHAR: u8 = b'"';
const MAX_CHAR: u8 = b'>';
#[inline(always)]
fn get_escaped(byte: u8) -> Option<u16> {
let c = byte.wrapping_sub(MIN_CHAR);
if (c < u32::BITS as u8) && (BITS & (1 << c as u32) != 0) {
Some(TABLE.lookup[c as usize])
} else {
None
}
}

struct Table {
_align: [usize; 0],
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
#[inline(always)]
fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result {
if !input.is_empty() {
output.write_str(input)
} else {
Ok(())
}
}

const TABLE: Table = {
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
assert!(MIN_CHAR <= c && c <= MAX_CHAR);
/// List of characters that need HTML escaping, not necessarily in ordinal order.
/// Filling the [`TABLE`] and [`BITS`] constants will ensure that the range of lowest to hightest
/// codepoint wont exceed [`u32::BITS`] (=32) items.
const CHARS: &[u8] = br#""&'<>"#;

let n0 = match NonZeroU8::new(c / 10 + b'0') {
Some(n) => n,
None => panic!(),
};
let n1 = match NonZeroU8::new(c % 10 + b'0') {
Some(n) => n,
None => panic!(),
};
Some([n0, n1])
/// The character with the smallest codepoint that needs HTML escaping.
/// Both [`TABLE`] and [`BITS`] start at this value instead of `0`.
const MIN_CHAR: u8 = {
let mut v = u8::MAX;
let mut i = 0;
while i < CHARS.len() {
if v > CHARS[i] {
v = CHARS[i];
}
i += 1;
}
v
};

#[allow(unused)]
const MAX_CHAR: u8 = {
let mut v = u8::MIN;
let mut i = 0;
while i < CHARS.len() {
if v < CHARS[i] {
v = CHARS[i];
}
i += 1;
}
v
};

struct Table {
_align: [usize; 0],
lookup: [u16; u32::BITS as usize],
}

/// For characters that need HTML escaping, the codepoint formatted as decimal digits,
/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
const TABLE: Table = {
let mut table = Table {
_align: [],
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
lookup: [0; u32::BITS as usize],
};

table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
let mut i = 0;
while i < CHARS.len() {
let c = CHARS[i];
let h = c / 10 + b'0';
let l = c % 10 + b'0';
table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]);
i += 1;
}
table
};

/// A bitset of the characters that need escaping, starting at [`MIN_CHAR`]
const BITS: u32 = {
let mut i = 0;
let mut bits = 0;
while i < CHARS.len() {
bits |= 1 << (CHARS[i] - MIN_CHAR) as u32;
i += 1;
}
bits
};

// RATIONALE: llvm generates better code if the buffer is register sized
const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0";
const ESCAPED_BUF_LEN: usize = b"&#__;".len();

#[test]
fn simple() {
let mut buf = String::new();
write_escaped_str(&mut buf, "<script>").unwrap();
assert_eq!(buf, "&#60;script&#62;");

buf.clear();
write_escaped_str(&mut buf, "s<crip>t").unwrap();
assert_eq!(buf, "s&#60;crip&#62;t");

buf.clear();
write_escaped_str(&mut buf, "s<cripcripcripcripcripcripcripcripcripcrip>t").unwrap();
assert_eq!(buf, "s&#60;cripcripcripcripcripcripcripcripcripcrip&#62;t");
}
3 changes: 3 additions & 0 deletions rinja_derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ quote = "1"
serde = { version = "1.0", optional = true, features = ["derive"] }
syn = "2.0.3"

[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
console = "0.15.8"
similar = "2.6.0"
Expand Down
3 changes: 3 additions & 0 deletions rinja_derive_standalone/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ quote = "1"
serde = { version = "1.0", optional = true, features = ["derive"] }
syn = "2"

[target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies]
jetscii = "0.5.3"

[dev-dependencies]
criterion = "0.5"

Expand Down

0 comments on commit fe8750f

Please sign in to comment.