diff --git a/src/lib.rs b/src/lib.rs index 3698875..09d4d07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -267,10 +267,7 @@ mod aux { } }; - for function in module - .get_functions() - .filter(|f| f.count_basic_blocks() > 0) - { + for function in module.get_functions() { for bb in function.get_basic_blocks() { for instr in bb.get_instructions() { let Ok(call) = CallSiteValue::try_from(instr) else { @@ -793,20 +790,7 @@ mod aux { } let call_args: Vec = extract_operands(instr)?; - let qubit_ptr = match call_args.as_slice() { - [BasicValueEnum::PointerValue(ptr), _] => *ptr, - [_, _] => { - return Err( - "Malformed mz_leaked call: expected first argument to be a pointer".into(), - ); - } - _ => { - return Err(format!( - "Malformed mz_leaked call: expected 1 argument plus callee, got {} operands", - call_args.len() - )); - } - }; + let qubit_ptr = mz_leaked_qubit_operand(&call_args)?; let q_handle = { let idx_fn = module @@ -878,6 +862,21 @@ mod aux { Ok(()) } + pub fn mz_leaked_qubit_operand<'ctx>( + call_args: &[BasicValueEnum<'ctx>], + ) -> Result, String> { + match call_args { + [BasicValueEnum::PointerValue(ptr), _] => Ok(*ptr), + [_, _] => { + Err("Malformed mz_leaked call: expected first argument to be a pointer".into()) + } + _ => Err(format!( + "Malformed mz_leaked call: expected 1 argument plus callee, got {} operands", + call_args.len() + )), + } + } + fn handle_reset_call(args: &ProcessCallArgs) -> Result<(), String> { let ProcessCallArgs { ctx, module, instr, .. @@ -2997,6 +2996,56 @@ declare void @__quantum__rt__result_record_output(%Result*, i8*) assert!(err.contains("Malformed mz_leaked call: expected signature i64 (ptr)")); } + #[test] + fn test_qir_to_qis_rejects_mz_leaked_with_wrong_return_width() { + let ll_text = minimal_qir_with_body( + "1", + "0", + "1", + "declare i1 @__quantum__qis__mz_leaked__body(%Qubit*)", + r" %0 = call i1 @__quantum__qis__mz_leaked__body(%Qubit* null)", + ); + + let bc_bytes = qir_ll_to_bc(&ll_text).expect("Failed to convert inline QIR to bitcode"); + let err = qir_to_qis(&bc_bytes, 0, "native", None) + .expect_err("mz_leaked with the wrong return width should fail cleanly"); + assert_eq!( + err, + "Malformed mz_leaked call: expected signature i64 (ptr)" + ); + } + + #[test] + fn test_qir_to_qis_rejects_mz_leaked_with_non_pointer_parameter() { + let ll_text = minimal_qir_with_body( + "1", + "0", + "1", + "declare i64 @__quantum__qis__mz_leaked__body(i64)", + r" %0 = call i64 @__quantum__qis__mz_leaked__body(i64 0)", + ); + + let bc_bytes = qir_ll_to_bc(&ll_text).expect("Failed to convert inline QIR to bitcode"); + let err = qir_to_qis(&bc_bytes, 0, "native", None) + .expect_err("mz_leaked with a non-pointer parameter should fail cleanly"); + assert_eq!( + err, + "Malformed mz_leaked call: expected signature i64 (ptr)" + ); + } + + #[test] + fn test_mz_leaked_operand_check_rejects_non_pointer_first_operand() { + let ctx = Context::create(); + let value = ctx.i64_type().const_zero().into(); + let err = crate::aux::mz_leaked_qubit_operand(&[value, value]) + .expect_err("non-pointer mz_leaked operands should fail cleanly"); + assert_eq!( + err, + "Malformed mz_leaked call: expected first argument to be a pointer" + ); + } + #[cfg(not(windows))] #[test] fn test_qir_to_qis_mz_leaked_lowers_via_uint_future_runtime() {