diff --git a/ast/src/analyzed/expression_evaluator.rs b/ast/src/analyzed/expression_evaluator.rs new file mode 100644 index 0000000000..928a03ce16 --- /dev/null +++ b/ast/src/analyzed/expression_evaluator.rs @@ -0,0 +1,264 @@ +use core::ops::{Add, Mul, Sub}; +use std::collections::BTreeMap; +use std::ops::Neg; + +use crate::analyzed::{ + AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, + AlgebraicReference, AlgebraicReferenceThin, AlgebraicUnaryOperation, AlgebraicUnaryOperator, + Analyzed, Challenge, PolyID, PolynomialType, +}; +use powdr_number::FieldElement; +use powdr_number::LargeInt; + +/// Accessor for terminal symbols. +pub trait TerminalAccess { + fn get(&self, _poly_ref: &AlgebraicReference) -> T { + unimplemented!(); + } + fn get_public(&self, _public: &str) -> T { + unimplemented!(); + } + fn get_challenge(&self, _challenge: &Challenge) -> T { + unimplemented!(); + } +} + +/// A simple container for trace values. +pub struct OwnedTerminalValues { + pub trace: BTreeMap>, + pub public_values: BTreeMap, + pub challenge_values: BTreeMap, +} + +/// A view into the trace values for a single row. +pub struct RowValues<'a, F> { + values: &'a OwnedTerminalValues, + row: usize, +} + +impl OwnedTerminalValues { + pub fn new( + pil: &Analyzed, + witness_columns: Vec<(String, Vec)>, + fixed_columns: Vec<(String, Vec)>, + ) -> Self { + let mut columns_by_name = witness_columns + .into_iter() + .chain(fixed_columns) + .collect::>(); + let trace = pil + .committed_polys_in_source_order() + .chain(pil.constant_polys_in_source_order()) + .flat_map(|(symbol, _)| symbol.array_elements()) + .filter_map(|(name, poly_id)| { + columns_by_name + .remove(&name) + .map(|column| (poly_id, column)) + }) + .collect(); + Self { + trace, + public_values: Default::default(), + challenge_values: Default::default(), + } + } + + pub fn with_publics(mut self, publics: Vec<(String, F)>) -> Self { + self.public_values = publics.into_iter().collect(); + self + } + + pub fn with_challenges(mut self, challenges: BTreeMap) -> Self { + self.challenge_values = challenges; + self + } + + pub fn height(&self) -> usize { + self.trace.values().next().map(|v| v.len()).unwrap() + } + + pub fn row(&self, row: usize) -> RowValues { + RowValues { values: self, row } + } +} + +impl> TerminalAccess for RowValues<'_, F> { + fn get(&self, column: &AlgebraicReference) -> T { + match column.poly_id.ptype { + PolynomialType::Committed | PolynomialType::Constant => { + let column_values = self.values.trace.get(&column.poly_id).unwrap(); + let row = (self.row + column.next as usize) % column_values.len(); + column_values[row].into() + } + PolynomialType::Intermediate => unreachable!( + "Intermediate polynomials should have been handled by ExpressionEvaluator" + ), + } + } + + fn get_public(&self, public: &str) -> T { + self.values.public_values[public].into() + } + + fn get_challenge(&self, challenge: &Challenge) -> T { + self.values.challenge_values[&challenge.id].into() + } +} + +pub trait ExpressionWalkerCallback { + fn handle_binary_operation( + &self, + left: T, + op: &AlgebraicBinaryOperator, + right: T, + right_expr: &Expression, + ) -> T; + fn handle_unary_operation(&self, op: &AlgebraicUnaryOperator, arg: T) -> T; + fn handle_number(&self, fe: &F) -> T; +} + +/// Evaluates an algebraic expression to a value. +pub struct ExpressionWalker<'a, T, Expr, TA, C> { + terminal_access: TA, + intermediate_definitions: &'a BTreeMap>, + /// Maps intermediate reference to their evaluation. Updated throughout the lifetime of the + /// ExpressionEvaluator. + intermediates_cache: BTreeMap, + callback: C, +} + +impl<'a, T, Expr: Clone, TA, C> ExpressionWalker<'a, T, Expr, TA, C> +where + TA: TerminalAccess, + C: ExpressionWalkerCallback, +{ + /// Create a new expression evaluator with custom expression converters. + pub fn new( + terminal_access: TA, + intermediate_definitions: &'a BTreeMap>, + callback: C, + ) -> Self { + Self { + terminal_access, + intermediate_definitions, + intermediates_cache: Default::default(), + callback, + } + } + + pub fn evaluate(&mut self, expr: &'a Expression) -> Expr { + match expr { + Expression::Reference(reference) => match reference.poly_id.ptype { + PolynomialType::Committed => self.terminal_access.get(reference), + PolynomialType::Constant => self.terminal_access.get(reference), + PolynomialType::Intermediate => { + let reference = reference.to_thin(); + let value = self.intermediates_cache.get(&reference).cloned(); + match value { + Some(v) => v, + None => { + let definition = self.intermediate_definitions.get(&reference).unwrap(); + let result = self.evaluate(definition); + self.intermediates_cache.insert(reference, result.clone()); + result + } + } + } + }, + Expression::PublicReference(public) => self.terminal_access.get_public(public), + Expression::Challenge(challenge) => self.terminal_access.get_challenge(challenge), + Expression::Number(n) => self.callback.handle_number(n), + Expression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { + let left_value = self.evaluate(left); + let right_value = self.evaluate(right); + self.callback + .handle_binary_operation(left_value, op, right_value, right) + } + Expression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => { + let arg = self.evaluate(expr); + self.callback.handle_unary_operation(op, arg) + } + } + } +} + +struct EvaluatorCallback { + to_expr: fn(&F) -> T, +} + +impl ExpressionWalkerCallback for EvaluatorCallback +where + Expr: Clone + Add + Sub + Mul + Neg, + F: FieldElement, +{ + fn handle_binary_operation( + &self, + left: Expr, + op: &AlgebraicBinaryOperator, + right: Expr, + right_expr: &Expression, + ) -> Expr { + match op { + AlgebraicBinaryOperator::Add => left + right, + AlgebraicBinaryOperator::Sub => left - right, + AlgebraicBinaryOperator::Mul => left * right, + AlgebraicBinaryOperator::Pow => match right_expr { + Expression::Number(n) => (0u32..n.to_integer().try_into_u32().unwrap()) + .fold((self.to_expr)(&F::one()), |acc, _| acc * left.clone()), + _ => unimplemented!("pow with non-constant exponent"), + }, + } + } + + fn handle_unary_operation(&self, op: &AlgebraicUnaryOperator, arg: Expr) -> Expr { + match op { + AlgebraicUnaryOperator::Minus => -arg, + } + } + + fn handle_number(&self, fe: &F) -> Expr { + (self.to_expr)(fe) + } +} + +/// Evaluates an algebraic expression to a value. +pub struct ExpressionEvaluator<'a, T, Expr, TA> { + expression_walker: ExpressionWalker<'a, T, Expr, TA, EvaluatorCallback>, +} + +impl<'a, T, TA> ExpressionEvaluator<'a, T, T, TA> +where + TA: TerminalAccess, + T: FieldElement, +{ + /// Create a new expression evaluator (for the case where Expr = T). + pub fn new( + terminal_access: TA, + intermediate_definitions: &'a BTreeMap>, + ) -> Self { + Self::new_with_custom_expr(terminal_access, intermediate_definitions, |x| *x) + } +} + +impl<'a, T, Expr, TA> ExpressionEvaluator<'a, T, Expr, TA> +where + TA: TerminalAccess, + Expr: Clone + Add + Sub + Mul + Neg, + T: FieldElement, +{ + /// Create a new expression evaluator with custom expression converters. + pub fn new_with_custom_expr( + terminal_access: TA, + intermediate_definitions: &'a BTreeMap>, + to_expr: fn(&T) -> Expr, + ) -> Self { + let callback = EvaluatorCallback { to_expr }; + let expression_walker = + ExpressionWalker::new(terminal_access, intermediate_definitions, callback); + Self { expression_walker } + } + + pub fn evaluate(&mut self, expr: &'a Expression) -> Expr { + self.expression_walker.evaluate(expr) + } +} diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index adf0070545..1a63df30c6 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -9,6 +9,7 @@ use std::iter::{self, empty, once}; use std::ops::{self, ControlFlow}; use std::sync::Arc; +use expression_evaluator::{ExpressionWalker, ExpressionWalkerCallback, TerminalAccess}; use itertools::Itertools; use num_traits::One; use powdr_number::{DegreeType, FieldElement}; @@ -25,6 +26,8 @@ use crate::parsed::{ TraitDeclaration, TraitImplementation, TypeDeclaration, }; +pub mod expression_evaluator; + #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub enum StatementIdentifier { /// Either an intermediate column or a definition. @@ -1048,18 +1051,55 @@ pub enum Identity { PhantomBusInteraction(PhantomBusInteractionIdentity), } +struct NextRefCallback(); +impl ExpressionWalkerCallback for &NextRefCallback { + fn handle_binary_operation( + &self, + left: bool, + _op: &AlgebraicBinaryOperator, + right: bool, + _right_expr: &AlgebraicExpression, + ) -> bool { + left || right + } + + fn handle_unary_operation(&self, _op: &AlgebraicUnaryOperator, arg: bool) -> bool { + arg + } + + fn handle_number(&self, _fe: &T) -> bool { + false + } +} +impl TerminalAccess for &NextRefCallback { + fn get(&self, poly_ref: &AlgebraicReference) -> bool { + poly_ref.next + } + fn get_public(&self, _public: &str) -> bool { + false + } + fn get_challenge(&self, _challenge: &Challenge) -> bool { + false + } +} + impl Identity { - pub fn contains_next_ref(&self) -> bool { - self.children().any(|e| e.contains_next_ref()) + pub fn contains_next_ref( + &self, + intermediate_definitions: &BTreeMap>, + ) -> bool { + let callback = NextRefCallback(); + let mut expression_walker = + ExpressionWalker::new(&callback, intermediate_definitions, &callback); + self.children().any(|e| expression_walker.evaluate(e)) } pub fn degree( &self, - intermediate_polynomials: &BTreeMap>, + intermediate_definitions: &BTreeMap>, ) -> usize { - let mut cache = BTreeMap::new(); self.children() - .map(|e| e.degree_with_cache(intermediate_polynomials, &mut cache)) + .map(|e| e.degree(intermediate_definitions)) .max() .unwrap_or(0) } @@ -1362,6 +1402,44 @@ impl num_traits::One for AlgebraicExpression { } } +struct DegreeCallback(); +impl ExpressionWalkerCallback for &DegreeCallback { + fn handle_binary_operation( + &self, + left: usize, + op: &AlgebraicBinaryOperator, + right: usize, + _right_expr: &AlgebraicExpression, + ) -> usize { + match op { + AlgebraicBinaryOperator::Add | AlgebraicBinaryOperator::Sub => max(left, right), + AlgebraicBinaryOperator::Mul => left + right, + AlgebraicBinaryOperator::Pow => todo!(), + } + } + + fn handle_unary_operation(&self, _op: &AlgebraicUnaryOperator, arg: usize) -> usize { + arg + } + + fn handle_number(&self, _fe: &T) -> usize { + 0 + } +} +impl TerminalAccess for &DegreeCallback { + fn get(&self, _poly_ref: &AlgebraicReference) -> usize { + 1 + } + + fn get_public(&self, _public: &str) -> usize { + 0 + } + + fn get_challenge(&self, _challenge: &Challenge) -> usize { + 0 + } +} + impl AlgebraicExpression { /// Returns an iterator over all (top-level) expressions in this expression. /// This specifically does not implement Children because otherwise it would @@ -1428,42 +1506,14 @@ impl AlgebraicExpression { } /// Returns the degree of the expressions - pub fn degree_with_cache( + pub fn degree( &self, intermediate_definitions: &BTreeMap>, - cache: &mut BTreeMap, ) -> usize { - match self { - AlgebraicExpression::Reference(reference) => match reference.poly_id.ptype { - PolynomialType::Committed | PolynomialType::Constant => 1, - PolynomialType::Intermediate => { - let reference = reference.to_thin(); - cache.get(&reference).cloned().unwrap_or_else(|| { - let def = intermediate_definitions - .get(&reference) - .expect("Intermediate definition not found."); - let result = def.degree_with_cache(intermediate_definitions, cache); - cache.insert(reference, result); - result - }) - } - }, - // Multiplying two expressions adds their degrees - AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { - op: AlgebraicBinaryOperator::Mul, - left, - right, - }) => { - left.degree_with_cache(intermediate_definitions, cache) - + right.degree_with_cache(intermediate_definitions, cache) - } - // In all other cases, we take the maximum of the degrees of the children - _ => self - .children() - .map(|e| e.degree_with_cache(intermediate_definitions, cache)) - .max() - .unwrap_or(0), - } + let callback = DegreeCallback(); + let mut expression_walker = + ExpressionWalker::new(&callback, intermediate_definitions, &callback); + expression_walker.evaluate(self) } } @@ -1731,33 +1781,30 @@ mod tests { // No intermediates let intermediate_definitions = Default::default(); - let mut cache = Default::default(); let expr = one.clone() + one.clone() * one.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 0); let expr = column.clone() + one.clone() * one.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 1); let expr = column.clone() + one.clone() * column.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 1); let expr = column.clone() + column.clone() * column.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 2); let expr = column.clone() + column.clone() * (column.clone() + one.clone()); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 2); let expr = column.clone() * column.clone() * column.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 3); - - assert!(cache.is_empty()); } #[test] @@ -1787,27 +1834,26 @@ mod tests { let intermediate_definitions = [(column_squared_ref.to_thin(), column_squared.clone())] .into_iter() .collect(); - let mut cache = Default::default(); let expr = column_squared_intermediate.clone() + one.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 2); let expr = column_squared_intermediate.clone() + column.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 2); let expr = column_squared_intermediate.clone() * column_squared_intermediate.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 4); let expr = column_squared_intermediate.clone() * (column_squared_intermediate.clone() + one.clone()); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 4); let expr = column_squared_intermediate.clone() * column.clone(); - let degree = expr.degree_with_cache(&intermediate_definitions, &mut cache); + let degree = expr.degree(&intermediate_definitions); assert_eq!(degree, 3); } diff --git a/backend/src/composite/mod.rs b/backend/src/composite/mod.rs index e202c78a7f..c420c0a80d 100644 --- a/backend/src/composite/mod.rs +++ b/backend/src/composite/mod.rs @@ -178,7 +178,10 @@ fn log_machine_stats(machine_name: &str, pil: &Analyzed) { .map(|i| i.degree(&intermediate_definitions)) .max() .unwrap_or(0); - let uses_next_operator = pil.identities.iter().any(|i| i.contains_next_ref()); + let uses_next_operator = pil + .identities + .iter() + .any(|i| i.contains_next_ref(&intermediate_definitions)); // This assumes that we'll always at least once reference the current row let number_of_rotations = 1 + if uses_next_operator { 1 } else { 0 }; let num_identities_by_kind = pil diff --git a/backend/src/halo2/circuit_builder.rs b/backend/src/halo2/circuit_builder.rs index 7cf0d5742a..c739e52236 100644 --- a/backend/src/halo2/circuit_builder.rs +++ b/backend/src/halo2/circuit_builder.rs @@ -15,10 +15,10 @@ use powdr_executor::witgen::WitgenCallback; use powdr_ast::analyzed::Analyzed; use powdr_ast::analyzed::{ + expression_evaluator::{ExpressionEvaluator, TerminalAccess}, AlgebraicExpression, AlgebraicReferenceThin, Identity, PolynomialIdentity, PolynomialType, SelectedExpressions, }; -use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, GlobalValues, TraceValues}; use powdr_number::FieldElement; const FIRST_STEP_NAME: &str = "__first_step"; @@ -553,8 +553,8 @@ impl<'a, F: PrimeField> Data<'a, '_, '_, F> { fn evaluator( &self, intermediate_definitions: &'a BTreeMap>, - ) -> ExpressionEvaluator<'a, T, Expression, &Self, &Self> { - ExpressionEvaluator::new_with_custom_expr(self, self, intermediate_definitions, |n| { + ) -> ExpressionEvaluator<'a, T, Expression, &Self> { + ExpressionEvaluator::new_with_custom_expr(self, intermediate_definitions, |n| { Expression::Constant(convert_field(*n)) }) } @@ -564,7 +564,7 @@ impl<'a, F: PrimeField> Data<'a, '_, '_, F> { } } -impl TraceValues> for &Data<'_, '_, '_, F> { +impl TerminalAccess> for &Data<'_, '_, '_, F> { fn get(&self, poly_ref: &powdr_ast::analyzed::AlgebraicReference) -> Expression { let rotation = match poly_ref.next { false => Rotation::cur(), @@ -578,9 +578,7 @@ impl TraceValues> for &Data<'_, '_, '_, F> { panic!("Unknown reference: {}", poly_ref.name) } } -} -impl GlobalValues> for &Data<'_, '_, '_, F> { fn get_public(&self, _public: &str) -> Expression { unimplemented!() } diff --git a/backend/src/mock/connection_constraint_checker.rs b/backend/src/mock/connection_constraint_checker.rs index e6e77da6b5..0985a5eb40 100644 --- a/backend/src/mock/connection_constraint_checker.rs +++ b/backend/src/mock/connection_constraint_checker.rs @@ -4,19 +4,14 @@ use std::fmt; use std::ops::ControlFlow; use itertools::Itertools; -use powdr_ast::analyzed::AlgebraicExpression; -use powdr_ast::analyzed::AlgebraicReference; -use powdr_ast::analyzed::Analyzed; use powdr_ast::analyzed::{ - Identity, LookupIdentity, PermutationIdentity, PhantomLookupIdentity, - PhantomPermutationIdentity, SelectedExpressions, + expression_evaluator::{ExpressionEvaluator, TerminalAccess}, + AlgebraicExpression, Analyzed, Identity, LookupIdentity, PermutationIdentity, + PhantomLookupIdentity, PhantomPermutationIdentity, SelectedExpressions, }; use powdr_ast::parsed::visitor::ExpressionVisitable; use powdr_ast::parsed::visitor::VisitOrder; use powdr_backend_utils::referenced_namespaces_algebraic_expression; -use powdr_executor_utils::expression_evaluator::ExpressionEvaluator; -use powdr_executor_utils::expression_evaluator::OwnedGlobalValues; -use powdr_executor_utils::expression_evaluator::TraceValues; use powdr_number::FieldElement; use rayon::iter::IntoParallelIterator; use rayon::iter::ParallelIterator; @@ -152,7 +147,6 @@ impl Connection { pub struct ConnectionConstraintChecker<'a, F: FieldElement> { connections: &'a [Connection], machines: BTreeMap>, - global_values: OwnedGlobalValues, } impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> { @@ -160,17 +154,10 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> { pub fn new( connections: &'a [Connection], machines: BTreeMap>, - challenges: &'a BTreeMap, ) -> Self { - let global_values = OwnedGlobalValues { - // TODO: Support publics. - public_values: BTreeMap::new(), - challenge_values: challenges.clone(), - }; Self { connections, machines, - global_values, } } } @@ -276,8 +263,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> { .into_par_iter() .filter_map(|row| { let mut evaluator = ExpressionEvaluator::new( - machine.trace_values.row(row), - &self.global_values, + machine.values.row(row), &machine.intermediate_definitions, ); let result = evaluator.evaluate(&selected_expressions.selector); @@ -300,9 +286,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> { None => { let empty_variables = EmptyVariables {}; let empty_definitions = BTreeMap::new(); - let empty_globals = OwnedGlobalValues::default(); - let mut evaluator = - ExpressionEvaluator::new(empty_variables, &empty_globals, &empty_definitions); + let mut evaluator = ExpressionEvaluator::new(empty_variables, &empty_definitions); let selector_value: F = evaluator.evaluate(&selected_expressions.selector); match selector_value.to_degree() { @@ -339,14 +323,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> { struct EmptyVariables; -impl TraceValues for EmptyVariables -where - T: FieldElement, -{ - fn get(&self, _reference: &AlgebraicReference) -> T { - panic!() - } -} +impl TerminalAccess for EmptyVariables {} /// Converts a slice to a multi-set, represented as a map from elements to their count. fn to_multi_set(a: &[T]) -> BTreeMap<&T, usize> { diff --git a/backend/src/mock/machine.rs b/backend/src/mock/machine.rs index b5ce39e1f6..ba6e6acc63 100644 --- a/backend/src/mock/machine.rs +++ b/backend/src/mock/machine.rs @@ -1,17 +1,20 @@ use std::collections::BTreeMap; use itertools::Itertools; -use powdr_ast::analyzed::{AlgebraicExpression, AlgebraicReferenceThin, Analyzed}; +use powdr_ast::analyzed::{ + expression_evaluator::OwnedTerminalValues, AlgebraicExpression, AlgebraicReferenceThin, + Analyzed, +}; use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns}; use powdr_executor::constant_evaluator::VariablySizedColumn; -use powdr_executor_utils::{expression_evaluator::OwnedTraceValues, WitgenCallback}; +use powdr_executor_utils::WitgenCallback; use powdr_number::{DegreeType, FieldElement}; /// A collection of columns with self-contained constraints. pub struct Machine<'a, F> { pub machine_name: String, pub size: usize, - pub trace_values: OwnedTraceValues, + pub values: OwnedTerminalValues, pub pil: &'a Analyzed, pub intermediate_definitions: BTreeMap>, } @@ -55,12 +58,14 @@ impl<'a, F: FieldElement> Machine<'a, F> { let intermediate_definitions = pil.intermediate_definitions(); - let trace_values = OwnedTraceValues::new(pil, witness, fixed); + // TODO: Supports publics. + let values = + OwnedTerminalValues::new(pil, witness, fixed).with_challenges(challenges.clone()); Some(Self { machine_name, size, - trace_values, + values, pil, intermediate_definitions, }) diff --git a/backend/src/mock/mod.rs b/backend/src/mock/mod.rs index 9d58d0036e..9c0a905eba 100644 --- a/backend/src/mock/mod.rs +++ b/backend/src/mock/mod.rs @@ -117,14 +117,13 @@ impl Backend for MockBackend { ); } - let is_ok = - machines.values().all(|machine| { - !PolynomialConstraintChecker::new(machine, &challenges) - .check() - .has_errors() - }) && ConnectionConstraintChecker::new(&self.connections, machines, &challenges) + let is_ok = machines.values().all(|machine| { + !PolynomialConstraintChecker::new(machine) .check() - .is_ok(); + .has_errors() + }) && ConnectionConstraintChecker::new(&self.connections, machines) + .check() + .is_ok(); match is_ok { true => Ok(Vec::new()), diff --git a/backend/src/mock/polynomial_constraint_checker.rs b/backend/src/mock/polynomial_constraint_checker.rs index 325fc6580f..9eec42dabb 100644 --- a/backend/src/mock/polynomial_constraint_checker.rs +++ b/backend/src/mock/polynomial_constraint_checker.rs @@ -1,10 +1,12 @@ use std::{collections::BTreeMap, fmt}; use powdr_ast::{ - analyzed::{AlgebraicExpression, Identity, PolynomialIdentity}, + analyzed::{ + expression_evaluator::ExpressionEvaluator, AlgebraicExpression, Identity, + PolynomialIdentity, + }, parsed::visitor::AllChildren, }; -use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, OwnedGlobalValues}; use powdr_number::FieldElement; use rayon::iter::{IntoParallelIterator, ParallelIterator}; @@ -12,20 +14,11 @@ use super::machine::Machine; pub struct PolynomialConstraintChecker<'a, F> { machine: &'a Machine<'a, F>, - global_values: OwnedGlobalValues, } impl<'a, F: FieldElement> PolynomialConstraintChecker<'a, F> { - pub fn new(machine: &'a Machine<'a, F>, challenges: &'a BTreeMap) -> Self { - let global_values = OwnedGlobalValues { - // TODO: Support publics - public_values: Default::default(), - challenge_values: challenges.clone(), - }; - Self { - machine, - global_values, - } + pub fn new(machine: &'a Machine<'a, F>) -> Self { + Self { machine } } pub fn check(&self) -> MachineResult<'a, F> { @@ -59,8 +52,7 @@ impl<'a, F: FieldElement> PolynomialConstraintChecker<'a, F> { identities: &[&'a Identity], ) -> Vec> { let mut evaluator = ExpressionEvaluator::new( - self.machine.trace_values.row(row), - &self.global_values, + self.machine.values.row(row), &self.machine.intermediate_definitions, ); identities diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index bf5c0bbd58..1121c62240 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -1,12 +1,14 @@ use core::unreachable; use powdr_ast::parsed::visitor::AllChildren; -use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, GlobalValues, TraceValues}; use std::collections::HashSet; use std::sync::Arc; extern crate alloc; use alloc::collections::btree_map::BTreeMap; -use powdr_ast::analyzed::{AlgebraicExpression, AlgebraicReference, Analyzed, Challenge, Identity}; +use powdr_ast::analyzed::{ + expression_evaluator::{ExpressionEvaluator, TerminalAccess}, + AlgebraicExpression, AlgebraicReference, Analyzed, Challenge, Identity, +}; use powdr_number::FieldElement; use powdr_ast::analyzed::{PolyID, PolynomialType}; @@ -100,7 +102,7 @@ struct Data<'a, F> { constant_eval: &'a BTreeMap, } -impl TraceValues for &Data<'_, F> { +impl TerminalAccess for &Data<'_, F> { fn get(&self, poly_ref: &AlgebraicReference) -> F { match poly_ref.poly_id.ptype { PolynomialType::Committed => match poly_ref.next { @@ -114,12 +116,11 @@ impl TraceValues for &Data<'_, F> { PolynomialType::Intermediate => unreachable!(), } } -} -impl GlobalValues for &Data<'_, F> { fn get_public(&self, _public: &str) -> F { unimplemented!("Public references are not supported in stwo yet") } + fn get_challenge(&self, _challenge: &Challenge) -> F { unimplemented!("challenges are not supported in stwo yet") } @@ -182,12 +183,10 @@ impl FrameworkEval for PowdrEval { constant_shifted_eval: &constant_shifted_eval, constant_eval: &constant_eval, }; - let mut evaluator = ExpressionEvaluator::new_with_custom_expr( - &data, - &data, - &intermediate_definitions, - |v| E::F::from(M31::from(v.try_into_i32().unwrap())), - ); + let mut evaluator = + ExpressionEvaluator::new_with_custom_expr(&data, &intermediate_definitions, |v| { + E::F::from(M31::from(v.try_into_i32().unwrap())) + }); for id in &self.analyzed.identities { match id { diff --git a/executor-utils/src/expression_evaluator.rs b/executor-utils/src/expression_evaluator.rs deleted file mode 100644 index 2649f56b63..0000000000 --- a/executor-utils/src/expression_evaluator.rs +++ /dev/null @@ -1,189 +0,0 @@ -use core::ops::{Add, Mul, Sub}; -use std::collections::BTreeMap; - -use powdr_ast::analyzed::{ - AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression as Expression, - AlgebraicReference, AlgebraicReferenceThin, AlgebraicUnaryOperation, AlgebraicUnaryOperator, - Analyzed, Challenge, PolyID, PolynomialType, -}; -use powdr_number::FieldElement; -use powdr_number::LargeInt; - -/// Accessor for trace values. -pub trait TraceValues { - fn get(&self, poly_ref: &AlgebraicReference) -> T; -} - -/// Accessor for global values. -pub trait GlobalValues { - fn get_public(&self, public: &str) -> T; - fn get_challenge(&self, challenge: &Challenge) -> T; -} - -/// A simple container for trace values. -pub struct OwnedTraceValues { - pub values: BTreeMap>, -} - -/// A view into the trace values for a single row. -pub struct RowTraceValues<'a, T> { - trace: &'a OwnedTraceValues, - row: usize, -} - -impl OwnedTraceValues { - pub fn new( - pil: &Analyzed, - witness_columns: Vec<(String, Vec)>, - fixed_columns: Vec<(String, Vec)>, - ) -> Self { - let mut columns_by_name = witness_columns - .into_iter() - .chain(fixed_columns) - .collect::>(); - let values = pil - .committed_polys_in_source_order() - .chain(pil.constant_polys_in_source_order()) - .flat_map(|(symbol, _)| symbol.array_elements()) - .filter_map(|(name, poly_id)| { - columns_by_name - .remove(&name) - .map(|column| (poly_id, column)) - }) - .collect(); - Self { values } - } - - pub fn height(&self) -> usize { - self.values.values().next().map(|v| v.len()).unwrap() - } - - pub fn row(&self, row: usize) -> RowTraceValues { - RowTraceValues { trace: self, row } - } -} - -impl TraceValues for RowTraceValues<'_, F> { - fn get(&self, column: &AlgebraicReference) -> F { - match column.poly_id.ptype { - PolynomialType::Committed | PolynomialType::Constant => { - let column_values = self.trace.values.get(&column.poly_id).unwrap(); - let row = (self.row + column.next as usize) % column_values.len(); - column_values[row] - } - PolynomialType::Intermediate => unreachable!( - "Intermediate polynomials should have been handled by ExpressionEvaluator" - ), - } - } -} - -#[derive(Default)] -pub struct OwnedGlobalValues { - pub public_values: BTreeMap, - pub challenge_values: BTreeMap, -} - -impl GlobalValues for &OwnedGlobalValues { - fn get_public(&self, public: &str) -> T { - self.public_values[public].clone() - } - - fn get_challenge(&self, challenge: &Challenge) -> T { - self.challenge_values[&challenge.id].clone() - } -} - -/// Evaluates an algebraic expression to a value. -pub struct ExpressionEvaluator<'a, T, Expr, TV, GV> { - trace_values: TV, - global_values: GV, - intermediate_definitions: &'a BTreeMap>, - /// Maps intermediate reference to their evaluation. Updated throughout the lifetime of the - /// ExpressionEvaluator. - intermediates_cache: BTreeMap, - to_expr: fn(&T) -> Expr, -} - -impl<'a, T, TV, GV> ExpressionEvaluator<'a, T, T, TV, GV> -where - TV: TraceValues, - GV: GlobalValues, - T: FieldElement, -{ - /// Create a new expression evaluator (for the case where Expr = T). - pub fn new( - trace_values: TV, - global_values: GV, - intermediate_definitions: &'a BTreeMap>, - ) -> Self { - Self::new_with_custom_expr(trace_values, global_values, intermediate_definitions, |x| { - *x - }) - } -} - -impl<'a, T, Expr, TV, GV> ExpressionEvaluator<'a, T, Expr, TV, GV> -where - TV: TraceValues, - GV: GlobalValues, - Expr: Clone + Add + Sub + Mul, - T: FieldElement, -{ - /// Create a new expression evaluator with custom expression converters. - pub fn new_with_custom_expr( - trace_values: TV, - global_values: GV, - intermediate_definitions: &'a BTreeMap>, - to_expr: fn(&T) -> Expr, - ) -> Self { - Self { - trace_values, - global_values, - intermediate_definitions, - intermediates_cache: Default::default(), - to_expr, - } - } - - pub fn evaluate(&mut self, expr: &'a Expression) -> Expr { - match expr { - Expression::Reference(reference) => match reference.poly_id.ptype { - PolynomialType::Committed => self.trace_values.get(reference), - PolynomialType::Constant => self.trace_values.get(reference), - PolynomialType::Intermediate => { - let reference = reference.to_thin(); - let value = self.intermediates_cache.get(&reference).cloned(); - match value { - Some(v) => v, - None => { - let definition = self.intermediate_definitions.get(&reference).unwrap(); - let result = self.evaluate(definition); - self.intermediates_cache.insert(reference, result.clone()); - result - } - } - } - }, - Expression::PublicReference(_public) => unimplemented!(), - Expression::Number(n) => (self.to_expr)(n), - Expression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => match op { - AlgebraicBinaryOperator::Add => self.evaluate(left) + self.evaluate(right), - AlgebraicBinaryOperator::Sub => self.evaluate(left) - self.evaluate(right), - AlgebraicBinaryOperator::Mul => self.evaluate(left) * self.evaluate(right), - AlgebraicBinaryOperator::Pow => match &**right { - Expression::Number(n) => { - let left = self.evaluate(left); - (0u32..n.to_integer().try_into_u32().unwrap()) - .fold((self.to_expr)(&T::one()), |acc, _| acc * left.clone()) - } - _ => unimplemented!("pow with non-constant exponent"), - }, - }, - Expression::UnaryOperation(AlgebraicUnaryOperation { op, expr }) => match op { - AlgebraicUnaryOperator::Minus => self.evaluate(expr), - }, - Expression::Challenge(challenge) => self.global_values.get_challenge(challenge), - } - } -} diff --git a/executor-utils/src/lib.rs b/executor-utils/src/lib.rs index d9d6b4e307..3fafdbc35b 100644 --- a/executor-utils/src/lib.rs +++ b/executor-utils/src/lib.rs @@ -6,8 +6,6 @@ use std::sync::Arc; use powdr_number::{DegreeType, FieldElement}; -pub mod expression_evaluator; - /// A callback that computes an updated witness, given: /// - The PIL for the current machine. /// - The current witness. diff --git a/executor/src/witgen/bus_accumulator/mod.rs b/executor/src/witgen/bus_accumulator/mod.rs index 312a243de1..8d5059d30a 100644 --- a/executor/src/witgen/bus_accumulator/mod.rs +++ b/executor/src/witgen/bus_accumulator/mod.rs @@ -3,11 +3,11 @@ use std::collections::{BTreeMap, BTreeSet}; use fp2::Fp2; use itertools::Itertools; use num_traits::{One, Zero}; -use powdr_ast::analyzed::{Analyzed, Identity, PhantomBusInteractionIdentity}; -use powdr_executor_utils::{ - expression_evaluator::{ExpressionEvaluator, OwnedGlobalValues, OwnedTraceValues}, - VariablySizedColumn, +use powdr_ast::analyzed::{ + expression_evaluator::{ExpressionEvaluator, OwnedTerminalValues}, + Analyzed, Identity, PhantomBusInteractionIdentity, }; +use powdr_executor_utils::VariablySizedColumn; use powdr_number::{DegreeType, FieldElement}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; @@ -17,7 +17,7 @@ mod fp2; pub struct BusAccumulatorGenerator<'a, T> { pil: &'a Analyzed, bus_interactions: Vec<&'a PhantomBusInteractionIdentity>, - trace_values: OwnedTraceValues, + values: OwnedTerminalValues, powers_of_alpha: Vec>, beta: Fp2, } @@ -44,14 +44,6 @@ impl<'a, T: FieldElement> BusAccumulatorGenerator<'a, T> { .filter(|(n, _)| fixed_column_names.contains(n)) .map(|(n, v)| (n.clone(), v.get_by_size(size).unwrap())); - let trace_values = OwnedTraceValues::new( - pil, - witness_columns.to_vec(), - fixed_columns - .map(|(name, values)| (name, values.to_vec())) - .collect(), - ); - let bus_interactions = pil .identities .iter() @@ -71,10 +63,19 @@ impl<'a, T: FieldElement> BusAccumulatorGenerator<'a, T> { let beta = Fp2::new(challenges[&3], challenges[&4]); let powers_of_alpha = powers_of_alpha(alpha, max_tuple_size); + let values = OwnedTerminalValues::new( + pil, + witness_columns.to_vec(), + fixed_columns + .map(|(name, values)| (name, values.to_vec())) + .collect(), + ) + .with_challenges(challenges); + Self { pil, bus_interactions, - trace_values, + values, powers_of_alpha, beta, } @@ -100,20 +101,16 @@ impl<'a, T: FieldElement> BusAccumulatorGenerator<'a, T> { bus_interaction: &PhantomBusInteractionIdentity, ) -> Vec> { let intermediate_definitions = self.pil.intermediate_definitions(); - let empty_globals = OwnedGlobalValues::default(); - let size = self.trace_values.height(); + let size = self.values.height(); let mut folded1 = vec![T::zero(); size]; let mut folded2 = vec![T::zero(); size]; let mut acc1 = vec![T::zero(); size]; let mut acc2 = vec![T::zero(); size]; for i in 0..size { - let mut evaluator = ExpressionEvaluator::new( - self.trace_values.row(i), - &empty_globals, - &intermediate_definitions, - ); + let mut evaluator = + ExpressionEvaluator::new(self.values.row(i), &intermediate_definitions); let current_acc = if i == 0 { Fp2::zero() } else { diff --git a/executor/src/witgen/jit/single_step_processor.rs b/executor/src/witgen/jit/single_step_processor.rs index 55a704cd87..5ee98ec5e8 100644 --- a/executor/src/witgen/jit/single_step_processor.rs +++ b/executor/src/witgen/jit/single_step_processor.rs @@ -153,7 +153,12 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> { continue; } // TODO this is wrong if intermediate columns are referenced. - let row_offset = if id.contains_next_ref() { 0 } else { 1 }; + let row_offset = if id.contains_next_ref(&self.fixed_data.intermediate_definitions) + { + 0 + } else { + 1 + }; let result = witgen.process_identity(can_process.clone(), id, row_offset); progress |= result.progress; if result.complete { diff --git a/executor/src/witgen/machines/mod.rs b/executor/src/witgen/machines/mod.rs index 88d3e0ace1..2c662ff38e 100644 --- a/executor/src/witgen/machines/mod.rs +++ b/executor/src/witgen/machines/mod.rs @@ -456,7 +456,11 @@ impl<'a, T: FieldElement> MachineParts<'a, T> { let identities_with_next_reference = self .identities .iter() - .filter_map(|identity| identity.contains_next_ref().then_some(*identity)) + .filter_map(|identity| { + identity + .contains_next_ref(&self.fixed_data.intermediate_definitions) + .then_some(*identity) + }) .collect::>(); Self { identities: identities_with_next_reference, diff --git a/executor/src/witgen/processor.rs b/executor/src/witgen/processor.rs index 32a7fd4eab..55db2a5e47 100644 --- a/executor/src/witgen/processor.rs +++ b/executor/src/witgen/processor.rs @@ -334,7 +334,7 @@ Known values in current row (local: {row_index}, global {global_row_index}): ", self.data[row_index].render_values(false, self.parts) ); - if identity.contains_next_ref() { + if identity.contains_next_ref(&self.fixed_data.intermediate_definitions) { error += &format!( "Known values in next row (local: {}, global {}):\n{}\n", row_index + 1, diff --git a/executor/src/witgen/vm_processor.rs b/executor/src/witgen/vm_processor.rs index adc44c1d23..d64914be75 100644 --- a/executor/src/witgen/vm_processor.rs +++ b/executor/src/witgen/vm_processor.rs @@ -90,7 +90,7 @@ impl<'a, 'c, T: FieldElement, Q: QueryCallback> VmProcessor<'a, 'c, T, Q> { let (identities_with_next, identities_without_next): (Vec<_>, Vec<_>) = parts .identities .iter() - .partition(|identity| identity.contains_next_ref()); + .partition(|identity| identity.contains_next_ref(&fixed_data.intermediate_definitions)); let processor = Processor::new( row_offset, mutable_data, diff --git a/plonky3/src/circuit_builder.rs b/plonky3/src/circuit_builder.rs index 8553b0844e..8110f598eb 100644 --- a/plonky3/src/circuit_builder.rs +++ b/plonky3/src/circuit_builder.rs @@ -23,6 +23,7 @@ use crate::{ use p3_air::{Air, BaseAir, PairBuilder}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use powdr_ast::analyzed::{ + expression_evaluator::{ExpressionEvaluator, TerminalAccess}, AlgebraicExpression, AlgebraicReference, AlgebraicReferenceThin, Analyzed, Challenge, Identity, PolyID, PolynomialType, }; @@ -30,10 +31,7 @@ use powdr_ast::analyzed::{ use crate::{CallbackResult, MultiStageAir, MultistageAirBuilder}; use powdr_ast::parsed::visitor::ExpressionVisitable; -use powdr_executor_utils::{ - expression_evaluator::{ExpressionEvaluator, GlobalValues, TraceValues}, - WitgenCallback, -}; +use powdr_executor_utils::WitgenCallback; use powdr_number::FieldElement; /// A description of the constraint system. @@ -262,7 +260,7 @@ struct Data<'a, T, AB: MultistageAirBuilder> { challenges: &'a [BTreeMap<&'a u64, ::Challenge>], } -impl TraceValues for &Data<'_, T, AB> { +impl TerminalAccess for &Data<'_, T, AB> { fn get(&self, reference: &AlgebraicReference) -> AB::Expr { match reference.poly_id.ptype { PolynomialType::Committed => { @@ -277,9 +275,7 @@ impl TraceValues for &Data<'_, T, AB> { PolynomialType::Intermediate => unreachable!(), } } -} -impl GlobalValues for &Data<'_, T, AB> { fn get_challenge(&self, challenge: &Challenge) -> AB::Expr { self.challenges[challenge.stage as usize][&challenge.id] .clone() @@ -351,7 +347,6 @@ where fixed: &fixed, }; let mut evaluator = ExpressionEvaluator::new_with_custom_expr( - &data, &data, &self.constraint_system.intermediates, |value| AB::Expr::from(value.into_p3_field()),