Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pilopt: Optimize associative ADD operations #2419

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,14 @@ fn simplify_expression_single<T: FieldElement>(e: &mut AlgebraicExpression<T>) {
return;
}
}

if let AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) = e {
if let Some(simplified) = try_simplify_associative_operation(left, right, *op) {
*e = simplified;
return;
}
}

match e {
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left,
Expand Down Expand Up @@ -427,6 +435,81 @@ fn simplify_expression_single<T: FieldElement>(e: &mut AlgebraicExpression<T>) {
}
}

fn try_simplify_associative_operation<T: FieldElement>(
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
left: &AlgebraicExpression<T>,
right: &AlgebraicExpression<T>,
op: AlgebraicBinaryOperator,
) -> Option<AlgebraicExpression<T>> {
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
if op != AlgebraicBinaryOperator::Add {
return None;
}

// Find binary operation and other expression, handling both orderings:
// (X + C1) + Other
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
// Other + (X + C1)
let (x1, x2, other_expr) = match (left, right) {
(
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left: x1,
right: x2,
op: AlgebraicBinaryOperator::Add,
}),
other,
) => (x1, x2, other),
(
other,
AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation {
left: x1,
right: x2,
op: AlgebraicBinaryOperator::Add,
}),
) => (x1, x2, other),
_ => return None,
};

// Extract variable and constant from binary operation, handling both orderings:
// (X + C1) -> (X, C1)
// (C1 + X) -> (X, C1)
let (x, c1_val) = if let AlgebraicExpression::Number(val) = x1.as_ref() {
(x2, val)
} else if let AlgebraicExpression::Number(val) = x2.as_ref() {
(x1, val)
} else {
return None;
};

match other_expr {
// Case 1: Combining with a constant
// (X + C1) + C2 -> X + (C1 + C2)
AlgebraicExpression::Number(c2) => {
let result = *c1_val + *c2;
Some(AlgebraicExpression::BinaryOperation(
AlgebraicBinaryOperation {
left: x.clone(),
op: AlgebraicBinaryOperator::Add,
right: Box::new(AlgebraicExpression::Number(result)),
},
))
}

// Case 2: Combining with any non-numeric expression
// (X + C1) + Y -> (X + Y) + C1
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
y => Some(AlgebraicExpression::BinaryOperation(
AlgebraicBinaryOperation {
left: Box::new(AlgebraicExpression::BinaryOperation(
AlgebraicBinaryOperation {
left: x.clone(),
op: AlgebraicBinaryOperator::Add,
right: Box::new(y.clone()),
},
)),
op: AlgebraicBinaryOperator::Add,
right: Box::new(AlgebraicExpression::Number(*c1_val)),
},
)),
}
}

/// Extracts columns from lookups that are matched against constants and turns
/// them into polynomial identities.
fn extract_constant_lookups<T: FieldElement>(pil_file: &mut Analyzed<T>) {
Expand Down
36 changes: 36 additions & 0 deletions pilopt/tests/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,39 @@ fn equal_constrained_transitive() {
let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}
#[test]
fn simplify_associative_operations() {
let input = r#"namespace N(150);
col witness x;
col witness y;
col witness z;
col fixed c1 = [1]*;
col fixed c2 = [2]*;
col fixed c3 = [3]*;

(x + c2) + c1 = y;
(c2 + x) + c3 = y;
(x - c2) + c1 = y;

((x + 3) - y) - 9 = z;
(c3 + (x + 3)) - y = z;
((-x + 3) + y) + 9 = z;
((-x + 3) + c3) + 12 = z;
"#;

let expectation = r#"namespace N(150);
col witness x;
col witness y;
col witness z;
N::x + 3 = N::y;
N::x + 5 = N::y;
N::x - 2 + 1 = N::y;
gzanitti marked this conversation as resolved.
Show resolved Hide resolved
N::x + 3 - N::y - 9 = N::z;
N::x + 6 - N::y = N::z;
-N::x + N::y + 12 = N::z;
-N::x + 18 = N::z;
"#;

let optimized = optimize(analyze_string::<GoldilocksField>(input).unwrap()).to_string();
assert_eq!(optimized, expectation);
}
Loading