Skip to content

Commit f10e775

Browse files
authored
Add unit tests for Dict manager (#148)
* Add test for PyDictManager::new_dict * Add test reading from PyDictTracker * Add test writing to PyDictTracker * Add test reading from a default dict * Add test for current_ptr of PyDictTracker * Add more checks on previous tests
1 parent 255bc69 commit f10e775

File tree

1 file changed

+366
-0
lines changed

1 file changed

+366
-0
lines changed

src/dict_manager.rs

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,369 @@ impl PyDictTracker {
159159
}
160160
}
161161
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use crate::{ids::PyIds, memory::PyMemory, utils::to_vm_error, vm_core::PyVM};
166+
use cairo_rs::{
167+
hint_processor::hint_processor_definition::HintReference,
168+
serde::deserialize_program::{ApTracking, Member},
169+
types::relocatable::Relocatable,
170+
types::{instruction::Register, relocatable::MaybeRelocatable},
171+
};
172+
use num_bigint::{BigInt, Sign};
173+
use pyo3::{types::PyDict, PyCell};
174+
175+
use super::*;
176+
177+
#[test]
178+
fn new_dict() {
179+
Python::with_gil(|py| {
180+
let vm = PyVM::new(
181+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
182+
false,
183+
);
184+
for _ in 0..2 {
185+
vm.vm.borrow_mut().add_memory_segment();
186+
}
187+
188+
let dict_manager = PyDictManager::default();
189+
190+
let memory = PyMemory::new(&vm);
191+
let ap = PyRelocatable::from(vm.vm.borrow().get_ap());
192+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
193+
194+
let globals = PyDict::new(py);
195+
globals
196+
.set_item("memory", PyCell::new(py, memory).unwrap())
197+
.unwrap();
198+
globals
199+
.set_item("ap", PyCell::new(py, ap).unwrap())
200+
.unwrap();
201+
globals
202+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
203+
.unwrap();
204+
globals
205+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
206+
.unwrap();
207+
208+
let code = r#"
209+
memory[ap] = dict_manager.new_dict(segments, {})
210+
memory[ap + 1] = dict_manager.new_dict(segments, {})
211+
"#;
212+
213+
let py_result = py.run(code, Some(globals), None);
214+
215+
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
216+
217+
let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 0)));
218+
assert_eq!(
219+
mb_relocatable,
220+
Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from(
221+
(2, 0)
222+
))))
223+
);
224+
let mb_relocatable = vm.vm.borrow().get_maybe(&Relocatable::from((1, 1)));
225+
assert_eq!(
226+
mb_relocatable,
227+
Ok(Some(MaybeRelocatable::RelocatableValue(Relocatable::from(
228+
(3, 0)
229+
))))
230+
);
231+
});
232+
}
233+
234+
#[test]
235+
fn tracker_read() {
236+
Python::with_gil(|py| {
237+
let vm = PyVM::new(
238+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
239+
false,
240+
);
241+
for _ in 0..2 {
242+
vm.vm.borrow_mut().add_memory_segment();
243+
}
244+
245+
let dict_manager = PyDictManager::default();
246+
247+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
248+
249+
//Create references
250+
let mut references = HashMap::new();
251+
references.insert(
252+
String::from("dict"),
253+
HintReference {
254+
register: Some(Register::FP),
255+
offset1: 0,
256+
offset2: 0,
257+
inner_dereference: false,
258+
ap_tracking_data: None,
259+
immediate: None,
260+
dereference: true,
261+
cairo_type: Some(String::from("DictAccess*")),
262+
},
263+
);
264+
265+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
266+
struct_types.insert(String::from("DictAccess"), HashMap::new());
267+
268+
let ids = PyIds::new(
269+
&vm,
270+
&references,
271+
&ApTracking::default(),
272+
&HashMap::new(),
273+
Rc::new(struct_types),
274+
);
275+
276+
let globals = PyDict::new(py);
277+
globals
278+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
279+
.unwrap();
280+
globals
281+
.set_item("ids", PyCell::new(py, ids).unwrap())
282+
.unwrap();
283+
globals
284+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
285+
.unwrap();
286+
287+
let code = r#"
288+
initial_dict = { 1: 2, 4: 8, 16: 32 }
289+
ids.dict = dict_manager.new_dict(segments, initial_dict)
290+
dict_tracker = dict_manager.get_tracker(ids.dict)
291+
assert dict_tracker.data[1] == 2
292+
assert dict_tracker.data[4] == 8
293+
assert dict_tracker.data[16] == 32
294+
"#;
295+
296+
let py_result = py.run(code, Some(globals), None);
297+
298+
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
299+
});
300+
}
301+
302+
#[test]
303+
fn tracker_read_default_dict() {
304+
Python::with_gil(|py| {
305+
let vm = PyVM::new(
306+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
307+
false,
308+
);
309+
for _ in 0..2 {
310+
vm.vm.borrow_mut().add_memory_segment();
311+
}
312+
313+
let dict_manager = PyDictManager::default();
314+
315+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
316+
317+
let mut references = HashMap::new();
318+
319+
// Create reference with type DictAccess*
320+
references.insert(
321+
String::from("dict"),
322+
HintReference {
323+
register: Some(Register::FP),
324+
offset1: 0,
325+
offset2: 0,
326+
inner_dereference: false,
327+
ap_tracking_data: None,
328+
immediate: None,
329+
dereference: true,
330+
cairo_type: Some(String::from("DictAccess*")),
331+
},
332+
);
333+
334+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
335+
336+
// Create dummy type DictAccess
337+
struct_types.insert(String::from("DictAccess"), HashMap::new());
338+
339+
let ids = PyIds::new(
340+
&vm,
341+
&references,
342+
&ApTracking::default(),
343+
&HashMap::new(),
344+
Rc::new(struct_types),
345+
);
346+
347+
let globals = PyDict::new(py);
348+
globals
349+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
350+
.unwrap();
351+
globals
352+
.set_item("ids", PyCell::new(py, ids).unwrap())
353+
.unwrap();
354+
globals
355+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
356+
.unwrap();
357+
358+
let code = r#"
359+
ids.dict = dict_manager.new_default_dict(segments, 42, {})
360+
dict_tracker = dict_manager.get_tracker(ids.dict)
361+
assert dict_tracker.data[33] == 42
362+
assert dict_tracker.data[223] == 42
363+
assert dict_tracker.data[412] == 42
364+
"#;
365+
366+
let py_result = py.run(code, Some(globals), None);
367+
368+
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
369+
});
370+
}
371+
372+
#[test]
373+
fn tracker_write() {
374+
Python::with_gil(|py| {
375+
let vm = PyVM::new(
376+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
377+
false,
378+
);
379+
for _ in 0..2 {
380+
vm.vm.borrow_mut().add_memory_segment();
381+
}
382+
383+
let dict_manager = PyDictManager::default();
384+
385+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
386+
387+
//Create references
388+
let mut references = HashMap::new();
389+
references.insert(
390+
String::from("dict"),
391+
HintReference {
392+
register: Some(Register::FP),
393+
offset1: 0,
394+
offset2: 0,
395+
inner_dereference: false,
396+
ap_tracking_data: None,
397+
immediate: None,
398+
dereference: true,
399+
cairo_type: Some(String::from("DictAccess*")),
400+
},
401+
);
402+
403+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
404+
struct_types.insert(String::from("DictAccess"), HashMap::new());
405+
406+
let ids = PyIds::new(
407+
&vm,
408+
&references,
409+
&ApTracking::default(),
410+
&HashMap::new(),
411+
Rc::new(struct_types),
412+
);
413+
414+
let globals = PyDict::new(py);
415+
globals
416+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
417+
.unwrap();
418+
globals
419+
.set_item("ids", PyCell::new(py, ids).unwrap())
420+
.unwrap();
421+
globals
422+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
423+
.unwrap();
424+
425+
let code = r#"
426+
ids.dict = dict_manager.new_dict(segments, {})
427+
dict_tracker = dict_manager.get_tracker(ids.dict)
428+
429+
dict_tracker.data[1] = 5
430+
assert dict_tracker.data[1] == 5
431+
432+
dict_tracker.data[1] = 22
433+
assert dict_tracker.data[1] == 22
434+
"#;
435+
436+
let py_result = py.run(code, Some(globals), None);
437+
438+
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
439+
});
440+
}
441+
442+
#[test]
443+
fn tracker_get_and_set_current_ptr() {
444+
Python::with_gil(|py| {
445+
let vm = PyVM::new(
446+
BigInt::new(Sign::Plus, vec![1, 0, 0, 0, 0, 0, 17, 134217728]),
447+
false,
448+
);
449+
for _ in 0..2 {
450+
vm.vm.borrow_mut().add_memory_segment();
451+
}
452+
453+
let dict_manager = PyDictManager::default();
454+
455+
let segment_manager = PySegmentManager::new(&vm, PyMemory::new(&vm));
456+
457+
let mut references = HashMap::new();
458+
459+
// Inserts `start_ptr` on references and memory
460+
references.insert(String::from("start_ptr"), HintReference::new_simple(0));
461+
vm.vm
462+
.borrow_mut()
463+
.insert_value(&Relocatable::from((1, 0)), &MaybeRelocatable::from((2, 0)))
464+
.unwrap();
465+
466+
// Inserts `end_ptr` on references and memory
467+
references.insert(String::from("end_ptr"), HintReference::new_simple(1));
468+
vm.vm
469+
.borrow_mut()
470+
.insert_value(&Relocatable::from((1, 1)), &MaybeRelocatable::from((2, 1)))
471+
.unwrap();
472+
473+
// Create reference with type DictAccess*
474+
references.insert(
475+
String::from("dict"),
476+
HintReference {
477+
register: Some(Register::FP),
478+
offset1: 2,
479+
offset2: 0,
480+
inner_dereference: false,
481+
ap_tracking_data: None,
482+
immediate: None,
483+
dereference: true,
484+
cairo_type: Some(String::from("DictAccess*")),
485+
},
486+
);
487+
488+
let mut struct_types: HashMap<String, HashMap<String, Member>> = HashMap::new();
489+
490+
// Create dummy type DictAccess
491+
struct_types.insert(String::from("DictAccess"), HashMap::new());
492+
493+
let ids = PyIds::new(
494+
&vm,
495+
&references,
496+
&ApTracking::default(),
497+
&HashMap::new(),
498+
Rc::new(struct_types),
499+
);
500+
501+
let globals = PyDict::new(py);
502+
globals
503+
.set_item("dict_manager", PyCell::new(py, dict_manager).unwrap())
504+
.unwrap();
505+
globals
506+
.set_item("ids", PyCell::new(py, ids).unwrap())
507+
.unwrap();
508+
globals
509+
.set_item("segments", PyCell::new(py, segment_manager).unwrap())
510+
.unwrap();
511+
512+
let code = r#"
513+
ids.dict = dict_manager.new_dict(segments, {})
514+
dict_tracker = dict_manager.get_tracker(ids.dict)
515+
516+
assert dict_tracker.current_ptr == ids.start_ptr
517+
518+
dict_tracker.current_ptr = ids.end_ptr
519+
assert dict_tracker.current_ptr == ids.end_ptr
520+
"#;
521+
522+
let py_result = py.run(code, Some(globals), None);
523+
524+
assert_eq!(py_result.map_err(to_vm_error), Ok(()));
525+
});
526+
}
527+
}

0 commit comments

Comments
 (0)