diff --git a/src/dict_manager.rs b/src/dict_manager.rs index 05ad715f..ec21f08e 100644 --- a/src/dict_manager.rs +++ b/src/dict_manager.rs @@ -164,10 +164,12 @@ impl PyDictTracker { mod tests { use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM}; use cairo_rs::{ + bigint, hint_processor::hint_processor_definition::HintReference, serde::deserialize_program::{ApTracking, Member}, types::relocatable::Relocatable, types::{instruction::Register, relocatable::MaybeRelocatable}, + vm::errors::vm_errors::VirtualMachineError, }; use num_bigint::{BigInt, Sign}; use pyo3::{types::PyDict, PyCell}; @@ -439,6 +441,114 @@ assert dict_tracker.data[1] == 22 }); } + #[test] + fn tracker_read_and_write_invalid_key() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + //Create references + let mut references = HashMap::new(); + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + // Create ids.a + references.insert(String::from("a"), HintReference::new_simple(1)); + + //Insert ids.a into memory + vm.vm + .borrow_mut() + .insert_value( + &Relocatable::from((1, 1)), + &MaybeRelocatable::from((128, 64)), + ) + .unwrap(); + + let mut struct_types: HashMap> = HashMap::new(); + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +initial_dict = { 1: 2, 4: 8, 16: 32 } +ids.dict = dict_manager.new_dict(segments, initial_dict) +dict_tracker = dict_manager.get_tracker(ids.dict) +dict_tracker.data[3] +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!( + py_result.map_err(to_vm_error), + Err(to_vm_error(to_py_error( + VirtualMachineError::NoValueForKey(bigint!(3)) + ))), + ); + + let code = r#" +dict_tracker = dict_manager.get_tracker(ids.dict) +dict_tracker.data[ids.a] +"#; + + let py_result = py.run(code, Some(globals), None); + let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64))); + + assert_eq!( + py_result.map_err(to_vm_error), + Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error), + ); + + let code = r#" +dict_tracker = dict_manager.get_tracker(ids.dict) +dict_tracker.data[ids.a] = 5 +"#; + + let py_result = py.run(code, Some(globals), None); + let key = PyMaybeRelocatable::from(PyRelocatable::from((128, 64))); + + assert_eq!( + py_result.map_err(to_vm_error), + Err(PyKeyError::new_err(key.to_object(py))).map_err(to_vm_error), + ); + }); + } + #[test] fn tracker_get_and_set_current_ptr() { Python::with_gil(|py| { @@ -524,4 +634,105 @@ assert dict_tracker.current_ptr == ids.end_ptr assert_eq!(py_result.map_err(to_vm_error), Ok(())); }); } + + #[test] + fn manager_get_tracker_invalid_dict_ptr() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + //Create references + let mut references = HashMap::new(); + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + references.insert( + String::from("no_dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess")), + }, + ); + + let mut struct_types: HashMap> = HashMap::new(); + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +ids.dict = dict_manager.new_dict(segments, {}) +dict_tracker = dict_manager.get_tracker(ids.no_dict) +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!( + py_result.map_err(to_vm_error), + Err(to_vm_error(to_py_error( + VirtualMachineError::NoDictTracker(vm.vm.borrow().get_fp().segment_index), + ))), + ); + + let code = r#" +dict_tracker = dict_manager.get_tracker(ids.dict) +dict_tracker.current_ptr = dict_tracker.current_ptr + 3 + +dict_tracker = dict_manager.get_tracker(ids.dict) +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!( + py_result.map_err(to_vm_error), + Err(to_vm_error(to_py_error( + VirtualMachineError::MismatchedDictPtr( + Relocatable::from((2, 3)), + Relocatable::from((2, 0)), + ), + ))), + ); + }); + } }