Skip to content

Commit

Permalink
Fully implement irregular block shapes (#2281)
Browse files Browse the repository at this point in the history
With #2304, JITed block machine witgen functions can already access the
previous row in the last block. With this PR, functions that actually do
that are no longer filtered out.

Changes:
- The accessed cells are now checked properly (not just looking at
assignments)
- Instead of filtering out functions that access outside the bound
`0..block_size`, we *assert* that the accessed cells are in the bounds
`-1..block_size`. This is sufficient in practice (all the tests pass),
although in theory the block machine processor could generate generate
accesses to row `block_size` and even `block_size + 1` (via next
references). I'm not sure if that's worth handling though if it never
happens in practice, so I figured it's best to fail hard if it does
happen, so we know when we do need to handle it.

As a result, we now use the JIT path for 2 machines in the Keccak RISC-V
example:
```
Secondary machine 0: main_binary (BlockMachine): 50462 / 50462 blocks computed via JIT.
201848 of 262144 rows are used in machine 'Secondary machine 0: main_binary (BlockMachine)'.
40128 of 65536 rows are used in machine 'Secondary machine 1: main_memory (DoubleSortedWitnesses32)'.
Secondary machine 2: main_poseidon_gl (BlockMachine): 0 / 1 blocks computed via JIT.
31 of 32 rows are used in machine 'Secondary machine 2: main_poseidon_gl (BlockMachine)'.
361974 of 524288 rows are used in machine 'Secondary machine 4: main_regs (DoubleSortedWitnesses32)'.
Secondary machine 5: main_shift (BlockMachine): 11734 / 11734 blocks computed via JIT.
46936 of 65536 rows are used in machine 'Secondary machine 5: main_shift (BlockMachine)'.
Secondary machine 6: main_split_gl (BlockMachine): 0 / 7 blocks computed via JIT.
56 of 64 rows are used in machine 'Secondary machine 6: main_split_gl (BlockMachine)'.
```

---------

Co-authored-by: chriseth <[email protected]>
  • Loading branch information
georgwiese and chriseth authored Jan 9, 2025
1 parent 40aa582 commit 725fceb
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 22 deletions.
11 changes: 10 additions & 1 deletion executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,16 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables);

// In the latch row, set the RHS selector to 1.
witgen.assign_constant(&connection.right.selector, self.latch_row as i32, T::one());
let selector = &connection.right.selector;
witgen.assign_constant(selector, self.latch_row as i32, T::one());

// Set all other selectors to 0 in the latch row.
for other_connection in self.machine_parts.connections.values() {
let other_selector = &other_connection.right.selector;
if other_selector != selector {
witgen.assign_constant(other_selector, self.latch_row as i32, T::zero());
}
}

// For each argument, connect the expression on the RHS with the formal parameter.
for (index, expr) in connection.right.expressions.iter().enumerate() {
Expand Down
26 changes: 26 additions & 0 deletions executor/src/witgen/jit/effect.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use std::cmp::Ordering;

use std::iter;

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::indent;
use powdr_number::FieldElement;
use std::hash::Hash;

use crate::witgen::range_constraints::RangeConstraint;

Expand All @@ -24,6 +27,29 @@ pub enum Effect<T: FieldElement, V> {
Branch(BranchCondition<T, V>, Vec<Effect<T, V>>, Vec<Effect<T, V>>),
}

impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
pub fn referenced_variables(&self) -> Box<dyn Iterator<Item = &V> + '_> {
match self {
Effect::Assignment(v, expr) => {
Box::new(iter::once(v).chain(expr.referenced_symbols()).unique())
}
Effect::RangeConstraint(v, _) => Box::new(iter::once(v)),
Effect::Assertion(Assertion { lhs, rhs, .. }) => Box::new(
lhs.referenced_symbols()
.chain(rhs.referenced_symbols())
.unique(),
),
Effect::MachineCall(_, _, args) => Box::new(args.iter().unique()),
Effect::Branch(branch_condition, vec, vec1) => Box::new(
iter::once(&branch_condition.variable)
.chain(vec.iter().flat_map(|effect| effect.referenced_variables()))
.chain(vec1.iter().flat_map(|effect| effect.referenced_variables()))
.unique(),
),
}
}
}

