Skip to content

Commit

Permalink
JIT for block machines with non-rectangular shapes (#2275)
Browse files Browse the repository at this point in the history
Depends on #2279

This PR implements JIT code generation for block machines with irregular
block shape, such as `std::machines::large_field::binary::Binary`. This
is achieved as follows:
- Instead of solving rows `0..block_size`, we run the solver for rows
`-1..(block_size + 1)`. This way, the solver is able to generate code
that writes to the previous row of the last block or the first row of
the next block.
- At the end, we check whether the generated code is actually
consistent: For example, if the code writes to the last row of the
previous block, it can't have a known value in the same cell of the
current block (unless it's known to be the same).

Note that the generated code is still not used in practice, because we
don't call the JIT with the right amount of context. I started fixing
this in #2281, but it is still WIP.

---------

Co-authored-by: chriseth <[email protected]>
  • Loading branch information
georgwiese and chriseth authored Jan 3, 2025
1 parent 91202ee commit 4bb99be
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 46 deletions.
229 changes: 204 additions & 25 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::collections::HashSet;
use std::collections::{BTreeSet, HashSet};

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::analyzed::{AlgebraicReference, Identity, SelectedExpressions};
use powdr_ast::analyzed::{
AlgebraicReference, Identity, PolyID, PolynomialType, SelectedExpressions,
};
use powdr_number::FieldElement;

use crate::witgen::{jit::effect::format_code, machines::MachineParts, FixedData};

use super::{
effect::Effect,
variable::Variable,
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, Value, WitgenInference},
};

/// A processor for generating JIT code for a block machine.
Expand Down Expand Up @@ -85,6 +87,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
}
}

/// Rows visited by the solver: from row -1 up to and including row `block_size`.
fn row_range(&self) -> std::ops::Range<i32> {
    // The block is padded by one row on each side so that code can also be
    // derived for blocks whose shape is not rectangular.
    let past_block_end = (self.block_size + 1) as i32;
    -1..past_block_end
}

/// Repeatedly processes all identities on all rows, until no progress is made.
/// Fails iff there are incomplete machine calls in the latch row.
fn solve_block<CanProcess: CanProcessCall<T> + Clone>(
Expand All @@ -97,11 +104,10 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
for iteration in 0.. {
let mut progress = false;

// TODO: This algorithm is assuming a rectangular block shape.
for row in 0..self.block_size {
for row in self.row_range() {
for id in &self.machine_parts.identities {
if !complete.contains(&(id.id(), row)) {
let result = witgen.process_identity(can_process.clone(), id, row as i32);
let result = witgen.process_identity(can_process.clone(), id, row);
if result.complete {
complete.insert((id.id(), row));
}
Expand All @@ -125,22 +131,121 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
}
}

// If any machine call could not be completed, that's bad because machine calls typically have side effects.
// So, the underlying lookup / permutation / bus argument likely does not hold.
// TODO: This assumes a rectangular block shape.
let has_incomplete_machine_calls = (0..self.block_size)
.flat_map(|row| {
self.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id))
.map(move |id| (id, row))
// TODO: Fail hard (or return a different error), as this should never
// happen for valid block machines. Currently fails in:
// powdr-pipeline::powdr_std arith256_memory_large_test
self.check_block_shape(witgen)?;
self.check_incomplete_machine_calls(&complete)?;

Ok(())
}

/// After solving, the known values should be such that we can stack different blocks.
fn check_block_shape(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
let known_columns = witgen
.known_variables()
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell.id),
_ => None,
})
.any(|(identity, row)| !complete.contains(&(identity.id(), row)));
.collect::<BTreeSet<_>>();

let can_stack = known_columns.iter().all(|column_id| {
// Increase the range by 1, because in row <block_size>,
// we might have processed an identity with next references.
let row_range = self.row_range();
let values = (row_range.start..(row_range.end + 1))
.map(|row| {
witgen.value(&Variable::Cell(Cell {
id: *column_id,
row_offset: row,
// Dummy value, the column name is ignored in the implementation
// of Cell::eq, etc.
column_name: "".to_string(),
}))
})
.collect::<Vec<_>>();

// Two values that refer to the same row (modulo block size) are compatible if:
// - One of them is unknown, or
// - Both are concrete and equal
let is_compatible = |v1: Value<T>, v2: Value<T>| match (v1, v2) {
(Value::Unknown, _) | (_, Value::Unknown) => true,
(Value::Concrete(a), Value::Concrete(b)) => a == b,
_ => false,
};
// A column is stackable if all rows equal to each other modulo
// the block size are compatible.
let stackable = (0..(values.len() - self.block_size))
.all(|i| is_compatible(values[i], values[i + self.block_size]));

match has_incomplete_machine_calls {
true => Err("Incomplete machine calls".to_string()),
false => Ok(()),
if !stackable {
let column_name = self.fixed_data.column_name(&PolyID {
id: *column_id,
ptype: PolynomialType::Committed,
});
let block_list = values.iter().skip(1).take(self.block_size).join(", ");
let column_str = format!(
"... {} | {} | {} ...",
values[0],
block_list,
values[self.block_size + 1]
);
log::debug!("Column {column_name} is not stackable:\n{column_str}");
}

stackable
});

match can_stack {
true => Ok(()),
false => Err("Block machine shape does not allow stacking".to_string()),
}
}

/// If any machine call could not be completed, that's bad because machine calls typically have side effects.
/// So, the underlying lookup / permutation / bus argument likely does not hold.
/// This function checks that all machine calls are complete, at least for a window of <block_size> rows.
fn check_incomplete_machine_calls(&self, complete: &HashSet<(u64, i32)>) -> Result<(), String> {
let machine_calls = self
.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id));

