Skip to content

Commit 7559549

Browse files
authored Feb 26, 2023
Rollup merge of rust-lang#107110 - strega-nil:mbtwc-wctmb, r=ChrisDenton
[stdio][windows] Use MBTWC and WCTMB. `MultiByteToWideChar` (MBTWC) and `WideCharToMultiByte` (WCTMB) are extremely well optimized, and therefore should probably be used when we know we can (specifically in the Windows stdio code). Fixes rust-lang#107092.
2 parents e6f7f29 + 7f25580 commit 7559549

File tree

3 files changed

+78
-29
lines changed

3 files changed

+78
-29
lines changed
 

‎library/std/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@
232232
all(target_vendor = "fortanix", target_env = "sgx"),
233233
feature(slice_index_methods, coerce_unsized, sgx_platform)
234234
)]
235+
#![cfg_attr(windows, feature(round_char_boundary))]
235236
//
236237
// Language features:
237238
#![feature(alloc_error_handler)]

‎library/std/src/sys/windows/c.rs

+30-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
use crate::ffi::CStr;
88
use crate::mem;
9-
use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort};
9+
use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort};
1010
use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull};
1111
use crate::ptr;
1212
use core::ffi::NonZero_c_ulong;
1313

1414
use libc::{c_void, size_t, wchar_t};
1515

16+
pub use crate::os::raw::c_int;
17+
1618
#[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this
1719
mod errors;
1820
pub use errors::*;
@@ -47,16 +49,19 @@ pub type ACCESS_MASK = DWORD;
4749

4850
pub type LPBOOL = *mut BOOL;
4951
pub type LPBYTE = *mut BYTE;
52+
pub type LPCCH = *const CHAR;
5053
pub type LPCSTR = *const CHAR;
54+
pub type LPCWCH = *const WCHAR;
5155
pub type LPCWSTR = *const WCHAR;
56+
pub type LPCVOID = *const c_void;
5257
pub type LPDWORD = *mut DWORD;
5358
pub type LPHANDLE = *mut HANDLE;
5459
pub type LPOVERLAPPED = *mut OVERLAPPED;
5560
pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION;
5661
pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES;
5762
pub type LPSTARTUPINFO = *mut STARTUPINFO;
63+
pub type LPSTR = *mut CHAR;
5864
pub type LPVOID = *mut c_void;
59-
pub type LPCVOID = *const c_void;
6065
pub type LPWCH = *mut WCHAR;
6166
pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW;
6267
pub type LPWSADATA = *mut WSADATA;
@@ -132,6 +137,10 @@ pub const MAX_PATH: usize = 260;
132137

133138
pub const FILE_TYPE_PIPE: u32 = 3;
134139

140+
pub const CP_UTF8: DWORD = 65001;
141+
pub const MB_ERR_INVALID_CHARS: DWORD = 0x08;
142+
pub const WC_ERR_INVALID_CHARS: DWORD = 0x80;
143+
135144
#[repr(C)]
136145
#[derive(Copy)]
137146
pub struct WIN32_FIND_DATAW {
@@ -1147,6 +1156,25 @@ extern "system" {
11471156
lpFilePart: *mut LPWSTR,
11481157
) -> DWORD;
11491158
pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD;
1159+
1160+
pub fn MultiByteToWideChar(
1161+
CodePage: UINT,
1162+
dwFlags: DWORD,
1163+
lpMultiByteStr: LPCCH,
1164+
cbMultiByte: c_int,
1165+
lpWideCharStr: LPWSTR,
1166+
cchWideChar: c_int,
1167+
) -> c_int;
1168+
pub fn WideCharToMultiByte(
1169+
CodePage: UINT,
1170+
dwFlags: DWORD,
1171+
lpWideCharStr: LPCWCH,
1172+
cchWideChar: c_int,
1173+
lpMultiByteStr: LPSTR,
1174+
cbMultiByte: c_int,
1175+
lpDefaultChar: LPCCH,
1176+
lpUsedDefaultChar: LPBOOL,
1177+
) -> c_int;
11501178
}
11511179

11521180
#[link(name = "ws2_32")]

‎library/std/src/sys/windows/stdio.rs

+47-27
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,27 @@ fn write(
169169
}
170170

