diff --git a/Cargo.lock b/Cargo.lock index 8641a916fc3e..022c0278e099 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3149,6 +3149,7 @@ dependencies = [ "cap-rand", "cap-std", "io-extras", + "log", "rustix", "thiserror", "tracing", @@ -3506,6 +3507,7 @@ dependencies = [ "wasmtime-wasi", "wasmtime-wasi-crypto", "wasmtime-wasi-nn", + "wasmtime-wasi-threads", "wasmtime-wast", "wast 52.0.2", "wat", @@ -3746,6 +3748,7 @@ name = "wasmtime-wasi" version = "7.0.0" dependencies = [ "anyhow", + "libc", "wasi-cap-std-sync", "wasi-common", "wasi-tokio", @@ -3774,6 +3777,18 @@ dependencies = [ "wiggle", ] +[[package]] +name = "wasmtime-wasi-threads" +version = "7.0.0" +dependencies = [ + "anyhow", + "log", + "rand 0.8.5", + "wasi-common", + "wasmtime", + "wasmtime-wasi", +] + [[package]] name = "wasmtime-wast" version = "7.0.0" diff --git a/Cargo.toml b/Cargo.toml index 45cb7af61a12..33cf23dd4a0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,13 +28,13 @@ wasmtime-cli-flags = { workspace = true } wasmtime-cranelift = { workspace = true } wasmtime-environ = { workspace = true } wasmtime-wast = { workspace = true } -wasmtime-wasi = { workspace = true } +wasmtime-wasi = { workspace = true, features = ["exit"] } wasmtime-wasi-crypto = { workspace = true, optional = true } wasmtime-wasi-nn = { workspace = true, optional = true } +wasmtime-wasi-threads = { workspace = true, optional = true } clap = { workspace = true, features = ["color", "suggestions", "derive"] } anyhow = { workspace = true } target-lexicon = { workspace = true } -libc = "0.2.60" humantime = "2.0.0" once_cell = { workspace = true } listenfd = "1.0.0" @@ -68,6 +68,7 @@ wasmtime-component-util = { workspace = true } component-macro-test = { path = "crates/misc/component-macro-test" } component-test-util = { workspace = true } bstr = "0.2.17" +libc = "0.2.60" [target.'cfg(windows)'.dev-dependencies] windows-sys = { workspace = true, features = ["Win32_System_Memory"] } @@ -124,6 +125,7 @@ wasmtime-wast = { path = "crates/wast", version = "=7.0.0" } wasmtime-wasi = { path = "crates/wasi", version = "7.0.0" } wasmtime-wasi-crypto = { path = "crates/wasi-crypto", version = "7.0.0" } wasmtime-wasi-nn = { path = "crates/wasi-nn", version = "7.0.0" } +wasmtime-wasi-threads = { path = "crates/wasi-threads", version = "7.0.0" } wasmtime-component-util = { path = "crates/component-util", version = "=7.0.0" } wasmtime-component-macro = { path = "crates/component-macro", version = "=7.0.0" } wasmtime-asm-macros = { path = "crates/asm-macros", version = "=7.0.0" } @@ -205,6 +207,7 @@ jitdump = ["wasmtime/jitdump"] vtune = ["wasmtime/vtune"] wasi-crypto = ["dep:wasmtime-wasi-crypto"] wasi-nn = ["dep:wasmtime-wasi-nn"] +wasi-threads = ["dep:wasmtime-wasi-threads"] pooling-allocator = ["wasmtime/pooling-allocator", "wasmtime-cli-flags/pooling-allocator"] all-arch = ["wasmtime/all-arch"] posix-signals-on-macos = ["wasmtime/posix-signals-on-macos"] diff --git a/ci/run-tests.sh b/ci/run-tests.sh index f81b155f2d4a..db187b779e40 100755 --- a/ci/run-tests.sh +++ b/ci/run-tests.sh @@ -2,6 +2,7 @@ cargo test \ --features "test-programs/test_programs" \ + --features wasi-threads \ --workspace \ --exclude 'wasmtime-wasi-*' \ --exclude wasi-crypto \ diff --git a/crates/cli-flags/src/lib.rs b/crates/cli-flags/src/lib.rs index fb778af2dea5..7e40bff3aa8d 100644 --- a/crates/cli-flags/src/lib.rs +++ b/crates/cli-flags/src/lib.rs @@ -50,13 +50,17 @@ pub const SUPPORTED_WASI_MODULES: &[(&str, &str)] = &[ "wasi-common", "enables support for the WASI common APIs, see https://github.com/WebAssembly/WASI", ), + ( + "experimental-wasi-crypto", + "enables support for the WASI cryptography APIs (experimental), see https://github.com/WebAssembly/wasi-crypto", + ), ( "experimental-wasi-nn", "enables support for the WASI neural network API (experimental), see https://github.com/WebAssembly/wasi-nn", ), ( - "experimental-wasi-crypto", - "enables support for the WASI cryptography APIs (experimental), see https://github.com/WebAssembly/wasi-crypto", + "experimental-wasi-threads", + "enables support for the WASI threading API (experimental), see https://github.com/WebAssembly/wasi-threads", ), ]; @@ -466,8 +470,9 @@ fn parse_wasi_modules(modules: &str) -> Result { let mut set = |module: &str, enable: bool| match module { "" => Ok(()), "wasi-common" => Ok(wasi_modules.wasi_common = enable), - "experimental-wasi-nn" => Ok(wasi_modules.wasi_nn = enable), "experimental-wasi-crypto" => Ok(wasi_modules.wasi_crypto = enable), + "experimental-wasi-nn" => Ok(wasi_modules.wasi_nn = enable), + "experimental-wasi-threads" => Ok(wasi_modules.wasi_threads = enable), "default" => bail!("'default' cannot be specified with other WASI modules"), _ => bail!("unsupported WASI module '{}'", module), }; @@ -494,19 +499,23 @@ pub struct WasiModules { /// parts once the implementation allows for it (e.g. wasi-fs, wasi-clocks, etc.). pub wasi_common: bool, - /// Enable the experimental wasi-nn implementation + /// Enable the experimental wasi-crypto implementation. + pub wasi_crypto: bool, + + /// Enable the experimental wasi-nn implementation. pub wasi_nn: bool, - /// Enable the experimental wasi-crypto implementation - pub wasi_crypto: bool, + /// Enable the experimental wasi-threads implementation. + pub wasi_threads: bool, } impl Default for WasiModules { fn default() -> Self { Self { wasi_common: true, - wasi_nn: false, wasi_crypto: false, + wasi_nn: false, + wasi_threads: false, } } } @@ -518,6 +527,7 @@ impl WasiModules { wasi_common: false, wasi_nn: false, wasi_crypto: false, + wasi_threads: false, } } } @@ -663,8 +673,9 @@ mod test { options.wasi_modules.unwrap(), WasiModules { wasi_common: true, + wasi_crypto: false, wasi_nn: false, - wasi_crypto: false + wasi_threads: false } ); } @@ -676,8 +687,9 @@ mod test { options.wasi_modules.unwrap(), WasiModules { wasi_common: true, + wasi_crypto: false, wasi_nn: false, - wasi_crypto: false + wasi_threads: false } ); } @@ -693,8 +705,9 @@ mod test { options.wasi_modules.unwrap(), WasiModules { wasi_common: false, + wasi_crypto: false, wasi_nn: true, - wasi_crypto: false + wasi_threads: false } ); } @@ -707,8 +720,9 @@ mod test { options.wasi_modules.unwrap(), WasiModules { wasi_common: false, + wasi_crypto: false, wasi_nn: false, - wasi_crypto: false + wasi_threads: false } ); } diff --git a/crates/wasi-common/Cargo.toml b/crates/wasi-common/Cargo.toml index 065a03b52638..3e77c5c8488f 100644 --- a/crates/wasi-common/Cargo.toml +++ b/crates/wasi-common/Cargo.toml @@ -26,6 +26,7 @@ tracing = { workspace = true } cap-std = { workspace = true } cap-rand = { workspace = true } bitflags = { workspace = true } +log = { workspace = true } [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["fs"] } diff --git a/crates/wasi-common/cap-std-sync/src/file.rs b/crates/wasi-common/cap-std-sync/src/file.rs index 8fe395fdaf8b..49a86b8298d2 100644 --- a/crates/wasi-common/cap-std-sync/src/file.rs +++ b/crates/wasi-common/cap-std-sync/src/file.rs @@ -31,24 +31,23 @@ impl WasiFile for File { fn pollable(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(windows)] fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn datasync(&mut self) -> Result<(), Error> { + async fn datasync(&self) -> Result<(), Error> { self.0.sync_data()?; Ok(()) } - async fn sync(&mut self) -> Result<(), Error> { + async fn sync(&self) -> Result<(), Error> { self.0.sync_all()?; Ok(()) } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { let meta = self.0.metadata()?; Ok(filetype_from(&meta.file_type())) } - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { let fdflags = get_fd_flags(&self.0)?; Ok(fdflags) } @@ -64,7 +63,7 @@ impl WasiFile for File { self.0.set_fd_flags(set_fd_flags)?; Ok(()) } - async fn get_filestat(&mut self) -> Result { + async fn get_filestat(&self) -> Result { let meta = self.0.metadata()?; Ok(Filestat { device_id: meta.dev(), @@ -77,20 +76,20 @@ impl WasiFile for File { ctim: meta.created().map(|t| Some(t.into_std())).unwrap_or(None), }) } - async fn set_filestat_size(&mut self, size: u64) -> Result<(), Error> { + async fn set_filestat_size(&self, size: u64) -> Result<(), Error> { self.0.set_len(size)?; Ok(()) } - async fn advise(&mut self, offset: u64, len: u64, advice: Advice) -> Result<(), Error> { + async fn advise(&self, offset: u64, len: u64, advice: Advice) -> Result<(), Error> { self.0.advise(offset, len, convert_advice(advice))?; Ok(()) } - async fn allocate(&mut self, offset: u64, len: u64) -> Result<(), Error> { + async fn allocate(&self, offset: u64, len: u64) -> Result<(), Error> { self.0.allocate(offset, len)?; Ok(()) } async fn set_times( - &mut self, + &self, atime: Option, mtime: Option, ) -> Result<(), Error> { @@ -98,41 +97,41 @@ impl WasiFile for File { .set_times(convert_systimespec(atime), convert_systimespec(mtime))?; Ok(()) } - async fn read_vectored<'a>(&mut self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { + async fn read_vectored<'a>(&self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { let n = self.0.read_vectored(bufs)?; Ok(n.try_into()?) } async fn read_vectored_at<'a>( - &mut self, + &self, bufs: &mut [io::IoSliceMut<'a>], offset: u64, ) -> Result { let n = self.0.read_vectored_at(bufs, offset)?; Ok(n.try_into()?) } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result { let n = self.0.write_vectored(bufs)?; Ok(n.try_into()?) } async fn write_vectored_at<'a>( - &mut self, + &self, bufs: &[io::IoSlice<'a>], offset: u64, ) -> Result { let n = self.0.write_vectored_at(bufs, offset)?; Ok(n.try_into()?) } - async fn seek(&mut self, pos: std::io::SeekFrom) -> Result { + async fn seek(&self, pos: std::io::SeekFrom) -> Result { Ok(self.0.seek(pos)?) } - async fn peek(&mut self, buf: &mut [u8]) -> Result { + async fn peek(&self, buf: &mut [u8]) -> Result { let n = self.0.peek(buf)?; Ok(n.try_into()?) } - async fn num_ready_bytes(&self) -> Result { + fn num_ready_bytes(&self) -> Result { Ok(self.0.num_ready_bytes()?) } - fn isatty(&mut self) -> bool { + fn isatty(&self) -> bool { self.0.is_terminal() } } diff --git a/crates/wasi-common/cap-std-sync/src/lib.rs b/crates/wasi-common/cap-std-sync/src/lib.rs index 12f27bcec885..fbaa7bbbd9d9 100644 --- a/crates/wasi-common/cap-std-sync/src/lib.rs +++ b/crates/wasi-common/cap-std-sync/src/lib.rs @@ -94,15 +94,15 @@ impl WasiCtxBuilder { } Ok(self) } - pub fn stdin(mut self, f: Box) -> Self { + pub fn stdin(self, f: Box) -> Self { self.0.set_stdin(f); self } - pub fn stdout(mut self, f: Box) -> Self { + pub fn stdout(self, f: Box) -> Self { self.0.set_stdout(f); self } - pub fn stderr(mut self, f: Box) -> Self { + pub fn stderr(self, f: Box) -> Self { self.0.set_stderr(f); self } @@ -118,12 +118,12 @@ impl WasiCtxBuilder { pub fn inherit_stdio(self) -> Self { self.inherit_stdin().inherit_stdout().inherit_stderr() } - pub fn preopened_dir(mut self, dir: Dir, guest_path: impl AsRef) -> Result { + pub fn preopened_dir(self, dir: Dir, guest_path: impl AsRef) -> Result { let dir = Box::new(crate::dir::Dir::from_cap_std(dir)); self.0.push_preopened_dir(dir, guest_path)?; Ok(self) } - pub fn preopened_socket(mut self, fd: u32, socket: impl Into) -> Result { + pub fn preopened_socket(self, fd: u32, socket: impl Into) -> Result { let socket: Socket = socket.into(); let file: Box = socket.into(); diff --git a/crates/wasi-common/cap-std-sync/src/net.rs b/crates/wasi-common/cap-std-sync/src/net.rs index bdbf507fe48a..c0750cd83e46 100644 --- a/crates/wasi-common/cap-std-sync/src/net.rs +++ b/crates/wasi-common/cap-std-sync/src/net.rs @@ -86,22 +86,21 @@ macro_rules! wasi_listen_write_impl { fn pollable(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(windows)] fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn sock_accept(&mut self, fdflags: FdFlags) -> Result, Error> { + async fn sock_accept(&self, fdflags: FdFlags) -> Result, Error> { let (stream, _) = self.0.accept()?; let mut stream = <$stream>::from_cap_std(stream); stream.set_fdflags(fdflags).await?; Ok(Box::new(stream)) } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { Ok(FileType::SocketStream) } #[cfg(unix)] - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { let fdflags = get_fd_flags(&self.0)?; Ok(fdflags) } @@ -117,7 +116,7 @@ macro_rules! wasi_listen_write_impl { } Ok(()) } - async fn num_ready_bytes(&self) -> Result { + fn num_ready_bytes(&self) -> Result { Ok(1) } } @@ -180,16 +179,15 @@ macro_rules! wasi_stream_write_impl { fn pollable(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(windows)] fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { Ok(FileType::SocketStream) } #[cfg(unix)] - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { let fdflags = get_fd_flags(&self.0)?; Ok(fdflags) } @@ -206,23 +204,23 @@ macro_rules! wasi_stream_write_impl { Ok(()) } async fn read_vectored<'a>( - &mut self, + &self, bufs: &mut [io::IoSliceMut<'a>], ) -> Result { use std::io::Read; let n = Read::read_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?; Ok(n.try_into()?) } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result { use std::io::Write; let n = Write::write_vectored(&mut &*self.as_socketlike_view::<$std_ty>(), bufs)?; Ok(n.try_into()?) } - async fn peek(&mut self, buf: &mut [u8]) -> Result { + async fn peek(&self, buf: &mut [u8]) -> Result { let n = self.0.peek(buf)?; Ok(n.try_into()?) } - async fn num_ready_bytes(&self) -> Result { + fn num_ready_bytes(&self) -> Result { let val = self.as_socketlike_view::<$std_ty>().num_ready_bytes()?; Ok(val) } @@ -244,7 +242,7 @@ macro_rules! wasi_stream_write_impl { } async fn sock_recv<'a>( - &mut self, + &self, ri_data: &mut [std::io::IoSliceMut<'a>], ri_flags: RiFlags, ) -> Result<(u64, RoFlags), Error> { @@ -272,7 +270,7 @@ macro_rules! wasi_stream_write_impl { } async fn sock_send<'a>( - &mut self, + &self, si_data: &[std::io::IoSlice<'a>], si_flags: SiFlags, ) -> Result { @@ -284,7 +282,7 @@ macro_rules! wasi_stream_write_impl { Ok(n as u64) } - async fn sock_shutdown(&mut self, how: SdFlags) -> Result<(), Error> { + async fn sock_shutdown(&self, how: SdFlags) -> Result<(), Error> { let how = if how == SdFlags::RD | SdFlags::WR { cap_std::net::Shutdown::Both } else if how == SdFlags::RD { diff --git a/crates/wasi-common/cap-std-sync/src/sched/unix.rs b/crates/wasi-common/cap-std-sync/src/sched/unix.rs index c53acf1a29f8..13f20f018812 100644 --- a/crates/wasi-common/cap-std-sync/src/sched/unix.rs +++ b/crates/wasi-common/cap-std-sync/src/sched/unix.rs @@ -55,7 +55,7 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { let revents = pollfd.revents(); let (nbytes, rwsub) = match rwsub { Subscription::Read(sub) => { - let ready = sub.file.num_ready_bytes().await?; + let ready = sub.file.num_ready_bytes()?; (std::cmp::max(ready, 1), sub) } Subscription::Write(sub) => (0, sub), diff --git a/crates/wasi-common/cap-std-sync/src/sched/windows.rs b/crates/wasi-common/cap-std-sync/src/sched/windows.rs index c4ad559cc344..e3eeb930523e 100644 --- a/crates/wasi-common/cap-std-sync/src/sched/windows.rs +++ b/crates/wasi-common/cap-std-sync/src/sched/windows.rs @@ -96,7 +96,7 @@ pub async fn poll_oneoff_<'a>( } } for r in immediate_reads { - match r.file.num_ready_bytes().await { + match r.file.num_ready_bytes() { Ok(ready_bytes) => { r.complete(ready_bytes, RwEventFlags::empty()); ready = true; diff --git a/crates/wasi-common/cap-std-sync/src/stdio.rs b/crates/wasi-common/cap-std-sync/src/stdio.rs index 9d82348af15d..60f55056bccd 100644 --- a/crates/wasi-common/cap-std-sync/src/stdio.rs +++ b/crates/wasi-common/cap-std-sync/src/stdio.rs @@ -31,6 +31,7 @@ impl WasiFile for Stdin { fn as_any(&self) -> &dyn Any { self } + #[cfg(unix)] fn pollable(&self) -> Option { Some(self.0.as_fd()) @@ -40,32 +41,33 @@ impl WasiFile for Stdin { fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn get_filetype(&mut self) -> Result { + + async fn get_filetype(&self) -> Result { if self.isatty() { Ok(FileType::CharacterDevice) } else { Ok(FileType::Unknown) } } - async fn read_vectored<'a>(&mut self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { + async fn read_vectored<'a>(&self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { let n = (&*self.0.as_filelike_view::()).read_vectored(bufs)?; Ok(n.try_into().map_err(|_| Error::range())?) } async fn read_vectored_at<'a>( - &mut self, + &self, _bufs: &mut [io::IoSliceMut<'a>], _offset: u64, ) -> Result { Err(Error::seek_pipe()) } - async fn seek(&mut self, _pos: std::io::SeekFrom) -> Result { + async fn seek(&self, _pos: std::io::SeekFrom) -> Result { Err(Error::seek_pipe()) } - async fn peek(&mut self, _buf: &mut [u8]) -> Result { + async fn peek(&self, _buf: &mut [u8]) -> Result { Err(Error::seek_pipe()) } async fn set_times( - &mut self, + &self, atime: Option, mtime: Option, ) -> Result<(), Error> { @@ -73,10 +75,10 @@ impl WasiFile for Stdin { .set_times(convert_systimespec(atime), convert_systimespec(mtime))?; Ok(()) } - async fn num_ready_bytes(&self) -> Result { + fn num_ready_bytes(&self) -> Result { Ok(self.0.num_ready_bytes()?) } - fn isatty(&mut self) -> bool { + fn isatty(&self) -> bool { self.0.is_terminal() } } @@ -111,39 +113,38 @@ macro_rules! wasi_file_write_impl { fn pollable(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(windows)] fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { if self.isatty() { Ok(FileType::CharacterDevice) } else { Ok(FileType::Unknown) } } - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { Ok(FdFlags::APPEND) } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result { let n = (&*self.0.as_filelike_view::()).write_vectored(bufs)?; Ok(n.try_into().map_err(|_| { Error::range().context("converting write_vectored total length") })?) } async fn write_vectored_at<'a>( - &mut self, + &self, _bufs: &[io::IoSlice<'a>], _offset: u64, ) -> Result { Err(Error::seek_pipe()) } - async fn seek(&mut self, _pos: std::io::SeekFrom) -> Result { + async fn seek(&self, _pos: std::io::SeekFrom) -> Result { Err(Error::seek_pipe()) } async fn set_times( - &mut self, + &self, atime: Option, mtime: Option, ) -> Result<(), Error> { @@ -151,7 +152,7 @@ macro_rules! wasi_file_write_impl { .set_times(convert_systimespec(atime), convert_systimespec(mtime))?; Ok(()) } - fn isatty(&mut self) -> bool { + fn isatty(&self) -> bool { self.0.is_terminal() } } diff --git a/crates/wasi-common/src/ctx.rs b/crates/wasi-common/src/ctx.rs index b4b3a55324a7..6eabaa1fdee5 100644 --- a/crates/wasi-common/src/ctx.rs +++ b/crates/wasi-common/src/ctx.rs @@ -2,16 +2,29 @@ use crate::clocks::WasiClocks; use crate::dir::{DirCaps, DirEntry, WasiDir}; use crate::file::{FileCaps, FileEntry, WasiFile}; use crate::sched::WasiSched; -use crate::string_array::{StringArray, StringArrayError}; +use crate::string_array::StringArray; use crate::table::Table; -use crate::Error; +use crate::{Error, StringArrayError}; use cap_rand::RngCore; +use std::ops::Deref; use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; -pub struct WasiCtx { +/// An `Arc`-wrapper around the wasi-common context to allow mutable access to +/// the file descriptor table. This wrapper is only necessary due to the +/// signature of `fd_fdstat_set_flags`; if that changes, there are a variety of +/// improvements that can be made (TODO: +/// https://github.com/bytecodealliance/wasmtime/issues/5643). +#[derive(Clone)] +pub struct WasiCtx(Arc); + +pub struct WasiCtxInner { pub args: StringArray, pub env: StringArray, - pub random: Box, + // TODO: this mutex should not be necessary, it forces threads to serialize + // their access to randomness unnecessarily + // (https://github.com/bytecodealliance/wasmtime/issues/5660). + pub random: Mutex>, pub clocks: WasiClocks, pub sched: Box, pub table: Table, @@ -24,31 +37,31 @@ impl WasiCtx { sched: Box, table: Table, ) -> Self { - let mut s = WasiCtx { + let s = WasiCtx(Arc::new(WasiCtxInner { args: StringArray::new(), env: StringArray::new(), - random, + random: Mutex::new(random), clocks, sched, table, - }; + })); s.set_stdin(Box::new(crate::pipe::ReadPipe::new(std::io::empty()))); s.set_stdout(Box::new(crate::pipe::WritePipe::new(std::io::sink()))); s.set_stderr(Box::new(crate::pipe::WritePipe::new(std::io::sink()))); s } - pub fn insert_file(&mut self, fd: u32, file: Box, caps: FileCaps) { + pub fn insert_file(&self, fd: u32, file: Box, caps: FileCaps) { self.table() - .insert_at(fd, Box::new(FileEntry::new(caps, file))); + .insert_at(fd, Arc::new(FileEntry::new(caps, file))); } - pub fn push_file(&mut self, file: Box, caps: FileCaps) -> Result { - self.table().push(Box::new(FileEntry::new(caps, file))) + pub fn push_file(&self, file: Box, caps: FileCaps) -> Result { + self.table().push(Arc::new(FileEntry::new(caps, file))) } pub fn insert_dir( - &mut self, + &self, fd: u32, dir: Box, caps: DirCaps, @@ -57,45 +70,55 @@ impl WasiCtx { ) { self.table().insert_at( fd, - Box::new(DirEntry::new(caps, file_caps, Some(path), dir)), + Arc::new(DirEntry::new(caps, file_caps, Some(path), dir)), ); } pub fn push_dir( - &mut self, + &self, dir: Box, caps: DirCaps, file_caps: FileCaps, path: PathBuf, ) -> Result { self.table() - .push(Box::new(DirEntry::new(caps, file_caps, Some(path), dir))) + .push(Arc::new(DirEntry::new(caps, file_caps, Some(path), dir))) + } + + pub fn table(&self) -> &Table { + &self.table } - pub fn table(&mut self) -> &mut Table { - &mut self.table + pub fn table_mut(&mut self) -> Option<&mut Table> { + Arc::get_mut(&mut self.0).map(|c| &mut c.table) } pub fn push_arg(&mut self, arg: &str) -> Result<(), StringArrayError> { - self.args.push(arg.to_owned()) + let s = Arc::get_mut(&mut self.0).expect( + "`push_arg` should only be used during initialization before the context is cloned", + ); + s.args.push(arg.to_owned()) } pub fn push_env(&mut self, var: &str, value: &str) -> Result<(), StringArrayError> { - self.env.push(format!("{}={}", var, value))?; + let s = Arc::get_mut(&mut self.0).expect( + "`push_env` should only be used during initialization before the context is cloned", + ); + s.env.push(format!("{}={}", var, value))?; Ok(()) } - pub fn set_stdin(&mut self, mut f: Box) { + pub fn set_stdin(&self, mut f: Box) { let rights = Self::stdio_rights(&mut *f); self.insert_file(0, f, rights); } - pub fn set_stdout(&mut self, mut f: Box) { + pub fn set_stdout(&self, mut f: Box) { let rights = Self::stdio_rights(&mut *f); self.insert_file(1, f, rights); } - pub fn set_stderr(&mut self, mut f: Box) { + pub fn set_stderr(&self, mut f: Box) { let rights = Self::stdio_rights(&mut *f); self.insert_file(2, f, rights); } @@ -114,13 +137,13 @@ impl WasiCtx { } pub fn push_preopened_dir( - &mut self, + &self, dir: Box, path: impl AsRef, ) -> Result<(), Error> { let caps = DirCaps::all(); let file_caps = FileCaps::all(); - self.table().push(Box::new(DirEntry::new( + self.table().push(Arc::new(DirEntry::new( caps, file_caps, Some(path.as_ref().to_owned()), @@ -129,3 +152,10 @@ impl WasiCtx { Ok(()) } } + +impl Deref for WasiCtx { + type Target = WasiCtxInner; + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/crates/wasi-common/src/dir.rs b/crates/wasi-common/src/dir.rs index 56a8849db468..48cc10d2f604 100644 --- a/crates/wasi-common/src/dir.rs +++ b/crates/wasi-common/src/dir.rs @@ -3,6 +3,7 @@ use crate::{Error, ErrorExt, SystemTimeSpec}; use bitflags::bitflags; use std::any::Any; use std::path::PathBuf; +use std::sync::{Arc, RwLock}; #[wiggle::async_trait] pub trait WasiDir: Send + Sync { @@ -98,67 +99,50 @@ pub trait WasiDir: Send + Sync { } pub(crate) struct DirEntry { - caps: DirCaps, - file_caps: FileCaps, + caps: RwLock, preopen_path: Option, // precondition: PathBuf is valid unicode dir: Box, } impl DirEntry { pub fn new( - caps: DirCaps, + dir_caps: DirCaps, file_caps: FileCaps, preopen_path: Option, dir: Box, ) -> Self { DirEntry { - caps, - file_caps, + caps: RwLock::new(DirFdStat { + dir_caps, + file_caps, + }), preopen_path, dir, } } pub fn capable_of_dir(&self, caps: DirCaps) -> Result<(), Error> { - if self.caps.contains(caps) { - Ok(()) - } else { - let missing = caps & !self.caps; - let err = if missing.intersects(DirCaps::READDIR) { - Error::not_dir() - } else { - Error::perm() - }; - Err(err.context(format!("desired rights {:?}, has {:?}", caps, self.caps))) - } - } - pub fn capable_of_file(&self, caps: FileCaps) -> Result<(), Error> { - if self.file_caps.contains(caps) { - Ok(()) - } else { - Err(Error::perm().context(format!( - "desired rights {:?}, has {:?}", - caps, self.file_caps - ))) - } + let fdstat = self.caps.read().unwrap(); + fdstat.capable_of_dir(caps) } - pub fn drop_caps_to(&mut self, caps: DirCaps, file_caps: FileCaps) -> Result<(), Error> { - self.capable_of_dir(caps)?; - self.capable_of_file(file_caps)?; - self.caps = caps; - self.file_caps = file_caps; + + pub fn drop_caps_to(&self, dir_caps: DirCaps, file_caps: FileCaps) -> Result<(), Error> { + let mut fdstat = self.caps.write().unwrap(); + fdstat.capable_of_dir(dir_caps)?; + fdstat.capable_of_file(file_caps)?; + *fdstat = DirFdStat { + dir_caps, + file_caps, + }; Ok(()) } pub fn child_dir_caps(&self, desired_caps: DirCaps) -> DirCaps { - self.caps & desired_caps + self.caps.read().unwrap().dir_caps & desired_caps } pub fn child_file_caps(&self, desired_caps: FileCaps) -> FileCaps { - self.file_caps & desired_caps + self.caps.read().unwrap().file_caps & desired_caps } pub fn get_dir_fdstat(&self) -> DirFdStat { - DirFdStat { - dir_caps: self.caps, - file_caps: self.file_caps, - } + self.caps.read().unwrap().clone() } pub fn preopen_path(&self) -> &Option { &self.preopen_path @@ -203,18 +187,47 @@ pub struct DirFdStat { pub dir_caps: DirCaps, } +impl DirFdStat { + pub fn capable_of_dir(&self, caps: DirCaps) -> Result<(), Error> { + if self.dir_caps.contains(caps) { + Ok(()) + } else { + let missing = caps & !self.dir_caps; + let err = if missing.intersects(DirCaps::READDIR) { + Error::not_dir() + } else { + Error::perm() + }; + Err(err.context(format!( + "desired rights {:?}, has {:?}", + caps, self.dir_caps + ))) + } + } + pub fn capable_of_file(&self, caps: FileCaps) -> Result<(), Error> { + if self.file_caps.contains(caps) { + Ok(()) + } else { + Err(Error::perm().context(format!( + "desired rights {:?}, has {:?}", + caps, self.file_caps + ))) + } + } +} + pub(crate) trait TableDirExt { - fn get_dir(&self, fd: u32) -> Result<&DirEntry, Error>; + fn get_dir(&self, fd: u32) -> Result, Error>; fn is_preopen(&self, fd: u32) -> bool; } impl TableDirExt for crate::table::Table { - fn get_dir(&self, fd: u32) -> Result<&DirEntry, Error> { + fn get_dir(&self, fd: u32) -> Result, Error> { self.get(fd) } fn is_preopen(&self, fd: u32) -> bool { if self.is::(fd) { - let dir_entry: &DirEntry = self.get(fd).unwrap(); + let dir_entry: Arc = self.get(fd).unwrap(); dir_entry.preopen_path.is_some() } else { false diff --git a/crates/wasi-common/src/file.rs b/crates/wasi-common/src/file.rs index 8d4bf6a45a61..c799b6dcbbc5 100644 --- a/crates/wasi-common/src/file.rs +++ b/crates/wasi-common/src/file.rs @@ -1,11 +1,12 @@ use crate::{Error, ErrorExt, SystemTimeSpec}; use bitflags::bitflags; use std::any::Any; +use std::sync::{Arc, RwLock}; #[wiggle::async_trait] pub trait WasiFile: Send + Sync { fn as_any(&self) -> &dyn Any; - async fn get_filetype(&mut self) -> Result; + async fn get_filetype(&self) -> Result; #[cfg(unix)] fn pollable(&self) -> Option { @@ -17,16 +18,16 @@ pub trait WasiFile: Send + Sync { None } - fn isatty(&mut self) -> bool { + fn isatty(&self) -> bool { false } - async fn sock_accept(&mut self, _fdflags: FdFlags) -> Result, Error> { + async fn sock_accept(&self, _fdflags: FdFlags) -> Result, Error> { Err(Error::badf()) } async fn sock_recv<'a>( - &mut self, + &self, _ri_data: &mut [std::io::IoSliceMut<'a>], _ri_flags: RiFlags, ) -> Result<(u64, RoFlags), Error> { @@ -34,26 +35,26 @@ pub trait WasiFile: Send + Sync { } async fn sock_send<'a>( - &mut self, + &self, _si_data: &[std::io::IoSlice<'a>], _si_flags: SiFlags, ) -> Result { Err(Error::badf()) } - async fn sock_shutdown(&mut self, _how: SdFlags) -> Result<(), Error> { + async fn sock_shutdown(&self, _how: SdFlags) -> Result<(), Error> { Err(Error::badf()) } - async fn datasync(&mut self) -> Result<(), Error> { + async fn datasync(&self) -> Result<(), Error> { Ok(()) } - async fn sync(&mut self) -> Result<(), Error> { + async fn sync(&self) -> Result<(), Error> { Ok(()) } - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { Ok(FdFlags::empty()) } @@ -61,7 +62,7 @@ pub trait WasiFile: Send + Sync { Err(Error::badf()) } - async fn get_filestat(&mut self) -> Result { + async fn get_filestat(&self) -> Result { Ok(Filestat { device_id: 0, inode: 0, @@ -74,62 +75,59 @@ pub trait WasiFile: Send + Sync { }) } - async fn set_filestat_size(&mut self, _size: u64) -> Result<(), Error> { + async fn set_filestat_size(&self, _size: u64) -> Result<(), Error> { Err(Error::badf()) } - async fn advise(&mut self, _offset: u64, _len: u64, _advice: Advice) -> Result<(), Error> { + async fn advise(&self, _offset: u64, _len: u64, _advice: Advice) -> Result<(), Error> { Err(Error::badf()) } - async fn allocate(&mut self, _offset: u64, _len: u64) -> Result<(), Error> { + async fn allocate(&self, _offset: u64, _len: u64) -> Result<(), Error> { Err(Error::badf()) } async fn set_times( - &mut self, + &self, _atime: Option, _mtime: Option, ) -> Result<(), Error> { Err(Error::badf()) } - async fn read_vectored<'a>( - &mut self, - _bufs: &mut [std::io::IoSliceMut<'a>], - ) -> Result { + async fn read_vectored<'a>(&self, _bufs: &mut [std::io::IoSliceMut<'a>]) -> Result { Err(Error::badf()) } async fn read_vectored_at<'a>( - &mut self, + &self, _bufs: &mut [std::io::IoSliceMut<'a>], _offset: u64, ) -> Result { Err(Error::badf()) } - async fn write_vectored<'a>(&mut self, _bufs: &[std::io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, _bufs: &[std::io::IoSlice<'a>]) -> Result { Err(Error::badf()) } async fn write_vectored_at<'a>( - &mut self, + &self, _bufs: &[std::io::IoSlice<'a>], _offset: u64, ) -> Result { Err(Error::badf()) } - async fn seek(&mut self, _pos: std::io::SeekFrom) -> Result { + async fn seek(&self, _pos: std::io::SeekFrom) -> Result { Err(Error::badf()) } - async fn peek(&mut self, _buf: &mut [u8]) -> Result { + async fn peek(&self, _buf: &mut [u8]) -> Result { Err(Error::badf()) } - async fn num_ready_bytes(&self) -> Result { + fn num_ready_bytes(&self) -> Result { Ok(0) } @@ -212,11 +210,11 @@ pub struct Filestat { } pub(crate) trait TableFileExt { - fn get_file(&self, fd: u32) -> Result<&FileEntry, Error>; + fn get_file(&self, fd: u32) -> Result, Error>; fn get_file_mut(&mut self, fd: u32) -> Result<&mut FileEntry, Error>; } impl TableFileExt for crate::table::Table { - fn get_file(&self, fd: u32) -> Result<&FileEntry, Error> { + fn get_file(&self, fd: u32) -> Result, Error> { self.get(fd) } fn get_file_mut(&mut self, fd: u32) -> Result<&mut FileEntry, Error> { @@ -225,20 +223,23 @@ impl TableFileExt for crate::table::Table { } pub(crate) struct FileEntry { - caps: FileCaps, + caps: RwLock, file: Box, } impl FileEntry { pub fn new(caps: FileCaps, file: Box) -> Self { - FileEntry { caps, file } + FileEntry { + caps: RwLock::new(caps), + file, + } } pub fn capable_of(&self, caps: FileCaps) -> Result<(), Error> { - if self.caps.contains(caps) { + if self.caps.read().unwrap().contains(caps) { Ok(()) } else { - let missing = caps & !self.caps; + let missing = caps & !(*self.caps.read().unwrap()); let err = if missing.intersects(FileCaps::READ | FileCaps::WRITE) { // `EBADF` is a little surprising here because it's also used // for unknown-file-descriptor errors, but it's what POSIX uses @@ -251,16 +252,17 @@ impl FileEntry { } } - pub fn drop_caps_to(&mut self, caps: FileCaps) -> Result<(), Error> { + pub fn drop_caps_to(&self, caps: FileCaps) -> Result<(), Error> { self.capable_of(caps)?; - self.caps = caps; + *self.caps.write().unwrap() = caps; Ok(()) } - pub async fn get_fdstat(&mut self) -> Result { + pub async fn get_fdstat(&self) -> Result { + let caps = self.caps.read().unwrap().clone(); Ok(FdStat { filetype: self.file.get_filetype().await?, - caps: self.caps, + caps, flags: self.file.get_fdflags().await?, }) } @@ -276,7 +278,6 @@ impl FileEntryExt for FileEntry { self.capable_of(caps)?; Ok(&*self.file) } - fn get_cap_mut(&mut self, caps: FileCaps) -> Result<&mut dyn WasiFile, Error> { self.capable_of(caps)?; Ok(&mut *self.file) diff --git a/crates/wasi-common/src/pipe.rs b/crates/wasi-common/src/pipe.rs index a5fceb80a1b1..1700131bd6cc 100644 --- a/crates/wasi-common/src/pipe.rs +++ b/crates/wasi-common/src/pipe.rs @@ -105,10 +105,10 @@ impl WasiFile for ReadPipe { fn as_any(&self) -> &dyn Any { self } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { Ok(FileType::Pipe) } - async fn read_vectored<'a>(&mut self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { + async fn read_vectored<'a>(&self, bufs: &mut [io::IoSliceMut<'a>]) -> Result { let n = self.borrow().read_vectored(bufs)?; Ok(n.try_into()?) } @@ -189,13 +189,13 @@ impl WasiFile for WritePipe { fn as_any(&self) -> &dyn Any { self } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { Ok(FileType::Pipe) } - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { Ok(FdFlags::APPEND) } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result { let n = self.borrow().write_vectored(bufs)?; Ok(n.try_into()?) } diff --git a/crates/wasi-common/src/snapshots/preview_0.rs b/crates/wasi-common/src/snapshots/preview_0.rs index 763f685fcccf..e2a47223af7b 100644 --- a/crates/wasi-common/src/snapshots/preview_0.rs +++ b/crates/wasi-common/src/snapshots/preview_0.rs @@ -528,10 +528,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { fd: types::Fd, iovs: &types::IovecArray<'a>, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ)?; let iovs: Vec> = iovs .iter() @@ -601,10 +599,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { iovs: &types::IovecArray<'a>, offset: types::Filesize, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ | FileCaps::SEEK)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ | FileCaps::SEEK)?; let iovs: Vec> = iovs .iter() @@ -675,10 +671,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { fd: types::Fd, ciovs: &types::CiovecArray<'a>, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::WRITE)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::WRITE)?; let guest_slices: Vec> = ciovs .iter() @@ -704,10 +698,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { ciovs: &types::CiovecArray<'a>, offset: types::Filesize, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::WRITE | FileCaps::SEEK)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::WRITE | FileCaps::SEEK)?; let guest_slices: Vec> = ciovs .iter() @@ -953,7 +945,7 @@ impl wasi_unstable::WasiUnstable for WasiCtx { } } - let table = &mut self.table; + let table = &self.table; let mut sub_fds: HashSet = HashSet::new(); // We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside let mut reads: Vec<(u32, Userdata)> = Vec::new(); @@ -1003,8 +995,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { sub_fds.insert(fd); } table - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::POLL_READWRITE)?; + .get_file(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; reads.push((u32::from(fd), sub.userdata.into())); } types::SubscriptionU::FdWrite(writesub) => { @@ -1016,8 +1008,8 @@ impl wasi_unstable::WasiUnstable for WasiCtx { sub_fds.insert(fd); } table - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::POLL_READWRITE)?; + .get_file(u32::from(fd))? + .get_cap(FileCaps::POLL_READWRITE)?; writes.push((u32::from(fd), sub.userdata.into())); } } diff --git a/crates/wasi-common/src/snapshots/preview_1.rs b/crates/wasi-common/src/snapshots/preview_1.rs index eac6d8544e5d..2d368305485d 100644 --- a/crates/wasi-common/src/snapshots/preview_1.rs +++ b/crates/wasi-common/src/snapshots/preview_1.rs @@ -14,6 +14,7 @@ use cap_std::time::{Duration, SystemClock}; use std::convert::{TryFrom, TryInto}; use std::io::{IoSlice, IoSliceMut}; use std::ops::Deref; +use std::sync::Arc; use wiggle::GuestPtr; pub mod error; @@ -111,8 +112,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { advice: types::Advice, ) -> Result<(), Error> { self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::ADVISE)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::ADVISE)? .advise(offset, len, advice.into()) .await?; Ok(()) @@ -125,8 +126,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { len: types::Filesize, ) -> Result<(), Error> { self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::ALLOCATE)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::ALLOCATE)? .allocate(offset, len) .await?; Ok(()) @@ -142,15 +143,15 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } // fd_close must close either a File or a Dir handle if table.is::(fd) { - let _ = table.delete(fd); + let _ = table.delete::(fd); } else if table.is::(fd) { // We cannot close preopened directories - let dir_entry: &DirEntry = table.get(fd).unwrap(); + let dir_entry: Arc = table.get(fd).unwrap(); if dir_entry.preopen_path().is_some() { return Err(Error::not_supported().context("cannot close propened directory")); } drop(dir_entry); - let _ = table.delete(fd); + let _ = table.delete::(fd); } else { return Err(Error::badf().context("key does not refer to file or directory")); } @@ -160,8 +161,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { async fn fd_datasync(&mut self, fd: types::Fd) -> Result<(), Error> { self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::DATASYNC)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::DATASYNC)? .datasync() .await?; Ok(()) @@ -171,11 +172,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { let table = self.table(); let fd = u32::from(fd); if table.is::(fd) { - let file_entry: &mut FileEntry = table.get_mut(fd)?; + let file_entry: Arc = table.get(fd)?; let fdstat = file_entry.get_fdstat().await?; Ok(types::Fdstat::from(&fdstat)) } else if table.is::(fd) { - let dir_entry: &DirEntry = table.get(fd)?; + let dir_entry: Arc = table.get(fd)?; let dir_fdstat = dir_entry.get_dir_fdstat(); Ok(types::Fdstat::from(&dir_fdstat)) } else { @@ -188,11 +189,16 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { fd: types::Fd, flags: types::Fdflags, ) -> Result<(), Error> { - self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::FDSTAT_SET_FLAGS)? - .set_fdflags(FdFlags::from(flags)) - .await + if let Some(table) = self.table_mut() { + table + .get_file_mut(u32::from(fd))? + .get_cap_mut(FileCaps::FDSTAT_SET_FLAGS)? + .set_fdflags(FdFlags::from(flags)) + .await + } else { + log::warn!("`fd_fdstat_set_flags` does not work with wasi-threads enabled; see https://github.com/bytecodealliance/wasmtime/issues/5643"); + Err(Error::not_supported()) + } } async fn fd_fdstat_set_rights( @@ -204,11 +210,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { let table = self.table(); let fd = u32::from(fd); if table.is::(fd) { - let file_entry: &mut FileEntry = table.get_mut(fd)?; + let file_entry: Arc = table.get(fd)?; let file_caps = FileCaps::from(&fs_rights_base); file_entry.drop_caps_to(file_caps) } else if table.is::(fd) { - let dir_entry: &mut DirEntry = table.get_mut(fd)?; + let dir_entry: Arc = table.get(fd)?; let dir_caps = DirCaps::from(&fs_rights_base); let file_caps = FileCaps::from(&fs_rights_inheriting); dir_entry.drop_caps_to(dir_caps, file_caps) @@ -222,8 +228,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { let fd = u32::from(fd); if table.is::(fd) { let filestat = table - .get_file_mut(fd)? - .get_cap_mut(FileCaps::FILESTAT_GET)? + .get_file(fd)? + .get_cap(FileCaps::FILESTAT_GET)? .get_filestat() .await?; Ok(filestat.into()) @@ -245,8 +251,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { size: types::Filesize, ) -> Result<(), Error> { self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::FILESTAT_SET_SIZE)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::FILESTAT_SET_SIZE)? .set_filestat_size(size) .await?; Ok(()) @@ -272,9 +278,9 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { if table.is::(fd) { table - .get_file_mut(fd) + .get_file(fd) .expect("checked that entry is file") - .get_cap_mut(FileCaps::FILESTAT_SET_TIMES)? + .get_cap(FileCaps::FILESTAT_SET_TIMES)? .set_times(atim, mtim) .await } else if table.is::(fd) { @@ -294,10 +300,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { fd: types::Fd, iovs: &types::IovecArray<'a>, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ)?; let iovs: Vec> = iovs .iter() @@ -367,10 +371,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { iovs: &types::IovecArray<'a>, offset: types::Filesize, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ | FileCaps::SEEK)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ | FileCaps::SEEK)?; let iovs: Vec> = iovs .iter() @@ -441,10 +443,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { fd: types::Fd, ciovs: &types::CiovecArray<'a>, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::WRITE)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::WRITE)?; let guest_slices: Vec> = ciovs .iter() @@ -470,10 +470,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { ciovs: &types::CiovecArray<'a>, offset: types::Filesize, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::WRITE | FileCaps::SEEK)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::WRITE | FileCaps::SEEK)?; let guest_slices: Vec> = ciovs .iter() @@ -495,7 +493,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { async fn fd_prestat_get(&mut self, fd: types::Fd) -> Result { let table = self.table(); - let dir_entry: &DirEntry = table.get(u32::from(fd)).map_err(|_| Error::badf())?; + let dir_entry: Arc = table.get(u32::from(fd)).map_err(|_| Error::badf())?; if let Some(ref preopen) = dir_entry.preopen_path() { let path_str = preopen.to_str().ok_or_else(|| Error::not_supported())?; let pr_name_len = u32::try_from(path_str.as_bytes().len())?; @@ -512,7 +510,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { path_max_len: types::Size, ) -> Result<(), Error> { let table = self.table(); - let dir_entry: &DirEntry = table.get(u32::from(fd)).map_err(|_| Error::not_dir())?; + let dir_entry: Arc = table.get(u32::from(fd)).map_err(|_| Error::not_dir())?; if let Some(ref preopen) = dir_entry.preopen_path() { let path_bytes = preopen .to_str() @@ -538,11 +536,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { if table.is_preopen(from) || table.is_preopen(to) { return Err(Error::not_supported().context("cannot renumber a preopen")); } - let from_entry = table - .delete(from) - .expect("we checked that table contains from"); - table.insert_at(to, from_entry); - Ok(()) + table.renumber(from, to) } async fn fd_seek( @@ -566,8 +560,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { }; let newoffset = self .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(required_caps)? + .get_file(u32::from(fd))? + .get_cap(required_caps)? .seek(whence) .await?; Ok(newoffset) @@ -575,8 +569,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { async fn fd_sync(&mut self, fd: types::Fd) -> Result<(), Error> { self.table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::SYNC)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::SYNC)? .sync() .await?; Ok(()) @@ -586,8 +580,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { // XXX should this be stream_position? let offset = self .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::TELL)? + .get_file(u32::from(fd))? + .get_cap(FileCaps::TELL)? .seek(std::io::SeekFrom::Current(0)) .await?; Ok(offset) @@ -714,12 +708,10 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { target_path: &GuestPtr<'a, str>, ) -> Result<(), Error> { let table = self.table(); - let src_dir = table - .get_dir(u32::from(src_fd))? - .get_cap(DirCaps::LINK_SOURCE)?; - let target_dir = table - .get_dir(u32::from(target_fd))? - .get_cap(DirCaps::LINK_TARGET)?; + let src_dir = table.get_dir(u32::from(src_fd))?; + let src_dir = src_dir.get_cap(DirCaps::LINK_SOURCE)?; + let target_dir = table.get_dir(u32::from(target_fd))?; + let target_dir = target_dir.get_cap(DirCaps::LINK_TARGET)?; let symlink_follow = src_flags.contains(types::Lookupflags::SYMLINK_FOLLOW); if symlink_follow { return Err(Error::invalid_argument() @@ -769,7 +761,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { let dir = dir_entry.get_cap(DirCaps::OPEN)?; let child_dir = dir.open_dir(symlink_follow, path.deref()).await?; drop(dir); - let fd = table.push(Box::new(DirEntry::new( + let fd = table.push(Arc::new(DirEntry::new( dir_caps, file_caps, None, child_dir, )))?; Ok(types::Fd::from(fd)) @@ -789,7 +781,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { .open_file(symlink_follow, path.deref(), oflags, read, write, fdflags) .await?; drop(dir); - let fd = table.push(Box::new(FileEntry::new(file_caps, file)))?; + let fd = table.push(Arc::new(FileEntry::new(file_caps, file)))?; Ok(types::Fd::from(fd)) } } @@ -839,12 +831,10 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { dest_path: &GuestPtr<'a, str>, ) -> Result<(), Error> { let table = self.table(); - let src_dir = table - .get_dir(u32::from(src_fd))? - .get_cap(DirCaps::RENAME_SOURCE)?; - let dest_dir = table - .get_dir(u32::from(dest_fd))? - .get_cap(DirCaps::RENAME_TARGET)?; + let src_dir = table.get_dir(u32::from(src_fd))?; + let src_dir = src_dir.get_cap(DirCaps::RENAME_SOURCE)?; + let dest_dir = table.get_dir(u32::from(dest_fd))?; + let dest_dir = dest_dir.get_cap(DirCaps::RENAME_TARGET)?; src_dir .rename( src_path.as_cow()?.deref(), @@ -914,10 +904,11 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } } - let table = &mut self.table; + let table = &self.table; // We need these refmuts to outlive Poll, which will hold the &mut dyn WasiFile inside - let mut read_refs: Vec<(&dyn WasiFile, Userdata)> = Vec::new(); - let mut write_refs: Vec<(&dyn WasiFile, Userdata)> = Vec::new(); + let mut read_refs: Vec<(Arc, Option)> = Vec::new(); + let mut write_refs: Vec<(Arc, Option)> = Vec::new(); + let mut poll = Poll::new(); let subs = subs.as_array(nsubscriptions); @@ -983,25 +974,37 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { }, types::SubscriptionU::FdRead(readsub) => { let fd = readsub.file_descriptor; - let file_ref = table - .get_file(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - read_refs.push((file_ref, sub.userdata.into())); + let file_ref = table.get_file(u32::from(fd))?; + let _file = file_ref.get_cap(FileCaps::POLL_READWRITE)?; + + read_refs.push((file_ref, Some(sub.userdata.into()))); } types::SubscriptionU::FdWrite(writesub) => { let fd = writesub.file_descriptor; - let file_ref = table - .get_file(u32::from(fd))? - .get_cap(FileCaps::POLL_READWRITE)?; - write_refs.push((file_ref, sub.userdata.into())); + let file_ref = table.get_file(u32::from(fd))?; + let _file = file_ref.get_cap(FileCaps::POLL_READWRITE)?; + write_refs.push((file_ref, Some(sub.userdata.into()))); } } } - for (f, ud) in read_refs.iter_mut() { + let mut read_mut_refs: Vec<(&dyn WasiFile, Userdata)> = Vec::new(); + for (file_lock, userdata) in read_refs.iter_mut() { + let file = file_lock.get_cap(FileCaps::POLL_READWRITE)?; + read_mut_refs.push((file, userdata.take().unwrap())); + } + + for (f, ud) in read_mut_refs.iter_mut() { poll.subscribe_read(*f, *ud); } - for (f, ud) in write_refs.iter_mut() { + + let mut write_mut_refs: Vec<(&dyn WasiFile, Userdata)> = Vec::new(); + for (file_lock, userdata) in write_refs.iter_mut() { + let file = file_lock.get_cap(FileCaps::POLL_READWRITE)?; + write_mut_refs.push((file, userdata.take().unwrap())); + } + + for (f, ud) in write_mut_refs.iter_mut() { poll.subscribe_write(*f, *ud); } @@ -1112,7 +1115,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { while copied < buf.len() { let len = (buf.len() - copied).min(MAX_SHARED_BUFFER_SIZE as u32); let mut tmp = vec![0; len as usize]; - self.random.try_fill_bytes(&mut tmp)?; + self.random.lock().unwrap().try_fill_bytes(&mut tmp)?; let dest = buf .get_range(copied..copied + len) .unwrap() @@ -1124,7 +1127,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { // If the Wasm memory is non-shared, copy directly into the linear // memory. let mem = &mut buf.as_slice_mut()?.unwrap(); - self.random.try_fill_bytes(mem)?; + self.random.lock().unwrap().try_fill_bytes(mem)?; } Ok(()) } @@ -1135,9 +1138,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { flags: types::Fdflags, ) -> Result { let table = self.table(); - let f = table - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ)?; + let f = table.get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ)?; let file = f.sock_accept(FdFlags::from(flags)).await?; let file_caps = FileCaps::READ @@ -1146,7 +1148,7 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { | FileCaps::POLL_READWRITE | FileCaps::FILESTAT_GET; - let fd = table.push(Box::new(FileEntry::new(file_caps, file)))?; + let fd = table.push(Arc::new(FileEntry::new(file_caps, file)))?; Ok(types::Fd::from(fd)) } @@ -1156,10 +1158,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { ri_data: &types::IovecArray<'a>, ri_flags: types::Riflags, ) -> Result<(types::Size, types::Roflags), Error> { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::READ)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::READ)?; let iovs: Vec> = ri_data .iter() @@ -1231,10 +1231,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { si_data: &types::CiovecArray<'a>, _si_flags: types::Siflags, ) -> Result { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::WRITE)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::WRITE)?; let guest_slices: Vec> = si_data .iter() @@ -1255,10 +1253,8 @@ impl wasi_snapshot_preview1::WasiSnapshotPreview1 for WasiCtx { } async fn sock_shutdown(&mut self, fd: types::Fd, how: types::Sdflags) -> Result<(), Error> { - let f = self - .table() - .get_file_mut(u32::from(fd))? - .get_cap_mut(FileCaps::FDSTAT_SET_FLAGS)?; + let f = self.table().get_file(u32::from(fd))?; + let f = f.get_cap(FileCaps::FDSTAT_SET_FLAGS)?; f.sock_shutdown(SdFlags::from(how)).await } diff --git a/crates/wasi-common/src/table.rs b/crates/wasi-common/src/table.rs index 195d8babb677..40069636786e 100644 --- a/crates/wasi-common/src/table.rs +++ b/crates/wasi-common/src/table.rs @@ -1,6 +1,7 @@ use crate::{Error, ErrorExt}; use std::any::Any; use std::collections::HashMap; +use std::sync::{Arc, RwLock}; /// The `Table` type is designed to map u32 handles to resources. The table is now part of the /// public interface to a `WasiCtx` - it is reference counted so that it can be shared beyond a @@ -9,84 +10,105 @@ use std::collections::HashMap; /// /// The `Table` type is intended to model how the Interface Types concept of Resources is shaping /// up. Right now it is just an approximation. -pub struct Table { - map: HashMap>, +pub struct Table(RwLock); + +struct Inner { + map: HashMap>, next_key: u32, } impl Table { /// Create an empty table. New insertions will begin at 3, above stdio. pub fn new() -> Self { - Table { + Table(RwLock::new(Inner { map: HashMap::new(), next_key: 3, // 0, 1 and 2 are reserved for stdio - } + })) } /// Insert a resource at a certain index. - pub fn insert_at(&mut self, key: u32, a: Box) { - self.map.insert(key, a); + pub fn insert_at(&self, key: u32, a: Arc) { + self.0.write().unwrap().map.insert(key, a); } /// Insert a resource at the next available index. - pub fn push(&mut self, a: Box) -> Result { + pub fn push(&self, a: Arc) -> Result { + let mut inner = self.0.write().unwrap(); // NOTE: The performance of this new key calculation could be very bad once keys wrap // around. - if self.map.len() == u32::MAX as usize { + if inner.map.len() == u32::MAX as usize { return Err(Error::trap(anyhow::Error::msg("table has no free keys"))); } loop { - let key = self.next_key; - self.next_key = self.next_key.wrapping_add(1); - if self.map.contains_key(&key) { + let key = inner.next_key; + inner.next_key += 1; + if inner.map.contains_key(&key) { continue; } - self.map.insert(key, a); + inner.map.insert(key, a); return Ok(key); } } /// Check if the table has a resource at the given index. pub fn contains_key(&self, key: u32) -> bool { - self.map.contains_key(&key) + self.0.read().unwrap().map.contains_key(&key) } /// Check if the resource at a given index can be downcast to a given type. /// Note: this will always fail if the resource is already borrowed. pub fn is(&self, key: u32) -> bool { - if let Some(r) = self.map.get(&key) { + if let Some(r) = self.0.read().unwrap().map.get(&key) { r.is::() } else { false } } - /// Get an immutable reference to a resource of a given type at a given index. Multiple - /// immutable references can be borrowed at any given time. Borrow failure - /// results in a trapping error. - pub fn get(&self, key: u32) -> Result<&T, Error> { - if let Some(r) = self.map.get(&key) { - r.downcast_ref::() - .ok_or_else(|| Error::badf().context("element is a different type")) + /// Get an Arc reference to a resource of a given type at a given index. Multiple + /// immutable references can be borrowed at any given time. + pub fn get(&self, key: u32) -> Result, Error> { + if let Some(r) = self.0.read().unwrap().map.get(&key).cloned() { + r.downcast::() + .map_err(|_| Error::badf().context("element is a different type")) } else { Err(Error::badf().context("key not in table")) } } - /// Get a mutable reference to a resource of a given type at a given index. Only one mutable - /// reference can be borrowed at any given time. Borrow failure results in a trapping error. - pub fn get_mut(&mut self, key: u32) -> Result<&mut T, Error> { - if let Some(r) = self.map.get_mut(&key) { - r.downcast_mut::() - .ok_or_else(|| Error::badf().context("element is a different type")) - } else { - Err(Error::badf().context("key not in table")) - } + /// Get a mutable reference to a resource of a given type at a given index. + /// Only one such reference can be borrowed at any given time. + pub fn get_mut(&mut self, key: u32) -> Result<&mut T, Error> { + let entry = match self.0.get_mut().unwrap().map.get_mut(&key) { + Some(entry) => entry, + None => return Err(Error::badf().context("key not in table")), + }; + let entry = match Arc::get_mut(entry) { + Some(entry) => entry, + None => return Err(Error::badf().context("cannot mutably borrow shared file")), + }; + entry + .downcast_mut::() + .ok_or_else(|| Error::badf().context("element is a different type")) + } + + /// Remove a resource at a given index from the table. Returns the resource + /// if it was present. + pub fn delete(&self, key: u32) -> Option> { + self.0 + .write() + .unwrap() + .map + .remove(&key) + .map(|r| r.downcast::().unwrap()) } /// Remove a resource at a given index from the table. Returns the resource /// if it was present. - pub fn delete(&mut self, key: u32) -> Option> { - self.map.remove(&key) + pub fn renumber(&self, from: u32, to: u32) -> Result<(), Error> { + let map = &mut self.0.write().unwrap().map; + let from_entry = map.remove(&from).ok_or(Error::badf())?; + map.insert(to, from_entry); + Ok(()) } } diff --git a/crates/wasi-common/tokio/src/file.rs b/crates/wasi-common/tokio/src/file.rs index 030e60e5119c..114b5a2eae34 100644 --- a/crates/wasi-common/tokio/src/file.rs +++ b/crates/wasi-common/tokio/src/file.rs @@ -4,6 +4,7 @@ use io_extras::os::windows::{AsRawHandleOrSocket, RawHandleOrSocket}; #[cfg(not(windows))] use io_lifetimes::AsFd; use std::any::Any; +use std::borrow::Borrow; use std::io; use wasi_common::{ file::{Advice, FdFlags, FileType, Filestat, WasiFile}, @@ -98,78 +99,77 @@ macro_rules! wasi_file_impl { fn pollable(&self) -> Option { Some(self.0.as_fd()) } - #[cfg(windows)] fn pollable(&self) -> Option { Some(self.0.as_raw_handle_or_socket()) } - async fn datasync(&mut self) -> Result<(), Error> { + async fn datasync(&self) -> Result<(), Error> { block_on_dummy_executor(|| self.0.datasync()) } - async fn sync(&mut self) -> Result<(), Error> { + async fn sync(&self) -> Result<(), Error> { block_on_dummy_executor(|| self.0.sync()) } - async fn get_filetype(&mut self) -> Result { + async fn get_filetype(&self) -> Result { block_on_dummy_executor(|| self.0.get_filetype()) } - async fn get_fdflags(&mut self) -> Result { + async fn get_fdflags(&self) -> Result { block_on_dummy_executor(|| self.0.get_fdflags()) } async fn set_fdflags(&mut self, fdflags: FdFlags) -> Result<(), Error> { block_on_dummy_executor(|| self.0.set_fdflags(fdflags)) } - async fn get_filestat(&mut self) -> Result { + async fn get_filestat(&self) -> Result { block_on_dummy_executor(|| self.0.get_filestat()) } - async fn set_filestat_size(&mut self, size: u64) -> Result<(), Error> { + async fn set_filestat_size(&self, size: u64) -> Result<(), Error> { block_on_dummy_executor(move || self.0.set_filestat_size(size)) } - async fn advise(&mut self, offset: u64, len: u64, advice: Advice) -> Result<(), Error> { + async fn advise(&self, offset: u64, len: u64, advice: Advice) -> Result<(), Error> { block_on_dummy_executor(move || self.0.advise(offset, len, advice)) } - async fn allocate(&mut self, offset: u64, len: u64) -> Result<(), Error> { + async fn allocate(&self, offset: u64, len: u64) -> Result<(), Error> { block_on_dummy_executor(move || self.0.allocate(offset, len)) } async fn read_vectored<'a>( - &mut self, + &self, bufs: &mut [io::IoSliceMut<'a>], ) -> Result { block_on_dummy_executor(move || self.0.read_vectored(bufs)) } async fn read_vectored_at<'a>( - &mut self, + &self, bufs: &mut [io::IoSliceMut<'a>], offset: u64, ) -> Result { block_on_dummy_executor(move || self.0.read_vectored_at(bufs, offset)) } - async fn write_vectored<'a>(&mut self, bufs: &[io::IoSlice<'a>]) -> Result { + async fn write_vectored<'a>(&self, bufs: &[io::IoSlice<'a>]) -> Result { block_on_dummy_executor(move || self.0.write_vectored(bufs)) } async fn write_vectored_at<'a>( - &mut self, + &self, bufs: &[io::IoSlice<'a>], offset: u64, ) -> Result { block_on_dummy_executor(move || self.0.write_vectored_at(bufs, offset)) } - async fn seek(&mut self, pos: std::io::SeekFrom) -> Result { + async fn seek(&self, pos: std::io::SeekFrom) -> Result { block_on_dummy_executor(move || self.0.seek(pos)) } - async fn peek(&mut self, buf: &mut [u8]) -> Result { + async fn peek(&self, buf: &mut [u8]) -> Result { block_on_dummy_executor(move || self.0.peek(buf)) } async fn set_times( - &mut self, + &self, atime: Option, mtime: Option, ) -> Result<(), Error> { block_on_dummy_executor(move || self.0.set_times(atime, mtime)) } - async fn num_ready_bytes(&self) -> Result { - block_on_dummy_executor(|| self.0.num_ready_bytes()) + fn num_ready_bytes(&self) -> Result { + self.0.num_ready_bytes() } - fn isatty(&mut self) -> bool { + fn isatty(&self) -> bool { self.0.isatty() } @@ -182,7 +182,7 @@ macro_rules! wasi_file_impl { // lifetime of the AsyncFd. use std::os::unix::io::AsRawFd; use tokio::io::{unix::AsyncFd, Interest}; - let rawfd = self.0.as_fd().as_raw_fd(); + let rawfd = self.0.borrow().as_fd().as_raw_fd(); match AsyncFd::with_interest(rawfd, Interest::READABLE) { Ok(asyncfd) => { let _ = asyncfd.readable().await?; @@ -206,7 +206,7 @@ macro_rules! wasi_file_impl { // lifetime of the AsyncFd. use std::os::unix::io::AsRawFd; use tokio::io::{unix::AsyncFd, Interest}; - let rawfd = self.0.as_fd().as_raw_fd(); + let rawfd = self.0.borrow().as_fd().as_raw_fd(); match AsyncFd::with_interest(rawfd, Interest::WRITABLE) { Ok(asyncfd) => { let _ = asyncfd.writable().await?; @@ -221,7 +221,7 @@ macro_rules! wasi_file_impl { } } - async fn sock_accept(&mut self, fdflags: FdFlags) -> Result, Error> { + async fn sock_accept(&self, fdflags: FdFlags) -> Result, Error> { block_on_dummy_executor(|| self.0.sock_accept(fdflags)) } } @@ -229,7 +229,7 @@ macro_rules! wasi_file_impl { impl AsRawHandleOrSocket for $ty { #[inline] fn as_raw_handle_or_socket(&self) -> RawHandleOrSocket { - self.0.as_raw_handle_or_socket() + self.0.borrow().as_raw_handle_or_socket() } } }; diff --git a/crates/wasi-common/tokio/src/lib.rs b/crates/wasi-common/tokio/src/lib.rs index 577c6e2e1e38..1c7a1decb300 100644 --- a/crates/wasi-common/tokio/src/lib.rs +++ b/crates/wasi-common/tokio/src/lib.rs @@ -62,15 +62,15 @@ impl WasiCtxBuilder { } Ok(self) } - pub fn stdin(mut self, f: Box) -> Self { + pub fn stdin(self, f: Box) -> Self { self.0.set_stdin(f); self } - pub fn stdout(mut self, f: Box) -> Self { + pub fn stdout(self, f: Box) -> Self { self.0.set_stdout(f); self } - pub fn stderr(mut self, f: Box) -> Self { + pub fn stderr(self, f: Box) -> Self { self.0.set_stderr(f); self } @@ -87,7 +87,7 @@ impl WasiCtxBuilder { self.inherit_stdin().inherit_stdout().inherit_stderr() } pub fn preopened_dir( - mut self, + self, dir: cap_std::fs::Dir, guest_path: impl AsRef, ) -> Result { @@ -95,7 +95,7 @@ impl WasiCtxBuilder { self.0.push_preopened_dir(dir, guest_path)?; Ok(self) } - pub fn preopened_socket(mut self, fd: u32, socket: impl Into) -> Result { + pub fn preopened_socket(self, fd: u32, socket: impl Into) -> Result { let socket: Socket = socket.into(); let file: Box = socket.into(); diff --git a/crates/wasi-common/tokio/src/sched/unix.rs b/crates/wasi-common/tokio/src/sched/unix.rs index 5ca3b1200d09..4fd47d1cb248 100644 --- a/crates/wasi-common/tokio/src/sched/unix.rs +++ b/crates/wasi-common/tokio/src/sched/unix.rs @@ -63,7 +63,6 @@ pub async fn poll_oneoff<'a>(poll: &mut Poll<'a>) -> Result<(), Error> { f.complete( f.file .num_ready_bytes() - .await .map_err(|e| e.context("read num_ready_bytes"))?, RwEventFlags::empty(), ); diff --git a/crates/wasi-common/tokio/tests/poll_oneoff.rs b/crates/wasi-common/tokio/tests/poll_oneoff.rs index abaacef891fc..9ba85f6deeb5 100644 --- a/crates/wasi-common/tokio/tests/poll_oneoff.rs +++ b/crates/wasi-common/tokio/tests/poll_oneoff.rs @@ -20,7 +20,7 @@ async fn empty_file_readable() -> Result<(), Error> { let d = workspace.open_dir("d").context("open dir")?; let d = Dir::from_cap_std(d); - let mut f = d + let f = d .open_file(false, "f", OFlags::CREATE, false, true, FdFlags::empty()) .await .context("create writable file f")?; diff --git a/crates/wasi-nn/src/api.rs b/crates/wasi-nn/src/api.rs index 89fd46fbdc2f..2ad6e0edf94e 100644 --- a/crates/wasi-nn/src/api.rs +++ b/crates/wasi-nn/src/api.rs @@ -7,7 +7,7 @@ use thiserror::Error; use wiggle::GuestError; /// A [Backend] contains the necessary state to load [BackendGraph]s. -pub(crate) trait Backend: Send { +pub(crate) trait Backend: Send + Sync { fn name(&self) -> &str; fn load( &mut self, @@ -18,7 +18,7 @@ pub(crate) trait Backend: Send { /// A [BackendGraph] can create [BackendExecutionContext]s; this is the backing /// implementation for a [crate::witx::types::Graph]. -pub(crate) trait BackendGraph: Send { +pub(crate) trait BackendGraph: Send + Sync { fn init_execution_context(&mut self) -> Result, BackendError>; } diff --git a/crates/wasi-nn/src/openvino.rs b/crates/wasi-nn/src/openvino.rs index fff9bf7e5cf1..769beb3dad70 100644 --- a/crates/wasi-nn/src/openvino.rs +++ b/crates/wasi-nn/src/openvino.rs @@ -1,4 +1,5 @@ //! Implements the wasi-nn API. + use crate::api::{Backend, BackendError, BackendExecutionContext, BackendGraph}; use crate::witx::types::{ExecutionTarget, GraphBuilderArray, Tensor, TensorType}; use openvino::{InferenceError, Layout, Precision, SetupError, TensorDesc}; @@ -7,6 +8,9 @@ use std::sync::Arc; #[derive(Default)] pub(crate) struct OpenvinoBackend(Option); +unsafe impl Send for OpenvinoBackend {} +unsafe impl Sync for OpenvinoBackend {} + impl Backend for OpenvinoBackend { fn name(&self) -> &str { "openvino" @@ -65,6 +69,9 @@ impl Backend for OpenvinoBackend { struct OpenvinoGraph(Arc, openvino::ExecutableNetwork); +unsafe impl Send for OpenvinoGraph {} +unsafe impl Sync for OpenvinoGraph {} + impl BackendGraph for OpenvinoGraph { fn init_execution_context(&mut self) -> Result, BackendError> { let infer_request = self.1.create_infer_request()?; diff --git a/crates/wasi-threads/Cargo.toml b/crates/wasi-threads/Cargo.toml new file mode 100644 index 000000000000..bba6e21bd178 --- /dev/null +++ b/crates/wasi-threads/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "wasmtime-wasi-threads" +version.workspace = true +authors.workspace = true +description = "Wasmtime implementation of the wasi-threads API" +documentation = "https://docs.rs/wasmtime-wasi-nn" +license = "Apache-2.0 WITH LLVM-exception" +categories = ["wasm", "parallelism", "threads"] +keywords = ["webassembly", "wasm", "neural-network"] +repository = "https://github.com/bytecodealliance/wasmtime" +readme = "README.md" +edition.workspace = true + +[dependencies] +anyhow = { workspace = true } +log = { workspace = true } +rand = "0.8" +wasi-common = { workspace = true } +wasmtime = { workspace = true } +wasmtime-wasi = { workspace = true, features = ["exit"] } + +[badges] +maintenance = { status = "experimental" } diff --git a/crates/wasi-threads/README.md b/crates/wasi-threads/README.md new file mode 100644 index 000000000000..31478778e07c --- /dev/null +++ b/crates/wasi-threads/README.md @@ -0,0 +1,12 @@ +# wasmtime-wasi-threads + +Implement the `wasi-threads` [specification] in Wasmtime. + +[specification]: https://github.com/WebAssembly/wasi-threads + +> Note: this crate is experimental and not yet suitable for use in multi-tenant +> embeddings. As specified, a trap or WASI exit in one thread must end execution +> for all threads. Due to the complexity of stopping threads, however, this +> implementation currently exits the process entirely. This will work for some +> use cases (e.g., CLI usage) but not for embedders. This warning can be removed +> once a suitable mechanism is implemented that avoids exiting the process. diff --git a/crates/wasi-threads/src/lib.rs b/crates/wasi-threads/src/lib.rs new file mode 100644 index 000000000000..9116e4508e3e --- /dev/null +++ b/crates/wasi-threads/src/lib.rs @@ -0,0 +1,159 @@ +//! Implement [`wasi-threads`]. +//! +//! [`wasi-threads`]: https://github.com/WebAssembly/wasi-threads + +use anyhow::{anyhow, bail, Result}; +use rand::Rng; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::Arc; +use std::thread; +use wasmtime::{Caller, Linker, Module, SharedMemory, Store, ValType}; +use wasmtime_wasi::maybe_exit_on_error; + +// This name is a function export designated by the wasi-threads specification: +// https://github.com/WebAssembly/wasi-threads/#detailed-design-discussion +const WASI_ENTRY_POINT: &str = "wasi_thread_start"; + +pub struct WasiThreadsCtx { + module: Module, + linker: Arc>, +} + +impl WasiThreadsCtx { + pub fn new(module: Module, linker: Arc>) -> Result { + if !has_wasi_entry_point(&module) { + bail!( + "failed to find wasi-threads entry point function: {}", + WASI_ENTRY_POINT + ); + } + Ok(Self { module, linker }) + } + + pub fn spawn(&self, host: T, thread_start_arg: i32) -> Result { + let module = self.module.clone(); + let linker = self.linker.clone(); + + // Start a Rust thread running a new instance of the current module. + let wasi_thread_id = random_thread_id(); + let builder = thread::Builder::new().name(format!("wasi-thread-{}", wasi_thread_id)); + builder.spawn(move || { + // Catch any panic failures in host code; e.g., if a WASI module + // were to crash, we want all threads to exit, not just this one. + let result = catch_unwind(AssertUnwindSafe(|| { + // Each new instance is created in its own store. + let mut store = Store::new(&module.engine(), host); + + // Ideally, we would have already checked much earlier (e.g., + // `new`) whether the module can be instantiated. Because + // `Linker::instantiate_pre` requires a `Store` and that is only + // available now. TODO: + // https://github.com/bytecodealliance/wasmtime/issues/5675. + let instance = linker.instantiate(&mut store, &module).expect(&format!( + "wasi-thread-{} exited unsuccessfully: failed to instantiate", + wasi_thread_id + )); + let thread_entry_point = instance + .get_typed_func::<(i32, i32), ()>(&mut store, WASI_ENTRY_POINT) + .unwrap(); + + // Start the thread's entry point. Any traps or calls to + // `proc_exit`, by specification, should end execution for all + // threads. This code uses `process::exit` to do so, which is what + // the user expects from the CLI but probably not in a Wasmtime + // embedding. + log::trace!( + "spawned thread id = {}; calling start function `{}` with: {}", + wasi_thread_id, + WASI_ENTRY_POINT, + thread_start_arg + ); + match thread_entry_point.call(&mut store, (wasi_thread_id, thread_start_arg)) { + Ok(_) => log::trace!("exiting thread id = {} normally", wasi_thread_id), + Err(e) => { + log::trace!("exiting thread id = {} due to error", wasi_thread_id); + let e = maybe_exit_on_error(e); + eprintln!("Error: {:?}", e); + std::process::exit(1); + } + } + })); + + if let Err(e) = result { + eprintln!("wasi-thread-{} panicked: {:?}", wasi_thread_id, e); + std::process::exit(1); + } + })?; + + Ok(wasi_thread_id) + } +} + +/// Helper for generating valid WASI thread IDs (TID). +/// +/// Callers of `wasi_thread_spawn` expect a TID >=0 to indicate a successful +/// spawning of the thread whereas a negative return value indicates an +/// failure to spawn. +fn random_thread_id() -> i32 { + let tid: u32 = rand::thread_rng().gen(); + (tid >> 1) as i32 +} + +/// Manually add the WASI `thread_spawn` function to the linker. +/// +/// It is unclear what namespace the `wasi-threads` proposal should live under: +/// it is not clear if it should be included in any of the `preview*` releases +/// so for the time being its module namespace is simply `"wasi"` (TODO). +pub fn add_to_linker( + linker: &mut wasmtime::Linker, + store: &wasmtime::Store, + module: &Module, + get_cx: impl Fn(&mut T) -> &WasiThreadsCtx + Send + Sync + Copy + 'static, +) -> anyhow::Result { + linker.func_wrap( + "wasi", + "thread_spawn", + move |mut caller: Caller<'_, T>, start_arg: i32| -> i32 { + log::trace!("new thread requested via `wasi::thread_spawn` call"); + let host = caller.data().clone(); + let ctx = get_cx(caller.data_mut()); + match ctx.spawn(host, start_arg) { + Ok(thread_id) => { + assert!(thread_id >= 0, "thread_id = {}", thread_id); + thread_id + } + Err(e) => { + log::error!("failed to spawn thread: {}", e); + -1 + } + } + }, + )?; + + // Find the shared memory import and satisfy it with a newly-created shared + // memory import. This currently does not handle multiple memories (TODO). + for import in module.imports() { + if let Some(m) = import.ty().memory() { + if m.is_shared() { + let mem = SharedMemory::new(module.engine(), m.clone())?; + linker.define(store, import.module(), import.name(), mem.clone())?; + return Ok(mem); + } + } + } + Err(anyhow!( + "unable to link a shared memory import to the module; a `wasi-threads` \ + module should import a single shared memory as \"memory\"" + )) +} + +fn has_wasi_entry_point(module: &Module) -> bool { + module + .get_export(WASI_ENTRY_POINT) + .and_then(|t| t.func().cloned()) + .and_then(|t| { + let params: Vec = t.params().collect(); + Some(params == [ValType::I32, ValType::I32] && t.results().len() == 0) + }) + .unwrap_or(false) +} diff --git a/crates/wasi/Cargo.toml b/crates/wasi/Cargo.toml index cca14f9597ce..f26ec10c1b9b 100644 --- a/crates/wasi/Cargo.toml +++ b/crates/wasi/Cargo.toml @@ -13,6 +13,7 @@ include = ["src/**/*", "README.md", "LICENSE", "build.rs"] build = "build.rs" [dependencies] +libc = "0.2.60" wasi-common = { workspace = true } wasi-cap-std-sync = { workspace = true, optional = true } wasi-tokio = { workspace = true, optional = true } @@ -24,3 +25,4 @@ anyhow = { workspace = true } default = ["sync"] sync = ["wasi-cap-std-sync"] tokio = ["wasi-tokio", "wasmtime/async", "wiggle/wasmtime_async"] +exit = [] diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index a86557769515..5227f4e8993d 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -82,3 +82,47 @@ pub mod snapshots { } } } + +/// Exit the process with a conventional OS error code as long as Wasmtime +/// understands the error. If the error is not an `I32Exit` or `Trap`, return +/// the error back to the caller for it to decide what to do. +/// +/// Note: this function is designed for usage where it is acceptable for +/// Wasmtime failures to terminate the parent process, such as in the Wasmtime +/// CLI; this would not be suitable for use in multi-tenant embeddings. +#[cfg(feature = "exit")] +pub fn maybe_exit_on_error(e: anyhow::Error) -> anyhow::Error { + use std::process; + use wasmtime::Trap; + + // If a specific WASI error code was requested then that's + // forwarded through to the process here without printing any + // extra error information. + if let Some(exit) = e.downcast_ref::() { + // Print the error message in the usual way. + // On Windows, exit status 3 indicates an abort (see below), + // so return 1 indicating a non-zero status to avoid ambiguity. + if cfg!(windows) && exit.0 >= 3 { + process::exit(1); + } + process::exit(exit.0); + } + + // If the program exited because of a trap, return an error code + // to the outside environment indicating a more severe problem + // than a simple failure. + if e.is::() { + eprintln!("Error: {:?}", e); + + if cfg!(unix) { + // On Unix, return the error code of an abort. + process::exit(128 + libc::SIGABRT); + } else if cfg!(windows) { + // On Windows, return 3. + // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/abort?view=vs-2019 + process::exit(3); + } + } + + e +} diff --git a/scripts/publish.rs b/scripts/publish.rs index fc2c6c8bf499..df377254c2e8 100644 --- a/scripts/publish.rs +++ b/scripts/publish.rs @@ -64,8 +64,9 @@ const CRATES_TO_PUBLISH: &[&str] = &[ "wasi-tokio", // other misc wasmtime crates "wasmtime-wasi", - "wasmtime-wasi-nn", "wasmtime-wasi-crypto", + "wasmtime-wasi-nn", + "wasmtime-wasi-threads", "wasmtime-wast", "wasmtime-cli-flags", "wasmtime-cli", @@ -84,8 +85,9 @@ const PUBLIC_CRATES: &[&str] = &[ // patch releases. "wasmtime", "wasmtime-wasi", - "wasmtime-wasi-nn", "wasmtime-wasi-crypto", + "wasmtime-wasi-nn", + "wasmtime-wasi-threads", "wasmtime-cli", // all cranelift crates are considered "public" in that they can't // have breaking API changes in patch releases diff --git a/src/commands/run.rs b/src/commands/run.rs index 9b2ed8268a4e..4f5751c3befe 100644 --- a/src/commands/run.rs +++ b/src/commands/run.rs @@ -3,17 +3,17 @@ use anyhow::{anyhow, bail, Context as _, Result}; use clap::Parser; use once_cell::sync::Lazy; +use std::ffi::OsStr; +use std::path::{Component, Path, PathBuf}; use std::thread; use std::time::Duration; -use std::{ - ffi::OsStr, - path::{Component, Path, PathBuf}, - process, -}; -use wasmtime::{Engine, Func, Linker, Module, Store, Trap, Val, ValType}; +use wasmtime::{Engine, Func, Linker, Module, Store, Val, ValType}; use wasmtime_cli_flags::{CommonOptions, WasiModules}; +use wasmtime_wasi::maybe_exit_on_error; use wasmtime_wasi::sync::{ambient_authority, Dir, TcpListener, WasiCtxBuilder}; -use wasmtime_wasi::I32Exit; + +#[cfg(any(feature = "wasi-crypto", feature = "wasi-nn", feature = "wasi-threads"))] +use std::sync::Arc; #[cfg(feature = "wasi-nn")] use wasmtime_wasi_nn::WasiNnCtx; @@ -21,6 +21,9 @@ use wasmtime_wasi_nn::WasiNnCtx; #[cfg(feature = "wasi-crypto")] use wasmtime_wasi_crypto::WasiCryptoCtx; +#[cfg(feature = "wasi-threads")] +use wasmtime_wasi_threads::WasiThreadsCtx; + fn parse_module(s: &OsStr) -> anyhow::Result { // Do not accept wasmtime subcommand names as the module name match s.to_str() { @@ -164,13 +167,6 @@ impl RunCommand { config.epoch_interruption(true); } let engine = Engine::new(&config)?; - let mut store = Store::new(&engine, Host::default()); - - // If fuel has been configured, we want to add the configured - // fuel amount to this store. - if let Some(fuel) = self.common.fuel { - store.add_fuel(fuel)?; - } let preopen_sockets = self.compute_preopen_sockets()?; @@ -181,9 +177,15 @@ impl RunCommand { let mut linker = Linker::new(&engine); linker.allow_unknown_exports(self.allow_unknown_exports); + // Read the wasm module binary either as `*.wat` or a raw binary. + let module = self.load_module(linker.engine(), &self.module)?; + + let host = Host::default(); + let mut store = Store::new(&engine, host); populate_with_wasi( - &mut store, &mut linker, + &mut store, + module.clone(), preopen_dirs, &argv, &self.vars, @@ -192,6 +194,12 @@ impl RunCommand { preopen_sockets, )?; + // If fuel has been configured, we want to add the configured + // fuel amount to this store. + if let Some(fuel) = self.common.fuel { + store.add_fuel(fuel)?; + } + // Load the preload wasm modules. for (name, path) in self.preloads.iter() { // Read the wasm module binary either as `*.wat` or a raw binary @@ -207,43 +215,15 @@ impl RunCommand { // Load the main wasm module. match self - .load_main_module(&mut store, &mut linker) + .load_main_module(&mut store, &mut linker, module) .with_context(|| format!("failed to run main module `{}`", self.module.display())) { Ok(()) => (), Err(e) => { - // If a specific WASI error code was requested then that's - // forwarded through to the process here without printing any - // extra error information. - if let Some(exit) = e.downcast_ref::() { - // Print the error message in the usual way. - // On Windows, exit status 3 indicates an abort (see below), - // so return 1 indicating a non-zero status to avoid ambiguity. - if cfg!(windows) && exit.0 >= 3 { - process::exit(1); - } - process::exit(exit.0); - } - - // If the program exited because of a trap, return an error code - // to the outside environment indicating a more severe problem - // than a simple failure. - if e.is::() { - eprintln!("Error: {:?}", e); - - if cfg!(unix) { - // On Unix, return the error code of an abort. - process::exit(128 + libc::SIGABRT); - } else if cfg!(windows) { - // On Windows, return 3. - // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/abort?view=vs-2019 - process::exit(3); - } - } - - // Otherwise fall back on Rust's default error printing/return + // Exit the process if Wasmtime understands the error; + // otherwise, fall back on Rust's default error printing/return // code. - return Err(e); + return Err(maybe_exit_on_error(e)); } } @@ -309,7 +289,12 @@ impl RunCommand { result } - fn load_main_module(&self, store: &mut Store, linker: &mut Linker) -> Result<()> { + fn load_main_module( + &self, + store: &mut Store, + linker: &mut Linker, + module: Module, + ) -> Result<()> { if let Some(timeout) = self.wasm_timeout { store.set_epoch_deadline(1); let engine = store.engine().clone(); @@ -319,8 +304,6 @@ impl RunCommand { }); } - // Read the wasm module binary either as `*.wat` or a raw binary. - let module = self.load_module(linker.engine(), &self.module)?; // The main module might be allowed to have unknown imports, which // should be defined as traps: if self.trap_unknown_imports { @@ -432,19 +415,22 @@ impl RunCommand { } } -#[derive(Default)] +#[derive(Default, Clone)] struct Host { wasi: Option, - #[cfg(feature = "wasi-nn")] - wasi_nn: Option, #[cfg(feature = "wasi-crypto")] - wasi_crypto: Option, + wasi_crypto: Option>, + #[cfg(feature = "wasi-nn")] + wasi_nn: Option>, + #[cfg(feature = "wasi-threads")] + wasi_threads: Option>>, } /// Populates the given `Linker` with WASI APIs. fn populate_with_wasi( - store: &mut Store, linker: &mut Linker, + store: &mut Store, + module: Module, preopen_dirs: Vec<(String, Dir)>, argv: &[String], vars: &[(String, String)], @@ -478,6 +464,28 @@ fn populate_with_wasi( store.data_mut().wasi = Some(builder.build()); } + if wasi_modules.wasi_crypto { + #[cfg(not(feature = "wasi-crypto"))] + { + bail!("Cannot enable wasi-crypto when the binary is not compiled with this feature."); + } + #[cfg(feature = "wasi-crypto")] + { + wasmtime_wasi_crypto::add_to_linker(linker, |host| { + // This WASI proposal is currently not protected against + // concurrent access--i.e., when wasi-threads is actively + // spawning new threads, we cannot (yet) safely allow access and + // fail if more than one thread has `Arc`-references to the + // context. Once this proposal is updated (as wasi-common has + // been) to allow concurrent access, this `Arc::get_mut` + // limitation can be removed. + Arc::get_mut(host.wasi_crypto.as_mut().unwrap()) + .expect("wasi-crypto is not implemented with multi-threading support") + })?; + store.data_mut().wasi_crypto = Some(Arc::new(WasiCryptoCtx::new())); + } + } + if wasi_modules.wasi_nn { #[cfg(not(feature = "wasi-nn"))] { @@ -485,20 +493,33 @@ fn populate_with_wasi( } #[cfg(feature = "wasi-nn")] { - wasmtime_wasi_nn::add_to_linker(linker, |host| host.wasi_nn.as_mut().unwrap())?; - store.data_mut().wasi_nn = Some(WasiNnCtx::new()?); + wasmtime_wasi_nn::add_to_linker(linker, |host| { + // See documentation for wasi-crypto for why this is needed. + Arc::get_mut(host.wasi_nn.as_mut().unwrap()) + .expect("wasi-nn is not implemented with multi-threading support") + })?; + store.data_mut().wasi_nn = Some(Arc::new(WasiNnCtx::new()?)); } } - if wasi_modules.wasi_crypto { - #[cfg(not(feature = "wasi-crypto"))] + if wasi_modules.wasi_threads { + #[cfg(not(feature = "wasi-threads"))] { - bail!("Cannot enable wasi-crypto when the binary is not compiled with this feature."); + // Silence the unused warning for `module` as it is only used in the + // conditionally-compiled wasi-threads. + drop(&module); + + bail!("Cannot enable wasi-threads when the binary is not compiled with this feature."); } - #[cfg(feature = "wasi-crypto")] + #[cfg(feature = "wasi-threads")] { - wasmtime_wasi_crypto::add_to_linker(linker, |host| host.wasi_crypto.as_mut().unwrap())?; - store.data_mut().wasi_crypto = Some(WasiCryptoCtx::new()); + wasmtime_wasi_threads::add_to_linker(linker, store, &module, |host| { + host.wasi_threads.as_ref().unwrap() + })?; + store.data_mut().wasi_threads = Some(Arc::new(WasiThreadsCtx::new( + module, + Arc::new(linker.clone()), + )?)); } } diff --git a/tests/all/cli_tests.rs b/tests/all/cli_tests.rs index b8dbfce83139..7086ebea68dc 100644 --- a/tests/all/cli_tests.rs +++ b/tests/all/cli_tests.rs @@ -473,3 +473,28 @@ fn run_cwasm_from_stdin() -> Result<()> { } Ok(()) } + +#[cfg(feature = "wasi-threads")] +#[test] +fn run_threads() -> Result<()> { + let wasm = build_wasm("tests/all/cli_tests/threads.wat")?; + let stdout = run_wasmtime(&[ + "run", + "--wasi-modules", + "experimental-wasi-threads", + "--wasm-features", + "threads", + "--disable-cache", + wasm.path().to_str().unwrap(), + ])?; + + assert!( + stdout + == "Called _start\n\ + Running wasi_thread_start\n\ + Running wasi_thread_start\n\ + Running wasi_thread_start\n\ + Done\n" + ); + Ok(()) +} diff --git a/tests/all/cli_tests/threads.wat b/tests/all/cli_tests/threads.wat new file mode 100644 index 000000000000..d935289738c1 --- /dev/null +++ b/tests/all/cli_tests/threads.wat @@ -0,0 +1,62 @@ +(module + ;; As we have discussed, it makes sense to make the shared memory an import + ;; so that all + (import "" "memory" (memory $shmem 1 1 shared)) + (import "wasi_snapshot_preview1" "fd_write" + (func $__wasi_fd_write (param i32 i32 i32 i32) (result i32))) + (import "wasi_snapshot_preview1" "proc_exit" + (func $__wasi_proc_exit (param i32))) + (import "wasi" "thread_spawn" + (func $__wasi_thread_spawn (param i32) (result i32))) + + (func (export "_start") + (local $i i32) + + ;; Print "Called _start". + (call $print (i32.const 32) (i32.const 14)) + + ;; Print "Running wasi_thread_start" in several threads. + (drop (call $__wasi_thread_spawn (i32.const 0))) + (drop (call $__wasi_thread_spawn (i32.const 0))) + (drop (call $__wasi_thread_spawn (i32.const 0))) + + ;; Wait for all the threads to notify us that they are done. + (loop $again + ;; Retrieve the i32 at address 128, compare it to -1 (it should always + ;; fail) and load it atomically to check if all three threads are + ;; complete. This wait is for 1ms or until notified, whichever is first. + (drop (memory.atomic.wait32 (i32.const 128) (i32.const -1) (i64.const 1000000))) + (br_if $again (i32.lt_s (i32.atomic.load (i32.const 128)) (i32.const 3))) + ) + + ;; Print "Done". + (call $print (i32.const 64) (i32.const 5)) + ) + + ;; A threads-enabled module must export this spec-designated entry point. + (func (export "wasi_thread_start") (param $tid i32) (param $start_arg i32) + (call $print (i32.const 96) (i32.const 26)) + ;; After printing, we atomically increment the value at address 128 and then + ;; wake up the main thread's join loop. + (drop (i32.atomic.rmw.add (i32.const 128) (i32.const 1))) + (drop (memory.atomic.notify (i32.const 128) (i32.const 1))) + ) + + ;; A helper function for printing ptr-len strings. + (func $print (param $ptr i32) (param $len i32) + (i32.store (i32.const 8) (local.get $len)) + (i32.store (i32.const 4) (local.get $ptr)) + (drop (call $__wasi_fd_write + (i32.const 1) + (i32.const 4) + (i32.const 1) + (i32.const 0))) + ) + + ;; We still need to export the shared memory for Wiggle's sake. + (export "memory" (memory $shmem)) + + (data (i32.const 32) "Called _start\0a") + (data (i32.const 64) "Done\0a") + (data (i32.const 96) "Running wasi_thread_start\0a") +)