diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index c0b1067625..7b4414604d 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -1,16 +1,18 @@ -use std::collections::HashSet; +use std::collections::{BTreeSet, HashSet}; use bit_vec::BitVec; use itertools::Itertools; -use powdr_ast::analyzed::{AlgebraicReference, Identity, SelectedExpressions}; +use powdr_ast::analyzed::{ + AlgebraicReference, Identity, PolyID, PolynomialType, SelectedExpressions, +}; use powdr_number::FieldElement; use crate::witgen::{jit::effect::format_code, machines::MachineParts, FixedData}; use super::{ effect::Effect, - variable::Variable, - witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference}, + variable::{Cell, Variable}, + witgen_inference::{CanProcessCall, FixedEvaluator, Value, WitgenInference}, }; /// A processor for generating JIT code for a block machine. @@ -85,6 +87,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { } } + fn row_range(&self) -> std::ops::Range { + // We iterate over all rows of the block +/- one row, so that we can also solve for non-rectangular blocks. + -1..(self.block_size + 1) as i32 + } + /// Repeatedly processes all identities on all rows, until no progress is made. /// Fails iff there are incomplete machine calls in the latch row. fn solve_block + Clone>( @@ -97,11 +104,10 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { for iteration in 0.. { let mut progress = false; - // TODO: This algorithm is assuming a rectangular block shape. - for row in 0..self.block_size { + for row in self.row_range() { for id in &self.machine_parts.identities { if !complete.contains(&(id.id(), row)) { - let result = witgen.process_identity(can_process.clone(), id, row as i32); + let result = witgen.process_identity(can_process.clone(), id, row); if result.complete { complete.insert((id.id(), row)); } @@ -125,22 +131,121 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { } } - // If any machine call could not be completed, that's bad because machine calls typically have side effects. - // So, the underlying lookup / permutation / bus argument likely does not hold. - // TODO: This assumes a rectangular block shape. - let has_incomplete_machine_calls = (0..self.block_size) - .flat_map(|row| { - self.machine_parts - .identities - .iter() - .filter(|id| is_machine_call(id)) - .map(move |id| (id, row)) + // TODO: Fail hard (or return a different error), as this should never + // happen for valid block machines. Currently fails in: + // powdr-pipeline::powdr_std arith256_memory_large_test + self.check_block_shape(witgen)?; + self.check_incomplete_machine_calls(&complete)?; + + Ok(()) + } + + /// After solving, the known values should be such that we can stack different blocks. + fn check_block_shape(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> { + let known_columns = witgen + .known_variables() + .iter() + .filter_map(|var| match var { + Variable::Cell(cell) => Some(cell.id), + _ => None, }) - .any(|(identity, row)| !complete.contains(&(identity.id(), row))); + .collect::>(); + + let can_stack = known_columns.iter().all(|column_id| { + // Increase the range by 1, because in row , + // we might have processed an identity with next references. + let row_range = self.row_range(); + let values = (row_range.start..(row_range.end + 1)) + .map(|row| { + witgen.value(&Variable::Cell(Cell { + id: *column_id, + row_offset: row, + // Dummy value, the column name is ignored in the implementation + // of Cell::eq, etc. + column_name: "".to_string(), + })) + }) + .collect::>(); + + // Two values that refer to the same row (modulo block size) are compatible if: + // - One of them is unknown, or + // - Both are concrete and equal + let is_compatible = |v1: Value, v2: Value| match (v1, v2) { + (Value::Unknown, _) | (_, Value::Unknown) => true, + (Value::Concrete(a), Value::Concrete(b)) => a == b, + _ => false, + }; + // A column is stackable if all rows equal to each other modulo + // the block size are compatible. + let stackable = (0..(values.len() - self.block_size)) + .all(|i| is_compatible(values[i], values[i + self.block_size])); - match has_incomplete_machine_calls { - true => Err("Incomplete machine calls".to_string()), - false => Ok(()), + if !stackable { + let column_name = self.fixed_data.column_name(&PolyID { + id: *column_id, + ptype: PolynomialType::Committed, + }); + let block_list = values.iter().skip(1).take(self.block_size).join(", "); + let column_str = format!( + "... {} | {} | {} ...", + values[0], + block_list, + values[self.block_size + 1] + ); + log::debug!("Column {column_name} is not stackable:\n{column_str}"); + } + + stackable + }); + + match can_stack { + true => Ok(()), + false => Err("Block machine shape does not allow stacking".to_string()), + } + } + + /// If any machine call could not be completed, that's bad because machine calls typically have side effects. + /// So, the underlying lookup / permutation / bus argument likely does not hold. + /// This function checks that all machine calls are complete, at least for a window of rows. + fn check_incomplete_machine_calls(&self, complete: &HashSet<(u64, i32)>) -> Result<(), String> { + let machine_calls = self + .machine_parts + .identities + .iter() + .filter(|id| is_machine_call(id)); + + let incomplete_machine_calls = machine_calls + .flat_map(|call| { + let complete_rows = self + .row_range() + .filter(|row| complete.contains(&(call.id(), *row))) + .collect::>(); + // Because we process rows -1..block_size+1, it is fine to have two incomplete machine calls, + // as long as consecutive rows are complete. + if complete_rows.len() >= self.block_size { + let (min, max) = complete_rows.iter().minmax().into_option().unwrap(); + let is_consecutive = max - min == complete_rows.len() as i32 - 1; + if is_consecutive { + return vec![]; + } + } + self.row_range() + .filter(|row| !complete.contains(&(call.id(), *row))) + .map(|row| (call, row)) + .collect::>() + }) + .collect::>(); + + if !incomplete_machine_calls.is_empty() { + Err(format!( + "Incomplete machine calls:\n {}", + incomplete_machine_calls + .iter() + .map(|(identity, row)| format!("{identity} (row {row})")) + .join("\n ") + )) + } else { + Ok(()) } } } @@ -160,7 +265,22 @@ impl FixedEvaluator for &BlockMachineProcessor<'_, T> { fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option { assert!(var.is_fixed()); let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size(); - let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len(); + + // By assumption of the block machine, all fixed columns are cyclic with a period of . + // An exception might be the first and last row. + assert!(row_offset >= -1); + assert!(self.block_size >= 1); + // The current row is guaranteed to be at least 1. + let current_row = (2 * self.block_size as i32 + row_offset) as usize; + let row = current_row + var.next as usize; + + assert!(values.len() >= self.block_size * 4); + + // Fixed columns are assumed to be cyclic, except in the first and last row. + // The code above should ensure that we never access the first or last row. + assert!(row > 0); + assert!(row < values.len() - 1); + Some(values[row]) } } @@ -265,11 +385,70 @@ params[2] = Add::c[0];" } #[test] - // TODO: Currently fails, because the machine has a non-rectangular block shape. - #[should_panic = "Unable to derive algorithm to compute output value \\\"main_binary::C\\\""] + #[should_panic = "Block machine shape does not allow stacking"] + fn not_stackable() { + let input = " + namespace Main(256); + col witness a, b, c; + [a] is NotStackable.sel $ [NotStackable.a]; + namespace NotStackable(256); + col witness sel, a; + a = a'; + "; + generate_for_block_machine(input, "NotStackable", 1, 0).unwrap(); + } + + #[test] fn binary() { let input = read_to_string("../test_data/pil/binary.pil").unwrap(); - generate_for_block_machine(&input, "main_binary", 3, 1).unwrap(); + let code = generate_for_block_machine(&input, "main_binary", 3, 1).unwrap(); + assert_eq!( + format_code(&code), + "main_binary::sel[0][3] = 1; +main_binary::operation_id[3] = params[0]; +main_binary::A[3] = params[1]; +main_binary::B[3] = params[2]; +main_binary::operation_id[2] = main_binary::operation_id[3]; +main_binary::A_byte[2] = ((main_binary::A[3] & 4278190080) // 16777216); +main_binary::A[2] = (main_binary::A[3] & 16777215); +assert (main_binary::A[3] & 18446744069414584320) == 0; +main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216); +main_binary::B[2] = (main_binary::B[3] & 16777215); +assert (main_binary::B[3] & 18446744069414584320) == 0; +main_binary::operation_id_next[2] = main_binary::operation_id[3]; +machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]); +main_binary::C_byte[2] = ret(9, 2, 3); +main_binary::operation_id[1] = main_binary::operation_id[2]; +main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536); +main_binary::A[1] = (main_binary::A[2] & 65535); +assert (main_binary::A[2] & 18446744073692774400) == 0; +main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536); +main_binary::B[1] = (main_binary::B[2] & 65535); +assert (main_binary::B[2] & 18446744073692774400) == 0; +main_binary::operation_id_next[1] = main_binary::operation_id[2]; +machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]); +main_binary::C_byte[1] = ret(9, 1, 3); +main_binary::operation_id[0] = main_binary::operation_id[1]; +main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256); +main_binary::A[0] = (main_binary::A[1] & 255); +assert (main_binary::A[1] & 18446744073709486080) == 0; +main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256); +main_binary::B[0] = (main_binary::B[1] & 255); +assert (main_binary::B[1] & 18446744073709486080) == 0; +main_binary::operation_id_next[0] = main_binary::operation_id[1]; +machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]); +main_binary::C_byte[0] = ret(9, 0, 3); +main_binary::A_byte[-1] = main_binary::A[0]; +main_binary::B_byte[-1] = main_binary::B[0]; +main_binary::operation_id_next[-1] = main_binary::operation_id[0]; +machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]); +main_binary::C_byte[-1] = ret(9, -1, 3); +main_binary::C[0] = main_binary::C_byte[-1]; +main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256)); +main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536)); +main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216)); +params[3] = main_binary::C[3];" + ) } #[test] diff --git a/executor/src/witgen/jit/function_cache.rs b/executor/src/witgen/jit/function_cache.rs index 2a79ba2d1f..4bc7f49283 100644 --- a/executor/src/witgen/jit/function_cache.rs +++ b/executor/src/witgen/jit/function_cache.rs @@ -5,6 +5,7 @@ use powdr_number::{FieldElement, KnownField}; use crate::witgen::{ data_structures::finalizable_data::{ColumnLayout, CompactDataRef}, + jit::effect::Effect, machines::{LookupCell, MachineParts}, EvalError, FixedData, MutableState, QueryCallback, }; @@ -28,6 +29,7 @@ pub struct FunctionCache<'a, T: FieldElement> { /// but failed. witgen_functions: HashMap>>, column_layout: ColumnLayout, + block_size: usize, } impl<'a, T: FieldElement> FunctionCache<'a, T> { @@ -45,6 +47,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { processor, column_layout: metadata, witgen_functions: HashMap::new(), + block_size, } } @@ -89,9 +92,29 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { cache_key: &CacheKey, ) -> Option> { log::trace!("Compiling JIT function for {:?}", cache_key); + self.processor .generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args) .ok() + .and_then(|code| { + // TODO: Remove this once BlockMachine passes the right amount of context for machines with + // non-rectangular block shapes. + let is_rectangular = code + .iter() + .filter_map(|effect| match effect { + Effect::Assignment(v, _) => Some(v), + _ => None, + }) + .filter_map(|assigned_variable| match assigned_variable { + Variable::Cell(cell) => Some(cell.row_offset), + _ => None, + }) + .all(|row_offset| row_offset >= 0 && row_offset < self.block_size as i32); + if !is_rectangular { + log::debug!("Filtering out code for non-rectangular block shape"); + } + is_rectangular.then_some(code) + }) .map(|code| { log::trace!("Generated code ({} steps)", code.len()); let known_inputs = cache_key diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index f591c7f2f3..88fb1dffd4 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -1,4 +1,7 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + fmt::{Display, Formatter}, +}; use bit_vec::BitVec; use itertools::Itertools; @@ -41,6 +44,23 @@ pub struct WitgenInference<'a, T: FieldElement, FixedEval: FixedEvaluator> { code: Vec>, } +#[derive(Debug, Clone, Copy)] +pub enum Value { + Concrete(T), + Known, + Unknown, +} + +impl Display for Value { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Value::Concrete(v) => write!(f, "{v}"), + Value::Known => write!(f, ""), + Value::Unknown => write!(f, "???"), + } + } +} + impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, FixedEval> { pub fn new( fixed_data: &'a FixedData<'a, T>, @@ -61,10 +81,25 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F self.code } + pub fn known_variables(&self) -> &HashSet { + &self.known_variables + } + pub fn is_known(&self, variable: &Variable) -> bool { self.known_variables.contains(variable) } + pub fn value(&self, variable: &Variable) -> Value { + let rc = self.range_constraint(variable); + if let Some(val) = rc.as_ref().and_then(|rc| rc.try_to_single_value()) { + Value::Concrete(val) + } else if self.is_known(variable) { + Value::Known + } else { + Value::Unknown + } + } + /// Process an identity on a certain row. pub fn process_identity>( &mut self, @@ -272,7 +307,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F /// Adds a range constraint to the set of derived range constraints. Returns true if progress was made. fn add_range_constraint(&mut self, variable: Variable, rc: RangeConstraint) -> bool { let rc = self - .range_constraint(variable.clone()) + .range_constraint(&variable) .map_or(rc.clone(), |existing_rc| existing_rc.conjunction(&rc)); if !self.known_variables.contains(&variable) { if let Some(v) = rc.try_to_single_value() { @@ -292,7 +327,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F /// Returns the current best-known range constraint on the given variable /// combining global range constraints and newly derived local range constraints. - fn range_constraint(&self, variable: Variable) -> Option> { + fn range_constraint(&self, variable: &Variable) -> Option> { variable .try_to_witness_poly_id() .and_then(|poly_id| { @@ -305,7 +340,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F }) }) .iter() - .chain(self.derived_range_constraints.get(&variable)) + .chain(self.derived_range_constraints.get(variable)) .cloned() .reduce(|gc, rc| gc.conjunction(&rc)) } @@ -382,13 +417,14 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Evaluator<'a, T, FixedEv pub fn evaluate_variable(&self, variable: Variable) -> AffineSymbolicExpression { // If a variable is known and has a compile-time constant value, // that value is stored in the range constraints. - let rc = self.witgen_inference.range_constraint(variable.clone()); - if let Some(val) = rc.as_ref().and_then(|rc| rc.try_to_single_value()) { - val.into() - } else if !self.only_concrete_known && self.witgen_inference.is_known(&variable) { - AffineSymbolicExpression::from_known_symbol(variable, rc) - } else { - AffineSymbolicExpression::from_unknown_variable(variable, rc) + let rc = self.witgen_inference.range_constraint(&variable); + match self.witgen_inference.value(&variable) { + Value::Concrete(val) => val.into(), + Value::Unknown => AffineSymbolicExpression::from_unknown_variable(variable, rc), + Value::Known if self.only_concrete_known => { + AffineSymbolicExpression::from_unknown_variable(variable, rc) + } + Value::Known => AffineSymbolicExpression::from_known_symbol(variable, rc), } } diff --git a/executor/src/witgen/machines/block_machine.rs b/executor/src/witgen/machines/block_machine.rs index 4dc857e0dc..d8eb2a54b9 100644 --- a/executor/src/witgen/machines/block_machine.rs +++ b/executor/src/witgen/machines/block_machine.rs @@ -401,6 +401,10 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { } } + if self.rows() + self.block_size as DegreeType > self.degree { + return Err(EvalError::RowsExhausted(self.name.clone())); + } + let known_inputs = outer_query.left.iter().map(|e| e.is_constant()).collect(); if self .function_cache @@ -429,10 +433,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { )); } - if self.rows() + self.block_size as DegreeType > self.degree { - return Err(EvalError::RowsExhausted(self.name.clone())); - } - let process_result = self.process(mutable_state, &mut sequence_iterator, outer_query.clone())?; @@ -481,7 +481,7 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> { let values = outer_query.prepare_for_direct_lookup(&mut input_output_data); assert!( - (self.rows() + self.block_size as DegreeType) < self.degree, + (self.rows() + self.block_size as DegreeType) <= self.degree, "Block machine is full (this should have been checked before)" ); self.data diff --git a/test_data/asm/block_machine_exact_number_of_rows.asm b/test_data/asm/block_machine_exact_number_of_rows.asm index ffff1e1e05..26a1369e23 100644 --- a/test_data/asm/block_machine_exact_number_of_rows.asm +++ b/test_data/asm/block_machine_exact_number_of_rows.asm @@ -17,15 +17,17 @@ machine Main with min_degree: 32, max_degree: 64 { reg A; ByteBinary byte_binary; - // We'll call the binary machine twice and the block size - // is 4, so we need exactly 8 rows. - Binary binary(byte_binary, 8, 8); + // We'll call the binary machine 4 times and the block size + // is 4, so we need exactly 16 rows. + Binary binary(byte_binary, 16, 16); instr and X0, X1 -> X2 link ~> X2 = binary.and(X0, X1); function main { A <== and(0xaaaaaaaa, 0xaaaaaaaa); + A <== and(0x55555555, 0x55555555); + A <== and(0x00000000, 0xffffffff); A <== and(0xffffffff, 0xffffffff); return; diff --git a/test_data/pil/binary.pil b/test_data/pil/binary.pil index ddb29c91f2..0219bdbedc 100644 --- a/test_data/pil/binary.pil +++ b/test_data/pil/binary.pil @@ -1,9 +1,9 @@ // A compiled version of std/machines/large_field/binary.asm namespace main(128); - col witness a, b, c; + col witness binary_op, a, b, c; // Dummy connection constraint - [0, a, b, c] is main_binary::latch * main_binary::sel[0] $ [main_binary::operation_id, main_binary::A, main_binary::B, main_binary::C]; + [binary_op, a, b, c] is main_binary::latch * main_binary::sel[0] $ [main_binary::operation_id, main_binary::A, main_binary::B, main_binary::C]; namespace main_binary(128); col witness operation_id;