Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move assignments to processor #2454

Merged
merged 23 commits into from
Feb 7, 2025
31 changes: 24 additions & 7 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use powdr_ast::analyzed::{ContainsNextRef, PolyID, PolynomialType};
use powdr_number::FieldElement;

use crate::witgen::{
jit::{processor::Processor, prover_function_heuristics::decode_prover_functions},
jit::{
processor::Processor, prover_function_heuristics::decode_prover_functions,
witgen_inference::Assignment,
},
machines::MachineParts,
FixedData,
};
Expand Down Expand Up @@ -61,25 +64,38 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.enumerate()
.filter_map(|(i, is_input)| is_input.then_some(Variable::Param(i)))
.collect::<HashSet<_>>();
let mut witgen = WitgenInference::new(self.fixed_data, self, known_variables, []);
let witgen = WitgenInference::new(self.fixed_data, self, known_variables, []);

let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?;

// In the latch row, set the RHS selector to 1.
let mut assignments = vec![];
let selector = &connection.right.selector;
witgen.assign_constant(selector, self.latch_row as i32, T::one());
assignments.push(Assignment::assign_constant(
selector,
self.latch_row as i32,
T::one(),
));

// Set all other selectors to 0 in the latch row.
for other_connection in self.machine_parts.connections.values() {
let other_selector = &other_connection.right.selector;
if other_selector != selector {
witgen.assign_constant(other_selector, self.latch_row as i32, T::zero());
assignments.push(Assignment::assign_constant(
other_selector,
self.latch_row as i32,
T::zero(),
));
}
}

// For each argument, connect the expression on the RHS with the formal parameter.
for (index, expr) in connection.right.expressions.iter().enumerate() {
witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index));
assignments.push(Assignment::assign_variable(
expr,
self.latch_row as i32,
Variable::Param(index),
));
}

let intermediate_definitions = self.fixed_data.analyzed.intermediate_definitions();
Expand Down Expand Up @@ -124,6 +140,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
self.fixed_data,
self,
identities,
assignments,
requested_known,
BLOCK_MACHINE_MAX_BRANCH_DEPTH,
)
Expand Down Expand Up @@ -313,10 +330,10 @@ params[2] = Add::c[0];"
assert_eq!(c_rc, &RangeConstraint::from_mask(0xffffffffu64));
assert_eq!(
format_code(&result.code),
"main_binary::sel[0][3] = 1;
main_binary::operation_id[3] = params[0];
"main_binary::operation_id[3] = params[0];
main_binary::A[3] = params[1];
main_binary::B[3] = params[2];
main_binary::sel[0][3] = 1;
main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::operation_id[0] = main_binary::operation_id[1];
Expand Down
37 changes: 31 additions & 6 deletions executor/src/witgen/jit/identity_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use powdr_ast::{
};
use powdr_number::FieldElement;

use crate::witgen::{data_structures::identity::Identity, FixedData};
use crate::witgen::{
data_structures::identity::Identity, jit::variable::MachineCallVariable, FixedData,
};

use super::{
variable::Variable,
Expand Down Expand Up @@ -128,13 +130,25 @@ fn compute_occurrences_map<'b, 'a: 'b, T: FieldElement>(
.flat_map(|item| {
let variables = match item {
QueueItem::Identity(id, row) => {
references_in_identity(id, fixed_data, &mut intermediate_cache)
.into_iter()
let mut variables = references_per_identity[&id.id()]
.iter()
.map(|r| {
let name = fixed_data.column_name(&r.poly_id).to_string();
Variable::from_reference(&r.with_name(name), *row)
})
.collect_vec()
.collect_vec();
if let Identity::BusSend(bus_send) = id {
variables.extend((0..bus_send.selected_payload.expressions.len()).map(
|index| {
Variable::MachineCallParam(MachineCallVariable {
identity_id: id.id(),
row_offset: *row,
index,
})
},
));
};
variables
}
QueueItem::Assignment(a) => {
variables_in_assignment(a, fixed_data, &mut intermediate_cache)
Expand All @@ -152,9 +166,20 @@ fn references_in_identity<T: FieldElement>(
intermediate_cache: &mut HashMap<AlgebraicReferenceThin, Vec<AlgebraicReferenceThin>>,
) -> Vec<AlgebraicReferenceThin> {
let mut result = BTreeSet::new();
for e in identity.children() {
result.extend(references_in_expression(e, fixed_data, intermediate_cache));

match identity {
Identity::BusSend(bus_send) => result.extend(references_in_expression(
&bus_send.selected_payload.selector,
fixed_data,
intermediate_cache,
)),
_ => {
for e in identity.children() {
result.extend(references_in_expression(e, fixed_data, intermediate_cache));
}
}
}

result.into_iter().collect()
}

Expand Down
68 changes: 43 additions & 25 deletions executor/src/witgen/jit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use super::{
identity_queue::{IdentityQueue, QueueItem},
prover_function_heuristics::ProverFunction,
variable::{Cell, MachineCallVariable, Variable},
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference},
witgen_inference::{
Assignment, BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference,
},
};

/// A generic processor for generating JIT code.
Expand All @@ -31,6 +33,8 @@ pub struct Processor<'a, T: FieldElement, FixedEval> {
fixed_evaluator: FixedEval,
/// List of identities and row offsets to process them on.
identities: Vec<(&'a Identity<T>, i32)>,
/// List of assignments provided from outside.
initial_assignments: Vec<Assignment<'a, T>>,
/// The prover functions, i.e. helpers to compute certain values that
/// we cannot easily determine.
prover_functions: Vec<(ProverFunction<'a, T>, i32)>,
Expand Down Expand Up @@ -60,6 +64,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
fixed_data: &'a FixedData<'a, T>,
fixed_evaluator: FixedEval,
identities: impl IntoIterator<Item = (&'a Identity<T>, i32)>,
assignments: Vec<Assignment<'a, T>>,
requested_known_vars: impl IntoIterator<Item = Variable>,
max_branch_depth: usize,
) -> Self {
Expand All @@ -68,6 +73,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
fixed_data,
fixed_evaluator,
identities,
initial_assignments: assignments,
prover_functions: vec![],
block_size: 1,
check_block_shape: false,
Expand Down Expand Up @@ -111,23 +117,37 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
pub fn generate_code(
self,
can_process: impl CanProcessCall<T>,
mut witgen: WitgenInference<'a, T, FixedEval>,
witgen: WitgenInference<'a, T, FixedEval>,
) -> Result<ProcessorResult<T>, Error<'a, T, FixedEval>> {
// Create variables for bus send arguments.
for (id, row_offset) in &self.identities {
if let Identity::BusSend(bus_send) = id {
for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id: bus_send.identity_id,
row_offset: *row_offset,
index,
});
witgen.assign_variable(arg, *row_offset, var.clone());
}
}
}
// Create variable assignments for bus send arguments.
let mut assignments = self.initial_assignments.clone();
assignments.extend(
self.identities
.iter()
.filter_map(|(id, row_offset)| {
if let Identity::BusSend(bus_send) = id {
Some((
bus_send.identity_id,
&bus_send.selected_payload.expressions,
*row_offset,
))
} else {
None
}
})
.flat_map(|(identity_id, arguments, row_offset)| {
arguments.iter().enumerate().map(move |(index, arg)| {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id,
row_offset,
index,
});
Assignment::assign_variable(arg, row_offset, var)
})
}),
);
let branch_depth = 0;
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &[]);
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities, &assignments);
self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth)
}

