Skip to content

Commit

Permalink
Fix contains_next_ref (#2409)
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff authored Jan 31, 2025
1 parent 9dff547 commit e77d380
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 52 deletions.
101 changes: 101 additions & 0 deletions ast/src/analyzed/contains_next_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use std::collections::BTreeMap;

use crate::parsed::visitor::AllChildren;

use super::{AlgebraicExpression, AlgebraicReferenceThin, PolynomialType};

/// Auxiliary function to check if an AST node contains a next reference
/// References to intermediate values are resolved recursively
fn contains_next_ref_with_intermediates<T, E: AllChildren<AlgebraicExpression<T>>>(
e: &E,
intermediate_definitions: &BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<T>>,
intermediates_cache: &mut BTreeMap<AlgebraicReferenceThin, bool>,
) -> bool {
e.all_children()
.filter_map(|e| {
if let AlgebraicExpression::Reference(reference) = e {
Some(reference)
} else {
None
}
})
.any(|reference| {
if reference.next {
true
} else if reference.poly_id.ptype == PolynomialType::Intermediate {
let reference = reference.to_thin();
intermediates_cache
.get(&reference)
.cloned()
.unwrap_or_else(|| {
let result = contains_next_ref_with_intermediates(
&intermediate_definitions[&reference],
intermediate_definitions,
intermediates_cache,
);
intermediates_cache.insert(reference, result);
result
})
} else {
false
}
})
}

pub trait ContainsNextRef<T> {
fn contains_next_ref(
&self,
intermediate_definitions: &BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<T>>,
) -> bool;
}

impl<T, E: AllChildren<AlgebraicExpression<T>>> ContainsNextRef<T> for E {
fn contains_next_ref(
&self,
intermediate_definitions: &BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<T>>,
) -> bool {
contains_next_ref_with_intermediates(self, intermediate_definitions, &mut BTreeMap::new())
}
}

#[cfg(test)]
mod tests {
use std::iter::once;

use crate::analyzed::{
contains_next_ref::ContainsNextRef, AlgebraicExpression, AlgebraicReference, PolyID,
PolynomialType,
};

#[test]
fn contains_next_ref() {
let column = AlgebraicExpression::<i32>::Reference(AlgebraicReference {
name: "column".to_string(),
poly_id: PolyID {
id: 0,
ptype: PolynomialType::Committed,
},
next: false,
});

let one = AlgebraicExpression::Number(1);

let expr = column.clone() + one.clone() * column.clone();
assert!(!expr.contains_next_ref(&Default::default()));

let expr = column.clone() + one.clone() * column.clone().next().unwrap();
assert!(expr.contains_next_ref(&Default::default()));

let inter = AlgebraicReference {
name: "inter".to_string(),
poly_id: PolyID {
id: 1,
ptype: PolynomialType::Intermediate,
},
next: false,
};
let intermediates = once((inter.to_thin(), column.clone().next().unwrap())).collect();
let expr = column.clone() + one.clone() * AlgebraicExpression::Reference(inter.clone());
assert!(expr.contains_next_ref(&intermediates));
}
}
23 changes: 2 additions & 21 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod contains_next_ref;
mod display;
pub mod visitor;

Expand All @@ -24,6 +25,7 @@ use crate::parsed::{
self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, SourceReference,
TraitDeclaration, TraitImplementation, TypeDeclaration,
};
pub use contains_next_ref::ContainsNextRef;

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
pub enum StatementIdentifier {
Expand Down Expand Up @@ -1058,10 +1060,6 @@ pub enum Identity<T> {
}

impl<T> Identity<T> {
pub fn contains_next_ref(&self) -> bool {
self.children().any(|e| e.contains_next_ref())
}

pub fn degree(
&self,
intermediate_polynomials: &BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<T>>,
Expand Down Expand Up @@ -1163,14 +1161,6 @@ pub enum IdentityKind {
PhantomBusInteraction,
}

impl<T> SelectedExpressions<T> {
/// @returns true if the expression contains a reference to a next value of a
/// (witness or fixed) column
pub fn contains_next_ref(&self) -> bool {
self.children().any(|e| e.contains_next_ref())
}
}

pub type Expression = parsed::Expression<Reference>;
pub type TypedExpression = crate::parsed::TypedExpression<Reference, u64>;

Expand Down Expand Up @@ -1580,15 +1570,6 @@ impl<T> AlgebraicExpression<T> {
pub fn new_unary(op: AlgebraicUnaryOperator, expr: Self) -> Self {
AlgebraicUnaryOperation::new(op, expr).into()
}

/// @returns true if the expression contains a reference to a next value of a
/// (witness or fixed) column
pub fn contains_next_ref(&self) -> bool {
self.expr_any(|e| match e {
AlgebraicExpression::Reference(poly) => poly.next,
_ => false,
})
}
}

impl<T> ops::Add for AlgebraicExpression<T> {
Expand Down
1 change: 1 addition & 0 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ impl<R> Expression<R> {
}

/// Returns true if the expression contains a reference to a next value
// TODO: Is it fine that this does not check references to intermediate polynomials?
pub fn contains_next_ref(&self) -> bool {
self.expr_any(|e| {
matches!(
Expand Down
7 changes: 5 additions & 2 deletions backend/src/composite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
};

use itertools::Itertools;
use powdr_ast::analyzed::Analyzed;
use powdr_ast::analyzed::{Analyzed, ContainsNextRef};
use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns};
use powdr_executor::{constant_evaluator::VariablySizedColumn, witgen::WitgenCallback};
use powdr_number::{DegreeType, FieldElement};
Expand Down Expand Up @@ -178,7 +178,10 @@ fn log_machine_stats<T: FieldElement>(machine_name: &str, pil: &Analyzed<T>) {
.map(|i| i.degree(&intermediate_definitions))
.max()
.unwrap_or(0);
let uses_next_operator = pil.identities.iter().any(|i| i.contains_next_ref());
let uses_next_operator = pil
.identities
.iter()
.any(|i| i.contains_next_ref(&intermediate_definitions));
// This assumes that we'll always at least once reference the current row
let number_of_rotations = 1 + if uses_next_operator { 1 } else { 0 };
let num_identities_by_kind = pil
Expand Down
5 changes: 0 additions & 5 deletions executor/src/witgen/data_structures/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,6 @@ impl<T> Children<AlgebraicExpression<T>> for Identity<T> {
}

impl<T> Identity<T> {
pub fn contains_next_ref(&self) -> bool {
// TODO: This does not check the definitions of intermediate polynomials!
self.children().any(|e| e.contains_next_ref())
}

pub fn id(&self) -> u64 {
match self {
Identity::Polynomial(i) => i.id,
Expand Down
4 changes: 2 additions & 2 deletions executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use num_traits::Zero;
use num_traits::One;
use powdr_ast::analyzed::{
AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression,
AlgebraicReference, AlgebraicReferenceThin, PolyID, PolynomialType,
AlgebraicReference, AlgebraicReferenceThin, ContainsNextRef, PolyID, PolynomialType,
};

use powdr_number::FieldElement;
Expand Down Expand Up @@ -361,7 +361,7 @@ fn try_transfer_constraints<T: FieldElement>(
expr: &Expression<T>,
known_constraints: &BTreeMap<PolyID, RangeConstraint<T>>,
) -> Vec<(PolyID, RangeConstraint<T>)> {
if expr.contains_next_ref() {
if expr.contains_next_ref(intermediate_definitions) {
return vec![];
}

Expand Down
30 changes: 19 additions & 11 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashSet;

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::analyzed::{PolyID, PolynomialType};
use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType};
use powdr_number::FieldElement;

use crate::witgen::{
Expand Down Expand Up @@ -81,12 +81,14 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index));
}

let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();

// Compute the identity-row-pairs we consider.
let have_next_ref = self
.machine_parts
.identities
.iter()
.any(|id| id.contains_next_ref());
.any(|id| id.contains_next_ref(&intermediate_definitions));
let start_row = if !have_next_ref {
// No identity contains a next reference - we do not need to consider row -1,
// and the block has to be rectangular-shaped.
Expand All @@ -96,15 +98,21 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
// We iterate over all rows of the block +/- one row.
-1
};
let identities = (start_row..self.block_size as i32).flat_map(move |row| {
self.machine_parts.identities.iter().filter_map(move |&id| {
// Filter out identities with next references on the last row.
if row as usize == self.block_size - 1 && id.contains_next_ref() {
None
} else {
Some((id, row))
}
})
let identities = (start_row..self.block_size as i32).flat_map(|row| {
self.machine_parts
.identities
.iter()
.filter_map(|id| {
// Filter out identities with next references on the last row.
if row as usize == self.block_size - 1
&& id.contains_next_ref(&intermediate_definitions)
{
None
} else {
Some((*id, row))
}
})
.collect_vec()
});

let requested_known = known_args
Expand Down
5 changes: 3 additions & 2 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use itertools::Itertools;
use powdr_ast::analyzed::{
AlgebraicExpression as Expression, AlgebraicReference, PolyID, PolynomialType,
AlgebraicExpression as Expression, AlgebraicReference, ContainsNextRef, PolyID, PolynomialType,
};
use powdr_number::FieldElement;

Expand Down Expand Up @@ -37,6 +37,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
&self,
can_process: impl CanProcessCall<T>,
) -> Result<Vec<Effect<T, Variable>>, String> {
let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();
let all_witnesses = self
.machine_parts
.witnesses
Expand All @@ -56,7 +57,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
.identities
.iter()
.flat_map(|&id| {
if id.contains_next_ref() {
if id.contains_next_ref(&intermediate_definitions) {
vec![(id, 0)]
} else {
// Process it on both rows, but mark it as complete on row 0,
Expand Down
9 changes: 7 additions & 2 deletions executor/src/witgen/machines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Display;

use bit_vec::BitVec;
use dynamic_machine::DynamicMachine;
use powdr_ast::analyzed::{self, AlgebraicExpression, DegreeRange, PolyID};
use powdr_ast::analyzed::{self, AlgebraicExpression, ContainsNextRef, DegreeRange, PolyID};

use powdr_number::DegreeType;
use powdr_number::FieldElement;
Expand Down Expand Up @@ -361,10 +361,15 @@ impl<'a, T: FieldElement> MachineParts<'a, T> {
/// Returns a copy of the machine parts but only containing identities that
/// have a "next" reference.
pub fn restricted_to_identities_with_next_references(&self) -> MachineParts<'a, T> {
let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();
let identities_with_next_reference = self
.identities
.iter()
.filter_map(|identity| identity.contains_next_ref().then_some(*identity))
.filter_map(|identity| {
identity
.contains_next_ref(&intermediate_definitions)
.then_some(*identity)
})
.collect::<Vec<_>>();
Self {
identities: identities_with_next_reference,
Expand Down
5 changes: 3 additions & 2 deletions executor/src/witgen/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::BTreeMap;

use num_traits::One;
use powdr_ast::analyzed::{
AlgebraicExpression as Expression, AlgebraicReference, PolyID, PolynomialType,
AlgebraicExpression as Expression, AlgebraicReference, ContainsNextRef, PolyID, PolynomialType,
};

use powdr_number::{DegreeType, FieldElement};
Expand Down Expand Up @@ -293,7 +293,8 @@ Known values in current row (local: {row_index}, global {global_row_index}):
",
self.data[row_index].render_values(false, self.parts)
);
if identity.contains_next_ref() {
let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();
if identity.contains_next_ref(&intermediate_definitions) {
error += &format!(
"Known values in next row (local: {}, global {}):\n{}\n",
row_index + 1,
Expand Down
10 changes: 5 additions & 5 deletions executor/src/witgen/vm_processor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use indicatif::{ProgressBar, ProgressStyle};
use itertools::Itertools;
use powdr_ast::analyzed::DegreeRange;
use powdr_ast::analyzed::{ContainsNextRef, DegreeRange};
use powdr_ast::indent;
use powdr_number::{DegreeType, FieldElement};
use std::cmp::max;
Expand Down Expand Up @@ -87,10 +87,10 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'c, T, Q> {
) -> Self {
let degree_range = parts.common_degree_range();

let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) = parts
.identities
.iter()
.partition(|identity| identity.contains_next_ref());
let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) =
parts.identities.iter().partition(|identity| {
identity.contains_next_ref(&fixed_data.analyzed.intermediate_definitions())
});
let processor = Processor::new(
row_offset,
mutable_data,
Expand Down

0 comments on commit e77d380

Please sign in to comment.