let incomplete_machine_calls = machine_calls
.flat_map(|call| {
let complete_rows = self
.row_range()
.filter(|row| complete.contains(&(call.id(), *row)))
.collect::<Vec<_>>();
// Because we process rows -1..block_size+1, it is fine to have two incomplete machine calls,
// as long as <block_size> consecutive rows are complete.
if complete_rows.len() >= self.block_size {
let (min, max) = complete_rows.iter().minmax().into_option().unwrap();
let is_consecutive = max - min == complete_rows.len() as i32 - 1;
if is_consecutive {
return vec![];
}
}
self.row_range()
.filter(|row| !complete.contains(&(call.id(), *row)))
.map(|row| (call, row))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

if !incomplete_machine_calls.is_empty() {
Err(format!(
"Incomplete machine calls:\n {}",
incomplete_machine_calls
.iter()
.map(|(identity, row)| format!("{identity} (row {row})"))
.join("\n ")
))
} else {
Ok(())
}
}
}
Expand All @@ -160,7 +265,22 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
fn evaluate(&self, var: &AlgebraicReference, row_offset: i32) -> Option<T> {
assert!(var.is_fixed());
let values = self.fixed_data.fixed_cols[&var.poly_id].values_max_size();
let row = (row_offset + var.next as i32 + values.len() as i32) as usize % values.len();

// By assumption of the block machine, all fixed columns are cyclic with a period of <block_size>.
// An exception might be the first and last row.
assert!(row_offset >= -1);
assert!(self.block_size >= 1);
// The current row is guaranteed to be at least 1.
let current_row = (2 * self.block_size as i32 + row_offset) as usize;
let row = current_row + var.next as usize;

assert!(values.len() >= self.block_size * 4);

// Fixed columns are assumed to be cyclic, except in the first and last row.
// The code above should ensure that we never access the first or last row.
assert!(row > 0);
assert!(row < values.len() - 1);

Some(values[row])
}
}
Expand Down Expand Up @@ -265,11 +385,70 @@ params[2] = Add::c[0];"
}

#[test]
// TODO: Currently fails, because the machine has a non-rectangular block shape.
#[should_panic = "Unable to derive algorithm to compute output value \\\"main_binary::C\\\""]
#[should_panic = "Block machine shape does not allow stacking"]
fn not_stackable() {
let input = "
namespace Main(256);
col witness a, b, c;
[a] is NotStackable.sel $ [NotStackable.a];
namespace NotStackable(256);
col witness sel, a;
a = a';
";
generate_for_block_machine(input, "NotStackable", 1, 0).unwrap();
}