Expand Down Expand Up @@ -296,11 +316,11 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
identity_queue: &mut IdentityQueue<'a, T>,
) -> Result<(), affine_symbolic_expression::Error> {
loop {
let identity = identity_queue.next();
let updated_vars = match identity {
let item = identity_queue.next();
let updated_vars = match &item {
Some(QueueItem::Identity(identity, row_offset)) => match identity {
Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => {
witgen.process_polynomial_identity(*id, expression, row_offset)
witgen.process_polynomial_identity(*id, expression, *row_offset)
}
Identity::BusSend(BusSend {
bus_id: _,
Expand All @@ -311,23 +331,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
*identity_id,
&selected_payload.selector,
selected_payload.expressions.len(),
row_offset,
*row_offset,
),
Identity::Connect(..) => Ok(vec![]),
},
Some(QueueItem::Assignment(assignment)) => witgen.process_assignment(assignment),
// TODO Also add prover functions to the queue (activated by their variables)
// and sort them so that they are always last.
Some(QueueItem::Assignment(_assignment)) => {
todo!()
}
None => self.process_prover_functions(witgen),
}?;
if updated_vars.is_empty() && identity.is_none() {
if updated_vars.is_empty() && item.is_none() {
// No identities to process and prover functions did not make any progress,
// we are done.
return Ok(());
}
identity_queue.variables_updated(updated_vars, identity);
identity_queue.variables_updated(updated_vars, item);
}
}

Expand Down
5 changes: 3 additions & 2 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
self.fixed_data,
self,
identities,
vec![],
requested_known,
SINGLE_STEP_MACHINE_MAX_BRANCH_DEPTH,
)
Expand Down Expand Up @@ -237,9 +238,9 @@ namespace M(256);
assert_eq!(
format_code(&code),
"\
call_var(1, 0, 0) = VM::pc[0];
call_var(1, 0, 1) = VM::instr_add[0];
call_var(1, 0, 2) = VM::instr_mul[0];
call_var(1, 0, 0) = VM::pc[0];
VM::pc[1] = (VM::pc[0] + 1);
call_var(1, 1, 0) = VM::pc[1];
VM::B[1] = VM::B[0];
Expand Down Expand Up @@ -280,9 +281,9 @@ if (VM::instr_add[0] == 1) {
assert_eq!(
format_code(&code),
"\
call_var(2, 0, 0) = VM::pc[0];
call_var(2, 0, 1) = VM::instr_add[0];
call_var(2, 0, 2) = VM::instr_mul[0];
call_var(2, 0, 0) = VM::pc[0];
VM::pc[1] = VM::pc[0];
call_var(2, 1, 0) = VM::pc[1];
VM::instr_add[1] = 0;
Expand Down
Loading
Loading