/// A run-time assertion. If this fails, we have conflicting constraints.
#[derive(Clone, PartialEq, Eq)]
pub struct Assertion<T: FieldElement, V> {
Expand Down
24 changes: 7 additions & 17 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
data_structures::finalizable_data::{ColumnLayout, CompactDataRef},
jit::effect::Effect,
machines::{LookupCell, MachineParts},
EvalError, FixedData, MutableState, QueryCallback,
};
Expand Down Expand Up @@ -96,26 +95,17 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
self.processor
.generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args)
.ok()
.and_then(|code| {
// TODO: Remove this once BlockMachine passes the right amount of context for machines with
// non-rectangular block shapes.
let is_rectangular = code
.map(|code| {
let is_in_bounds = code
.iter()
.filter_map(|effect| match effect {
Effect::Assignment(v, _) => Some(v),
_ => None,
})
.filter_map(|assigned_variable| match assigned_variable {
.flat_map(|effect| effect.referenced_variables())
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell.row_offset),
_ => None,
})
.all(|row_offset| row_offset >= 0 && row_offset < self.block_size as i32);
if !is_rectangular {
log::debug!("Filtering out code for non-rectangular block shape");
}
is_rectangular.then_some(code)
})
.map(|code| {
.all(|row_offset| row_offset >= -1 && row_offset < self.block_size as i32);
assert!(is_in_bounds, "Expected JITed code to only reference cells in the block + the last row of the previous block.");

log::trace!("Generated code ({} steps)", code.len());
let known_inputs = cache_key
.known_args
Expand Down
41 changes: 38 additions & 3 deletions executor/src/witgen/jit/symbolic_expression.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use itertools::Itertools;
use num_traits::Zero;
use powdr_ast::parsed::visitor::Children;
use powdr_number::FieldElement;
use std::hash::Hash;
use std::{
fmt::{self, Display, Formatter},
iter,
ops::{Add, BitAnd, Mul, Neg},
rc::Rc,
};

use num_traits::Zero;
use powdr_number::FieldElement;

use crate::witgen::range_constraints::RangeConstraint;

/// A value that is known at run-time, defined through a complex expression
Expand Down Expand Up @@ -45,6 +48,25 @@ pub enum UnaryOperator {
Neg,
}

impl<T: FieldElement, S> Children<SymbolicExpression<T, S>> for SymbolicExpression<T, S> {
fn children(&self) -> Box<dyn Iterator<Item = &SymbolicExpression<T, S>> + '_> {
match self {
SymbolicExpression::BinaryOperation(lhs, _, rhs, _) => {
Box::new(iter::once(lhs.as_ref()).chain(iter::once(rhs.as_ref())))
}
SymbolicExpression::UnaryOperation(_, expr, _) => Box::new(iter::once(expr.as_ref())),
SymbolicExpression::BitOperation(expr, _, _, _) => Box::new(iter::once(expr.as_ref())),
SymbolicExpression::Concrete(_) | SymbolicExpression::Symbol(..) => {
Box::new(iter::empty())
}
}
}

fn children_mut(&mut self) -> Box<dyn Iterator<Item = &mut SymbolicExpression<T, S>> + '_> {
unimplemented!()
}
}

impl<T: FieldElement, S> SymbolicExpression<T, S> {
pub fn from_symbol(symbol: S, rc: RangeConstraint<T>) -> Self {
SymbolicExpression::Symbol(symbol, rc)
Expand Down Expand Up @@ -89,6 +111,19 @@ impl<T: FieldElement, S> SymbolicExpression<T, S> {
}
}

impl<T: FieldElement, S: Hash + Eq> SymbolicExpression<T, S> {
pub fn referenced_symbols(&self) -> Box<dyn Iterator<Item = &S> + '_> {
match self {
SymbolicExpression::Symbol(s, _) => Box::new(iter::once(s)),
_ => Box::new(
self.children()
.flat_map(|c| c.referenced_symbols())
.unique(),
),
}
}
}

/// Display for affine symbolic expressions, for informational purposes only.
impl<T: FieldElement, V: Display> Display for SymbolicExpression<T, V> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Expand Down
1 change: 0 additions & 1 deletion executor/src/witgen/machines/block_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,6 @@ impl<'a, T: FieldElement> BlockMachine<'a, T> {
"Block machine is full (this should have been checked before)"
);
self.data.finalize_all();
//TODO can we properly access the last row of the dummy block?
let data = self.data.append_new_finalized_rows(self.block_size);

let success =
Expand Down
19 changes: 19 additions & 0 deletions executor/src/witgen/processor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::BTreeMap;

use itertools::Itertools;
use num_traits::One;
use powdr_ast::analyzed::PolynomialType;
use powdr_ast::analyzed::{AlgebraicExpression as Expression, AlgebraicReference, PolyID};

Expand Down Expand Up @@ -639,6 +640,24 @@ Known values in current row (local: {row_index}, global {global_row_index}):
),
};

if let Ok(connection) = Connection::try_from(identity) {
// JITed submachines would panic if passed a wrong input / output pair.
// Therefore, if any machine call is activated, we resort to the full
// solving routine.
// An to this is when the call is always active (e.g. the PC lookup).
// In that case, we know that the call has been active before with the
// same input / output pair, so we can be sure that it will succeed.
let selector = &connection.left.selector;
if selector != &Expression::one() {
let selector_value = row_pair
.evaluate(selector)
.unwrap()
.constant_value()
.unwrap();
return selector_value.is_zero();
}
}

if identity_processor
.process_identity(identity, &row_pair)
.is_err()
Expand Down

1 comment on commit 725fceb

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Benchmarks'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.20.

Benchmark suite Current: 725fceb Previous: 1e097e8 Ratio
executor-benchmark/keccak 11995307129 ns/iter (± 33172172) 8898077628 ns/iter (± 54759768) 1.35

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.