diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index fa02647ccd..4002b1cdc3 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -321,50 +321,50 @@ main_binary::operation_id[2] = main_binary::operation_id[3]; main_binary::operation_id[1] = main_binary::operation_id[2]; main_binary::operation_id[0] = main_binary::operation_id[1]; main_binary::operation_id_next[-1] = main_binary::operation_id[0]; +call_var(9, -1, 0) = main_binary::operation_id_next[-1]; main_binary::operation_id_next[0] = main_binary::operation_id[1]; +call_var(9, 0, 0) = main_binary::operation_id_next[0]; main_binary::operation_id_next[1] = main_binary::operation_id[2]; +call_var(9, 1, 0) = main_binary::operation_id_next[1]; main_binary::A_byte[2] = ((main_binary::A[3] & 0xff000000) // 16777216); main_binary::A[2] = (main_binary::A[3] & 0xffffff); assert (main_binary::A[3] & 0xffffffff00000000) == 0; +call_var(9, 2, 1) = main_binary::A_byte[2]; main_binary::A_byte[1] = ((main_binary::A[2] & 0xff0000) // 65536); main_binary::A[1] = (main_binary::A[2] & 0xffff); assert (main_binary::A[2] & 0xffffffffff000000) == 0; +call_var(9, 1, 1) = main_binary::A_byte[1]; main_binary::A_byte[0] = ((main_binary::A[1] & 0xff00) // 256); main_binary::A[0] = (main_binary::A[1] & 0xff); assert (main_binary::A[1] & 0xffffffffffff0000) == 0; +call_var(9, 0, 1) = main_binary::A_byte[0]; main_binary::A_byte[-1] = main_binary::A[0]; +call_var(9, -1, 1) = main_binary::A_byte[-1]; main_binary::B_byte[2] = ((main_binary::B[3] & 0xff000000) // 16777216); main_binary::B[2] = (main_binary::B[3] & 0xffffff); assert (main_binary::B[3] & 0xffffffff00000000) == 0; +call_var(9, 2, 2) = main_binary::B_byte[2]; main_binary::B_byte[1] = ((main_binary::B[2] & 0xff0000) // 65536); main_binary::B[1] = (main_binary::B[2] & 0xffff); assert (main_binary::B[2] & 0xffffffffff000000) == 0; +call_var(9, 1, 2) = main_binary::B_byte[1]; main_binary::B_byte[0] = ((main_binary::B[1] & 0xff00) // 256); main_binary::B[0] = (main_binary::B[1] & 0xff); assert (main_binary::B[1] & 0xffffffffffff0000) == 0; +call_var(9, 0, 2) = main_binary::B_byte[0]; main_binary::B_byte[-1] = main_binary::B[0]; -call_var(9, -1, 0) = main_binary::operation_id_next[-1]; -call_var(9, -1, 1) = main_binary::A_byte[-1]; call_var(9, -1, 2) = main_binary::B_byte[-1]; machine_call(9, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]); main_binary::C_byte[-1] = call_var(9, -1, 3); main_binary::C[0] = main_binary::C_byte[-1]; -call_var(9, 0, 0) = main_binary::operation_id_next[0]; -call_var(9, 0, 1) = main_binary::A_byte[0]; -call_var(9, 0, 2) = main_binary::B_byte[0]; machine_call(9, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]); main_binary::C_byte[0] = call_var(9, 0, 3); main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256)); -call_var(9, 1, 0) = main_binary::operation_id_next[1]; -call_var(9, 1, 1) = main_binary::A_byte[1]; -call_var(9, 1, 2) = main_binary::B_byte[1]; machine_call(9, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]); main_binary::C_byte[1] = call_var(9, 1, 3); main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536)); main_binary::operation_id_next[2] = main_binary::operation_id[3]; call_var(9, 2, 0) = main_binary::operation_id_next[2]; -call_var(9, 2, 1) = main_binary::A_byte[2]; -call_var(9, 2, 2) = main_binary::B_byte[2]; machine_call(9, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]); main_binary::C_byte[2] = call_var(9, 2, 3); main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216)); diff --git a/executor/src/witgen/jit/processor.rs b/executor/src/witgen/jit/processor.rs index 195d0c2de7..74728127b6 100644 --- a/executor/src/witgen/jit/processor.rs +++ b/executor/src/witgen/jit/processor.rs @@ -5,12 +5,14 @@ use std::{ }; use itertools::Itertools; -use powdr_ast::analyzed::{PolyID, PolynomialType}; +use powdr_ast::analyzed::{PolyID, PolynomialIdentity, PolynomialType}; use powdr_number::FieldElement; use crate::witgen::{ - data_structures::identity::Identity, jit::debug_formatter::format_identities, - range_constraints::RangeConstraint, FixedData, + data_structures::identity::{BusSend, Identity}, + jit::debug_formatter::format_identities, + range_constraints::RangeConstraint, + FixedData, }; use super::{ @@ -18,7 +20,7 @@ use super::{ effect::{format_code, Effect}, identity_queue::IdentityQueue, prover_function_heuristics::ProverFunction, - variable::{Cell, Variable}, + variable::{Cell, MachineCallVariable, Variable}, witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference}, }; @@ -109,8 +111,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv pub fn generate_code( self, can_process: impl CanProcessCall, - witgen: WitgenInference<'a, T, FixedEval>, + mut witgen: WitgenInference<'a, T, FixedEval>, ) -> Result, Error<'a, T, FixedEval>> { + // Create variables for bus send arguments. + for (id, row_offset) in &self.identities { + if let Identity::BusSend(bus_send) = id { + for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() { + let var = Variable::MachineCallParam(MachineCallVariable { + identity_id: bus_send.identity_id, + row_offset: *row_offset, + index, + }); + witgen.assign_variable(arg, *row_offset, var.clone()); + } + } + } let branch_depth = 0; let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities); self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth) @@ -283,9 +298,25 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv loop { let identity = identity_queue.next(); let updated_vars = match identity { - Some((identity, row_offset)) => { - witgen.process_identity(can_process.clone(), identity, row_offset) - } + Some((identity, row_offset)) => match identity { + Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => { + witgen.process_polynomial_identity(*id, expression, row_offset) + } + Identity::BusSend(BusSend { + bus_id: _, + identity_id, + selected_payload, + }) => witgen.process_call( + can_process.clone(), + *identity_id, + &selected_payload.selector, + selected_payload.expressions.len(), + row_offset, + ), + Identity::Connect(..) => Ok(vec![]), + }, + // TODO Also add prover functions to the queue (activated by their variables) + // and sort them so that they are always last. None => self.process_prover_functions(witgen), }?; if updated_vars.is_empty() && identity.is_none() { diff --git a/executor/src/witgen/jit/single_step_processor.rs b/executor/src/witgen/jit/single_step_processor.rs index 289566474f..aae13fb543 100644 --- a/executor/src/witgen/jit/single_step_processor.rs +++ b/executor/src/witgen/jit/single_step_processor.rs @@ -237,12 +237,12 @@ namespace M(256); assert_eq!( format_code(&code), "\ -VM::pc[1] = (VM::pc[0] + 1); call_var(1, 0, 0) = VM::pc[0]; call_var(1, 0, 1) = VM::instr_add[0]; call_var(1, 0, 2) = VM::instr_mul[0]; -VM::B[1] = VM::B[0]; +VM::pc[1] = (VM::pc[0] + 1); call_var(1, 1, 0) = VM::pc[1]; +VM::B[1] = VM::B[0]; machine_call(1, [Known(call_var(1, 1, 0)), Unknown(call_var(1, 1, 1)), Unknown(call_var(1, 1, 2))]); VM::instr_add[1] = call_var(1, 1, 1); VM::instr_mul[1] = call_var(1, 1, 2); @@ -280,12 +280,12 @@ if (VM::instr_add[0] == 1) { assert_eq!( format_code(&code), "\ -VM::pc[1] = VM::pc[0]; call_var(2, 0, 0) = VM::pc[0]; -call_var(2, 0, 1) = 0; +call_var(2, 0, 1) = VM::instr_add[0]; call_var(2, 0, 2) = VM::instr_mul[0]; -VM::instr_add[1] = 0; +VM::pc[1] = VM::pc[0]; call_var(2, 1, 0) = VM::pc[1]; +VM::instr_add[1] = 0; call_var(2, 1, 1) = 0; call_var(2, 1, 2) = 1; machine_call(2, [Known(call_var(2, 1, 0)), Known(call_var(2, 1, 1)), Unknown(call_var(2, 1, 2))]); diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index cf936ef476..66dcfb2265 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -7,8 +7,7 @@ use bit_vec::BitVec; use itertools::Itertools; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, - AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolynomialIdentity, - PolynomialType, + AlgebraicReference, AlgebraicUnaryOperation, AlgebraicUnaryOperator, PolynomialType, }; use powdr_number::FieldElement; @@ -169,32 +168,36 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F self.add_range_constraint(var.clone(), range_constraint); } - /// Process an identity on a certain row. - /// Returns Ok(true) if there was progress and Ok(false) if there was no progress. - /// If this returns an error, it means we have conflicting constraints. - pub fn process_identity( + pub fn process_polynomial_identity( &mut self, - can_process: impl CanProcessCall, - id: &'a Identity, + identity_id: u64, + expression: &'a Expression, row_offset: i32, ) -> Result, Error> { - let result = match id { - Identity::Polynomial(PolynomialIdentity { expression, .. }) => self - .process_equality_on_row( - expression, - row_offset, - &VariableOrValue::Value(T::from(0)), - )?, - Identity::BusSend(bus_interaction) => self.process_call( - can_process, - bus_interaction.identity_id, - &bus_interaction.selected_payload.selector, - &bus_interaction.selected_payload.expressions, - row_offset, - ), - Identity::Connect(_) => ProcessResult::empty(), - }; - self.ingest_effects(result, Some((id.id(), row_offset))) + let result = self.process_equality_on_row( + expression, + row_offset, + &VariableOrValue::Value(T::from(0)), + )?; + self.ingest_effects(result, Some((identity_id, row_offset))) + } + + pub fn process_call( + &mut self, + can_process_call: impl CanProcessCall, + lookup_id: u64, + selector: &Expression, + argument_count: usize, + row_offset: i32, + ) -> Result, Error> { + let result = self.process_call_inner( + can_process_call, + lookup_id, + selector, + argument_count, + row_offset, + ); + self.ingest_effects(result, Some((lookup_id, row_offset))) } /// Process a prover function on a row, i.e. determine if we can execute it and if it will @@ -321,14 +324,15 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F (lhs_evaluated - rhs_evaluated).solve() } - fn process_call( + fn process_call_inner( &mut self, can_process_call: impl CanProcessCall, lookup_id: u64, selector: &Expression, - arguments: &'a [Expression], + argument_count: usize, row_offset: i32, ) -> ProcessResult { + self.process_assignments().unwrap(); // We need to know the selector. let Some(selector) = self .evaluate(selector, row_offset) @@ -346,44 +350,36 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F assert_eq!(selector, 1.into(), "Selector is non-binary"); } - let evaluated = arguments - .iter() - .map(|a| self.evaluate(a, row_offset)) - .collect::>(); - let range_constraints = evaluated - .iter() - .map(|e| e.as_ref().map(|e| e.range_constraint()).unwrap_or_default()) + let arguments = (0..argument_count) + .map(|index| { + Variable::MachineCallParam(MachineCallVariable { + identity_id: lookup_id, + row_offset, + index, + }) + }) .collect_vec(); - let known: BitVec = evaluated + let range_constraints = arguments .iter() - .map(|e| e.as_ref().and_then(|e| e.try_to_known()).is_some()) - .collect(); + .map(|v| self.range_constraint(v)) + .collect_vec(); + let known: BitVec = arguments.iter().map(|v| self.is_known(v)).collect(); let Some(new_range_constraints) = can_process_call.can_process_call_fully(lookup_id, &known, &range_constraints) else { return ProcessResult::empty(); }; - let mut effects = vec![]; - let vars = arguments + let effects = arguments .iter() .zip_eq(new_range_constraints) - .enumerate() - .map(|(index, (arg, new_rc))| { - let var = Variable::MachineCallParam(MachineCallVariable { - identity_id: lookup_id, - row_offset, - index, - }); - self.assign_variable(arg, row_offset, var.clone()); - effects.push(Effect::RangeConstraint(var.clone(), new_rc.clone())); - if known[index] { - assert!(self.is_known(&var)); - } - var - }) - .collect_vec(); - effects.push(Effect::MachineCall(lookup_id, known, vars.clone())); + .map(|(var, new_rc)| Effect::RangeConstraint(var.clone(), new_rc.clone())) + .chain(std::iter::once(Effect::MachineCall( + lookup_id, + known, + arguments.to_vec(), + ))) + .collect(); ProcessResult { effects, complete: true, @@ -492,6 +488,8 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F if let Some(identity_id) = identity_id { // We actually only need to store completeness for submachine calls, // but we do it for all identities. + // TODO this is not correct, somehow we get problems if we do not provide + // IDs for poly identities. Need to revisit this. self.complete_identities.insert(identity_id); } } @@ -732,12 +730,13 @@ impl> CanProcessCall for &MutableState<' #[cfg(test)] mod test { - use powdr_ast::analyzed::PolyID; + use powdr_ast::analyzed::{PolyID, PolynomialIdentity}; use powdr_number::GoldilocksField; use pretty_assertions::assert_eq; use test_log::test; use crate::witgen::{ + data_structures::identity::BusSend, global_constraints, jit::{effect::format_code, test_util::read_pil, variable::Cell}, machines::{Connection, FixedLookup, KnownMachine}, @@ -792,15 +791,47 @@ mod test { let ref_eval = FixedEvaluatorForFixedData(&fixed_data); let mut witgen = WitgenInference::new(&fixed_data, ref_eval, known_cells, []); let mut counter = 0; + // Create variables for bus send arguments. + for row in rows { + for id in &fixed_data.identities { + if let Identity::BusSend(bus_send) = id { + for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() { + let var = Variable::MachineCallParam(MachineCallVariable { + identity_id: bus_send.identity_id, + row_offset: *row, + index, + }); + witgen.assign_variable(arg, *row, var.clone()); + } + } + } + } + loop { let mut progress = false; counter += 1; for row in rows { for id in fixed_data.identities.iter() { - progress |= !witgen - .process_identity(&mutable_state, id, *row) - .unwrap() - .is_empty(); + let updated_vars = match id { + Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => witgen + .process_polynomial_identity(*id, expression, *row) + .unwrap(), + Identity::BusSend(BusSend { + bus_id: _, + identity_id, + selected_payload, + }) => witgen + .process_call( + &mutable_state, + *identity_id, + &selected_payload.selector, + selected_payload.expressions.len(), + *row, + ) + .unwrap(), + Identity::Connect(..) => vec![], + }; + progress |= !updated_vars.is_empty(); } } if !progress { @@ -899,37 +930,37 @@ namespace Xor(256 * 256); Xor::A_byte[6] = ((Xor::A[7] & 0xff000000) // 16777216); Xor::A[6] = (Xor::A[7] & 0xffffff); assert (Xor::A[7] & 0xffffffff00000000) == 0; +call_var(0, 6, 0) = Xor::A_byte[6]; Xor::C_byte[6] = ((Xor::C[7] & 0xff000000) // 16777216); Xor::C[6] = (Xor::C[7] & 0xffffff); assert (Xor::C[7] & 0xffffffff00000000) == 0; +call_var(0, 6, 2) = Xor::C_byte[6]; Xor::A_byte[5] = ((Xor::A[6] & 0xff0000) // 65536); Xor::A[5] = (Xor::A[6] & 0xffff); assert (Xor::A[6] & 0xffffffffff000000) == 0; +call_var(0, 5, 0) = Xor::A_byte[5]; Xor::C_byte[5] = ((Xor::C[6] & 0xff0000) // 65536); Xor::C[5] = (Xor::C[6] & 0xffff); assert (Xor::C[6] & 0xffffffffff000000) == 0; -call_var(0, 6, 0) = Xor::A_byte[6]; -call_var(0, 6, 2) = Xor::C_byte[6]; +call_var(0, 5, 2) = Xor::C_byte[5]; machine_call(0, [Known(call_var(0, 6, 0)), Unknown(call_var(0, 6, 1)), Known(call_var(0, 6, 2))]); Xor::B_byte[6] = call_var(0, 6, 1); Xor::A_byte[4] = ((Xor::A[5] & 0xff00) // 256); Xor::A[4] = (Xor::A[5] & 0xff); assert (Xor::A[5] & 0xffffffffffff0000) == 0; +call_var(0, 4, 0) = Xor::A_byte[4]; Xor::C_byte[4] = ((Xor::C[5] & 0xff00) // 256); Xor::C[4] = (Xor::C[5] & 0xff); assert (Xor::C[5] & 0xffffffffffff0000) == 0; -call_var(0, 5, 0) = Xor::A_byte[5]; -call_var(0, 5, 2) = Xor::C_byte[5]; +call_var(0, 4, 2) = Xor::C_byte[4]; machine_call(0, [Known(call_var(0, 5, 0)), Unknown(call_var(0, 5, 1)), Known(call_var(0, 5, 2))]); Xor::B_byte[5] = call_var(0, 5, 1); Xor::A_byte[3] = Xor::A[4]; +call_var(0, 3, 0) = Xor::A_byte[3]; Xor::C_byte[3] = Xor::C[4]; -call_var(0, 4, 0) = Xor::A_byte[4]; -call_var(0, 4, 2) = Xor::C_byte[4]; +call_var(0, 3, 2) = Xor::C_byte[3]; machine_call(0, [Known(call_var(0, 4, 0)), Unknown(call_var(0, 4, 1)), Known(call_var(0, 4, 2))]); Xor::B_byte[4] = call_var(0, 4, 1); -call_var(0, 3, 0) = Xor::A_byte[3]; -call_var(0, 3, 2) = Xor::C_byte[3]; machine_call(0, [Known(call_var(0, 3, 0)), Unknown(call_var(0, 3, 1)), Known(call_var(0, 3, 2))]); Xor::B_byte[3] = call_var(0, 3, 1); Xor::B[4] = Xor::B_byte[3];