Skip to content

Commit 74b4601

Browse files
committed
Rework Lua hooks:
- Support global hooks inherited by new threads - Support thread hooks, where each thread can have its own hook This should also allow to enable hooks for async calls. Related to #489 #347
1 parent a89800b commit 74b4601

File tree

8 files changed

+253
-91
lines changed

8 files changed

+253
-91
lines changed

src/state.rs

+39-18
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use crate::util::{
3030
use crate::value::{Nil, Value};
3131

3232
#[cfg(not(feature = "luau"))]
33-
use crate::hook::HookTriggers;
33+
use crate::{hook::HookTriggers, types::HookKind};
3434

3535
#[cfg(any(feature = "luau", doc))]
3636
use crate::{buffer::Buffer, chunk::Compiler};
@@ -501,6 +501,26 @@ impl Lua {
501501
}
502502
}
503503

504+
/// Sets or replaces a global hook function that will periodically be called as Lua code
505+
/// executes.
506+
///
507+
/// All new threads created (by mlua) after this call will use the global hook function.
508+
///
509+
/// For more information see [`Lua::set_hook`].
510+
#[cfg(not(feature = "luau"))]
511+
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
512+
pub fn set_global_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
513+
where
514+
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
515+
{
516+
let lua = self.lock();
517+
unsafe {
518+
(*lua.extra.get()).hook_triggers = triggers;
519+
(*lua.extra.get()).hook_callback = Some(Box::new(callback));
520+
lua.set_thread_hook(lua.state(), HookKind::Global)
521+
}
522+
}
523+
504524
/// Sets a hook function that will periodically be called as Lua code executes.
505525
///
506526
/// When exactly the hook function is called depends on the contents of the `triggers`
@@ -511,12 +531,10 @@ impl Lua {
511531
/// limited form of execution limits by setting [`HookTriggers.every_nth_instruction`] and
512532
/// erroring once an instruction limit has been reached.
513533
///
514-
/// This method sets a hook function for the current thread of this Lua instance.
534+
/// This method sets a hook function for the *current* thread of this Lua instance.
515535
/// If you want to set a hook function for another thread (coroutine), use
516536
/// [`Thread::set_hook`] instead.
517537
///
518-
/// Please note you cannot have more than one hook function set at a time for this Lua instance.
519-
///
520538
/// # Example
521539
///
522540
/// Shows each line number of code being executed by the Lua interpreter.
@@ -541,33 +559,36 @@ impl Lua {
541559
/// [`HookTriggers.every_nth_instruction`]: crate::HookTriggers::every_nth_instruction
542560
#[cfg(not(feature = "luau"))]
543561
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
544-
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
562+
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
545563
where
546564
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
547565
{
548566
let lua = self.lock();
549-
unsafe { lua.set_thread_hook(lua.state(), triggers, callback) };
567+
unsafe { lua.set_thread_hook(lua.state(), HookKind::Thread(triggers, Box::new(callback))) }
550568
}
551569

552-
/// Removes any hook previously set by [`Lua::set_hook`] or [`Thread::set_hook`].
570+
/// Removes a global hook previously set by [`Lua::set_global_hook`].
553571
///
554572
/// This function has no effect if a hook was not previously set.
555573
#[cfg(not(feature = "luau"))]
556574
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
557-
pub fn remove_hook(&self) {
575+
pub fn remove_global_hook(&self) {
558576
let lua = self.lock();
559577
unsafe {
560-
let state = lua.state();
561-
ffi::lua_sethook(state, None, 0, 0);
562-
match lua.main_state {
563-
Some(main_state) if state != main_state.as_ptr() => {
564-
// If main_state is different from state, remove hook from it too
565-
ffi::lua_sethook(main_state.as_ptr(), None, 0, 0);
566-
}
567-
_ => {}
568-
};
569578
(*lua.extra.get()).hook_callback = None;
570-
(*lua.extra.get()).hook_thread = ptr::null_mut();
579+
(*lua.extra.get()).hook_triggers = HookTriggers::default();
580+
}
581+
}
582+
583+
/// Removes any hook from the current thread.
584+
///
585+
/// This function has no effect if a hook was not previously set.
586+
#[cfg(not(feature = "luau"))]
587+
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
588+
pub fn remove_hook(&self) {
589+
let lua = self.lock();
590+
unsafe {
591+
ffi::lua_sethook(lua.state(), None, 0, 0);
571592
}
572593
}
573594

src/state/extra.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub(crate) struct ExtraData {
7575
#[cfg(not(feature = "luau"))]
7676
pub(super) hook_callback: Option<crate::types::HookCallback>,
7777
#[cfg(not(feature = "luau"))]
78-
pub(super) hook_thread: *mut ffi::lua_State,
78+
pub(super) hook_triggers: crate::hook::HookTriggers,
7979
#[cfg(feature = "lua54")]
8080
pub(super) warn_callback: Option<crate::types::WarnCallback>,
8181
#[cfg(feature = "luau")]
@@ -171,7 +171,7 @@ impl ExtraData {
171171
#[cfg(not(feature = "luau"))]
172172
hook_callback: None,
173173
#[cfg(not(feature = "luau"))]
174-
hook_thread: ptr::null_mut(),
174+
hook_triggers: Default::default(),
175175
#[cfg(feature = "lua54")]
176176
warn_callback: None,
177177
#[cfg(feature = "luau")]

src/state/raw.rs

+101-34
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ use super::extra::ExtraData;
3737
use super::{Lua, LuaOptions, WeakLua};
3838

3939
#[cfg(not(feature = "luau"))]
40-
use crate::hook::{Debug, HookTriggers};
40+
use crate::{
41+
hook::Debug,
42+
types::{HookCallback, HookKind, VmState},
43+
};
4144

4245
#[cfg(feature = "async")]
4346
use {
@@ -186,6 +189,8 @@ impl RawLua {
186189
init_internal_metatable::<XRc<UnsafeCell<ExtraData>>>(state, None)?;
187190
init_internal_metatable::<Callback>(state, None)?;
188191
init_internal_metatable::<CallbackUpvalue>(state, None)?;
192+
#[cfg(not(feature = "luau"))]
193+
init_internal_metatable::<HookCallback>(state, None)?;
189194
#[cfg(feature = "async")]
190195
{
191196
init_internal_metatable::<AsyncCallback>(state, None)?;
@@ -373,42 +378,22 @@ impl RawLua {
373378
status
374379
}
375380

376-
/// Sets a 'hook' function for a thread (coroutine).
381+
/// Sets a hook for a thread (coroutine).
377382
#[cfg(not(feature = "luau"))]
378-
pub(crate) unsafe fn set_thread_hook<F>(
383+
pub(crate) unsafe fn set_thread_hook(
379384
&self,
380-
state: *mut ffi::lua_State,
381-
triggers: HookTriggers,
382-
callback: F,
383-
) where
384-
F: Fn(&Lua, Debug) -> Result<crate::VmState> + MaybeSend + 'static,
385-
{
386-
use crate::types::VmState;
387-
use std::rc::Rc;
385+
thread_state: *mut ffi::lua_State,
386+
hook: HookKind,
387+
) -> Result<()> {
388+
// Key to store hooks in the registry
389+
const HOOKS_KEY: *const c_char = cstr!("__mlua_hooks");
388390

389-
unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
390-
let extra = ExtraData::get(state);
391-
if (*extra).hook_thread != state {
392-
// Hook was destined for a different thread, ignore
393-
ffi::lua_sethook(state, None, 0, 0);
394-
return;
395-
}
396-
let result = callback_error_ext(state, extra, move |extra, _| {
397-
let hook_cb = (*extra).hook_callback.clone();
398-
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
399-
if Rc::strong_count(&hook_cb) > 2 {
400-
return Ok(VmState::Continue); // Don't allow recursion
401-
}
402-
let rawlua = (*extra).raw_lua();
403-
let _guard = StateGuard::new(rawlua, state);
404-
let debug = Debug::new(rawlua, ar);
405-
hook_cb((*extra).lua(), debug)
406-
});
407-
match result {
391+
unsafe fn process_status(state: *mut ffi::lua_State, event: c_int, status: VmState) {
392+
match status {
408393
VmState::Continue => {}
409394
VmState::Yield => {
410395
// Only count and line events can yield
411-
if (*ar).event == ffi::LUA_HOOKCOUNT || (*ar).event == ffi::LUA_HOOKLINE {
396+
if event == ffi::LUA_HOOKCOUNT || event == ffi::LUA_HOOKLINE {
412397
#[cfg(any(feature = "lua54", feature = "lua53"))]
413398
if ffi::lua_isyieldable(state) != 0 {
414399
ffi::lua_yield(state, 0);
@@ -423,9 +408,86 @@ impl RawLua {
423408
}
424409
}
425410

426-
(*self.extra.get()).hook_callback = Some(Rc::new(callback));
427-
(*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set
428-
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
411+
unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
412+
let status = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
413+
let rawlua = (*extra).raw_lua();
414+
let debug = Debug::new(rawlua, ar);
415+
match (*extra).hook_callback.take() {
416+
Some(hook_cb) => {
417+
// Temporary obtain ownership of the hook callback
418+
let result = hook_cb((*extra).lua(), debug);
419+
(*extra).hook_callback = Some(hook_cb);
420+
result
421+
}
422+
None => {
423+
ffi::lua_sethook(state, None, 0, 0);
424+
Ok(VmState::Continue)
425+
}
426+
}
427+
});
428+
process_status(state, (*ar).event, status);
429+
}
430+
431+
unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
432+
ffi::luaL_checkstack(state, 3, ptr::null());
433+
ffi::lua_getfield(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY);
434+
ffi::lua_pushthread(state);
435+
if ffi::lua_rawget(state, -2) != ffi::LUA_TUSERDATA {
436+
ffi::lua_pop(state, 2);
437+
ffi::lua_sethook(state, None, 0, 0);
438+
return;
439+
}
440+
441+
let status = callback_error_ext(state, ptr::null_mut(), |extra, _| {
442+
let rawlua = (*extra).raw_lua();
443+
let debug = Debug::new(rawlua, ar);
444+
match get_internal_userdata::<HookCallback>(state, -1, ptr::null()).as_ref() {
445+
Some(hook_cb) => hook_cb((*extra).lua(), debug),
446+
None => {
447+
ffi::lua_sethook(state, None, 0, 0);
448+
Ok(VmState::Continue)
449+
}
450+
}
451+
});
452+
process_status(state, (*ar).event, status)
453+
}
454+
455+
let (triggers, callback) = match hook {
456+
HookKind::Global if (*self.extra.get()).hook_callback.is_none() => {
457+
return Ok(());
458+
}
459+
HookKind::Global => {
460+
let triggers = (*self.extra.get()).hook_triggers;
461+
let (mask, count) = (triggers.mask(), triggers.count());
462+
ffi::lua_sethook(thread_state, Some(global_hook_proc), mask, count);
463+
return Ok(());
464+
}
465+
HookKind::Thread(triggers, callback) => (triggers, callback),
466+
};
467+
468+
// Hooks for threads stored in the registry (in a weak table)
469+
let state = self.state();
470+
let _sg = StackGuard::new(state);
471+
check_stack(state, 3)?;
472+
protect_lua!(state, 0, 0, |state| {
473+
if ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, HOOKS_KEY) == 0 {
474+
// Table just created, initialize it
475+
ffi::lua_pushliteral(state, "k");
476+
ffi::lua_setfield(state, -2, cstr!("__mode")); // hooktable.__mode = "k"
477+
ffi::lua_pushvalue(state, -1);
478+
ffi::lua_setmetatable(state, -2); // metatable(hooktable) = hooktable
479+
}
480+
481+
ffi::lua_pushthread(thread_state);
482+
ffi::lua_xmove(thread_state, state, 1); // key (thread)
483+
let callback: HookCallback = Box::new(callback);
484+
let _ = push_internal_userdata(state, callback, false); // value (hook callback)
485+
ffi::lua_rawset(state, -3); // hooktable[thread] = hook callback
486+
})?;
487+
488+
ffi::lua_sethook(thread_state, Some(hook_proc), triggers.mask(), triggers.count());
489+
490+
Ok(())
429491
}
430492

431493
/// See [`Lua::create_string`]
@@ -497,6 +559,11 @@ impl RawLua {
497559
} else {
498560
protect_lua!(state, 0, 1, |state| ffi::lua_newthread(state))?
499561
};
562+
563+
// Inherit global hook if set
564+
#[cfg(not(feature = "luau"))]
565+
self.set_thread_hook(thread_state, HookKind::Global)?;
566+
500567
let thread = Thread(self.pop_ref(), thread_state);
501568
ffi::lua_xpush(self.ref_thread(), thread_state, func.0.index);
502569
Ok(thread)

src/thread.rs

+15-5
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
1111
#[cfg(not(feature = "luau"))]
1212
use crate::{
1313
hook::{Debug, HookTriggers},
14-
types::MaybeSend,
14+
types::HookKind,
1515
};
1616

1717
#[cfg(feature = "async")]
@@ -262,16 +262,26 @@ impl Thread {
262262
/// Sets a hook function that will periodically be called as Lua code executes.
263263
///
264264
/// This function is similar or [`Lua::set_hook`] except that it sets for the thread.
265-
/// To remove a hook call [`Lua::remove_hook`].
265+
/// You can have multiple hooks for different threads.
266+
///
267+
/// To remove a hook call [`Thread::remove_hook`].
266268
#[cfg(not(feature = "luau"))]
267269
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
268-
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
270+
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F) -> Result<()>
269271
where
270-
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + MaybeSend + 'static,
272+
F: Fn(&crate::Lua, Debug) -> Result<crate::VmState> + crate::MaybeSend + 'static,
271273
{
272274
let lua = self.0.lua.lock();
275+
unsafe { lua.set_thread_hook(self.state(), HookKind::Thread(triggers, Box::new(callback))) }
276+
}
277+
278+
/// Removes any hook function from this thread.
279+
#[cfg(not(feature = "luau"))]
280+
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
281+
pub fn remove_hook(&self) {
282+
let _lua = self.0.lua.lock();
273283
unsafe {
274-
lua.set_thread_hook(self.state(), triggers, callback);
284+
ffi::lua_sethook(self.state(), None, 0, 0);
275285
}
276286
}
277287

src/types.rs

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use std::cell::UnsafeCell;
22
use std::os::raw::{c_int, c_void};
3-
use std::rc::Rc;
43

54
use crate::error::Result;
65
#[cfg(not(feature = "luau"))]
7-
use crate::hook::Debug;
6+
use crate::hook::{Debug, HookTriggers};
87
use crate::state::{ExtraData, Lua, RawLua};
98

109
// Re-export mutex wrappers
@@ -73,17 +72,23 @@ pub enum VmState {
7372
Yield,
7473
}
7574

75+
#[cfg(not(feature = "luau"))]
76+
pub(crate) enum HookKind {
77+
Global,
78+
Thread(HookTriggers, HookCallback),
79+
}
80+
7681
#[cfg(all(feature = "send", not(feature = "luau")))]
77-
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
82+
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;
7883

7984
#[cfg(all(not(feature = "send"), not(feature = "luau")))]
80-
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState>>;
85+
pub(crate) type HookCallback = Box<dyn Fn(&Lua, Debug) -> Result<VmState>>;
8186

8287
#[cfg(all(feature = "send", feature = "luau"))]
83-
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
88+
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
8489

8590
#[cfg(all(not(feature = "send"), feature = "luau"))]
86-
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState>>;
91+
pub(crate) type InterruptCallback = std::rc::Rc<dyn Fn(&Lua) -> Result<VmState>>;
8792

8893
#[cfg(all(feature = "send", feature = "lua54"))]
8994
pub(crate) type WarnCallback = Box<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;

0 commit comments

Comments
 (0)