@@ -37,7 +37,10 @@ use super::extra::ExtraData;
37
37
use super :: { Lua , LuaOptions , WeakLua } ;
38
38
39
39
#[ cfg( not( feature = "luau" ) ) ]
40
- use crate :: hook:: { Debug , HookTriggers } ;
40
+ use crate :: {
41
+ hook:: Debug ,
42
+ types:: { HookCallback , HookKind , VmState } ,
43
+ } ;
41
44
42
45
#[ cfg( feature = "async" ) ]
43
46
use {
@@ -186,6 +189,8 @@ impl RawLua {
186
189
init_internal_metatable:: <XRc <UnsafeCell <ExtraData >>>( state, None ) ?;
187
190
init_internal_metatable:: <Callback >( state, None ) ?;
188
191
init_internal_metatable:: <CallbackUpvalue >( state, None ) ?;
192
+ #[ cfg( not( feature = "luau" ) ) ]
193
+ init_internal_metatable:: <HookCallback >( state, None ) ?;
189
194
#[ cfg( feature = "async" ) ]
190
195
{
191
196
init_internal_metatable:: <AsyncCallback >( state, None ) ?;
@@ -373,42 +378,22 @@ impl RawLua {
373
378
status
374
379
}
375
380
376
- /// Sets a ' hook' function for a thread (coroutine).
381
+ /// Sets a hook for a thread (coroutine).
377
382
#[ cfg( not( feature = "luau" ) ) ]
378
- pub ( crate ) unsafe fn set_thread_hook < F > (
383
+ pub ( crate ) unsafe fn set_thread_hook (
379
384
& self ,
380
- state : * mut ffi:: lua_State ,
381
- triggers : HookTriggers ,
382
- callback : F ,
383
- ) where
384
- F : Fn ( & Lua , Debug ) -> Result < crate :: VmState > + MaybeSend + ' static ,
385
- {
386
- use crate :: types:: VmState ;
387
- use std:: rc:: Rc ;
385
+ thread_state : * mut ffi:: lua_State ,
386
+ hook : HookKind ,
387
+ ) -> Result < ( ) > {
388
+ // Key to store hooks in the registry
389
+ const HOOKS_KEY : * const c_char = cstr ! ( "__mlua_hooks" ) ;
388
390
389
- unsafe extern "C-unwind" fn hook_proc ( state : * mut ffi:: lua_State , ar : * mut ffi:: lua_Debug ) {
390
- let extra = ExtraData :: get ( state) ;
391
- if ( * extra) . hook_thread != state {
392
- // Hook was destined for a different thread, ignore
393
- ffi:: lua_sethook ( state, None , 0 , 0 ) ;
394
- return ;
395
- }
396
- let result = callback_error_ext ( state, extra, move |extra, _| {
397
- let hook_cb = ( * extra) . hook_callback . clone ( ) ;
398
- let hook_cb = mlua_expect ! ( hook_cb, "no hook callback set in hook_proc" ) ;
399
- if Rc :: strong_count ( & hook_cb) > 2 {
400
- return Ok ( VmState :: Continue ) ; // Don't allow recursion
401
- }
402
- let rawlua = ( * extra) . raw_lua ( ) ;
403
- let _guard = StateGuard :: new ( rawlua, state) ;
404
- let debug = Debug :: new ( rawlua, ar) ;
405
- hook_cb ( ( * extra) . lua ( ) , debug)
406
- } ) ;
407
- match result {
391
+ unsafe fn process_status ( state : * mut ffi:: lua_State , event : c_int , status : VmState ) {
392
+ match status {
408
393
VmState :: Continue => { }
409
394
VmState :: Yield => {
410
395
// Only count and line events can yield
411
- if ( * ar ) . event == ffi:: LUA_HOOKCOUNT || ( * ar ) . event == ffi:: LUA_HOOKLINE {
396
+ if event == ffi:: LUA_HOOKCOUNT || event == ffi:: LUA_HOOKLINE {
412
397
#[ cfg( any( feature = "lua54" , feature = "lua53" ) ) ]
413
398
if ffi:: lua_isyieldable ( state) != 0 {
414
399
ffi:: lua_yield ( state, 0 ) ;
@@ -423,9 +408,86 @@ impl RawLua {
423
408
}
424
409
}
425
410
426
- ( * self . extra . get ( ) ) . hook_callback = Some ( Rc :: new ( callback) ) ;
427
- ( * self . extra . get ( ) ) . hook_thread = state; // Mark for what thread the hook is set
428
- ffi:: lua_sethook ( state, Some ( hook_proc) , triggers. mask ( ) , triggers. count ( ) ) ;
411
+ unsafe extern "C-unwind" fn global_hook_proc ( state : * mut ffi:: lua_State , ar : * mut ffi:: lua_Debug ) {
412
+ let status = callback_error_ext ( state, ptr:: null_mut ( ) , move |extra, _| {
413
+ let rawlua = ( * extra) . raw_lua ( ) ;
414
+ let debug = Debug :: new ( rawlua, ar) ;
415
+ match ( * extra) . hook_callback . take ( ) {
416
+ Some ( hook_cb) => {
417
+ // Temporary obtain ownership of the hook callback
418
+ let result = hook_cb ( ( * extra) . lua ( ) , debug) ;
419
+ ( * extra) . hook_callback = Some ( hook_cb) ;
420
+ result
421
+ }
422
+ None => {
423
+ ffi:: lua_sethook ( state, None , 0 , 0 ) ;
424
+ Ok ( VmState :: Continue )
425
+ }
426
+ }
427
+ } ) ;
428
+ process_status ( state, ( * ar) . event , status) ;
429
+ }
430
+
431
+ unsafe extern "C-unwind" fn hook_proc ( state : * mut ffi:: lua_State , ar : * mut ffi:: lua_Debug ) {
432
+ ffi:: luaL_checkstack ( state, 3 , ptr:: null ( ) ) ;
433
+ ffi:: lua_getfield ( state, ffi:: LUA_REGISTRYINDEX , HOOKS_KEY ) ;
434
+ ffi:: lua_pushthread ( state) ;
435
+ if ffi:: lua_rawget ( state, -2 ) != ffi:: LUA_TUSERDATA {
436
+ ffi:: lua_pop ( state, 2 ) ;
437
+ ffi:: lua_sethook ( state, None , 0 , 0 ) ;
438
+ return ;
439
+ }
440
+
441
+ let status = callback_error_ext ( state, ptr:: null_mut ( ) , |extra, _| {
442
+ let rawlua = ( * extra) . raw_lua ( ) ;
443
+ let debug = Debug :: new ( rawlua, ar) ;
444
+ match get_internal_userdata :: < HookCallback > ( state, -1 , ptr:: null ( ) ) . as_ref ( ) {
445
+ Some ( hook_cb) => hook_cb ( ( * extra) . lua ( ) , debug) ,
446
+ None => {
447
+ ffi:: lua_sethook ( state, None , 0 , 0 ) ;
448
+ Ok ( VmState :: Continue )
449
+ }
450
+ }
451
+ } ) ;
452
+ process_status ( state, ( * ar) . event , status)
453
+ }
454
+
455
+ let ( triggers, callback) = match hook {
456
+ HookKind :: Global if ( * self . extra . get ( ) ) . hook_callback . is_none ( ) => {
457
+ return Ok ( ( ) ) ;
458
+ }
459
+ HookKind :: Global => {
460
+ let triggers = ( * self . extra . get ( ) ) . hook_triggers ;
461
+ let ( mask, count) = ( triggers. mask ( ) , triggers. count ( ) ) ;
462
+ ffi:: lua_sethook ( thread_state, Some ( global_hook_proc) , mask, count) ;
463
+ return Ok ( ( ) ) ;
464
+ }
465
+ HookKind :: Thread ( triggers, callback) => ( triggers, callback) ,
466
+ } ;
467
+
468
+ // Hooks for threads stored in the registry (in a weak table)
469
+ let state = self . state ( ) ;
470
+ let _sg = StackGuard :: new ( state) ;
471
+ check_stack ( state, 3 ) ?;
472
+ protect_lua ! ( state, 0 , 0 , |state| {
473
+ if ffi:: luaL_getsubtable( state, ffi:: LUA_REGISTRYINDEX , HOOKS_KEY ) == 0 {
474
+ // Table just created, initialize it
475
+ ffi:: lua_pushliteral( state, "k" ) ;
476
+ ffi:: lua_setfield( state, -2 , cstr!( "__mode" ) ) ; // hooktable.__mode = "k"
477
+ ffi:: lua_pushvalue( state, -1 ) ;
478
+ ffi:: lua_setmetatable( state, -2 ) ; // metatable(hooktable) = hooktable
479
+ }
480
+
481
+ ffi:: lua_pushthread( thread_state) ;
482
+ ffi:: lua_xmove( thread_state, state, 1 ) ; // key (thread)
483
+ let callback: HookCallback = Box :: new( callback) ;
484
+ let _ = push_internal_userdata( state, callback, false ) ; // value (hook callback)
485
+ ffi:: lua_rawset( state, -3 ) ; // hooktable[thread] = hook callback
486
+ } ) ?;
487
+
488
+ ffi:: lua_sethook ( thread_state, Some ( hook_proc) , triggers. mask ( ) , triggers. count ( ) ) ;
489
+
490
+ Ok ( ( ) )
429
491
}
430
492
431
493
/// See [`Lua::create_string`]
@@ -497,6 +559,11 @@ impl RawLua {
497
559
} else {
498
560
protect_lua ! ( state, 0 , 1 , |state| ffi:: lua_newthread( state) ) ?
499
561
} ;
562
+
563
+ // Inherit global hook if set
564
+ #[ cfg( not( feature = "luau" ) ) ]
565
+ self . set_thread_hook ( thread_state, HookKind :: Global ) ?;
566
+
500
567
let thread = Thread ( self . pop_ref ( ) , thread_state) ;
501
568
ffi:: lua_xpush ( self . ref_thread ( ) , thread_state, func. 0 . index ) ;
502
569
Ok ( thread)
0 commit comments