diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index a743577f76..fa31e998d9 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -293,6 +293,7 @@ impl FixedEvaluator for &BlockMachineProcessor<'_, T> { mod test { use std::fs::read_to_string; + use pretty_assertions::assert_eq; use test_log::test; use powdr_number::GoldilocksField; @@ -420,8 +421,11 @@ main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216); main_binary::B[2] = (main_binary::B[3] & 16777215); assert (main_binary::B[3] & 18446744069414584320) == 0; main_binary::operation_id_next[2] = main_binary::operation_id[3]; -machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]); -main_binary::C_byte[2] = ret(9, 2, 3); +call_var(9, 2, 0) = main_binary::operation_id_next[2]; +call_var(9, 2, 1) = main_binary::A_byte[2]; +call_var(9, 2, 2) = main_binary::B_byte[2]; +machine_call(9, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]); +main_binary::C_byte[2] = call_var(9, 2, 3); main_binary::operation_id[1] = main_binary::operation_id[2]; main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536); main_binary::A[1] = (main_binary::A[2] & 65535); @@ -430,8 +434,11 @@ main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536); main_binary::B[1] = (main_binary::B[2] & 65535); assert (main_binary::B[2] & 18446744073692774400) == 0; main_binary::operation_id_next[1] = main_binary::operation_id[2]; -machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]); -main_binary::C_byte[1] = ret(9, 1, 3); +call_var(9, 1, 0) = main_binary::operation_id_next[1]; +call_var(9, 1, 1) = main_binary::A_byte[1]; +call_var(9, 1, 2) = main_binary::B_byte[1]; +machine_call(9, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]); +main_binary::C_byte[1] = call_var(9, 1, 3); main_binary::operation_id[0] = main_binary::operation_id[1]; main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256); main_binary::A[0] = (main_binary::A[1] & 255); @@ -440,13 +447,19 @@ main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256); main_binary::B[0] = (main_binary::B[1] & 255); assert (main_binary::B[1] & 18446744073709486080) == 0; main_binary::operation_id_next[0] = main_binary::operation_id[1]; -machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]); -main_binary::C_byte[0] = ret(9, 0, 3); +call_var(9, 0, 0) = main_binary::operation_id_next[0]; +call_var(9, 0, 1) = main_binary::A_byte[0]; +call_var(9, 0, 2) = main_binary::B_byte[0]; +machine_call(9, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]); +main_binary::C_byte[0] = call_var(9, 0, 3); main_binary::A_byte[-1] = main_binary::A[0]; main_binary::B_byte[-1] = main_binary::B[0]; main_binary::operation_id_next[-1] = main_binary::operation_id[0]; -machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]); -main_binary::C_byte[-1] = ret(9, -1, 3); +call_var(9, -1, 0) = main_binary::operation_id_next[-1]; +call_var(9, -1, 1) = main_binary::A_byte[-1]; +call_var(9, -1, 2) = main_binary::B_byte[-1]; +machine_call(9, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]); +main_binary::C_byte[-1] = call_var(9, -1, 3); main_binary::C[0] = main_binary::C_byte[-1]; main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256)); main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536)); diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index 61019c9da4..582d2234b7 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -7,7 +7,6 @@ use powdr_number::{FieldElement, KnownField}; use crate::witgen::{ data_structures::{finalizable_data::CompactDataRef, mutable_state::MutableState}, - jit::effect::MachineCallArgument, machines::{ profiling::{record_end, record_start}, LookupCell, @@ -150,8 +149,8 @@ fn witgen_code( format!("get(data, row_offset, {}, {})", c.row_offset, c.id) } Variable::Param(i) => format!("get_param(params, {i})"), - Variable::MachineCallReturnValue(_) => { - unreachable!("Machine call return values should not be pre-known.") + Variable::MachineCallParam(_) => { + unreachable!("Machine call variables should not be pre-known.") } }; format!(" let {var_name} = {value};") @@ -173,7 +172,7 @@ fn witgen_code( cell.row_offset, cell.id, )), Variable::Param(i) => Some(format!(" set_param(params, {i}, {value});")), - Variable::MachineCallReturnValue(_) => { + Variable::MachineCallParam(_) => { // This is just an internal variable. None } @@ -186,7 +185,7 @@ fn witgen_code( .iter() .filter_map(|var| match var { Variable::Cell(cell) => Some(cell), - Variable::Param(_) | Variable::MachineCallReturnValue(_) => None, + Variable::Param(_) | Variable::MachineCallParam(_) => None, }) .map(|cell| { format!( @@ -234,10 +233,11 @@ fn written_vars_in_effect( Effect::Assignment(var, _) => Box::new(iter::once((var, false))), Effect::RangeConstraint(..) => unreachable!(), Effect::Assertion(..) => Box::new(iter::empty()), - Effect::MachineCall(_, arguments) => Box::new(arguments.iter().flat_map(|e| match e { - MachineCallArgument::Unknown(v) => Some((v, true)), - MachineCallArgument::Known(_) => None, - })), + Effect::MachineCall(_, known, vars) => Box::new( + vars.iter() + .zip_eq(known) + .flat_map(|(v, known)| (!known).then_some((v, true))), + ), Effect::Branch(_, first, second) => Box::new( first .iter() @@ -287,21 +287,21 @@ fn format_effect(effect: &Effect, is_top_level: bo if *expected_equal { "==" } else { "!=" }, format_expression(rhs) ), - Effect::MachineCall(id, arguments) => { + Effect::MachineCall(id, known, vars) => { let mut result_vars = vec![]; - let args = arguments + let args = vars .iter() - .map(|a| match a { - MachineCallArgument::Unknown(v) => { - let var_name = variable_to_string(v); + .zip_eq(known) + .map(|(v, known)| { + let var_name = variable_to_string(v); + if known { + format!("LookupCell::Input(&{var_name})") + } else { if is_top_level { result_vars.push(var_name.clone()); } format!("LookupCell::Output(&mut {var_name})") } - MachineCallArgument::Known(v) => { - format!("LookupCell::Input(&{})", format_expression(v)) - } }) .format(", ") .to_string(); @@ -396,12 +396,12 @@ fn variable_to_string(v: &Variable) -> String { format_row_offset(cell.row_offset) ), Variable::Param(i) => format!("p_{i}"), - Variable::MachineCallReturnValue(ret) => { + Variable::MachineCallParam(call_var) => { format!( - "ret_{}_{}_{}", - ret.identity_id, - format_row_offset(ret.row_offset), - ret.index + "call_var_{}_{}_{}", + call_var.identity_id, + format_row_offset(call_var.row_offset), + call_var.index ) } } @@ -483,13 +483,14 @@ fn util_code(first_column_id: u64, column_count: usize) -> Resu #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; use test_log::test; use powdr_number::GoldilocksField; use crate::witgen::jit::variable::Cell; - use crate::witgen::jit::variable::MachineCallReturnVariable; + use crate::witgen::jit::variable::MachineCallVariable; use crate::witgen::range_constraints::RangeConstraint; use super::*; @@ -518,8 +519,8 @@ mod tests { Variable::Param(i) } - fn ret_val(identity_id: u64, row_offset: i32, index: usize) -> Variable { - Variable::MachineCallReturnValue(MachineCallReturnVariable { + fn call_var(identity_id: u64, row_offset: i32, index: usize) -> Variable { + Variable::MachineCallParam(MachineCallVariable { identity_id, row_offset, index, @@ -547,15 +548,15 @@ mod tests { let x0 = cell("x", 0, 0); let ym1 = cell("y", 1, -1); let yp1 = cell("y", 1, 1); - let r1 = ret_val(7, 1, 1); + let cv1 = call_var(7, 1, 0); + let r1 = call_var(7, 1, 1); let effects = vec![ assignment(&x0, number(7) * symbol(&a0)), + assignment(&cv1, symbol(&x0)), Effect::MachineCall( 7, - vec![ - MachineCallArgument::Unknown(r1.clone()), - MachineCallArgument::Known(symbol(&x0)), - ], + [false, true].into_iter().collect(), + vec![r1.clone(), cv1.clone()], ), assignment(&ym1, symbol(&r1)), assignment(&yp1, symbol(&ym1) + symbol(&x0)), @@ -568,8 +569,8 @@ mod tests { let known_inputs = vec![a0.clone()]; let code = witgen_code(&known_inputs, &effects); assert_eq!( - code, - " + code, + " #[no_mangle] extern \"C\" fn witgen( WitgenFunctionParams{ @@ -588,9 +589,10 @@ extern \"C\" fn witgen( let c_a_2_0 = get(data, row_offset, 0, 2); let c_x_0_0 = (FieldElement::from(7) * c_a_2_0); - let mut ret_7_1_1 = FieldElement::default(); - assert!(call_machine(mutable_state, 7, MutSlice::from((&mut [LookupCell::Output(&mut ret_7_1_1), LookupCell::Input(&c_x_0_0)]).as_mut_slice()))); - let c_y_1_m1 = ret_7_1_1; + let call_var_7_1_0 = c_x_0_0; + let mut call_var_7_1_1 = FieldElement::default(); + assert!(call_machine(mutable_state, 7, MutSlice::from((&mut [LookupCell::Output(&mut call_var_7_1_1), LookupCell::Input(&call_var_7_1_0)]).as_mut_slice()))); + let c_y_1_m1 = call_var_7_1_1; let c_y_1_1 = (c_y_1_m1 + c_x_0_0); assert!(c_y_1_m1 == c_x_0_0); @@ -603,7 +605,7 @@ extern \"C\" fn witgen( set_known(known, row_offset, 1, 1); } " - ); + ); } extern "C" fn no_call_machine( @@ -817,16 +819,15 @@ extern \"C\" fn witgen( fn submachine_calls() { let x = cell("x", 0, 0); let y = cell("y", 1, 0); - let r1 = ret_val(7, 0, 1); - let r2 = ret_val(7, 0, 2); + let v1 = call_var(7, 0, 0); + let r1 = call_var(7, 0, 1); + let r2 = call_var(7, 0, 2); let effects = vec![ + Effect::Assignment(v1.clone(), number(7)), Effect::MachineCall( 7, - vec![ - MachineCallArgument::Known(number(7)), - MachineCallArgument::Unknown(r1.clone()), - MachineCallArgument::Unknown(r2.clone()), - ], + [true, false, false].into_iter().collect(), + vec![v1, r1.clone(), r2.clone()], ), Effect::Assignment(x.clone(), symbol(&r1)), Effect::Assignment(y.clone(), symbol(&r2)), diff --git a/executor/src/witgen/jit/effect.rs b/executor/src/witgen/jit/effect.rs index 68ec23e7e9..c34bc36424 100644 --- a/executor/src/witgen/jit/effect.rs +++ b/executor/src/witgen/jit/effect.rs @@ -1,5 +1,6 @@ use std::cmp::Ordering; +use bit_vec::BitVec; use itertools::Itertools; use powdr_ast::indent; use powdr_number::FieldElement; @@ -17,8 +18,8 @@ pub enum Effect { RangeConstraint(V, RangeConstraint), /// A run-time assertion. If this fails, we have conflicting constraints. Assertion(Assertion), - /// A call to a different machine. - MachineCall(u64, Vec>), + /// A call to a different machine, with identity ID, known inputs and argument variables. + MachineCall(u64, BitVec, Vec), /// A branch on a variable. Branch(BranchCondition, Vec>, Vec>), } @@ -59,12 +60,6 @@ impl Assertion { } } -#[derive(Clone, PartialEq, Eq)] -pub enum MachineCallArgument { - Known(SymbolicExpression), - Unknown(V), -} - #[derive(Clone, PartialEq, Eq)] pub struct BranchCondition { pub variable: V, @@ -88,13 +83,15 @@ pub fn format_code(effects: &[Effect]) -> String { if *expected_equal { "==" } else { "!=" } ) } - Effect::MachineCall(id, args) => { + Effect::MachineCall(id, known, vars) => { format!( "machine_call({id}, [{}]);", - args.iter() - .map(|arg| match arg { - MachineCallArgument::Known(k) => format!("Known({k})"), - MachineCallArgument::Unknown(u) => format!("Unknown({u})"), + vars.iter() + .zip(known) + .map(|(v, known)| if known { + format!("Known({v})") + } else { + format!("Unknown({v})") }) .join(", ") ) diff --git a/executor/src/witgen/jit/single_step_processor.rs b/executor/src/witgen/jit/single_step_processor.rs index 542d802fc5..e64669a670 100644 --- a/executor/src/witgen/jit/single_step_processor.rs +++ b/executor/src/witgen/jit/single_step_processor.rs @@ -285,9 +285,10 @@ mod test { format_code(&code), "\ VM::pc[1] = (VM::pc[0] + 1); -machine_call(1, [Known(VM::pc[1]), Unknown(ret(1, 1, 1)), Unknown(ret(1, 1, 2))]); -VM::instr_add[1] = ret(1, 1, 1); -VM::instr_mul[1] = ret(1, 1, 2); +call_var(1, 1, 0) = VM::pc[1]; +machine_call(1, [Known(call_var(1, 1, 0)), Unknown(call_var(1, 1, 1)), Unknown(call_var(1, 1, 2))]); +VM::instr_add[1] = call_var(1, 1, 1); +VM::instr_mul[1] = call_var(1, 1, 2); VM::B[1] = VM::B[0]; if (VM::instr_add[0] == 1) { if (VM::instr_mul[0] == 1) { diff --git a/executor/src/witgen/jit/variable.rs b/executor/src/witgen/jit/variable.rs index f0af2bebff..34ac92c36d 100644 --- a/executor/src/witgen/jit/variable.rs +++ b/executor/src/witgen/jit/variable.rs @@ -4,9 +4,6 @@ use std::{ }; use powdr_ast::analyzed::{AlgebraicReference, PolyID, PolynomialType}; -use powdr_number::FieldElement; - -use super::effect::MachineCallArgument; #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Debug)] /// A variable that can be used in the inference engine. @@ -16,9 +13,9 @@ pub enum Variable { /// A parameter (input or output) of the machine. #[allow(dead_code)] Param(usize), - /// The return value of a machine call on a certain + /// An input or output value of a machine call on a certain /// identity on a certain row offset. - MachineCallReturnValue(MachineCallReturnVariable), + MachineCallParam(MachineCallVariable), } impl Display for Variable { @@ -26,10 +23,10 @@ impl Display for Variable { match self { Variable::Cell(cell) => write!(f, "{cell}"), Variable::Param(i) => write!(f, "params[{i}]"), - Variable::MachineCallReturnValue(ret) => { + Variable::MachineCallParam(ret) => { write!( f, - "ret({}, {}, {})", + "call_var({}, {}, {})", ret.identity_id, ret.row_offset, ret.index ) } @@ -55,24 +52,18 @@ impl Variable { id: cell.id, ptype: PolynomialType::Committed, }), - Variable::Param(_) | Variable::MachineCallReturnValue(_) => None, + Variable::Param(_) | Variable::MachineCallParam(_) => None, } } } #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Debug)] -pub struct MachineCallReturnVariable { +pub struct MachineCallVariable { pub identity_id: u64, pub row_offset: i32, pub index: usize, } -impl MachineCallReturnVariable { - pub fn into_argument(self) -> MachineCallArgument { - MachineCallArgument::Unknown(Variable::MachineCallReturnValue(self)) - } -} - /// The identifier of a witness cell in the trace table. /// The `row_offset` is relative to a certain "zero row" defined /// by the component that uses this data structure. diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index 33312ed104..d497f26bfc 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -20,8 +20,8 @@ use crate::witgen::{ use super::{ affine_symbolic_expression::{AffineSymbolicExpression, ProcessResult}, - effect::{BranchCondition, Effect, MachineCallArgument}, - variable::{MachineCallReturnVariable, Variable}, + effect::{BranchCondition, Effect}, + variable::{MachineCallVariable, Variable}, }; /// Summary of the effect of processing an action. @@ -261,30 +261,24 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F }).format(", ")); return ProcessResult::empty(); } - let args = evaluated - .into_iter() - .zip(arguments) + let vars = arguments + .iter() .enumerate() - .map(|(index, (eval_expr, arg))| { - if let Some(e) = eval_expr { - MachineCallArgument::Known(e) - } else { - let ret_var = MachineCallReturnVariable { - identity_id: lookup_id, - row_offset, - index, - }; - self.assign_variable( - arg, - row_offset, - Variable::MachineCallReturnValue(ret_var.clone()), - ); - ret_var.into_argument() + .map(|(index, arg)| { + let var = Variable::MachineCallParam(MachineCallVariable { + identity_id: lookup_id, + row_offset, + index, + }); + self.assign_variable(arg, row_offset, var.clone()); + if known[index] { + assert!(self.is_known(&var)); } + var }) .collect_vec(); ProcessResult { - effects: vec![Effect::MachineCall(lookup_id, args)], + effects: vec![Effect::MachineCall(lookup_id, known, vars)], complete: true, } } @@ -331,11 +325,10 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F Effect::RangeConstraint(variable, rc) => { progress |= self.add_range_constraint(variable.clone(), rc.clone()); } - Effect::MachineCall(_, arguments) => { - for arg in arguments { - if let MachineCallArgument::Unknown(v) = arg { - self.known_variables.insert(v.clone()); - } + Effect::MachineCall(_, _, vars) => { + for v in vars { + // Inputs are already known, but it does not hurt to add all of them. + self.known_variables.insert(v.clone()); } progress = true; self.code.push(e); @@ -732,22 +725,30 @@ assert (Xor::A[6] & 18446744073692774400) == 0; Xor::C_byte[5] = ((Xor::C[6] & 16711680) // 65536); Xor::C[5] = (Xor::C[6] & 65535); assert (Xor::C[6] & 18446744073692774400) == 0; -machine_call(0, [Known(Xor::A_byte[6]), Unknown(ret(0, 6, 1)), Known(Xor::C_byte[6])]); -Xor::B_byte[6] = ret(0, 6, 1); +call_var(0, 6, 0) = Xor::A_byte[6]; +call_var(0, 6, 2) = Xor::C_byte[6]; +machine_call(0, [Known(call_var(0, 6, 0)), Unknown(call_var(0, 6, 1)), Known(call_var(0, 6, 2))]); +Xor::B_byte[6] = call_var(0, 6, 1); Xor::A_byte[4] = ((Xor::A[5] & 65280) // 256); Xor::A[4] = (Xor::A[5] & 255); assert (Xor::A[5] & 18446744073709486080) == 0; Xor::C_byte[4] = ((Xor::C[5] & 65280) // 256); Xor::C[4] = (Xor::C[5] & 255); assert (Xor::C[5] & 18446744073709486080) == 0; -machine_call(0, [Known(Xor::A_byte[5]), Unknown(ret(0, 5, 1)), Known(Xor::C_byte[5])]); -Xor::B_byte[5] = ret(0, 5, 1); +call_var(0, 5, 0) = Xor::A_byte[5]; +call_var(0, 5, 2) = Xor::C_byte[5]; +machine_call(0, [Known(call_var(0, 5, 0)), Unknown(call_var(0, 5, 1)), Known(call_var(0, 5, 2))]); +Xor::B_byte[5] = call_var(0, 5, 1); Xor::A_byte[3] = Xor::A[4]; Xor::C_byte[3] = Xor::C[4]; -machine_call(0, [Known(Xor::A_byte[4]), Unknown(ret(0, 4, 1)), Known(Xor::C_byte[4])]); -Xor::B_byte[4] = ret(0, 4, 1); -machine_call(0, [Known(Xor::A_byte[3]), Unknown(ret(0, 3, 1)), Known(Xor::C_byte[3])]); -Xor::B_byte[3] = ret(0, 3, 1); +call_var(0, 4, 0) = Xor::A_byte[4]; +call_var(0, 4, 2) = Xor::C_byte[4]; +machine_call(0, [Known(call_var(0, 4, 0)), Unknown(call_var(0, 4, 1)), Known(call_var(0, 4, 2))]); +Xor::B_byte[4] = call_var(0, 4, 1); +call_var(0, 3, 0) = Xor::A_byte[3]; +call_var(0, 3, 2) = Xor::C_byte[3]; +machine_call(0, [Known(call_var(0, 3, 0)), Unknown(call_var(0, 3, 1)), Known(call_var(0, 3, 2))]); +Xor::B_byte[3] = call_var(0, 3, 1); Xor::B[4] = Xor::B_byte[3]; Xor::B[5] = (Xor::B[4] + (Xor::B_byte[4] * 256)); Xor::B[6] = (Xor::B[5] + (Xor::B_byte[5] * 65536));