Skip to content

Commit 7504473

Browse files
committed
Move OsStr::slice_encoded_bytes validation to platform modules
On Unix this opens the possibility of removing the checks later. On other platforms this improves performance and error messaging.
1 parent 24067a6 commit 7504473

File tree

8 files changed

+225
-47
lines changed

8 files changed

+225
-47
lines changed

library/std/src/ffi/mod.rs

+7
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,11 @@
127127
//! trait, which provides a [`from_wide`] method to convert a native Windows
128128
//! string (without the terminating nul character) to an [`OsString`].
129129
//!
130+
//! ## Other platforms
131+
//!
132+
//! Many other platforms provide their own extension traits in a
133+
//! `std::os::*::ffi` module.
134+
//!
130135
//! ## On all platforms
131136
//!
132137
//! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
@@ -135,6 +140,8 @@
135140
//! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
136141
//! [`OsStr::from_encoded_bytes_unchecked`].
137142
//!
143+
//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
144+
//!
138145
//! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
139146
//! [Unicode code point]: https://www.unicode.org/glossary/#code_point
140147
//! [`env::set_var()`]: crate::env::set_var "env::set_var"

library/std/src/ffi/os_str.rs

+4 -35
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
99
use crate::ops::{self, Range};
1010
use crate::rc::Rc;
1111
use crate::slice;
12-
use crate::str::{from_utf8 as str_from_utf8, FromStr};
12+
use crate::str::FromStr;
1313
use crate::sync::Arc;
1414

1515
use crate::sys::os_str::{Buf, Slice};
@@ -997,42 +997,11 @@ impl OsStr {
997997
/// ```
998998
#[unstable(feature = "os_str_slice", issue = "118485")]
999999
pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
1000-
#[track_caller]
1001-
fn check_valid_boundary(bytes: &[u8], index: usize) {
1002-
if index == 0 || index == bytes.len() {
1003-
return;
1004-
}
1005-
1006-
// Fast path
1007-
if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
1008-
return;
1009-
}
1010-
1011-
let (before, after) = bytes.split_at(index);
1012-
1013-
// UTF-8 takes at most 4 bytes per codepoint, so we don't
1014-
// need to check more than that.
1015-
let after = after.get(..4).unwrap_or(after);
1016-
match str_from_utf8(after) {
1017-
Ok(_) => return,
1018-
Err(err) if err.valid_up_to() != 0 => return,
1019-
Err(_) => (),
1020-
}
1021-
1022-
for len in 2..=4.min(index) {
1023-
let before = &before[index - len..];
1024-
if str_from_utf8(before).is_ok() {
1025-
return;
1026-
}
1027-
}
1028-
1029-
panic!("byte index {index} is not an OsStr boundary");
1030-
}
1031-
10321000
let encoded_bytes = self.as_encoded_bytes();
10331001
let Range { start, end } = slice::range(range, ..encoded_bytes.len());
1034-
check_valid_boundary(encoded_bytes, start);
1035-
check_valid_boundary(encoded_bytes, end);
1002+
1003+
self.inner.check_public_boundary(start);
1004+
self.inner.check_public_boundary(end);
10361005

10371006
// SAFETY: `slice::range` ensures that `start` and `end` are valid
10381007
let slice = unsafe { encoded_bytes.get_unchecked(start..end) };

library/std/src/ffi/os_str/tests.rs

+61 -7
Original file line number | Diff line number | Diff line change
@@ -194,15 +194,65 @@ fn slice_encoded_bytes() {
194194
}
195195

196196
#[test]
197-
#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
197+
#[should_panic]
198+
fn slice_out_of_bounds() {
199+
let crab = OsStr::new("🦀");
200+
let _ = crab.slice_encoded_bytes(..5);
201+
}
202+
203+
#[test]
204+
#[should_panic]
198205
fn slice_mid_char() {
199206
let crab = OsStr::new("🦀");
200207
let _ = crab.slice_encoded_bytes(..2);
201208
}
202209

