diff --git a/examples/hello_c/hello.c b/examples/hello_c/hello.c index 6a1e4b4..5364273 100644 --- a/examples/hello_c/hello.c +++ b/examples/hello_c/hello.c @@ -4,23 +4,37 @@ #include #include #include +#include #define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern "C" #else #include #include #include +#include #define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern #endif +// === +// Functions for the protocol + PROTOCOL_FUNCTION void wasm_minimal_protocol_send_result_to_host(const uint8_t *ptr, size_t len); PROTOCOL_FUNCTION void wasm_minimal_protocol_write_args_to_buffer(uint8_t *ptr); +EMSCRIPTEN_KEEPALIVE void wasm_minimal_protocol_free_byte_buffer(uint8_t *ptr, + size_t len) { + free(ptr); +} + +// === + EMSCRIPTEN_KEEPALIVE int32_t hello(void) { - const char message[] = "Hello world !"; - wasm_minimal_protocol_send_result_to_host((uint8_t *)message, - sizeof(message) - 1); + const char static_message[] = "Hello world !"; + const size_t length = sizeof(static_message); + char *message = malloc(length); + memcpy((void *)message, (void *)static_message, length); + wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1); return 0; } @@ -36,7 +50,6 @@ int32_t double_it(size_t arg_len) { alloc_result[arg_len + i] = alloc_result[i]; } wasm_minimal_protocol_send_result_to_host(alloc_result, result_len); - free(alloc_result); return 0; } @@ -66,7 +79,6 @@ int32_t concatenate(size_t arg1_len, size_t arg2_len) { wasm_minimal_protocol_send_result_to_host(result, total_len + 1); - free(result); free(args); return 0; } @@ -102,24 +114,27 @@ int32_t shuffle(size_t arg1_len, size_t arg2_len, size_t arg3_len) { wasm_minimal_protocol_send_result_to_host(result, result_len); - free(result); free(args); return 0; } EMSCRIPTEN_KEEPALIVE int32_t returns_ok() { - const char message[] = "This is an `Ok`"; - wasm_minimal_protocol_send_result_to_host((uint8_t *)message, - sizeof(message) - 1); + const char static_message[] = "This is an `Ok`"; + const size_t length = sizeof(static_message); + char *message = malloc(length); + memcpy((void *)message, (void *)static_message, length); + wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1); return 0; } EMSCRIPTEN_KEEPALIVE int32_t returns_err() { - const char message[] = "This is an `Err`"; - wasm_minimal_protocol_send_result_to_host((uint8_t *)message, - sizeof(message) - 1); + const char static_message[] = "This is an `Err`"; + const size_t length = sizeof(static_message); + char *message = malloc(length); + memcpy((void *)message, (void *)static_message, length); + wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1); return 1; } diff --git a/examples/hello_zig/hello.zig b/examples/hello_zig/hello.zig index cd92a95..b6bfdf9 100644 --- a/examples/hello_zig/hello.zig +++ b/examples/hello_zig/hello.zig @@ -1,23 +1,36 @@ const std = @import("std"); const allocator = std.heap.page_allocator; +// === +// Functions for the protocol + extern "typst_env" fn wasm_minimal_protocol_send_result_to_host(ptr: [*]const u8, len: usize) void; extern "typst_env" fn wasm_minimal_protocol_write_args_to_buffer(ptr: [*]u8) void; +export fn wasm_minimal_protocol_free_byte_buffer(ptr: [*]u8, len: usize) void { + var slice: []u8 = undefined; + slice.ptr = ptr; + slice.len = len; + allocator.free(slice); +} + +// === + export fn hello() i32 { const message = "Hello world !"; - wasm_minimal_protocol_send_result_to_host(message.ptr, message.len); + var result = allocator.alloc(u8, message.len) catch return 1; + @memcpy(result, message); + wasm_minimal_protocol_send_result_to_host(result.ptr, result.len); return 0; } export fn double_it(arg1_len: usize) i32 { - var alloc_result = allocator.alloc(u8, arg1_len * 2) catch return 1; - defer allocator.free(alloc_result); - wasm_minimal_protocol_write_args_to_buffer(alloc_result.ptr); + var result = allocator.alloc(u8, arg1_len * 2) catch return 1; + wasm_minimal_protocol_write_args_to_buffer(result.ptr); for (0..arg1_len) |i| { - alloc_result[i + arg1_len] = alloc_result[i]; + result[i + arg1_len] = result[i]; } - wasm_minimal_protocol_send_result_to_host(alloc_result.ptr, alloc_result.len); + wasm_minimal_protocol_send_result_to_host(result.ptr, result.len); return 0; } @@ -27,7 +40,6 @@ export fn concatenate(arg1_len: usize, arg2_len: usize) i32 { wasm_minimal_protocol_write_args_to_buffer(args.ptr); var result = allocator.alloc(u8, arg1_len + arg2_len + 1) catch return 1; - defer allocator.free(result); for (0..arg1_len) |i| { result[i] = args[i]; } @@ -49,27 +61,30 @@ export fn shuffle(arg1_len: usize, arg2_len: usize, arg3_len: usize) i32 { var arg2 = args[arg1_len .. arg1_len + arg2_len]; var arg3 = args[arg1_len + arg2_len .. args.len]; - var result: std.ArrayList(u8) = std.ArrayList(u8).initCapacity(allocator, args_len + 2) catch return 1; - defer result.deinit(); - result.appendSlice(arg3) catch return 1; - result.append('-') catch return 1; - result.appendSlice(arg1) catch return 1; - result.append('-') catch return 1; - result.appendSlice(arg2) catch return 1; + var result = allocator.alloc(u8, arg1_len + arg2_len + arg3_len + 2) catch return 1; + @memcpy(result[0..arg3.len], arg3); + result[arg3.len] = '-'; + @memcpy(result[arg3.len + 1 ..][0..arg1.len], arg1); + result[arg3.len + arg1.len + 1] = '-'; + @memcpy(result[arg3.len + arg1.len + 2 ..][0..arg2.len], arg2); - wasm_minimal_protocol_send_result_to_host(result.items.ptr, result.items.len); + wasm_minimal_protocol_send_result_to_host(result.ptr, result.len); return 0; } export fn returns_ok() i32 { const message = "This is an `Ok`"; - wasm_minimal_protocol_send_result_to_host(message.ptr, message.len); + var result = allocator.alloc(u8, message.len) catch return 1; + @memcpy(result, message); + wasm_minimal_protocol_send_result_to_host(result.ptr, result.len); return 0; } export fn returns_err() i32 { const message = "This is an `Err`"; - wasm_minimal_protocol_send_result_to_host(message.ptr, message.len); + var result = allocator.alloc(u8, message.len) catch return 1; + @memcpy(result, message); + wasm_minimal_protocol_send_result_to_host(result.ptr, result.len); return 1; } diff --git a/examples/host-wasmi/src/lib.rs b/examples/host-wasmi/src/lib.rs index e141519..c41c969 100644 --- a/examples/host-wasmi/src/lib.rs +++ b/examples/host-wasmi/src/lib.rs @@ -1,16 +1,57 @@ -use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Module, Value}; +use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Memory, Module, Value}; type Store = wasmi::Store; +/// Reference to a slice of memory returned after +/// [calling a wasm function](PluginInstance::call). +/// +/// # Drop +/// On [`Drop`], this will free the slice of memory inside the plugin. +/// +/// As such, this structure mutably borrows the [`PluginInstance`], which prevents +/// another function from being called. +pub struct ReturnedData<'a> { + memory: Memory, + ptr: u32, + len: u32, + free_function: &'a Function, + context_mut: &'a mut Store, +} + +impl<'a> ReturnedData<'a> { + /// Get a reference to the returned slice of data. + /// + /// # Panic + /// This may panic if the function returned an invalid `(ptr, len)` pair. + pub fn get(&self) -> &[u8] { + &self.memory.data(&*self.context_mut)[self.ptr as usize..(self.ptr + self.len) as usize] + } +} + +impl Drop for ReturnedData<'_> { + fn drop(&mut self) { + self.free_function + .call( + &mut *self.context_mut, + &[Value::I32(self.ptr as _), Value::I32(self.len as _)], + &mut [], + ) + .unwrap(); + } +} + #[derive(Debug, Clone)] struct PersistentData { - result_data: Vec, + result_ptr: u32, + result_len: u32, arg_buffer: Vec, } #[derive(Debug)] pub struct PluginInstance { store: Store, + memory: Memory, + free_function: Function, functions: Vec<(String, Function)>, } @@ -18,8 +59,9 @@ impl PluginInstance { pub fn new_from_bytes(bytes: impl AsRef<[u8]>) -> Result { let engine = Engine::default(); let data = PersistentData { - result_data: Vec::new(), arg_buffer: Vec::new(), + result_ptr: 0, + result_len: 0, }; let mut store = Store::new(&engine, data); @@ -32,11 +74,8 @@ impl PluginInstance { "typst_env", "wasm_minimal_protocol_send_result_to_host", move |mut caller: Caller, ptr: u32, len: u32| { - let memory = caller.get_export("memory").unwrap().into_memory().unwrap(); - let mut buffer = std::mem::take(&mut caller.data_mut().result_data); - buffer.resize(len as usize, 0); - memory.read(&caller, ptr as _, &mut buffer).unwrap(); - caller.data_mut().result_data = buffer; + caller.data_mut().result_ptr = ptr; + caller.data_mut().result_len = len; }, ) .unwrap() @@ -51,54 +90,44 @@ impl PluginInstance { }, ) .unwrap() - // hack to accept wasi file - // https://github.com/near/wasi-stub is preferred - /* - .func_wrap( - "wasi_snapshot_preview1", - "fd_write", - |_: i32, _: i32, _: i32, _: i32| 0i32, - ) - .unwrap() - .func_wrap( - "wasi_snapshot_preview1", - "environ_get", - |_: i32, _: i32| 0i32, - ) - .unwrap() - .func_wrap( - "wasi_snapshot_preview1", - "environ_sizes_get", - |_: i32, _: i32| 0i32, - ) - .unwrap() - .func_wrap( - "wasi_snapshot_preview1", - "proc_exit", - |_: i32| {}, - ) - .unwrap() - */ .instantiate(&mut store, &module) .map_err(|e| format!("{e}"))? .start(&mut store) .map_err(|e| format!("{e}"))?; + let mut free_function = None; let functions = instance .exports(&store) .filter_map(|e| { let name = e.name().to_owned(); - e.into_func().map(|func| (name, func)) + + e.into_func().map(|func| { + if name == "wasm_minimal_protocol_free_byte_buffer" { + free_function = Some(func); + } + (name, func) + }) }) .collect::>(); - Ok(Self { store, functions }) + let free_function = free_function.unwrap(); + let memory = instance + .get_export(&store, "memory") + .unwrap() + .into_memory() + .unwrap(); + Ok(Self { + store, + memory, + free_function, + functions, + }) } fn write(&mut self, args: &[&[u8]]) { self.store.data_mut().arg_buffer = args.concat(); } - pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result, String> { + pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result { self.write(args); let (_, function) = self @@ -122,11 +151,19 @@ impl PluginInstance { code.first().cloned().unwrap_or(Value::I32(3)) // if the function returns nothing }; - let s = std::mem::take(&mut self.store.data_mut().result_data); + let (ptr, len) = (self.store.data().result_ptr, self.store.data().result_len); + + let result = ReturnedData { + memory: self.memory, + ptr, + len, + free_function: &self.free_function, + context_mut: &mut self.store, + }; match code { - Value::I32(0) => Ok(s), - Value::I32(1) => Err(match String::from_utf8(s) { + Value::I32(0) => Ok(result), + Value::I32(1) => Err(match std::str::from_utf8(result.get()) { Ok(err) => format!("plugin errored with: '{}'", err,), Err(_) => String::from("plugin errored and did not return valid UTF-8"), }), diff --git a/examples/test-runner/src/main.rs b/examples/test-runner/src/main.rs index cd8467e..a14fe91 100644 --- a/examples/test-runner/src/main.rs +++ b/examples/test-runner/src/main.rs @@ -2,9 +2,8 @@ // you need to build the hello example first use anyhow::Result; -use std::process::Command; - use host_wasmi::PluginInstance; +use std::process::Command; #[cfg(not(feature = "wasi"))] mod consts { @@ -118,7 +117,7 @@ fn main() -> Result<()> { return Ok(()); } }; - match String::from_utf8(result) { + match std::str::from_utf8(result.get()) { Ok(s) => println!("{s}"), Err(_) => panic!("Error: function call '{function}' did not return UTF-8"), } @@ -141,7 +140,7 @@ fn main() -> Result<()> { continue; } }; - match String::from_utf8(result) { + match std::str::from_utf8(result.get()) { Ok(s) => println!("{s}"), Err(_) => panic!("Error: function call '{function}' did not return UTF-8"), } diff --git a/protocol.md b/protocol.md index 6c34868..e817fd9 100644 --- a/protocol.md +++ b/protocol.md @@ -18,23 +18,33 @@ Valid plugins need to import two functions (that will be provided by the runtime Write the arguments for the current function into the buffer pointed at by `ptr`. - Each function for the protocol receives lengths as its arguments (see [Exported functions](#exported-functions)). The capacity of the buffer pointed at by `ptr` should be at least the sum of all those lengths. + Each function for the protocol receives lengths as its arguments (see [User-defined functions](#user-defined-functions)). The capacity of the buffer pointed at by `ptr` should be at least the sum of all those lengths. - `(import "typst_env" "wasm_minimal_protocol_send_result_to_host" (func (param i32 i32)))` The first parameter is a pointer to a buffer (`ptr`), the second is the length of the buffer (`len`). - Reads `len` bytes pointed at by `ptr` into host memory. The memory pointed at by `ptr` can be freed immediately after this function returns. + Send `len` and `ptr` to host memory. The buffer must not be freed by the end of the function: it will be freed by the runtime by calling [`wasm_minimal_protocol_send_result_to_host`](#exports). - If the message should be interpreted as an error message (see [Exported functions](#exported-functions)), it should be encoded as UTF-8. + If the message should be interpreted as an error message (see [User-defined functions](#user-defined-functions)), it should be encoded as UTF-8. -# Exported functions + ### Note + + If [`wasm_minimal_protocol_send_result_to_host`](#exports) calls `free` (or a similar routine), be careful that the buffer does not point to static memory. + +# Exports + +Valid plugins need to export a function named `wasm_minimal_protocol_send_result_to_host`, that has signature `func (param i32 i32)`. + +This function will be used by the runtime to free the block of memory returned by a [user-defined](#user-defined-functions) function. + +# User-defined functions To conform to the protocol, an exported function should: -- Take `n` arguments `a₁`, `a₂`, ..., `aₙ` of type `u32` (interpreted as lengths, so `usize/size_t` may be preferable), and return one `i32`. +- Take `n` arguments `a₁`, `a₂`, ..., `aₙ` of type `u32` (interpreted as lengths, so `usize/size_t` may be preferable), and return one `i32`. We will call the return `return_code`. - The function should first allocate a buffer `buf` of length `a₁ + a₂ + ⋯ + aₙ`, and call `wasm_minimal_protocol_write_args_to_buffer(buf.ptr)`. - The `a₁` first bytes of the buffer constitute the first argument, the `a₂` next bytes the second argument, and so on. - Before returning, the function should call `wasm_minimal_protocol_send_result_to_host` to send its result back to the host. -- To signal success, the function should return `0`. -- To signal an error, the function should return `1`. The written buffer is then interpreted as an error message. +- To signal success, `return_code` must be `0`. +- To signal an error, `return_code` must be `1`. The sent buffer is then interpreted as an error message, and must be encoded as UTF-8. diff --git a/src/lib.rs b/src/lib.rs index 38c96c4..1f7561b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,10 @@ //! ``` //! use wasm_minimal_protocol::wasm_func; //! +//! #[cfg(target_arch = "wasm32")] //! wasm_minimal_protocol::initiate_protocol!(); //! -//! #[wasm_func] +//! #[cfg_attr(target_arch = "wasm32", wasm_func)] //! fn concatenate(arg1: &[u8], arg2: &[u8]) -> Vec { //! [arg1, arg2].concat() //! } @@ -37,6 +38,16 @@ pub fn initiate_protocol(stream: TokenStream) -> TokenStream { #[cfg(not(target_arch = "wasm32"))] compile_error!("Error: this protocol may only be used when compiling to wasm architectures"); + /// Safety: `data` and `len` should form a `Box`-allocated slice together, + /// ready to be dropped. + #[export_name = "wasm_minimal_protocol_free_byte_buffer"] + pub unsafe extern "C" fn __free_byte_buffer(data: u32, len: u32) { + let data = data as usize as *mut u8; + let len = len as usize; + let ptr_slice = ::std::ptr::slice_from_raw_parts_mut(data, len); + drop(::std::boxed::Box::from_raw(ptr_slice)); + } + #[link(wasm_import_module = "typst_env")] extern "C" { #[link_name = "wasm_minimal_protocol_send_result_to_host"] @@ -88,19 +99,20 @@ pub fn initiate_protocol(stream: TokenStream) -> TokenStream { /// ``` /// use wasm_minimal_protocol::wasm_func; /// +/// #[cfg(target_arch = "wasm32")] /// wasm_minimal_protocol::initiate_protocol!(); /// -/// #[wasm_func] +/// #[cfg_attr(target_arch = "wasm32", wasm_func)] /// fn function_one() -> Vec { /// Vec::new() /// } /// -/// #[wasm_func] +/// #[cfg_attr(target_arch = "wasm32", wasm_func)] /// fn function_two(arg1: &[u8], arg2: &[u8]) -> Result, i32> { /// Ok(b"Normal message".to_vec()) /// } /// -/// #[wasm_func] +/// #[cfg_attr(target_arch = "wasm32", wasm_func)] /// fn function_three(arg1: &[u8]) -> Result, String> { /// Err(String::from("Error message")) /// } @@ -218,17 +230,18 @@ pub fn wasm_func(_: TokenStream, item: TokenStream) -> TokenStream { } else { result.extend(quote!( #[export_name = #export_name] - #vis_marker fn #inner_name(#(#p_idx: usize),*) -> i32 { + #vis_marker extern "C" fn #inner_name(#(#p_idx: usize),*) -> i32 { #get_unsplit_params #set_args let result = __BytesOrResultBytes::convert(#name(#(#p),*)); let (message, code) = match result { - Ok(s) => (s, 0), - Err(err) => (err.to_string().into_bytes(), 1), + Ok(s) => (s.into_boxed_slice(), 0), + Err(err) => (err.to_string().into_bytes().into_boxed_slice(), 1), }; unsafe { __send_result_to_host(message.as_ptr(), message.len()); } - code // indicates everything was successful + ::std::mem::forget(message); + code } )) }