Skip to content

Commit 05778fb

Browse files
committed
Don't store and use wrong main Lua state in module mode (Lua 5.1/JIT only).
When mlua module is loaded from a non-main coroutine we store a reference to it to use later. If the coroutine is destroyed by GC we can pass a wrong pointer to Lua that will trigger a segfault. Instead, set main_state as Option and use current (active) state if needed. Relates to #479
1 parent b34d67e commit 05778fb

File tree

2 files changed

+40
-35
lines changed

2 files changed

+40
-35
lines changed

src/state.rs

+28-25
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ impl Lua {
485485
let lua = self.lock();
486486
unsafe {
487487
if (*lua.extra.get()).sandboxed != enabled {
488-
let state = lua.main_state;
488+
let state = lua.main_state();
489489
check_stack(state, 3)?;
490490
protect_lua!(state, 0, 0, |state| {
491491
if enabled {
@@ -562,10 +562,10 @@ impl Lua {
562562
unsafe {
563563
let state = lua.state();
564564
ffi::lua_sethook(state, None, 0, 0);
565-
match crate::util::get_main_state(lua.main_state) {
566-
Some(main_state) if !ptr::eq(state, main_state) => {
565+
match lua.main_state {
566+
Some(main_state) if state != main_state.as_ptr() => {
567567
// If main_state is different from state, remove hook from it too
568-
ffi::lua_sethook(main_state, None, 0, 0);
568+
ffi::lua_sethook(main_state.as_ptr(), None, 0, 0);
569569
}
570570
_ => {}
571571
};
@@ -654,7 +654,7 @@ impl Lua {
654654
let lua = self.lock();
655655
unsafe {
656656
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
657-
(*ffi::lua_callbacks(lua.main_state)).interrupt = Some(interrupt_proc);
657+
(*ffi::lua_callbacks(lua.main_state())).interrupt = Some(interrupt_proc);
658658
}
659659
}
660660

@@ -667,7 +667,7 @@ impl Lua {
667667
let lua = self.lock();
668668
unsafe {
669669
(*lua.extra.get()).interrupt_callback = None;
670-
(*ffi::lua_callbacks(lua.main_state)).interrupt = None;
670+
(*ffi::lua_callbacks(lua.main_state())).interrupt = None;
671671
}
672672
}
673673

@@ -697,10 +697,9 @@ impl Lua {
697697
}
698698

699699
let lua = self.lock();
700-
let state = lua.main_state;
701700
unsafe {
702701
(*lua.extra.get()).warn_callback = Some(Box::new(callback));
703-
ffi::lua_setwarnf(state, Some(warn_proc), lua.extra.get() as *mut c_void);
702+
ffi::lua_setwarnf(lua.state(), Some(warn_proc), lua.extra.get() as *mut c_void);
704703
}
705704
}
706705

@@ -715,7 +714,7 @@ impl Lua {
715714
let lua = self.lock();
716715
unsafe {
717716
(*lua.extra.get()).warn_callback = None;
718-
ffi::lua_setwarnf(lua.main_state, None, ptr::null_mut());
717+
ffi::lua_setwarnf(lua.state(), None, ptr::null_mut());
719718
}
720719
}
721720

@@ -767,13 +766,14 @@ impl Lua {
767766
/// Returns the amount of memory (in bytes) currently used inside this Lua state.
768767
pub fn used_memory(&self) -> usize {
769768
let lua = self.lock();
769+
let state = lua.main_state();
770770
unsafe {
771-
match MemoryState::get(lua.main_state) {
771+
match MemoryState::get(state) {
772772
mem_state if !mem_state.is_null() => (*mem_state).used_memory(),
773773
_ => {
774774
// Get data from the Lua GC
775-
let used_kbytes = ffi::lua_gc(lua.main_state, ffi::LUA_GCCOUNT, 0);
776-
let used_kbytes_rem = ffi::lua_gc(lua.main_state, ffi::LUA_GCCOUNTB, 0);
775+
let used_kbytes = ffi::lua_gc(state, ffi::LUA_GCCOUNT, 0);
776+
let used_kbytes_rem = ffi::lua_gc(state, ffi::LUA_GCCOUNTB, 0);
777777
(used_kbytes as usize) * 1024 + (used_kbytes_rem as usize)
778778
}
779779
}
@@ -790,7 +790,7 @@ impl Lua {
790790
pub fn set_memory_limit(&self, limit: usize) -> Result<usize> {
791791
let lua = self.lock();
792792
unsafe {
793-
match MemoryState::get(lua.main_state) {
793+
match MemoryState::get(lua.state()) {
794794
mem_state if !mem_state.is_null() => Ok((*mem_state).set_memory_limit(limit)),
795795
_ => Err(Error::MemoryControlNotAvailable),
796796
}
@@ -803,19 +803,19 @@ impl Lua {
803803
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "luau"))]
804804
pub fn gc_is_running(&self) -> bool {
805805
let lua = self.lock();
806-
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCISRUNNING, 0) != 0 }
806+
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCISRUNNING, 0) != 0 }
807807
}
808808

809809
/// Stop the Lua GC from running
810810
pub fn gc_stop(&self) {
811811
let lua = self.lock();
812-
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCSTOP, 0) };
812+
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCSTOP, 0) };
813813
}
814814

815815
/// Restarts the Lua GC if it is not running
816816
pub fn gc_restart(&self) {
817817
let lua = self.lock();
818-
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCRESTART, 0) };
818+
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCRESTART, 0) };
819819
}
820820

821821
/// Perform a full garbage-collection cycle.
@@ -824,9 +824,10 @@ impl Lua {
824824
/// objects. Once to finish the current gc cycle, and once to start and finish the next cycle.
825825
pub fn gc_collect(&self) -> Result<()> {
826826
let lua = self.lock();
827+
let state = lua.main_state();
827828
unsafe {
828-
check_stack(lua.main_state, 2)?;
829-
protect_lua!(lua.main_state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0))
829+
check_stack(state, 2)?;
830+
protect_lua!(state, 0, 0, fn(state) ffi::lua_gc(state, ffi::LUA_GCCOLLECT, 0))
830831
}
831832
}
832833

@@ -843,9 +844,10 @@ impl Lua {
843844
/// finished a collection cycle.
844845
pub fn gc_step_kbytes(&self, kbytes: c_int) -> Result<bool> {
845846
let lua = self.lock();
847+
let state = lua.main_state();
846848
unsafe {
847-
check_stack(lua.main_state, 3)?;
848-
protect_lua!(lua.main_state, 0, 0, |state| {
849+
check_stack(state, 3)?;
850+
protect_lua!(state, 0, 0, |state| {
849851
ffi::lua_gc(state, ffi::LUA_GCSTEP, kbytes) != 0
850852
})
851853
}
@@ -861,11 +863,12 @@ impl Lua {
861863
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5
862864
pub fn gc_set_pause(&self, pause: c_int) -> c_int {
863865
let lua = self.lock();
866+
let state = lua.main_state();
864867
unsafe {
865868
#[cfg(not(feature = "luau"))]
866-
return ffi::lua_gc(lua.main_state, ffi::LUA_GCSETPAUSE, pause);
869+
return ffi::lua_gc(state, ffi::LUA_GCSETPAUSE, pause);
867870
#[cfg(feature = "luau")]
868-
return ffi::lua_gc(lua.main_state, ffi::LUA_GCSETGOAL, pause);
871+
return ffi::lua_gc(state, ffi::LUA_GCSETGOAL, pause);
869872
}
870873
}
871874

@@ -877,7 +880,7 @@ impl Lua {
877880
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5
878881
pub fn gc_set_step_multiplier(&self, step_multiplier: c_int) -> c_int {
879882
let lua = self.lock();
880-
unsafe { ffi::lua_gc(lua.main_state, ffi::LUA_GCSETSTEPMUL, step_multiplier) }
883+
unsafe { ffi::lua_gc(lua.main_state(), ffi::LUA_GCSETSTEPMUL, step_multiplier) }
881884
}
882885

883886
/// Changes the collector to incremental mode with the given parameters.
@@ -888,7 +891,7 @@ impl Lua {
888891
/// [documentation]: https://www.lua.org/manual/5.4/manual.html#2.5.1
889892
pub fn gc_inc(&self, pause: c_int, step_multiplier: c_int, step_size: c_int) -> GCMode {
890893
let lua = self.lock();
891-
let state = lua.main_state;
894+
let state = lua.main_state();
892895

893896
#[cfg(any(
894897
feature = "lua53",
@@ -941,7 +944,7 @@ impl Lua {
941944
#[cfg_attr(docsrs, doc(cfg(feature = "lua54")))]
942945
pub fn gc_gen(&self, minor_multiplier: c_int, major_multiplier: c_int) -> GCMode {
943946
let lua = self.lock();
944-
let state = lua.main_state;
947+
let state = lua.main_state();
945948
let prev_mode = unsafe { ffi::lua_gc(state, ffi::LUA_GCGEN, minor_multiplier, major_multiplier) };
946949
match prev_mode {
947950
ffi::LUA_GCGEN => GCMode::Generational,

src/state/raw.rs

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use std::any::TypeId;
22
use std::cell::{Cell, UnsafeCell};
33
use std::ffi::{CStr, CString};
4+
use std::mem;
45
use std::os::raw::{c_char, c_int, c_void};
56
use std::panic::resume_unwind;
7+
use std::ptr::{self, NonNull};
68
use std::result::Result as StdResult;
79
use std::sync::Arc;
8-
use std::{mem, ptr};
910

1011
use crate::chunk::ChunkMode;
1112
use crate::error::{Error, Result};
@@ -41,7 +42,6 @@ use {
4142
crate::multi::MultiValue,
4243
crate::traits::FromLuaMulti,
4344
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
44-
std::ptr::NonNull,
4545
std::task::{Context, Poll, Waker},
4646
};
4747

@@ -50,7 +50,7 @@ use {
5050
pub struct RawLua {
5151
// The state is dynamic and depends on context
5252
pub(super) state: Cell<*mut ffi::lua_State>,
53-
pub(super) main_state: *mut ffi::lua_State,
53+
pub(super) main_state: Option<NonNull<ffi::lua_State>>,
5454
pub(super) extra: XRc<UnsafeCell<ExtraData>>,
5555
}
5656

@@ -61,9 +61,9 @@ impl Drop for RawLua {
6161
return;
6262
}
6363

64-
let mem_state = MemoryState::get(self.main_state);
64+
let mem_state = MemoryState::get(self.main_state());
6565

66-
ffi::lua_close(self.main_state);
66+
ffi::lua_close(self.main_state());
6767

6868
// Deallocate `MemoryState`
6969
if !mem_state.is_null() {
@@ -95,10 +95,11 @@ impl RawLua {
9595
self.state.get()
9696
}
9797

98-
#[cfg(feature = "luau")]
9998
#[inline(always)]
10099
pub(crate) fn main_state(&self) -> *mut ffi::lua_State {
101100
self.main_state
101+
.map(|state| state.as_ptr())
102+
.unwrap_or_else(|| self.state())
102103
}
103104

104105
#[inline(always)]
@@ -221,7 +222,8 @@ impl RawLua {
221222
#[allow(clippy::arc_with_non_send_sync)]
222223
let rawlua = XRc::new(ReentrantMutex::new(RawLua {
223224
state: Cell::new(state),
224-
main_state,
225+
// Make sure that we don't store current state as main state (if it's not available)
226+
main_state: get_main_state(state).and_then(NonNull::new),
225227
extra: XRc::clone(&extra),
226228
}));
227229
(*extra.get()).set_lua(&rawlua);
@@ -263,7 +265,7 @@ impl RawLua {
263265
));
264266
}
265267

266-
let res = load_std_libs(self.main_state, libs);
268+
let res = load_std_libs(self.main_state(), libs);
267269

268270
// If `package` library loaded into a safe lua state then disable C modules
269271
let curr_libs = (*self.extra.get()).libs;
@@ -734,7 +736,7 @@ impl RawLua {
734736
}
735737

736738
// MemoryInfo is empty in module mode so we cannot predict memory limits
737-
match MemoryState::get(self.main_state) {
739+
match MemoryState::get(self.state()) {
738740
mem_state if !mem_state.is_null() => (*mem_state).memory_limit() == 0,
739741
_ => (*self.extra.get()).skip_memory_check, // Check the special flag (only for module mode)
740742
}
@@ -1095,7 +1097,7 @@ impl RawLua {
10951097
#[cfg(any(feature = "lua54", feature = "lua53", feature = "lua52", feature = "luau"))]
10961098
unsafe {
10971099
if !(*self.extra.get()).libs.contains(StdLib::COROUTINE) {
1098-
load_std_libs(self.main_state, StdLib::COROUTINE)?;
1100+
load_std_libs(self.main_state(), StdLib::COROUTINE)?;
10991101
(*self.extra.get()).libs |= StdLib::COROUTINE;
11001102
}
11011103
}

0 commit comments

Comments
 (0)