210+
#[cfg(unix)]
211+
#[test]
212+
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
213+
fn slice_invalid_data() {
214+
use crate::os::unix::ffi::OsStrExt;
215+
216+
let os_string = OsStr::from_bytes(b"\xFF\xFF");
217+
let _ = os_string.slice_encoded_bytes(1..);
218+
}
219+
220+
#[cfg(unix)]
221+
#[test]
222+
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
223+
fn slice_partial_utf8() {
224+
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
225+
226+
let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]);
227+
let mut os_string = OsString::from_vec(vec![0xFF]);
228+
os_string.push(part_crab);
229+
let _ = os_string.slice_encoded_bytes(1..);
230+
}
231+
232+
#[cfg(unix)]
233+
#[test]
234+
fn slice_invalid_edge() {
235+
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
236+
237+
let os_string = OsStr::from_bytes(b"a\xFFa");
238+
assert_eq!(os_string.slice_encoded_bytes(..1), "a");
239+
assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
240+
assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
241+
assert_eq!(os_string.slice_encoded_bytes(2..), "a");
242+
243+
let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
244+
assert_eq!(os_string.slice_encoded_bytes(..3), "abc");
245+
assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));
246+
247+
let mut os_string = OsString::from_vec(vec![0xFF]);
248+
os_string.push("🦀");
249+
assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
250+
assert_eq!(os_string.slice_encoded_bytes(1..), "🦀");
251+
}
252+
203253
#[cfg(windows)]
204254
#[test]
205-
#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
255+
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
206256
fn slice_between_surrogates() {
207257
use crate::os::windows::ffi::OsStringExt;
208258

@@ -216,10 +266,14 @@ fn slice_between_surrogates() {
216266
fn slice_surrogate_edge() {
217267
use crate::os::windows::ffi::OsStringExt;
218268

219-
let os_string = OsString::from_wide(&[0xD800]);
220-
let mut with_crab = os_string.clone();
221-
with_crab.push("🦀");
269+
let surrogate = OsString::from_wide(&[0xD800]);
270+
let mut pre_crab = surrogate.clone();
271+
pre_crab.push("🦀");
272+
assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
273+
assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");
222274

223-
assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
224-
assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
275+
let mut post_crab = OsString::from("🦀");
276+
post_crab.push(&surrogate);
277+
assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
278+
assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
225279
}

library/std/src/sys/pal/unix/os_str.rs

+43
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,49 @@ impl Slice {
212212
unsafe { mem::transmute(s) }
213213
}
214214

215+
#[track_caller]
216+
#[inline]
217+
pub fn check_public_boundary(&self, index: usize) {
218+
if index == 0 || index == self.inner.len() {
219+
return;
220+
}
221+
if index < self.inner.len()
222+
&& (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii())
223+
{
224+
return;
225+
}
226+
227+
slow_path(&self.inner, index);
228+
229+
/// We're betting that typical splits will involve an ASCII character.
230+
///
231+
/// Putting the expensive checks in a separate function generates notably
232+
/// better assembly.
233+
#[track_caller]
234+
#[inline(never)]
235+
fn slow_path(bytes: &[u8], index: usize) {
236+
let (before, after) = bytes.split_at(index);
237+
238+
// UTF-8 takes at most 4 bytes per codepoint, so we don't
239+
// need to check more than that.
240+
let after = after.get(..4).unwrap_or(after);
241+
match str::from_utf8(after) {
242+
Ok(_) => return,
243+
Err(err) if err.valid_up_to() != 0 => return,
244+
Err(_) => (),
245+
}
246+
247+
for len in 2..=4.min(index) {
248+
let before = &before[index - len..];
249+
if str::from_utf8(before).is_ok() {
250+
return;
251+
}
252+
}
253+
254+
panic!("byte index {index} is not an OsStr boundary");
255+
}
256+
}
257+
215258
#[inline]
216259
pub fn from_str(s: &str) -> &Slice {
217260
unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }

library/std/src/sys/pal/unsupported/os_str.rs

+10
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,16 @@ impl Slice {
197197
unsafe { mem::transmute(s) }
198198
}
199199

200+
#[inline]
201+
pub fn check_public_boundary(&self, index: usize) {
202+
// We need to check that self.inner.is_char_boundary(index).
203+
// If we delegate that to the Index impl then we'll get a nice panic
204+
// message courtesy of slice_error_fail.
205+
// (Ideally we'd use slice_error_fail directly with #[track_caller]
206+
// but it isn't exported.)
207+
let _ = &self.inner[..index];
208+
}
209+
200210
#[inline]
201211
pub fn from_str(s: &str) -> &Slice {
202212
unsafe { mem::transmute(s) }

library/std/src/sys/pal/windows/os_str.rs

+6 -1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ use crate::fmt;
66
use crate::mem;
77
use crate::rc::Rc;
88
use crate::sync::Arc;
9-
use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
9+
use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
1010
use crate::sys_common::{AsInner, FromInner, IntoInner};
1111

1212
#[derive(Clone, Hash)]
@@ -171,6 +171,11 @@ impl Slice {
171171
mem::transmute(Wtf8::from_bytes_unchecked(s))
172172
}
173173

174+
#[track_caller]
175+
pub fn check_public_boundary(&self, index: usize) {
176+
check_utf8_boundary(&self.inner, index);
177+
}
178+
174179
#[inline]
175180
pub fn from_str(s: &str) -> &Slice {
176181
unsafe { mem::transmute(Wtf8::from_str(s)) }

library/std/src/sys_common/wtf8.rs

+32 -4
Original file line number | Diff line number | Diff line change
@@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
885885
unsafe { char::from_u32_unchecked(code_point) }
886886
}
887887

888-
/// Copied from core::str::StrPrelude::is_char_boundary
888+
/// Copied from str::is_char_boundary
889889
#[inline]
890890
pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
891-
if index == slice.len() {
891+
if index == 0 {
892892
return true;
893893
}
894894
match slice.bytes.get(index) {
895-
None => false,
896-
Some(&b) => b < 128 || b >= 192,
895+
None => index == slice.len(),
896+
Some(&b) => (b as i8) >= -0x40,
897+
}
898+
}
899+
900+
/// Verify that `index` is at the edge of either a valid UTF-8 codepoint
901+
/// or of the whole string.
902+
///
903+
/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`.
904+
/// Splitting between surrogates is valid as far as WTF-8 is concerned, but
905+
/// we do not permit it in the public API because WTF-8 is considered an
906+
/// implementation detail.
907+
#[track_caller]
908+
#[inline]
909+
pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
910+
if index == 0 {
911+
return;
912+
}
913+
match slice.bytes.get(index) {
914+
Some(0xED) => (), // Might be a surrogate
915+
Some(&b) if (b as i8) >= -0x40 => return,
916+
Some(_) => panic!("byte index {index} is not a codepoint boundary"),
917+
None if index == slice.len() => return,
918+
None => panic!("byte index {index} is out of bounds"),
919+
}
920+
if slice.bytes[index + 1] >= 0xA0 {
921+
// There's a surrogate after index. Now check before index.
922+
if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 {
923+
panic!("byte index {index} lies between surrogate codepoints");
924+
}
897925
}
898926
}
899927

library/std/src/sys_common/wtf8/tests.rs

+62
Original file line number | Diff line number | Diff line change
@@ -663,3 +663,65 @@ fn wtf8_to_owned() {
663663
assert_eq!(string.bytes, b"\xED\xA0\x80");
664664
assert!(!string.is_known_utf8);
665665
}
666+
667+
#[test]
668+
fn wtf8_valid_utf8_boundaries() {
669+
let mut string = Wtf8Buf::from_str("aé 💩");
670+
string.push(CodePoint::from_u32(0xD800).unwrap());
671+
string.push(CodePoint::from_u32(0xD800).unwrap());
672+
check_utf8_boundary(&string, 0);
673+
check_utf8_boundary(&string, 1);
674+
check_utf8_boundary(&string, 3);
675+
check_utf8_boundary(&string, 4);
676+
check_utf8_boundary(&string, 8);
677+
check_utf8_boundary(&string, 14);
678+
assert_eq!(string.len(), 14);
679+
680+
string.push_char('a');
681+
check_utf8_boundary(&string, 14);
682+
check_utf8_boundary(&string, 15);
683+
684+
let mut string = Wtf8Buf::from_str("a");
685+
string.push(CodePoint::from_u32(0xD800).unwrap());
686+
check_utf8_boundary(&string, 1);
687+
688+
let mut string = Wtf8Buf::from_str("\u{D7FF}");
689+
string.push(CodePoint::from_u32(0xD800).unwrap());
690+
check_utf8_boundary(&string, 3);
691+
692+
let mut string = Wtf8Buf::new();
693+
string.push(CodePoint::from_u32(0xD800).unwrap());
694+
string.push_char('\u{D7FF}');
695+
check_utf8_boundary(&string, 3);
696+
}
697+
698+
#[test]
699+
#[should_panic(expected = "byte index 4 is out of bounds")]
700+
fn wtf8_utf8_boundary_out_of_bounds() {
701+
let string = Wtf8::from_str("aé");
702+
check_utf8_boundary(&string, 4);
703+
}
704+
705+
#[test]
706+
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
707+
fn wtf8_utf8_boundary_inside_codepoint() {
708+
let string = Wtf8::from_str("é");
709+
check_utf8_boundary(&string, 1);
710+
}
711+
712+
#[test]
713+
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
714+
fn wtf8_utf8_boundary_inside_surrogate() {
715+
let mut string = Wtf8Buf::new();
716+
string.push(CodePoint::from_u32(0xD800).unwrap());
717+
check_utf8_boundary(&string, 1);
718+
}
719+
720+
#[test]
721+
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
722+
fn wtf8_utf8_boundary_between_surrogates() {
723+
let mut string = Wtf8Buf::new();
724+
string.push(CodePoint::from_u32(0xD800).unwrap());
725+
string.push(CodePoint::from_u32(0xD800).unwrap());
726+
check_utf8_boundary(&string, 3);
727+
}

0 commit comments

Comments
 (0)