diff --git a/src/dict_manager.rs b/src/dict_manager.rs index dfd9e5a0..05ad715f 100644 --- a/src/dict_manager.rs +++ b/src/dict_manager.rs @@ -159,3 +159,369 @@ impl PyDictTracker { } } } + +#[cfg(test)] +mod tests { + use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM}; + use cairo_rs::{ + hint_processor::hint_processor_definition::HintReference, + serde::deserialize_program::{ApTracking, Member}, + types::relocatable::Relocatable, + types::{instruction::Register, relocatable::MaybeRelocatable}, + }; + use num_bigint::{BigInt, Sign}; + use pyo3::{types::PyDict, PyCell}; + + use super::*; + + #[test] + fn new_dict() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let memory = PyMemory::new(&vm); + let ap = PyRelocatable::from(vm.vm.borrow().get_ap()); + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + let globals = PyDict::new(py); + globals + .set_item("memory", PyCell::new(py, memory).unwrap()) + .unwrap(); + globals + .set_item("ap", PyCell::new(py, ap).unwrap()) + .unwrap(); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +memory[ap] = dict_manager.new_dict(segments, {}) +memory[ap + 1] = dict_manager.new_dict(segments, {}) +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!(py_result.map_err(to_vm_error), Ok(())); + + let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 0))); + assert_eq!( + mb_relocatable, + Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from( + (2, 0) + )))) + ); + let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 1))); + assert_eq!( + mb_relocatable, + Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from( + (3, 0) + )))) + ); + }); + } + + #[test] + fn tracker_read() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + //Create references + let mut references = HashMap::new(); + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + + let mut struct_types: HashMap> = HashMap::new(); + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +initial_dict = { 1: 2, 4: 8, 16: 32 } +ids.dict = dict_manager.new_dict(segments, initial_dict) +dict_tracker = dict_manager.get_tracker(ids.dict) +assert dict_tracker.data[1] == 2 +assert dict_tracker.data[4] == 8 +assert dict_tracker.data[16] == 32 +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!(py_result.map_err(to_vm_error), Ok(())); + }); + } + + #[test] + fn tracker_read_default_dict() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + let mut references = HashMap::new(); + + // Create reference with type DictAccess* + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + + let mut struct_types: HashMap> = HashMap::new(); + + // Create dummy type DictAccess + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +ids.dict = dict_manager.new_default_dict(segments, 42, {}) +dict_tracker = dict_manager.get_tracker(ids.dict) +assert dict_tracker.data[33] == 42 +assert dict_tracker.data[223] == 42 +assert dict_tracker.data[412] == 42 +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!(py_result.map_err(to_vm_error), Ok(())); + }); + } + + #[test] + fn tracker_write() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + //Create references + let mut references = HashMap::new(); + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 0, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + + let mut struct_types: HashMap> = HashMap::new(); + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +ids.dict = dict_manager.new_dict(segments, {}) +dict_tracker = dict_manager.get_tracker(ids.dict) + +dict_tracker.data[1] = 5 +assert dict_tracker.data[1] == 5 + +dict_tracker.data[1] = 22 +assert dict_tracker.data[1] == 22 +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!(py_result.map_err(to_vm_error), Ok(())); + }); + } + + #[test] + fn tracker_get_and_set_current_ptr() { + Python::with_gil(|py| { + let vm = PyVM::new( + BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]), + false, + ); + for _ in 0..2 { + vm.vm.borrow_mut().add_memory_segment(); + } + + let dict_manager = PyDictManager::default(); + + let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm)); + + let mut references = HashMap::new(); + + // Inserts `start_ptr` on references and memory + references.insert(String::from("start_ptr"), HintReference::new_simple(0)); + vm.vm + .borrow_mut() + .insert_value(&Relocatable::from((1, 0)), &MaybeRelocatable::from((2, 0))) + .unwrap(); + + // Inserts `end_ptr` on references and memory + references.insert(String::from("end_ptr"), HintReference::new_simple(1)); + vm.vm + .borrow_mut() + .insert_value(&Relocatable::from((1, 1)), &MaybeRelocatable::from((2, 1))) + .unwrap(); + + // Create reference with type DictAccess* + references.insert( + String::from("dict"), + HintReference { + register: Some(Register::FP), + offset1: 2, + offset2: 0, + inner_dereference: false, + ap_tracking_data: None, + immediate: None, + dereference: true, + cairo_type: Some(String::from("DictAccess*")), + }, + ); + + let mut struct_types: HashMap> = HashMap::new(); + + // Create dummy type DictAccess + struct_types.insert(String::from("DictAccess"), HashMap::new()); + + let ids = PyIds::new( + &vm, + &references, + &ApTracking::default(), + &HashMap::new(), + Rc::new(struct_types), + ); + + let globals = PyDict::new(py); + globals + .set_item("dict_manager", PyCell::new(py, dict_manager).unwrap()) + .unwrap(); + globals + .set_item("ids", PyCell::new(py, ids).unwrap()) + .unwrap(); + globals + .set_item("segments", PyCell::new(py, segment_manager).unwrap()) + .unwrap(); + + let code = r#" +ids.dict = dict_manager.new_dict(segments, {}) +dict_tracker = dict_manager.get_tracker(ids.dict) + +assert dict_tracker.current_ptr == ids.start_ptr + +dict_tracker.current_ptr = ids.end_ptr +assert dict_tracker.current_ptr == ids.end_ptr +"#; + + let py_result = py.run(code, Some(globals), None); + + assert_eq!(py_result.map_err(to_vm_error), Ok(())); + }); + } +}