From ae2bb778144db55a3d246c3e4f2936b854c14706 Mon Sep 17 00:00:00 2001 From: Sanandan Sashikumar Date: Mon, 20 Jan 2025 21:27:16 +0000 Subject: [PATCH 1/5] fix --- Cargo.toml | 1 + src/thread.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index ad49f56f..32a06cff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ members = [ ] [features] +default = ["luau"] lua54 = ["ffi/lua54"] lua53 = ["ffi/lua53"] lua52 = ["ffi/lua52"] diff --git a/src/thread.rs b/src/thread.rs index 14e93bb4..fc2bde01 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -8,6 +8,7 @@ use crate::traits::{FromLuaMulti, IntoLuaMulti}; use crate::types::{LuaType, ValueRef}; use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard}; +use crate::IntoLua; #[cfg(not(feature = "luau"))] use crate::{ hook::{Debug, HookTriggers}, @@ -173,6 +174,29 @@ impl Thread { } } + pub fn resume_error(&self, args: impl IntoLua) -> Result<()> { + let lua = self.0.lua.lock(); + + match self.status_inner(&lua) { + ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => {} + _ => return Err(Error::CoroutineUnresumable), + }; + + let state = lua.state(); + let thread_state = self.state(); + unsafe { + let _sg = StackGuard::new(state); + let _thread_sg = StackGuard::new(thread_state); + + check_stack(state, 1)?; + args.push_into_stack(&lua)?; + ffi::lua_xmove(state, thread_state, 1); + self.resumeerror_inner(&lua)?; + + Ok(()) + } + } + /// Resumes execution of this thread. /// /// It's similar to `resume()` but leaves `nresults` values on the thread stack. @@ -196,6 +220,28 @@ impl Thread { } } + /// Resumes execution of this thread. + /// + /// It's similar to `resume()` but leaves `nresults` values on the thread stack. + unsafe fn resumeerror_inner(&self, lua: &RawLua) -> Result { + let state = lua.state(); + let thread_state = self.state(); + let ret = ffi::luau::lua_resumeerror(thread_state, state); + match ret { + ffi::LUA_OK => Ok(ThreadStatusInner::Finished), + ffi::LUA_YIELD => Ok(ThreadStatusInner::Yielded(0)), + ffi::LUA_ERRMEM => { + // Don't call error handler for memory errors + Err(pop_error(thread_state, ret)) + } + _ => { + check_stack(state, 3)?; + protect_lua!(state, 0, 1, |state| error_traceback_thread(state, thread_state))?; + Err(pop_error(state, ret)) + } + } + } + /// Gets the status of the thread. pub fn status(&self) -> ThreadStatus { match self.status_inner(&self.0.lua.lock()) { @@ -585,3 +631,75 @@ mod assertions { #[cfg(all(feature = "async", feature = "send"))] static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync); } + +#[cfg(test)] +mod resumeerror_test { + #[test] + fn test_resumeerror() { + // Create tokio runtime and use spawn_local + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .worker_threads(10) + .build() + .unwrap(); + + let local = tokio::task::LocalSet::new(); + + local.block_on(&rt, async { + use crate::{Function, Lua, Thread, Value}; + + let lua = Lua::new(); + + let thread: Function = lua + .load( + r#" + local luacall = ... + local function callback(...) + print("AM HERE") + luacall(coroutine.running(), ...) + return coroutine.yield() + end + + return callback + "#, + ) + .call( + lua.create_function(|lua, th: Thread| { + tokio::task::spawn_local(async move { + println!("Thread: {:?}, {:?}", th, th.status()); + th.resume_error("An error here".to_string()).unwrap(); + tokio::task::yield_now().await; + }); + Ok(()) + }) + .unwrap(), + ) + .unwrap(); + + let thread_b: Thread = lua + .load( + r#" + local a = ... + return coroutine.create(function (...) + local b = ... + assert(b == 1) + local ok, result = pcall(a) + assert(not ok) + print("Done with: ", ok, result) + return result + end) + "#, + ) + .call(thread.clone()) + .unwrap(); + + println!("{:?}", thread_b.resume::(1)); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + println!("{:?}", thread_b.status()); + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + }); + } +} From 74928989ea33838e64060f5f756b137880d21055 Mon Sep 17 00:00:00 2001 From: Sanandan Sashikumar Date: Mon, 20 Jan 2025 21:43:14 +0000 Subject: [PATCH 2/5] fix --- src/thread.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/thread.rs b/src/thread.rs index fc2bde01..ea0a4eda 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -174,9 +174,11 @@ impl Thread { } } - pub fn resume_error(&self, args: impl IntoLua) -> Result<()> { + pub fn resume_error(&self, args: impl IntoLua) -> Result + where + R: FromLuaMulti, + { let lua = self.0.lua.lock(); - match self.status_inner(&lua) { ThreadStatusInner::New(_) | ThreadStatusInner::Yielded(_) => {} _ => return Err(Error::CoroutineUnresumable), @@ -191,9 +193,11 @@ impl Thread { check_stack(state, 1)?; args.push_into_stack(&lua)?; ffi::lua_xmove(state, thread_state, 1); - self.resumeerror_inner(&lua)?; - Ok(()) + let (_, nresults) = self.resumeerror_inner(&lua)?; + check_stack(state, nresults + 1)?; + ffi::lua_xmove(thread_state, state, nresults); + R::from_stack_multi(nresults, &lua) } } @@ -223,13 +227,19 @@ impl Thread { /// Resumes execution of this thread. /// /// It's similar to `resume()` but leaves `nresults` values on the thread stack. - unsafe fn resumeerror_inner(&self, lua: &RawLua) -> Result { + unsafe fn resumeerror_inner(&self, lua: &RawLua) -> Result<(ThreadStatusInner, c_int)> { let state = lua.state(); let thread_state = self.state(); + let mut nresults = 0; let ret = ffi::luau::lua_resumeerror(thread_state, state); + + if ret == ffi::LUA_OK || ret == ffi::LUA_YIELD { + nresults = ffi::lua_gettop(thread_state); + } + match ret { - ffi::LUA_OK => Ok(ThreadStatusInner::Finished), - ffi::LUA_YIELD => Ok(ThreadStatusInner::Yielded(0)), + ffi::LUA_OK => Ok((ThreadStatusInner::Finished, nresults)), + ffi::LUA_YIELD => Ok((ThreadStatusInner::Yielded(0), nresults)), ffi::LUA_ERRMEM => { // Don't call error handler for memory errors Err(pop_error(thread_state, ret)) @@ -667,7 +677,7 @@ mod resumeerror_test { lua.create_function(|lua, th: Thread| { tokio::task::spawn_local(async move { println!("Thread: {:?}, {:?}", th, th.status()); - th.resume_error("An error here".to_string()).unwrap(); + th.resume_error::<()>("An error here".to_string()).unwrap(); tokio::task::yield_now().await; }); Ok(()) From 04c6e7342276730af2445cc6bf7d5343e05e2ab3 Mon Sep 17 00:00:00 2001 From: Sanandan Sashikumar Date: Mon, 20 Jan 2025 21:48:50 +0000 Subject: [PATCH 3/5] . --- src/thread.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/thread.rs b/src/thread.rs index ea0a4eda..66691f98 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -188,7 +188,7 @@ impl Thread { let thread_state = self.state(); unsafe { let _sg = StackGuard::new(state); - let _thread_sg = StackGuard::new(thread_state); + let _thread_sg = StackGuard::with_top(thread_state, 0); check_stack(state, 1)?; args.push_into_stack(&lua)?; From f861f605b00fc89016cc17899c9488e0ba998fb9 Mon Sep 17 00:00:00 2001 From: Sanandan Sashikumar Date: Mon, 20 Jan 2025 21:50:28 +0000 Subject: [PATCH 4/5] fix --- Cargo.toml | 1 - src/thread.rs | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 32a06cff..ad49f56f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ members = [ ] [features] -default = ["luau"] lua54 = ["ffi/lua54"] lua53 = ["ffi/lua53"] lua52 = ["ffi/lua52"] diff --git a/src/thread.rs b/src/thread.rs index 66691f98..17e3c256 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -174,6 +174,8 @@ impl Thread { } } + #[cfg(feature = "luau")] + /// Resumes a thread with an error. pub fn resume_error(&self, args: impl IntoLua) -> Result where R: FromLuaMulti, @@ -224,9 +226,10 @@ impl Thread { } } - /// Resumes execution of this thread. + #[cfg(feature = "luau")] + /// Resumes execution of this thread with an error. /// - /// It's similar to `resume()` but leaves `nresults` values on the thread stack. + /// It's similar to `resume_error()` but leaves `nresults` values on the thread stack. unsafe fn resumeerror_inner(&self, lua: &RawLua) -> Result<(ThreadStatusInner, c_int)> { let state = lua.state(); let thread_state = self.state(); @@ -674,7 +677,7 @@ mod resumeerror_test { "#, ) .call( - lua.create_function(|lua, th: Thread| { + lua.create_function(|_lua, th: Thread| { tokio::task::spawn_local(async move { println!("Thread: {:?}, {:?}", th, th.status()); th.resume_error::<()>("An error here".to_string()).unwrap(); From 0e84d814918dbf07dc0bcdf775dacea6709a96bc Mon Sep 17 00:00:00 2001 From: Sanandan Sashikumar Date: Tue, 28 Jan 2025 22:56:20 +0000 Subject: [PATCH 5/5] move test to test folder --- src/thread.rs | 77 +++---------------------------------------------- tests/thread.rs | 72 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 73 deletions(-) diff --git a/src/thread.rs b/src/thread.rs index 17e3c256..5c174415 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -175,7 +175,10 @@ impl Thread { } #[cfg(feature = "luau")] - /// Resumes a thread with an error. + /// Resumes a thread with an error. This is useful when developing + /// custom async schedulers' etc. with mlua as it allows bubbling up + /// errors from the yielded thread upwards with working pcall out of + /// the box. pub fn resume_error(&self, args: impl IntoLua) -> Result where R: FromLuaMulti, @@ -644,75 +647,3 @@ mod assertions { #[cfg(all(feature = "async", feature = "send"))] static_assertions::assert_impl_all!(AsyncThread<()>: Send, Sync); } - -#[cfg(test)] -mod resumeerror_test { - #[test] - fn test_resumeerror() { - // Create tokio runtime and use spawn_local - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .worker_threads(10) - .build() - .unwrap(); - - let local = tokio::task::LocalSet::new(); - - local.block_on(&rt, async { - use crate::{Function, Lua, Thread, Value}; - - let lua = Lua::new(); - - let thread: Function = lua - .load( - r#" - local luacall = ... - local function callback(...) - print("AM HERE") - luacall(coroutine.running(), ...) - return coroutine.yield() - end - - return callback - "#, - ) - .call( - lua.create_function(|_lua, th: Thread| { - tokio::task::spawn_local(async move { - println!("Thread: {:?}, {:?}", th, th.status()); - th.resume_error::<()>("An error here".to_string()).unwrap(); - tokio::task::yield_now().await; - }); - Ok(()) - }) - .unwrap(), - ) - .unwrap(); - - let thread_b: Thread = lua - .load( - r#" - local a = ... - return coroutine.create(function (...) - local b = ... - assert(b == 1) - local ok, result = pcall(a) - assert(not ok) - print("Done with: ", ok, result) - return result - end) - "#, - ) - .call(thread.clone()) - .unwrap(); - - println!("{:?}", thread_b.resume::(1)); - - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - println!("{:?}", thread_b.status()); - - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - }); - } -} diff --git a/tests/thread.rs b/tests/thread.rs index 74f75614..a939c233 100644 --- a/tests/thread.rs +++ b/tests/thread.rs @@ -227,3 +227,75 @@ fn test_thread_pointer() -> Result<()> { Ok(()) } + +#[cfg(feature = "luau")] +#[cfg(test)] +mod resumeerror_test { + #[test] + fn test_resumeerror() { + // Create tokio runtime and use spawn_local [this is required for the test to work] + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .worker_threads(10) + .build() + .unwrap(); + + let local = tokio::task::LocalSet::new(); + + local.block_on(&rt, async { + use mlua::{Function, Lua, Thread}; + + let lua = Lua::new(); + + let thread: Function = lua + .load( + r#" +local luacall = ... +local function callback(...) + luacall(coroutine.running(), ...) + return coroutine.yield() +end + +return callback + "#, + ) + .call( + lua.create_function(|_lua, th: Thread| { + tokio::task::spawn_local(async move { + th.resume_error::<()>("ErrorABC".to_string()).unwrap(); + tokio::task::yield_now().await; + }); + Ok(()) + }) + .unwrap(), + ) + .unwrap(); + + let thread_b: Thread = lua + .load( + r#" + local a = ... + return coroutine.create(function (...) + local b = ... + assert(b == 1) + -- Test 1: Working pcall + local ok, result = pcall(a) + assert(not ok, "Should not be ok") + assert(result == "ErrorABC", "Should be ErrorABC") + + -- Repeat + local ok, result = pcall(a) + assert(not ok) + assert(result == "ErrorABC") + end) + "#, + ) + .call(thread.clone()) + .unwrap(); + + // Actually using this system in practice needs a few extra pieces (such as a way to track return + // values etc.) + thread_b.resume::<()>(1).expect("Error in thread_b"); + }); + } +}