#[test]
fn binary() {
let input = read_to_string("../test_data/pil/binary.pil").unwrap();
generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
let code = generate_for_block_machine(&input, "main_binary", 3, 1).unwrap();
assert_eq!(
format_code(&code),
"main_binary::sel[0][3] = 1;
main_binary::operation_id[3] = params[0];
main_binary::A[3] = params[1];
main_binary::B[3] = params[2];
main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::A_byte[2] = ((main_binary::A[3] & 4278190080) // 16777216);
main_binary::A[2] = (main_binary::A[3] & 16777215);
assert (main_binary::A[3] & 18446744069414584320) == 0;
main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216);
main_binary::B[2] = (main_binary::B[3] & 16777215);
assert (main_binary::B[3] & 18446744069414584320) == 0;
main_binary::operation_id_next[2] = main_binary::operation_id[3];
machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]);
main_binary::C_byte[2] = ret(9, 2, 3);
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536);
main_binary::A[1] = (main_binary::A[2] & 65535);
assert (main_binary::A[2] & 18446744073692774400) == 0;
main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536);
main_binary::B[1] = (main_binary::B[2] & 65535);
assert (main_binary::B[2] & 18446744073692774400) == 0;
main_binary::operation_id_next[1] = main_binary::operation_id[2];
machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]);
main_binary::C_byte[1] = ret(9, 1, 3);
main_binary::operation_id[0] = main_binary::operation_id[1];
main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256);
main_binary::A[0] = (main_binary::A[1] & 255);
assert (main_binary::A[1] & 18446744073709486080) == 0;
main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256);
main_binary::B[0] = (main_binary::B[1] & 255);
assert (main_binary::B[1] & 18446744073709486080) == 0;
main_binary::operation_id_next[0] = main_binary::operation_id[1];
machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]);
main_binary::C_byte[0] = ret(9, 0, 3);
main_binary::A_byte[-1] = main_binary::A[0];
main_binary::B_byte[-1] = main_binary::B[0];
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]);
main_binary::C_byte[-1] = ret(9, -1, 3);
main_binary::C[0] = main_binary::C_byte[-1];
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
params[3] = main_binary::C[3];"
)
}

#[test]
Expand Down
23 changes: 23 additions & 0 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
data_structures::finalizable_data::{ColumnLayout, CompactDataRef},
jit::effect::Effect,
machines::{LookupCell, MachineParts},
EvalError, FixedData, MutableState, QueryCallback,
};
Expand All @@ -28,6 +29,7 @@ pub struct FunctionCache<'a, T: FieldElement> {
/// but failed.
witgen_functions: HashMap<CacheKey, Option<WitgenFunction<T>>>,
column_layout: ColumnLayout,
block_size: usize,
}

impl<'a, T: FieldElement> FunctionCache<'a, T> {
Expand All @@ -45,6 +47,7 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
processor,
column_layout: metadata,
witgen_functions: HashMap::new(),
block_size,
}
}

Expand Down Expand Up @@ -89,9 +92,29 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
cache_key: &CacheKey,
) -> Option<WitgenFunction<T>> {
log::trace!("Compiling JIT function for {:?}", cache_key);

self.processor
.generate_code(mutable_state, cache_key.identity_id, &cache_key.known_args)
.ok()
.and_then(|code| {
// TODO: Remove this once BlockMachine passes the right amount of context for machines with
// non-rectangular block shapes.
let is_rectangular = code
.iter()
.filter_map(|effect| match effect {
Effect::Assignment(v, _) => Some(v),
_ => None,
})
.filter_map(|assigned_variable| match assigned_variable {
Variable::Cell(cell) => Some(cell.row_offset),
_ => None,
})
.all(|row_offset| row_offset >= 0 && row_offset < self.block_size as i32);
if !is_rectangular {
log::debug!("Filtering out code for non-rectangular block shape");
}
is_rectangular.then_some(code)
})
.map(|code| {
log::trace!("Generated code ({} steps)", code.len());
let known_inputs = cache_key
Expand Down
Loading

0 comments on commit 4bb99be

Please sign in to comment.