diff --git a/crates/monty-python/tests/test_repl.py b/crates/monty-python/tests/test_repl.py index 99726ddd..e4fde147 100644 --- a/crates/monty-python/tests/test_repl.py +++ b/crates/monty-python/tests/test_repl.py @@ -1,4 +1,5 @@ -from typing import Callable, Literal +import dataclasses +from typing import Callable, Literal, TypeAlias import pytest from inline_snapshot import snapshot @@ -6,6 +7,12 @@ import pydantic_monty PrintCallback = Callable[[Literal['stdout'], str], None] +ReplProgress: TypeAlias = ( + pydantic_monty.FunctionSnapshot + | pydantic_monty.NameLookupSnapshot + | pydantic_monty.FutureSnapshot + | pydantic_monty.MontyComplete +) def make_print_collector() -> tuple[list[str], PrintCallback]: @@ -524,6 +531,52 @@ def test_feed_start_multiple_external_calls(): assert progress.output == snapshot(30) +def test_feed_start_regression_281_comprehension_dataclass_repl_state(): + @dataclasses.dataclass + class Item: + text: str + + @dataclasses.dataclass + class Container: + style: str = 'Normal' + items: list[Item] = dataclasses.field(default_factory=list) + + def drive(state: ReplProgress) -> pydantic_monty.MontyComplete: + while not isinstance(state, pydantic_monty.MontyComplete): + if isinstance(state, pydantic_monty.NameLookupSnapshot): + state = state.resume() + elif isinstance(state, pydantic_monty.FunctionSnapshot): + fn = state.function_name + if fn == 'get_data': + data = [ + Container(style='A', items=[Item(text='hello'), Item(text='world')]), + Container(style='B', items=[Item(text='foo')]), + ] + state = state.resume(return_value=data) + elif fn == 'Container': + state = state.resume(return_value=Container(**state.kwargs)) + else: + state = state.resume(return_value=None) + else: + state = state.resume({}) + return state + + repl = pydantic_monty.MontyRepl() + repl.register_dataclass(Item) + repl.register_dataclass(Container) + + turn1 = repl.feed_start( + 'data = get_data()\nitems = [i.text for i in data[0].items]\nitems = [i.text for i in data[1].items]\n' + ) + turn1 = drive(turn1) + assert turn1.output == snapshot(None) + + turn2 = repl.feed_start('c = Container(style="X")\n') + turn2 = drive(turn2) + assert turn2.output == snapshot(None) + assert repl.feed_run('c.style') == snapshot('X') + + def test_feed_start_error_preserves_repl_state(): """REPL state is preserved when feed_start raises an error.""" repl = pydantic_monty.MontyRepl() diff --git a/crates/monty/src/bytecode/compiler.rs b/crates/monty/src/bytecode/compiler.rs index 7bea26ec..400a53ab 100644 --- a/crates/monty/src/bytecode/compiler.rs +++ b/crates/monty/src/bytecode/compiler.rs @@ -2379,6 +2379,7 @@ impl<'a> Compiler<'a> { /// /// Bytecode structure: /// ```text + /// SAVE_*_AND_CLEAR ... ; one per synthetic comprehension slot /// BUILD_LIST 0 ; empty result /// /// GET_ITER @@ -2391,9 +2392,11 @@ impl<'a> Compiler<'a> { /// LIST_APPEND depth /// JUMP loop_start /// end_loop: + /// RESTORE_* ... ; reverse order /// ; result list on stack /// ``` fn compile_list_comp(&mut self, elt: &ExprLoc, generators: &[Comprehension]) -> Result<(), CompileError> { + self.compile_comprehension_slot_saves(generators); // Build empty list self.code.emit_u16(Opcode::BuildList, 0); @@ -2405,11 +2408,14 @@ impl<'a> Compiler<'a> { Ok(()) })?; + self.compile_comprehension_slot_restores(generators); + Ok(()) } /// Compiles a set comprehension: `{elt for target in iter if cond...}` fn compile_set_comp(&mut self, elt: &ExprLoc, generators: &[Comprehension]) -> Result<(), CompileError> { + self.compile_comprehension_slot_saves(generators); // Build empty set self.code.emit_u16(Opcode::BuildSet, 0); @@ -2421,6 +2427,8 @@ impl<'a> Compiler<'a> { Ok(()) })?; + self.compile_comprehension_slot_restores(generators); + Ok(()) } @@ -2431,6 +2439,7 @@ impl<'a> Compiler<'a> { value: &ExprLoc, generators: &[Comprehension], ) -> Result<(), CompileError> { + self.compile_comprehension_slot_saves(generators); // Build empty dict self.code.emit_u16(Opcode::BuildDict, 0); @@ -2443,9 +2452,70 @@ impl<'a> Compiler<'a> { Ok(()) })?; + self.compile_comprehension_slot_restores(generators); + Ok(()) } + /// Emits save-and-clear opcodes for every synthetic comprehension target slot. + /// + /// Monty already inlines comprehensions, but their loop variables live in compiler- + /// generated slots. Saving and clearing those slots before execution makes their lifetime + /// match PEP 709 more closely and prevents stale values from leaking past the end of the + /// comprehension when execution completes incrementally. + fn compile_comprehension_slot_saves(&mut self, generators: &[Comprehension]) { + for slot in Self::comprehension_target_slots(generators) { + let opcode = if self.is_module_scope { + Opcode::SaveGlobalAndClear + } else { + Opcode::SaveLocalAndClear + }; + self.code.emit_u16(opcode, slot); + } + } + + /// Emits restore opcodes for every synthetic comprehension target slot. + /// + /// Restores are emitted in reverse order so the VM's save stack is unwound LIFO. + fn compile_comprehension_slot_restores(&mut self, generators: &[Comprehension]) { + for slot in Self::comprehension_target_slots(generators).into_iter().rev() { + let opcode = if self.is_module_scope { + Opcode::RestoreGlobal + } else { + Opcode::RestoreLocal + }; + self.code.emit_u16(opcode, slot); + } + } + + /// Collects every namespace slot assigned by the comprehension generators. + /// + /// Each target slot is compiler-generated and scoped to the comprehension, so saving all + /// of them is sufficient to restore the namespace even for nested generators and tuple + /// unpacking targets. + fn comprehension_target_slots(generators: &[Comprehension]) -> Vec { + let mut slots = Vec::new(); + for generator in generators { + Self::collect_unpack_target_slots(&generator.target, &mut slots); + } + slots + } + + /// Recursively collects namespace slots written by an unpack target. + fn collect_unpack_target_slots(target: &UnpackTarget, slots: &mut Vec) { + match target { + UnpackTarget::Name(ident) | UnpackTarget::Starred(ident) => { + let slot = u16::try_from(ident.namespace_id().index()).expect("local slot exceeds u16"); + slots.push(slot); + } + UnpackTarget::Tuple { targets, .. } => { + for target in targets { + Self::collect_unpack_target_slots(target, slots); + } + } + } + } + /// Recursively compiles comprehension generators (the for/if clauses). /// /// For each generator: diff --git a/crates/monty/src/bytecode/op.rs b/crates/monty/src/bytecode/op.rs index c3ff85fc..3c7625c9 100644 --- a/crates/monty/src/bytecode/op.rs +++ b/crates/monty/src/bytecode/op.rs @@ -440,6 +440,26 @@ pub enum Opcode { /// Pops iterable (TOS), adds each item to set at stack position `len - 2 - depth`. /// Raises `TypeError` if iterable is not iterable. SetExtend, + /// Save a local slot to the VM's comprehension-save stack, then clear it. + /// Operand: u16 slot. + /// + /// Used by inlined comprehensions to give their synthetic loop-variable slots the + /// same temporary lifetime as CPython's `LOAD_FAST_AND_CLEAR` behavior, without + /// exposing the saved value on the operand stack. + SaveLocalAndClear, + /// Restore the most recently saved local slot from the comprehension-save stack. + /// Operand: u16 slot. + RestoreLocal, + /// Save a global slot to the VM's comprehension-save stack, then clear it. + /// Operand: u16 slot. + /// + /// Module-scope comprehensions use global slots in Monty's runtime model, so they + /// need an explicit save/restore pair to avoid leaking temporary loop values across + /// REPL turns and other incremental execution boundaries. + SaveGlobalAndClear, + /// Restore the most recently saved global slot from the comprehension-save stack. + /// Operand: u16 slot. + RestoreGlobal, } impl TryFrom for Opcode { @@ -477,7 +497,7 @@ impl Opcode { LoadLocal | LoadLocalW | LoadLocalCallable | LoadLocalCallableW | LoadGlobal | LoadGlobalCallable | LoadCell => 1, StoreLocal | StoreLocalW | StoreGlobal | StoreCell => -1, - DeleteLocal | DeleteGlobal => 0, // doesn't affect stack + DeleteLocal | DeleteGlobal | SaveLocalAndClear | RestoreLocal | SaveGlobalAndClear | RestoreGlobal => 0, // Binary operations: pop 2, push 1 = -1 BinaryAdd | BinarySub | BinaryMul | BinaryDiv | BinaryFloorDiv | BinaryMod | BinaryPow | BinaryAnd @@ -584,8 +604,8 @@ mod tests { #[test] fn test_opcode_roundtrip() { - // Verify that all opcodes from 0 to DeleteGlobal (last opcode) can be converted to u8 and back. - for byte in 0..=Opcode::DeleteGlobal as u8 { + // Verify that all opcodes from 0 to RestoreGlobal (last opcode) can be converted to u8 and back. + for byte in 0..=Opcode::RestoreGlobal as u8 { let opcode = Opcode::try_from(byte).unwrap(); assert_eq!(opcode as u8, byte, "opcode {opcode:?} has wrong discriminant"); } @@ -601,12 +621,16 @@ mod tests { assert_eq!(Opcode::DeleteGlobal as u8, 112); assert_eq!(Opcode::DictUpdate as u8, 113); assert_eq!(Opcode::SetExtend as u8, 114); + assert_eq!(Opcode::SaveLocalAndClear as u8, 115); + assert_eq!(Opcode::RestoreLocal as u8, 116); + assert_eq!(Opcode::SaveGlobalAndClear as u8, 117); + assert_eq!(Opcode::RestoreGlobal as u8, 118); } #[test] fn test_invalid_opcode() { // Byte just after the last valid opcode should fail - let result = Opcode::try_from(Opcode::SetExtend as u8 + 1); + let result = Opcode::try_from(Opcode::RestoreGlobal as u8 + 1); assert!(result.is_err()); // 255 should also fail let result = Opcode::try_from(255u8); diff --git a/crates/monty/src/bytecode/vm/exceptions.rs b/crates/monty/src/bytecode/vm/exceptions.rs index f8737b16..963ab134 100644 --- a/crates/monty/src/bytecode/vm/exceptions.rs +++ b/crates/monty/src/bytecode/vm/exceptions.rs @@ -178,6 +178,7 @@ impl VM<'_, '_, T> { // No handler in this frame - pop frame and try outer if this.frames.len() <= 1 { + this.restore_saved_comprehension_slots_for_frame(0); // No more frames - exception is unhandled let is_spawned = this.is_spawned_task(); @@ -247,6 +248,7 @@ impl VM<'_, '_, T> { } } } + self.restore_saved_comprehension_slots_for_frame(0); error } diff --git a/crates/monty/src/bytecode/vm/mod.rs b/crates/monty/src/bytecode/vm/mod.rs index 333c49f5..d97edac7 100644 --- a/crates/monty/src/bytecode/vm/mod.rs +++ b/crates/monty/src/bytecode/vm/mod.rs @@ -507,6 +507,43 @@ pub struct VMSnapshot { /// /// Contains call ID counter, task state, pending calls, and resolved futures. scheduler: Scheduler, + + /// Saved comprehension slots that must survive pause/resume. + /// + /// Inlined comprehensions temporarily clear their synthetic local/global slots so + /// loop-variable state cannot leak after the comprehension exits. If execution pauses + /// mid-comprehension, these saved values must be serialized so resume can restore them + /// in the correct order. + saved_comprehension_slots: Vec, +} + +/// Storage class for a saved comprehension slot. +/// +/// Monty models inlined comprehensions with synthetic namespace slots. Saving whether a +/// slot lives in function locals or module globals lets the VM restore it correctly on the +/// normal path, during unwinding, and after snapshot restore. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +enum SavedComprehensionStorage { + /// Slot lives in the current frame's locals region on the operand stack. + Local, + /// Slot lives in the VM's persistent globals array. + Global, +} + +/// Saved value for a comprehension-local slot that has been temporarily cleared. +/// +/// The `frame_depth` tracks which frame owns the saved slot so the VM can restore pending +/// comprehension state automatically when a frame returns or unwinds with an exception. +#[derive(Debug, serde::Serialize, serde::Deserialize)] +struct SavedComprehensionSlot { + /// Zero-based depth of the frame that owns this saved slot. + frame_depth: usize, + /// Which storage area the slot belongs to. + storage: SavedComprehensionStorage, + /// Slot index within the storage area. + slot: u16, + /// Value that was present before the comprehension cleared the slot. + saved_value: Value, } // ============================================================================ @@ -583,6 +620,13 @@ pub struct VM<'h, 'a, T: ResourceTracker> { /// back to a `NameError`, so the traceback points to the name reference rather than /// the call expression. ext_function_load_ip: Option, + + /// Saved comprehension-local slots waiting to be restored. + /// + /// This mirrors CPython's save/restore behavior from PEP 709 for Monty's inlined + /// comprehensions. Values are restored explicitly by bytecode on the happy path, and + /// automatically when a frame exits early so temporary loop variables never leak. + saved_comprehension_slots: Vec, } impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { @@ -605,6 +649,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { scheduler: Scheduler::new(), ext_function_load_ip: None, // Set by LoadGlobalCallable/LoadLocalCallable module_code: None, + saved_comprehension_slots: Vec::new(), } } @@ -666,6 +711,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { scheduler: snapshot.scheduler, module_code: Some(module_code), ext_function_load_ip: None, + saved_comprehension_slots: snapshot.saved_comprehension_slots, } } /// Consumes the VM and creates a snapshot for pause/resume. @@ -686,6 +732,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { exception_stack: self.exception_stack, instruction_ip: self.instruction_ip, scheduler: self.scheduler, + saved_comprehension_slots: self.saved_comprehension_slots, } } @@ -855,6 +902,14 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { let slot = fetch_u16!(cached_frame); self.store_local(&cached_frame, slot); } + Opcode::SaveLocalAndClear => { + let slot = fetch_u16!(cached_frame); + self.save_local_and_clear(&cached_frame, slot); + } + Opcode::RestoreLocal => { + let slot = fetch_u16!(cached_frame); + self.restore_local(&cached_frame, slot); + } Opcode::DeleteLocal => { let slot = u16::from(fetch_u8!(cached_frame)); self.delete_local(&cached_frame, slot); @@ -888,6 +943,14 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { let slot = fetch_u16!(cached_frame); self.store_global(slot); } + Opcode::SaveGlobalAndClear => { + let slot = fetch_u16!(cached_frame); + self.save_global_and_clear(slot); + } + Opcode::RestoreGlobal => { + let slot = fetch_u16!(cached_frame); + self.restore_global(slot); + } // Variables - Cell Operations (closures) Opcode::LoadCell => { let slot = fetch_u16!(cached_frame); @@ -1452,6 +1515,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { if self.frames.len() == 1 { // Last frame - check if this is main task or spawned task let is_main_task = self.is_main_task(); + self.restore_saved_comprehension_slots_for_frame(self.frames.len() - 1); if is_main_task { // Module-level return - we're done @@ -1663,6 +1727,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { /// /// Returns `true` if this frame indicated evaluation should stop when popped. pub(super) fn pop_frame(&mut self) -> bool { + self.restore_saved_comprehension_slots_for_frame(self.frames.len() - 1); let frame = self.frames.pop().expect("no frame to pop"); self.cleanup_frame_state(&frame); // Sync instruction_ip to the parent frame so exception table lookups @@ -1698,6 +1763,7 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { /// Drains the stack with proper `drop_with_heap` for each value (since locals /// are inlined on the stack), then cleans up each frame's cell references. pub(super) fn cleanup_current_task(&mut self) { + self.restore_all_saved_comprehension_slots(); self.stack.drain(..).drop_with_heap(self.heap); self.frames.clear(); } @@ -1882,6 +1948,122 @@ impl<'h, 'a, T: ResourceTracker> VM<'h, 'a, T> { old_value.drop_with_heap(self); } + /// Saves a local slot for the current comprehension scope and clears it. + /// + /// The saved value is tracked outside the operand stack so Monty can keep its current + /// comprehension bytecode shape while still restoring temporary slots if execution exits + /// the comprehension early. + fn save_local_and_clear(&mut self, cached_frame: &CachedFrame<'a>, slot: u16) { + let frame_depth = self.frames.len() - 1; + let target = &mut self.stack[cached_frame.stack_base + slot as usize]; + let saved_value = mem::replace(target, Value::Undefined); + self.saved_comprehension_slots.push(SavedComprehensionSlot { + frame_depth, + storage: SavedComprehensionStorage::Local, + slot, + saved_value, + }); + } + + /// Restores the most recently saved local comprehension slot. + fn restore_local(&mut self, cached_frame: &CachedFrame<'a>, slot: u16) { + let saved = self + .saved_comprehension_slots + .pop() + .expect("RestoreLocal without matching SaveLocalAndClear"); + assert_eq!( + saved.frame_depth, + self.frames.len() - 1, + "RestoreLocal frame depth mismatch" + ); + assert_eq!( + saved.storage, + SavedComprehensionStorage::Local, + "RestoreLocal storage mismatch" + ); + assert_eq!(saved.slot, slot, "RestoreLocal slot mismatch"); + + let target = &mut self.stack[cached_frame.stack_base + slot as usize]; + let old_value = mem::replace(target, saved.saved_value); + old_value.drop_with_heap(self); + } + + /// Saves a global slot for the current comprehension scope and clears it. + fn save_global_and_clear(&mut self, slot: u16) { + let frame_depth = self.frames.len() - 1; + let target = &mut self.globals[slot as usize]; + let saved_value = mem::replace(target, Value::Undefined); + self.saved_comprehension_slots.push(SavedComprehensionSlot { + frame_depth, + storage: SavedComprehensionStorage::Global, + slot, + saved_value, + }); + } + + /// Restores the most recently saved global comprehension slot. + fn restore_global(&mut self, slot: u16) { + let saved = self + .saved_comprehension_slots + .pop() + .expect("RestoreGlobal without matching SaveGlobalAndClear"); + assert_eq!( + saved.frame_depth, + self.frames.len() - 1, + "RestoreGlobal frame depth mismatch" + ); + assert_eq!( + saved.storage, + SavedComprehensionStorage::Global, + "RestoreGlobal storage mismatch" + ); + assert_eq!(saved.slot, slot, "RestoreGlobal slot mismatch"); + + let target = &mut self.globals[slot as usize]; + let old_value = mem::replace(target, saved.saved_value); + old_value.drop_with_heap(self); + } + + /// Restores any comprehension slots still active for the given frame depth. + /// + /// This is the fallback that makes early returns, uncaught exceptions, and task cleanup + /// leave no synthetic comprehension values behind. + fn restore_saved_comprehension_slots_for_frame(&mut self, frame_depth: usize) { + while self + .saved_comprehension_slots + .last() + .is_some_and(|saved| saved.frame_depth == frame_depth) + { + let saved = self.saved_comprehension_slots.pop().expect("checked above"); + self.restore_saved_comprehension_slot(saved, frame_depth); + } + } + + /// Restores all saved comprehension slots, regardless of frame depth. + fn restore_all_saved_comprehension_slots(&mut self) { + while let Some(saved) = self.saved_comprehension_slots.pop() { + let frame_depth = saved.frame_depth; + self.restore_saved_comprehension_slot(saved, frame_depth); + } + } + + /// Restores one saved comprehension slot into its original storage location. + fn restore_saved_comprehension_slot(&mut self, saved: SavedComprehensionSlot, frame_depth: usize) { + match saved.storage { + SavedComprehensionStorage::Local => { + let frame = &self.frames[frame_depth]; + let target = &mut self.stack[frame.stack_base + saved.slot as usize]; + let old_value = mem::replace(target, saved.saved_value); + old_value.drop_with_heap(self); + } + SavedComprehensionStorage::Global => { + let target = &mut self.globals[saved.slot as usize]; + let old_value = mem::replace(target, saved.saved_value); + old_value.drop_with_heap(self); + } + } + } + /// Loads from a closure cell and pushes onto the stack. /// /// The cell `HeapId` is read from the frame's local variable slot on the stack diff --git a/crates/monty/tests/repl.rs b/crates/monty/tests/repl.rs index d51e2de9..d08cc64f 100644 --- a/crates/monty/tests/repl.rs +++ b/crates/monty/tests/repl.rs @@ -179,6 +179,42 @@ fn repl_start_external_call_resumes_to_updated_repl() { assert_eq!(feed_run_print(&mut repl, "x").unwrap(), MontyObject::Int(5)); } +#[test] +fn repl_feed_start_restores_comprehension_slots_before_next_turn() { + let (repl, _) = init_repl(""); + + let progress = repl + .feed_start( + "items = [i for i in [1]]\nitems = [i for i in [2]]\n", + vec![], + PrintWriter::Stdout, + ) + .unwrap(); + let (repl, value) = progress.into_complete().expect("expected completion"); + assert_eq!(value, MontyObject::None); + + let progress = repl.feed_start("foo()", vec![], PrintWriter::Stdout).unwrap(); + let call = progress.into_function_call().expect("expected function call"); + assert_eq!(call.function_name, "foo"); + assert!(call.args.is_empty()); + let _repl = call.into_repl(); +} + +#[test] +fn repl_feed_start_restores_comprehension_slots_after_runtime_error() { + let (repl, _) = init_repl(""); + + let err = repl + .feed_start("items = [1 / i for i in [0]]", vec![], PrintWriter::Stdout) + .expect_err("expected runtime error"); + + let progress = err.repl.feed_start("foo()", vec![], PrintWriter::Stdout).unwrap(); + let call = progress.into_function_call().expect("expected function call"); + assert_eq!(call.function_name, "foo"); + assert!(call.args.is_empty()); + let _repl = call.into_repl(); +} + #[test] fn repl_progress_dump_load_roundtrip() { let (repl, _) = init_repl("");