Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion crates/monty-python/tests/test_repl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import Callable, Literal
import dataclasses
from typing import Callable, Literal, TypeAlias

import pytest
from inline_snapshot import snapshot

import pydantic_monty

PrintCallback = Callable[[Literal['stdout'], str], None]
ReplProgress: TypeAlias = (
pydantic_monty.FunctionSnapshot
| pydantic_monty.NameLookupSnapshot
| pydantic_monty.FutureSnapshot
| pydantic_monty.MontyComplete
)


def make_print_collector() -> tuple[list[str], PrintCallback]:
Expand Down Expand Up @@ -524,6 +531,52 @@ def test_feed_start_multiple_external_calls():
assert progress.output == snapshot(30)


def test_feed_start_regression_281_comprehension_dataclass_repl_state():
@dataclasses.dataclass
class Item:
text: str

@dataclasses.dataclass
class Container:
style: str = 'Normal'
items: list[Item] = dataclasses.field(default_factory=list)

def drive(state: ReplProgress) -> pydantic_monty.MontyComplete:
while not isinstance(state, pydantic_monty.MontyComplete):
if isinstance(state, pydantic_monty.NameLookupSnapshot):
state = state.resume()
elif isinstance(state, pydantic_monty.FunctionSnapshot):
fn = state.function_name
if fn == 'get_data':
data = [
Container(style='A', items=[Item(text='hello'), Item(text='world')]),
Container(style='B', items=[Item(text='foo')]),
]
state = state.resume(return_value=data)
elif fn == 'Container':
state = state.resume(return_value=Container(**state.kwargs))
else:
state = state.resume(return_value=None)
else:
state = state.resume({})
return state

repl = pydantic_monty.MontyRepl()
repl.register_dataclass(Item)
repl.register_dataclass(Container)

turn1 = repl.feed_start(
'data = get_data()\nitems = [i.text for i in data[0].items]\nitems = [i.text for i in data[1].items]\n'
)
turn1 = drive(turn1)
assert turn1.output == snapshot(None)

turn2 = repl.feed_start('c = Container(style="X")\n')
turn2 = drive(turn2)
assert turn2.output == snapshot(None)
assert repl.feed_run('c.style') == snapshot('X')


def test_feed_start_error_preserves_repl_state():
"""REPL state is preserved when feed_start raises an error."""
repl = pydantic_monty.MontyRepl()
Expand Down
70 changes: 70 additions & 0 deletions crates/monty/src/bytecode/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2379,6 +2379,7 @@ impl<'a> Compiler<'a> {
///
/// Bytecode structure:
/// ```text
/// SAVE_*_AND_CLEAR ... ; one per synthetic comprehension slot
/// BUILD_LIST 0 ; empty result
/// <compile first iter>
/// GET_ITER
Expand All @@ -2391,9 +2392,11 @@ impl<'a> Compiler<'a> {
/// LIST_APPEND depth
/// JUMP loop_start
/// end_loop:
/// RESTORE_* ... ; reverse order
/// ; result list on stack
/// ```
fn compile_list_comp(&mut self, elt: &ExprLoc, generators: &[Comprehension]) -> Result<(), CompileError> {
self.compile_comprehension_slot_saves(generators);
// Build empty list
self.code.emit_u16(Opcode::BuildList, 0);

Expand All @@ -2405,11 +2408,14 @@ impl<'a> Compiler<'a> {
Ok(())
})?;

self.compile_comprehension_slot_restores(generators);

Ok(())
}

/// Compiles a set comprehension: `{elt for target in iter if cond...}`
fn compile_set_comp(&mut self, elt: &ExprLoc, generators: &[Comprehension]) -> Result<(), CompileError> {
self.compile_comprehension_slot_saves(generators);
// Build empty set
self.code.emit_u16(Opcode::BuildSet, 0);

Expand All @@ -2421,6 +2427,8 @@ impl<'a> Compiler<'a> {
Ok(())
})?;

self.compile_comprehension_slot_restores(generators);

Ok(())
}

Expand All @@ -2431,6 +2439,7 @@ impl<'a> Compiler<'a> {
value: &ExprLoc,
generators: &[Comprehension],
) -> Result<(), CompileError> {
self.compile_comprehension_slot_saves(generators);
// Build empty dict
self.code.emit_u16(Opcode::BuildDict, 0);

Expand All @@ -2443,9 +2452,70 @@ impl<'a> Compiler<'a> {
Ok(())
})?;

self.compile_comprehension_slot_restores(generators);

Ok(())
}

/// Emits save-and-clear opcodes for every synthetic comprehension target slot.
///
/// Monty already inlines comprehensions, but their loop variables live in compiler-
/// generated slots. Saving and clearing those slots before execution makes their lifetime
/// match PEP 709 more closely and prevents stale values from leaking past the end of the
/// comprehension when execution completes incrementally.
fn compile_comprehension_slot_saves(&mut self, generators: &[Comprehension]) {
for slot in Self::comprehension_target_slots(generators) {
let opcode = if self.is_module_scope {
Opcode::SaveGlobalAndClear
} else {
Opcode::SaveLocalAndClear
};
self.code.emit_u16(opcode, slot);
}
}

/// Emits restore opcodes for every synthetic comprehension target slot.
///
/// Restores are emitted in reverse order so the VM's save stack is unwound LIFO.
fn compile_comprehension_slot_restores(&mut self, generators: &[Comprehension]) {
for slot in Self::comprehension_target_slots(generators).into_iter().rev() {
let opcode = if self.is_module_scope {
Opcode::RestoreGlobal
} else {
Opcode::RestoreLocal
};
self.code.emit_u16(opcode, slot);
}
}

