From e77d3801c1decff039fd0ec6bbeb55ed734357fb Mon Sep 17 00:00:00 2001 From: Thibaut Schaeffer Date: Fri, 31 Jan 2025 18:43:26 +0100 Subject: [PATCH] Fix `contains_next_ref` (#2409) --- ast/src/analyzed/contains_next_ref.rs | 101 ++++++++++++++++++ ast/src/analyzed/mod.rs | 23 +--- ast/src/parsed/mod.rs | 1 + backend/src/composite/mod.rs | 7 +- .../src/witgen/data_structures/identity.rs | 5 - executor/src/witgen/global_constraints.rs | 4 +- .../src/witgen/jit/block_machine_processor.rs | 30 ++++-- .../src/witgen/jit/single_step_processor.rs | 5 +- executor/src/witgen/machines/mod.rs | 9 +- executor/src/witgen/processor.rs | 5 +- executor/src/witgen/vm_processor.rs | 10 +- 11 files changed, 148 insertions(+), 52 deletions(-) create mode 100644 ast/src/analyzed/contains_next_ref.rs diff --git a/ast/src/analyzed/contains_next_ref.rs b/ast/src/analyzed/contains_next_ref.rs new file mode 100644 index 0000000000..ed24bc5320 --- /dev/null +++ b/ast/src/analyzed/contains_next_ref.rs @@ -0,0 +1,101 @@ +use std::collections::BTreeMap; + +use crate::parsed::visitor::AllChildren; + +use super::{AlgebraicExpression, AlgebraicReferenceThin, PolynomialType}; + +/// Auxiliary function to check if an AST node contains a next reference +/// References to intermediate values are resolved recursively +fn contains_next_ref_with_intermediates>>( + e: &E, + intermediate_definitions: &BTreeMap>, + intermediates_cache: &mut BTreeMap, +) -> bool { + e.all_children() + .filter_map(|e| { + if let AlgebraicExpression::Reference(reference) = e { + Some(reference) + } else { + None + } + }) + .any(|reference| { + if reference.next { + true + } else if reference.poly_id.ptype == PolynomialType::Intermediate { + let reference = reference.to_thin(); + intermediates_cache + .get(&reference) + .cloned() + .unwrap_or_else(|| { + let result = contains_next_ref_with_intermediates( + &intermediate_definitions[&reference], + intermediate_definitions, + intermediates_cache, + ); + intermediates_cache.insert(reference, result); + result + }) + } else { + false + } + }) +} + +pub trait ContainsNextRef { + fn contains_next_ref( + &self, + intermediate_definitions: &BTreeMap>, + ) -> bool; +} + +impl>> ContainsNextRef for E { + fn contains_next_ref( + &self, + intermediate_definitions: &BTreeMap>, + ) -> bool { + contains_next_ref_with_intermediates(self, intermediate_definitions, &mut BTreeMap::new()) + } +} + +#[cfg(test)] +mod tests { + use std::iter::once; + + use crate::analyzed::{ + contains_next_ref::ContainsNextRef, AlgebraicExpression, AlgebraicReference, PolyID, + PolynomialType, + }; + + #[test] + fn contains_next_ref() { + let column = AlgebraicExpression::::Reference(AlgebraicReference { + name: "column".to_string(), + poly_id: PolyID { + id: 0, + ptype: PolynomialType::Committed, + }, + next: false, + }); + + let one = AlgebraicExpression::Number(1); + + let expr = column.clone() + one.clone() * column.clone(); + assert!(!expr.contains_next_ref(&Default::default())); + + let expr = column.clone() + one.clone() * column.clone().next().unwrap(); + assert!(expr.contains_next_ref(&Default::default())); + + let inter = AlgebraicReference { + name: "inter".to_string(), + poly_id: PolyID { + id: 1, + ptype: PolynomialType::Intermediate, + }, + next: false, + }; + let intermediates = once((inter.to_thin(), column.clone().next().unwrap())).collect(); + let expr = column.clone() + one.clone() * AlgebraicExpression::Reference(inter.clone()); + assert!(expr.contains_next_ref(&intermediates)); + } +} diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 6fc9e60964..d53ae67fe4 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -1,3 +1,4 @@ +mod contains_next_ref; mod display; pub mod visitor; @@ -24,6 +25,7 @@ use crate::parsed::{ self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, SourceReference, TraitDeclaration, TraitImplementation, TypeDeclaration, }; +pub use contains_next_ref::ContainsNextRef; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)] pub enum StatementIdentifier { @@ -1058,10 +1060,6 @@ pub enum Identity { } impl Identity { - pub fn contains_next_ref(&self) -> bool { - self.children().any(|e| e.contains_next_ref()) - } - pub fn degree( &self, intermediate_polynomials: &BTreeMap>, @@ -1163,14 +1161,6 @@ pub enum IdentityKind { PhantomBusInteraction, } -impl SelectedExpressions { - /// @returns true if the expression contains a reference to a next value of a - /// (witness or fixed) column - pub fn contains_next_ref(&self) -> bool { - self.children().any(|e| e.contains_next_ref()) - } -} - pub type Expression = parsed::Expression; pub type TypedExpression = crate::parsed::TypedExpression; @@ -1580,15 +1570,6 @@ impl AlgebraicExpression { pub fn new_unary(op: AlgebraicUnaryOperator, expr: Self) -> Self { AlgebraicUnaryOperation::new(op, expr).into() } - - /// @returns true if the expression contains a reference to a next value of a - /// (witness or fixed) column - pub fn contains_next_ref(&self) -> bool { - self.expr_any(|e| match e { - AlgebraicExpression::Reference(poly) => poly.next, - _ => false, - }) - } } impl ops::Add for AlgebraicExpression { diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index b3988024a7..65742c6727 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -929,6 +929,7 @@ impl Expression { } /// Returns true if the expression contains a reference to a next value + // TODO: Is it fine that this does not check references to intermediate polynomials? pub fn contains_next_ref(&self) -> bool { self.expr_any(|e| { matches!( diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index e202c78a7f..a1e6c7124b 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -8,7 +8,7 @@ use std::{ }; use itertools::Itertools; -use powdr_ast::analyzed::Analyzed; +use powdr_ast::analyzed::{Analyzed, ContainsNextRef}; use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns}; use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback}; use powdr_number::{DegreeType, FieldElement}; @@ -178,7 +178,10 @@ fn log_machine_stats(machine_name: &str, pil: &Analyzed) { .map(|i| i.degree(&intermediate_definitions)) .max() .unwrap_or(0); - let uses_next_operator = pil.identities.iter().any(|i| i.contains_next_ref()); + let uses_next_operator = pil + .identities + .iter() + .any(|i| i.contains_next_ref(&intermediate_definitions)); // This assumes that we'll always at least once reference the current row let number_of_rotations = 1 + if uses_next_operator { 1 } else { 0 }; let num_identities_by_kind = pil diff --git a/executor/src/witgen/data_structures/identity.rs b/executor/src/witgen/data_structures/identity.rs index 19e0b06fa2..081eb90909 100644 --- a/executor/src/witgen/data_structures/identity.rs +++ b/executor/src/witgen/data_structures/identity.rs @@ -166,11 +166,6 @@ impl Children> for Identity { } impl Identity { - pub fn contains_next_ref(&self) -> bool { - // TODO: This does not check the definitions of intermediate polynomials! - self.children().any(|e| e.contains_next_ref()) - } - pub fn id(&self) -> u64 { match self { Identity::Polynomial(i) => i.id, diff --git a/executor/src/witgen/global_constraints.rs b/executor/src/witgen/global_constraints.rs index 9147a1e456..f4dd4fef10 100644 --- a/executor/src/witgen/global_constraints.rs +++ b/executor/src/witgen/global_constraints.rs @@ -6,7 +6,7 @@ use num_traits::Zero; use num_traits::One; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, - AlgebraicReference, AlgebraicReferenceThin, PolyID, PolynomialType, + AlgebraicReference, AlgebraicReferenceThin, ContainsNextRef, PolyID, PolynomialType, }; use powdr_number::FieldElement; @@ -361,7 +361,7 @@ fn try_transfer_constraints( expr: &Expression, known_constraints: &BTreeMap>, ) -> Vec<(PolyID, RangeConstraint)> { - if expr.contains_next_ref() { + if expr.contains_next_ref(intermediate_definitions) { return vec![]; } diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index 1087ed6658..b82cb833f2 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use bit_vec::BitVec; use itertools::Itertools; -use powdr_ast::analyzed::{PolyID, PolynomialType}; +use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType}; use powdr_number::FieldElement; use crate::witgen::{ @@ -81,12 +81,14 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index)); } + let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions(); + // Compute the identity-row-pairs we consider. let have_next_ref = self .machine_parts .identities .iter() - .any(|id| id.contains_next_ref()); + .any(|id| id.contains_next_ref(&intermediate_definitions)); let start_row = if !have_next_ref { // No identity contains a next reference - we do not need to consider row -1, // and the block has to be rectangular-shaped. @@ -96,15 +98,21 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { // We iterate over all rows of the block +/- one row. -1 }; - let identities = (start_row..self.block_size as i32).flat_map(move |row| { - self.machine_parts.identities.iter().filter_map(move |&id| { - // Filter out identities with next references on the last row. - if row as usize == self.block_size - 1 && id.contains_next_ref() { - None - } else { - Some((id, row)) - } - }) + let identities = (start_row..self.block_size as i32).flat_map(|row| { + self.machine_parts + .identities + .iter() + .filter_map(|id| { + // Filter out identities with next references on the last row. + if row as usize == self.block_size - 1 + && id.contains_next_ref(&intermediate_definitions) + { + None + } else { + Some((*id, row)) + } + }) + .collect_vec() }); let requested_known = known_args diff --git a/executor/src/witgen/jit/single_step_processor.rs b/executor/src/witgen/jit/single_step_processor.rs index 2034fcb65b..a342781a24 100644 --- a/executor/src/witgen/jit/single_step_processor.rs +++ b/executor/src/witgen/jit/single_step_processor.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use powdr_ast::analyzed::{ - AlgebraicExpression as Expression, AlgebraicReference, PolyID, PolynomialType, + AlgebraicExpression as Expression, AlgebraicReference, ContainsNextRef, PolyID, PolynomialType, }; use powdr_number::FieldElement; @@ -37,6 +37,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { &self, can_process: impl CanProcessCall, ) -> Result>, String> { + let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions(); let all_witnesses = self .machine_parts .witnesses @@ -56,7 +57,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { .identities .iter() .flat_map(|&id| { - if id.contains_next_ref() { + if id.contains_next_ref(&intermediate_definitions) { vec![(id, 0)] } else { // Process it on both rows, but mark it as complete on row 0, diff --git a/executor/src/witgen/machines/mod.rs b/executor/src/witgen/machines/mod.rs index 09a36aaf64..9617a6f5e6 100644 --- a/executor/src/witgen/machines/mod.rs +++ b/executor/src/witgen/machines/mod.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use bit_vec::BitVec; use dynamic_machine::DynamicMachine; -use powdr_ast::analyzed::{self, AlgebraicExpression, DegreeRange, PolyID}; +use powdr_ast::analyzed::{self, AlgebraicExpression, ContainsNextRef, DegreeRange, PolyID}; use powdr_number::DegreeType; use powdr_number::FieldElement; @@ -361,10 +361,15 @@ impl<'a, T: FieldElement> MachineParts<'a, T> { /// Returns a copy of the machine parts but only containing identities that /// have a "next" reference. pub fn restricted_to_identities_with_next_references(&self) -> MachineParts<'a, T> { + let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions(); let identities_with_next_reference = self .identities .iter() - .filter_map(|identity| identity.contains_next_ref().then_some(*identity)) + .filter_map(|identity| { + identity + .contains_next_ref(&intermediate_definitions) + .then_some(*identity) + }) .collect::>(); Self { identities: identities_with_next_reference, diff --git a/executor/src/witgen/processor.rs b/executor/src/witgen/processor.rs index 3483834357..58383e6eaa 100644 --- a/executor/src/witgen/processor.rs +++ b/executor/src/witgen/processor.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use num_traits::One; use powdr_ast::analyzed::{ - AlgebraicExpression as Expression, AlgebraicReference, PolyID, PolynomialType, + AlgebraicExpression as Expression, AlgebraicReference, ContainsNextRef, PolyID, PolynomialType, }; use powdr_number::{DegreeType, FieldElement}; @@ -293,7 +293,8 @@ Known values in current row (local: {row_index}, global {global_row_index}): ", self.data[row_index].render_values(false, self.parts) ); - if identity.contains_next_ref() { + let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions(); + if identity.contains_next_ref(&intermediate_definitions) { error += &format!( "Known values in next row (local: {}, global {}):\n{}\n", row_index + 1, diff --git a/executor/src/witgen/vm_processor.rs b/executor/src/witgen/vm_processor.rs index 0258dcc9aa..11860fdd54 100644 --- a/executor/src/witgen/vm_processor.rs +++ b/executor/src/witgen/vm_processor.rs @@ -1,6 +1,6 @@ use indicatif::{ProgressBar, ProgressStyle}; use itertools::Itertools; -use powdr_ast::analyzed::DegreeRange; +use powdr_ast::analyzed::{ContainsNextRef, DegreeRange}; use powdr_ast::indent; use powdr_number::{DegreeType, FieldElement}; use std::cmp::max; @@ -87,10 +87,10 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'c, T, Q> { ) -> Self { let degree_range = parts.common_degree_range(); - let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) = parts - .identities - .iter() - .partition(|identity| identity.contains_next_ref()); + let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) = + parts.identities.iter().partition(|identity| { + identity.contains_next_ref(&fixed_data.analyzed.intermediate_definitions()) + }); let processor = Processor::new( row_offset, mutable_data,