171171
fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usize> {
172+
debug_assert!(!utf8.is_empty());
173+
172174
let mut utf16 = [MaybeUninit::<u16>::uninit(); MAX_BUFFER_SIZE / 2];
173-
let mut len_utf16 = 0;
174-
for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) {
175-
*dest = MaybeUninit::new(chr);
176-
len_utf16 += 1;
177-
}
178-
// Safety: We've initialized `len_utf16` values.
179-
let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) };
175+
let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())];
176+
177+
let utf16: &[u16] = unsafe {
178+
// Note that this theoretically checks validity twice in the (most common) case
179+
// where the underlying byte sequence is valid utf-8 (given the check in `write()`).
180+
let result = c::MultiByteToWideChar(
181+
c::CP_UTF8, // CodePage
182+
c::MB_ERR_INVALID_CHARS, // dwFlags
183+
utf8.as_ptr() as c::LPCCH, // lpMultiByteStr
184+
utf8.len() as c::c_int, // cbMultiByte
185+
utf16.as_mut_ptr() as c::LPWSTR, // lpWideCharStr
186+
utf16.len() as c::c_int, // cchWideChar
187+
);
188+
assert!(result != 0, "Unexpected error in MultiByteToWideChar");
189+
190+
// Safety: MultiByteToWideChar initializes `result` values.
191+
MaybeUninit::slice_assume_init_ref(&utf16[..result as usize])
192+
};
180193

181194
let mut written = write_u16s(handle, &utf16)?;
182195

@@ -189,8 +202,8 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
189202
// a missing surrogate can be produced (and also because of the UTF-8 validation above),
190203
// write the missing surrogate out now.
191204
// Buffering it would mean we have to lie about the number of bytes written.
192-
let first_char_remaining = utf16[written];
193-
if first_char_remaining >= 0xDCEE && first_char_remaining <= 0xDFFF {
205+
let first_code_unit_remaining = utf16[written];
206+
if first_code_unit_remaining >= 0xDCEE && first_code_unit_remaining <= 0xDFFF {
194207
// low surrogate
195208
// We just hope this works, and give up otherwise
196209
let _ = write_u16s(handle, &utf16[written..written + 1]);
@@ -212,6 +225,7 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result<usiz
212225
}
213226

214227
fn write_u16s(handle: c::HANDLE, data: &[u16]) -> io::Result<usize> {
228+
debug_assert!(data.len() < u32::MAX as usize);
215229
let mut written = 0;
216230
cvt(unsafe {
217231
c::WriteConsoleW(
@@ -365,26 +379,32 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit<u16>]) -> io::Result<usiz
365379
Ok(amount as usize)
366380
}
367381

368-
#[allow(unused)]
369382
fn utf16_to_utf8(utf16: &[u16], utf8: &mut [u8]) -> io::Result<usize> {
370-
let mut written = 0;
371-
for chr in char::decode_utf16(utf16.iter().cloned()) {
372-
match chr {
373-
Ok(chr) => {
374-
chr.encode_utf8(&mut utf8[written..]);
375-
written += chr.len_utf8();
376-
}
377-
Err(_) => {
378-
// We can't really do any better than forget all data and return an error.
379-
return Err(io::const_io_error!(
380-
io::ErrorKind::InvalidData,
381-
"Windows stdin in console mode does not support non-UTF-16 input; \
382-
encountered unpaired surrogate",
383-
));
384-
}
385-
}
383+
debug_assert!(utf16.len() <= c::c_int::MAX as usize);
384+
debug_assert!(utf8.len() <= c::c_int::MAX as usize);
385+
386+
let result = unsafe {
387+
c::WideCharToMultiByte(
388+
c::CP_UTF8, // CodePage
389+
c::WC_ERR_INVALID_CHARS, // dwFlags
390+
utf16.as_ptr(), // lpWideCharStr
391+
utf16.len() as c::c_int, // cchWideChar
392+
utf8.as_mut_ptr() as c::LPSTR, // lpMultiByteStr
393+
utf8.len() as c::c_int, // cbMultiByte
394+
ptr::null(), // lpDefaultChar
395+
ptr::null_mut(), // lpUsedDefaultChar
396+
)
397+
};
398+
if result == 0 {
399+
// We can't really do any better than forget all data and return an error.
400+
Err(io::const_io_error!(
401+
io::ErrorKind::InvalidData,
402+
"Windows stdin in console mode does not support non-UTF-16 input; \
403+
encountered unpaired surrogate",
404+
))
405+
} else {
406+
Ok(result as usize)
386407
}
387-
Ok(written)
388408
}
389409

390410
impl IncompleteUtf8 {

0 commit comments

Comments
 (0)
Please sign in to comment.