/// Collects every namespace slot assigned by the comprehension generators.
///
/// Each target slot is compiler-generated and scoped to the comprehension, so saving all
/// of them is sufficient to restore the namespace even for nested generators and tuple
/// unpacking targets.
fn comprehension_target_slots(generators: &[Comprehension]) -> Vec<u16> {
let mut slots = Vec::new();
for generator in generators {
Self::collect_unpack_target_slots(&generator.target, &mut slots);
}
slots
}

/// Recursively collects namespace slots written by an unpack target.
fn collect_unpack_target_slots(target: &UnpackTarget, slots: &mut Vec<u16>) {
match target {
UnpackTarget::Name(ident) | UnpackTarget::Starred(ident) => {
let slot = u16::try_from(ident.namespace_id().index()).expect("local slot exceeds u16");
slots.push(slot);
}
UnpackTarget::Tuple { targets, .. } => {
for target in targets {
Self::collect_unpack_target_slots(target, slots);
}
}
}
}

/// Recursively compiles comprehension generators (the for/if clauses).
///
/// For each generator:
Expand Down
32 changes: 28 additions & 4 deletions crates/monty/src/bytecode/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,26 @@ pub enum Opcode {
/// Pops iterable (TOS), adds each item to set at stack position `len - 2 - depth`.
/// Raises `TypeError` if iterable is not iterable.
SetExtend,
/// Save a local slot to the VM's comprehension-save stack, then clear it.
/// Operand: u16 slot.
///
/// Used by inlined comprehensions to give their synthetic loop-variable slots the
/// same temporary lifetime as CPython's `LOAD_FAST_AND_CLEAR` behavior, without
/// exposing the saved value on the operand stack.
SaveLocalAndClear,
/// Restore the most recently saved local slot from the comprehension-save stack.
/// Operand: u16 slot.
RestoreLocal,
/// Save a global slot to the VM's comprehension-save stack, then clear it.
/// Operand: u16 slot.
///
/// Module-scope comprehensions use global slots in Monty's runtime model, so they
/// need an explicit save/restore pair to avoid leaking temporary loop values across
/// REPL turns and other incremental execution boundaries.
SaveGlobalAndClear,
/// Restore the most recently saved global slot from the comprehension-save stack.
/// Operand: u16 slot.
RestoreGlobal,
}

impl TryFrom<u8> for Opcode {
Expand Down Expand Up @@ -477,7 +497,7 @@ impl Opcode {
LoadLocal | LoadLocalW | LoadLocalCallable | LoadLocalCallableW | LoadGlobal | LoadGlobalCallable
| LoadCell => 1,
StoreLocal | StoreLocalW | StoreGlobal | StoreCell => -1,
DeleteLocal | DeleteGlobal => 0, // doesn't affect stack
DeleteLocal | DeleteGlobal | SaveLocalAndClear | RestoreLocal | SaveGlobalAndClear | RestoreGlobal => 0,

// Binary operations: pop 2, push 1 = -1
BinaryAdd | BinarySub | BinaryMul | BinaryDiv | BinaryFloorDiv | BinaryMod | BinaryPow | BinaryAnd
Expand Down Expand Up @@ -584,8 +604,8 @@ mod tests {

#[test]
fn test_opcode_roundtrip() {
// Verify that all opcodes from 0 to DeleteGlobal (last opcode) can be converted to u8 and back.
for byte in 0..=Opcode::DeleteGlobal as u8 {
// Verify that all opcodes from 0 to RestoreGlobal (last opcode) can be converted to u8 and back.
for byte in 0..=Opcode::RestoreGlobal as u8 {
let opcode = Opcode::try_from(byte).unwrap();
assert_eq!(opcode as u8, byte, "opcode {opcode:?} has wrong discriminant");
}
Expand All @@ -601,12 +621,16 @@ mod tests {
assert_eq!(Opcode::DeleteGlobal as u8, 112);
assert_eq!(Opcode::DictUpdate as u8, 113);
assert_eq!(Opcode::SetExtend as u8, 114);
assert_eq!(Opcode::SaveLocalAndClear as u8, 115);
assert_eq!(Opcode::RestoreLocal as u8, 116);
assert_eq!(Opcode::SaveGlobalAndClear as u8, 117);
assert_eq!(Opcode::RestoreGlobal as u8, 118);
}

#[test]
fn test_invalid_opcode() {
// Byte just after the last valid opcode should fail
let result = Opcode::try_from(Opcode::SetExtend as u8 + 1);
let result = Opcode::try_from(Opcode::RestoreGlobal as u8 + 1);
assert!(result.is_err());
// 255 should also fail
let result = Opcode::try_from(255u8);
Expand Down
2 changes: 2 additions & 0 deletions crates/monty/src/bytecode/vm/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ impl<T: ResourceTracker> VM<'_, '_, T> {

// No handler in this frame - pop frame and try outer
if this.frames.len() <= 1 {
this.restore_saved_comprehension_slots_for_frame(0);
// No more frames - exception is unhandled
let is_spawned = this.is_spawned_task();

Expand Down Expand Up @@ -247,6 +248,7 @@ impl<T: ResourceTracker> VM<'_, '_, T> {
}
}
}
self.restore_saved_comprehension_slots_for_frame(0);
error
}

Expand Down
Loading
Loading