Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for Dict manager #148

Merged
merged 6 commits into from
Nov 18, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
366 changes: 366 additions & 0 deletions src/dict_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,369 @@ impl PyDictTracker {
}
}
}

#[cfg(test)]
mod tests {
use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM};
use cairo_rs::{
hint_processor::hint_processor_definition::HintReference,
serde::deserialize_program::{ApTracking, Member},
types::relocatable::Relocatable,
types::{instruction::Register, relocatable::MaybeRelocatable},
};
use num_bigint::{BigInt, Sign};
use pyo3::{types::PyDict, PyCell};

use super::*;

#[test]
fn new_dict() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let memory = PyMemory::new(&vm);
let ap = PyRelocatable::from(vm.vm.borrow().get_ap());
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

let globals = PyDict::new(py);
globals
.set_item("memory", PyCell::new(py, memory).unwrap())
.unwrap();
globals
.set_item("ap", PyCell::new(py, ap).unwrap())
.unwrap();
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
memory[ap] = dict_manager.new_dict(segments, {})
memory[ap + 1] = dict_manager.new_dict(segments, {})
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(py_result.map_err(to_vm_error), Ok(()));

let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 0)));
assert_eq!(
mb_relocatable,
Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from(
(2, 0)
))))
);
let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 1)));
assert_eq!(
mb_relocatable,
Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from(
(3, 0)
))))
);
});
}

#[test]
fn tracker_read() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

//Create references
let mut references = HashMap::new();
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
initial_dict = { 1: 2, 4: 8, 16: 32 }
ids.dict = dict_manager.new_dict(segments, initial_dict)
dict_tracker = dict_manager.get_tracker(ids.dict)
assert dict_tracker.data[1] == 2
assert dict_tracker.data[4] == 8
assert dict_tracker.data[16] == 32
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(py_result.map_err(to_vm_error), Ok(()));
});
}

#[test]
fn tracker_read_default_dict() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

let mut references = HashMap::new();

// Create reference with type DictAccess*
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();

// Create dummy type DictAccess
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
ids.dict = dict_manager.new_default_dict(segments, 42, {})
dict_tracker = dict_manager.get_tracker(ids.dict)
assert dict_tracker.data[33] == 42
assert dict_tracker.data[223] == 42
assert dict_tracker.data[412] == 42
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(py_result.map_err(to_vm_error), Ok(()));
});
}

#[test]
fn tracker_write() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

//Create references
let mut references = HashMap::new();
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 0,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
ids.dict = dict_manager.new_dict(segments, {})
dict_tracker = dict_manager.get_tracker(ids.dict)

dict_tracker.data[1] = 5
assert dict_tracker.data[1] == 5

dict_tracker.data[1] = 22
assert dict_tracker.data[1] == 22
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(py_result.map_err(to_vm_error), Ok(()));
});
}

#[test]
fn tracker_get_and_set_current_ptr() {
Python::with_gil(|py| {
let vm = PyVM::new(
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
false,
);
for _ in 0..2 {
vm.vm.borrow_mut().add_memory_segment();
}

let dict_manager = PyDictManager::default();

let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));

let mut references = HashMap::new();

// Inserts `start_ptr` on references and memory
references.insert(String::from("start_ptr"), HintReference::new_simple(0));
vm.vm
.borrow_mut()
.insert_value(&Relocatable::from((1, 0)), &MaybeRelocatable::from((2, 0)))
.unwrap();

// Inserts `end_ptr` on references and memory
references.insert(String::from("end_ptr"), HintReference::new_simple(1));
vm.vm
.borrow_mut()
.insert_value(&Relocatable::from((1, 1)), &MaybeRelocatable::from((2, 1)))
.unwrap();

// Create reference with type DictAccess*
references.insert(
String::from("dict"),
HintReference {
register: Some(Register::FP),
offset1: 2,
offset2: 0,
inner_dereference: false,
ap_tracking_data: None,
immediate: None,
dereference: true,
cairo_type: Some(String::from("DictAccess*")),
},
);

let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();

// Create dummy type DictAccess
struct_types.insert(String::from("DictAccess"), HashMap::new());

let ids = PyIds::new(
&vm,
&references,
&ApTracking::default(),
&HashMap::new(),
Rc::new(struct_types),
);

let globals = PyDict::new(py);
globals
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
.unwrap();
globals
.set_item("ids", PyCell::new(py, ids).unwrap())
.unwrap();
globals
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
.unwrap();

let code = r#"
ids.dict = dict_manager.new_dict(segments, {})
dict_tracker = dict_manager.get_tracker(ids.dict)

assert dict_tracker.current_ptr == ids.start_ptr

dict_tracker.current_ptr = ids.end_ptr
assert dict_tracker.current_ptr == ids.end_ptr
"#;

let py_result = py.run(code, Some(globals), None);

assert_eq!(py_result.map_err(to_vm_error), Ok(()));
});
}
}