Skip to content

Commit

Permalink
pilopt: optimize until fixpoint (#2225)
Browse files Browse the repository at this point in the history
Co-authored-by: chriseth <[email protected]>
  • Loading branch information
gzanitti and chriseth authored Jan 8, 2025
1 parent 707522d commit 067b633
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 58 deletions.
49 changes: 30 additions & 19 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::parsed::{
TraitDeclaration, TraitImplementation, TypeDeclaration,
};

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
pub enum StatementIdentifier {
/// Either an intermediate column or a definition.
Definition(String),
Expand Down Expand Up @@ -685,7 +685,7 @@ impl DegreeRange {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct Symbol {
pub id: u64,
pub source: SourceRef,
Expand Down Expand Up @@ -745,7 +745,7 @@ impl Symbol {
/// The "kind" of a symbol. In the future, this will be mostly
/// replaced by its type.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum SymbolKind {
/// Fixed, witness or intermediate polynomial
Expand Down Expand Up @@ -815,7 +815,7 @@ impl Children<Expression> for NamedType {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PublicDeclaration {
pub id: u64,
pub source: SourceRef,
Expand All @@ -835,7 +835,9 @@ impl PublicDeclaration {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct SelectedExpressions<T> {
pub selector: AlgebraicExpression<T>,
pub expressions: Vec<AlgebraicExpression<T>>,
Expand All @@ -861,7 +863,7 @@ impl<T> Children<AlgebraicExpression<T>> for SelectedExpressions<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PolynomialIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -878,7 +880,7 @@ impl<T> Children<AlgebraicExpression<T>> for PolynomialIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct LookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -900,7 +902,7 @@ impl<T> Children<AlgebraicExpression<T>> for LookupIdentity<T> {
///
/// This identity is used as a replacement for a lookup identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomLookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -929,7 +931,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomLookupIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -951,7 +953,7 @@ impl<T> Children<AlgebraicExpression<T>> for PermutationIdentity<T> {
///
/// This identity is used as a replacement for a permutation identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomPermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -969,7 +971,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomPermutationIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct ConnectIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -987,7 +989,9 @@ impl<T> Children<AlgebraicExpression<T>> for ConnectIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord)]
#[derive(
Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord, Hash,
)]
pub struct ExpressionList<T>(pub Vec<AlgebraicExpression<T>>);

impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
Expand All @@ -999,7 +1003,7 @@ impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomBusInteractionIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -1034,6 +1038,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomBusInteractionIdentity<T> {
Serialize,
Deserialize,
JsonSchema,
Hash,
derive_more::Display,
derive_more::From,
derive_more::TryInto,
Expand Down Expand Up @@ -1235,7 +1240,9 @@ impl Hash for AlgebraicReference {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicExpression<T> {
Reference(AlgebraicReference),
PublicReference(String),
Expand All @@ -1245,7 +1252,9 @@ pub enum AlgebraicExpression<T> {
UnaryOperation(AlgebraicUnaryOperation<T>),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicBinaryOperation<T> {
pub left: Box<AlgebraicExpression<T>>,
pub op: AlgebraicBinaryOperator,
Expand All @@ -1271,7 +1280,9 @@ impl<T> From<AlgebraicBinaryOperation<T>> for AlgebraicExpression<T> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicUnaryOperation<T> {
pub op: AlgebraicUnaryOperator,
pub expr: Box<AlgebraicExpression<T>>,
Expand Down Expand Up @@ -1468,7 +1479,7 @@ impl<T> AlgebraicExpression<T> {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct Challenge {
/// Challenge ID
Expand All @@ -1477,7 +1488,7 @@ pub struct Challenge {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicBinaryOperator {
Add,
Expand Down Expand Up @@ -1515,7 +1526,7 @@ impl TryFrom<BinaryOperator> for AlgebraicBinaryOperator {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicUnaryOperator {
Minus,
Expand Down
4 changes: 3 additions & 1 deletion ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ impl<R> Children<Expression<R>> for EnumVariant<Expression<R>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct TraitImplementation<Expr> {
pub name: SymbolPath,
pub source_ref: SourceRef,
Expand Down
23 changes: 0 additions & 23 deletions backend/src/plonky3/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ mod tests {
Commitment<F>: Send,
{
let mut pipeline = Pipeline::<F>::default().from_pil_string(pil.to_string());

let pil = pipeline.compute_optimized_pil().unwrap();
let witness_callback = pipeline.witgen_callback().unwrap();
let witness = &mut pipeline.compute_witness().unwrap();
Expand Down Expand Up @@ -465,28 +464,6 @@ mod tests {
run_test(content);
}

#[test]
fn two_tables() {
// This test is a bit contrived but witgen wouldn't allow a more direct example
let content = r#"
namespace Add(8);
col witness x;
col witness y;
col witness z;
x = 0;
y = 0;
x + y = z;
1 $ [ x, y, z ] in 1 $ [ Mul::x, Mul::y, Mul::z ];
namespace Mul(16);
col witness x;
col witness y;
col witness z;
x * y = z;
"#;
run_test(content);
}

#[test]
fn challenge() {
let content = r#"
Expand Down
79 changes: 66 additions & 13 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::hash::{DefaultHasher, Hash, Hasher};

use itertools::Itertools;
use powdr_ast::analyzed::{
AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, AlgebraicReference,
AlgebraicUnaryOperation, AlgebraicUnaryOperator, Analyzed, ConnectIdentity, Expression,
FunctionValueDefinition, Identity, LookupIdentity, PermutationIdentity, PhantomLookupIdentity,
PhantomPermutationIdentity, PolyID, PolynomialIdentity, PolynomialReference, PolynomialType,
Reference, Symbol, SymbolKind,
Reference, StatementIdentifier, Symbol, SymbolKind,
};
use powdr_ast::parsed::types::Type;
use powdr_ast::parsed::visitor::{AllChildren, Children, ExpressionVisitable};
Expand All @@ -22,17 +23,25 @@ use referenced_symbols::{ReferencedSymbols, SymbolReference};

pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
let col_count_pre = (pil_file.commitment_count(), pil_file.constant_count());
remove_unreferenced_definitions(&mut pil_file);
remove_constant_fixed_columns(&mut pil_file);
deduplicate_fixed_columns(&mut pil_file);
simplify_identities(&mut pil_file);
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
remove_constant_intermediate_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);
remove_unreferenced_definitions(&mut pil_file);
let mut pil_hash = hash_pil_state(&pil_file);
loop {
remove_unreferenced_definitions(&mut pil_file);
remove_constant_fixed_columns(&mut pil_file);
deduplicate_fixed_columns(&mut pil_file);
simplify_identities(&mut pil_file);
extract_constant_lookups(&mut pil_file);
remove_constant_witness_columns(&mut pil_file);
remove_constant_intermediate_columns(&mut pil_file);
simplify_identities(&mut pil_file);
remove_trivial_identities(&mut pil_file);
remove_duplicate_identities(&mut pil_file);

let new_hash = hash_pil_state(&pil_file);
if pil_hash == new_hash {
break;
}
pil_hash = new_hash;
}
let col_count_post = (pil_file.commitment_count(), pil_file.constant_count());
log::info!(
"Removed {} witness and {} fixed columns. Total count now: {} witness and {} fixed columns.",
Expand All @@ -44,6 +53,43 @@ pub fn optimize<T: FieldElement>(mut pil_file: Analyzed<T>) -> Analyzed<T> {
pil_file
}

fn hash_pil_state<T: Hash>(pil_file: &Analyzed<T>) -> u64 {
let mut hasher = DefaultHasher::new();

for so in &pil_file.source_order {
match so {
StatementIdentifier::Definition(d) => {
d.hash(&mut hasher);
if let Some(def) = pil_file.definitions.get(d) {
def.hash(&mut hasher);
} else if let Some(def) = pil_file.intermediate_columns.get(d) {
def.hash(&mut hasher);
} else {
unreachable!("Missing definition for {:?}", d);
}
}
StatementIdentifier::PublicDeclaration(pd) => {
pd.hash(&mut hasher);
pil_file.public_declarations[pd].hash(&mut hasher);
}
StatementIdentifier::ProofItem(pi) => {
pi.hash(&mut hasher);
pil_file.identities[*pi].hash(&mut hasher);
}
StatementIdentifier::ProverFunction(pf) => {
pf.hash(&mut hasher);
pil_file.prover_functions[*pf].hash(&mut hasher);
}
StatementIdentifier::TraitImplementation(ti) => {
ti.hash(&mut hasher);
pil_file.trait_impls[*ti].hash(&mut hasher);
}
}
}

hasher.finish()
}

/// Removes all definitions that are not referenced by an identity, public declaration
/// or witness column hint.
fn remove_unreferenced_definitions<T: FieldElement>(pil_file: &mut Analyzed<T>) {
Expand Down Expand Up @@ -456,10 +502,17 @@ fn remove_constant_witness_columns<T: FieldElement>(pil_file: &mut Analyzed<T>)
})
.filter_map(constrained_to_constant)
.collect::<Vec<((String, PolyID), _)>>();

let in_publics: HashSet<_> = pil_file
.public_declarations
.values()
.map(|pubd| pubd.polynomial.name.clone())
.collect();
// We cannot remove arrays or array elements, so filter them out.
// Also, we filter out columns that are used in public declarations.
let columns = pil_file
.committed_polys_in_source_order()
.filter(|&(s, _)| (!s.is_array()))
.filter(|&(s, _)| !s.is_array() && !in_publics.contains(&s.absolute_name))
.map(|(s, _)| s.into())
.collect::<HashSet<PolyID>>();
constant_polys.retain(|((_, id), _)| columns.contains(id));
Expand Down
3 changes: 1 addition & 2 deletions pilopt/tests/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ fn replace_intermediate() {
"#;
let expectation = r#"namespace N(65536);
col witness X;
col other_intermediate = 0;
N::X' = N::X + 1 + N::other_intermediate;
N::X' = N::X + 1;
"#;
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
Expand Down

0 comments on commit 067b633

Please sign in to comment.