Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract assignments from witgen inference, step 1 #2452

Merged
merged 17 commits into from
Feb 7, 2025
Merged
20 changes: 10 additions & 10 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,50 +321,50 @@ main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::operation_id[0] = main_binary::operation_id[1];
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
call_var(9, -1, 0) = main_binary::operation_id_next[-1];
main_binary::operation_id_next[0] = main_binary::operation_id[1];
call_var(9, 0, 0) = main_binary::operation_id_next[0];
main_binary::operation_id_next[1] = main_binary::operation_id[2];
call_var(9, 1, 0) = main_binary::operation_id_next[1];
main_binary::A_byte[2] = ((main_binary::A[3] & 0xff000000) // 16777216);
main_binary::A[2] = (main_binary::A[3] & 0xffffff);
assert (main_binary::A[3] & 0xffffffff00000000) == 0;
call_var(9, 2, 1) = main_binary::A_byte[2];
main_binary::A_byte[1] = ((main_binary::A[2] & 0xff0000) // 65536);
main_binary::A[1] = (main_binary::A[2] & 0xffff);
assert (main_binary::A[2] & 0xffffffffff000000) == 0;
call_var(9, 1, 1) = main_binary::A_byte[1];
main_binary::A_byte[0] = ((main_binary::A[1] & 0xff00) // 256);
main_binary::A[0] = (main_binary::A[1] & 0xff);
assert (main_binary::A[1] & 0xffffffffffff0000) == 0;
call_var(9, 0, 1) = main_binary::A_byte[0];
main_binary::A_byte[-1] = main_binary::A[0];
call_var(9, -1, 1) = main_binary::A_byte[-1];
main_binary::B_byte[2] = ((main_binary::B[3] & 0xff000000) // 16777216);
main_binary::B[2] = (main_binary::B[3] & 0xffffff);
assert (main_binary::B[3] & 0xffffffff00000000) == 0;
call_var(9, 2, 2) = main_binary::B_byte[2];
main_binary::B_byte[1] = ((main_binary::B[2] & 0xff0000) // 65536);
main_binary::B[1] = (main_binary::B[2] & 0xffff);
assert (main_binary::B[2] & 0xffffffffff000000) == 0;
call_var(9, 1, 2) = main_binary::B_byte[1];
main_binary::B_byte[0] = ((main_binary::B[1] & 0xff00) // 256);
main_binary::B[0] = (main_binary::B[1] & 0xff);
assert (main_binary::B[1] & 0xffffffffffff0000) == 0;
call_var(9, 0, 2) = main_binary::B_byte[0];
main_binary::B_byte[-1] = main_binary::B[0];
call_var(9, -1, 0) = main_binary::operation_id_next[-1];
call_var(9, -1, 1) = main_binary::A_byte[-1];
call_var(9, -1, 2) = main_binary::B_byte[-1];
machine_call(9, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]);
main_binary::C_byte[-1] = call_var(9, -1, 3);
main_binary::C[0] = main_binary::C_byte[-1];
call_var(9, 0, 0) = main_binary::operation_id_next[0];
call_var(9, 0, 1) = main_binary::A_byte[0];
call_var(9, 0, 2) = main_binary::B_byte[0];
machine_call(9, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]);
main_binary::C_byte[0] = call_var(9, 0, 3);
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
call_var(9, 1, 0) = main_binary::operation_id_next[1];
call_var(9, 1, 1) = main_binary::A_byte[1];
call_var(9, 1, 2) = main_binary::B_byte[1];
machine_call(9, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]);
main_binary::C_byte[1] = call_var(9, 1, 3);
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
main_binary::operation_id_next[2] = main_binary::operation_id[3];
call_var(9, 2, 0) = main_binary::operation_id_next[2];
call_var(9, 2, 1) = main_binary::A_byte[2];
call_var(9, 2, 2) = main_binary::B_byte[2];
machine_call(9, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]);
main_binary::C_byte[2] = call_var(9, 2, 3);
main_binary::C[3] = (main_binary::C[2] + (main_binary::C_byte[2] * 16777216));
Expand Down
47 changes: 39 additions & 8 deletions executor/src/witgen/jit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@ use std::{
};

use itertools::Itertools;
use powdr_ast::analyzed::{PolyID, PolynomialType};
use powdr_ast::analyzed::{PolyID, PolynomialIdentity, PolynomialType};
use powdr_number::FieldElement;

use crate::witgen::{
data_structures::identity::Identity, jit::debug_formatter::format_identities,
range_constraints::RangeConstraint, FixedData,
data_structures::identity::{BusSend, Identity},
jit::debug_formatter::format_identities,
range_constraints::RangeConstraint,
FixedData,
};

use super::{
affine_symbolic_expression,
effect::{format_code, Effect},
identity_queue::IdentityQueue,
prover_function_heuristics::ProverFunction,
variable::{Cell, Variable},
variable::{Cell, MachineCallVariable, Variable},
witgen_inference::{BranchResult, CanProcessCall, FixedEvaluator, Value, WitgenInference},
};

Expand Down Expand Up @@ -109,8 +111,21 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
pub fn generate_code(
self,
can_process: impl CanProcessCall<T>,
witgen: WitgenInference<'a, T, FixedEval>,
mut witgen: WitgenInference<'a, T, FixedEval>,
) -> Result<ProcessorResult<T>, Error<'a, T, FixedEval>> {
// Create variables for bus send arguments.
for (id, row_offset) in &self.identities {
if let Identity::BusSend(bus_send) = id {
for (index, arg) in bus_send.selected_payload.expressions.iter().enumerate() {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id: bus_send.identity_id,
row_offset: *row_offset,
index,
});
witgen.assign_variable(arg, *row_offset, var.clone());
}
}
}
let branch_depth = 0;
let identity_queue = IdentityQueue::new(self.fixed_data, &self.identities);
self.generate_code_for_branch(can_process, witgen, identity_queue, branch_depth)
Expand Down Expand Up @@ -283,9 +298,25 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
loop {
let identity = identity_queue.next();
let updated_vars = match identity {
Some((identity, row_offset)) => {
witgen.process_identity(can_process.clone(), identity, row_offset)
}
Some((identity, row_offset)) => match identity {
Identity::Polynomial(PolynomialIdentity { id, expression, .. }) => {
witgen.process_polynomial_identity(*id, expression, row_offset)
}
Identity::BusSend(BusSend {
bus_id: _,
identity_id,
selected_payload,
}) => witgen.process_call(
can_process.clone(),
*identity_id,
&selected_payload.selector,
selected_payload.expressions.len(),
Comment on lines +309 to +313
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we just pass a reference to the BusSend struct? Similarly I think that would be good for the Polynomial case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we pass a whole BusSend, it gives the impression that it processes the arguments instead of working with the variables.

row_offset,
),
Identity::Connect(..) => Ok(vec![]),
},
// TODO Also add prover functions to the queue (activated by their variables)
// and sort them so that they are always last.
None => self.process_prover_functions(witgen),
}?;
if updated_vars.is_empty() && identity.is_none() {
Expand Down
10 changes: 5 additions & 5 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ namespace M(256);
assert_eq!(
format_code(&code),
"\
VM::pc[1] = (VM::pc[0] + 1);
call_var(1, 0, 0) = VM::pc[0];
call_var(1, 0, 1) = VM::instr_add[0];
call_var(1, 0, 2) = VM::instr_mul[0];
VM::B[1] = VM::B[0];
VM::pc[1] = (VM::pc[0] + 1);
call_var(1, 1, 0) = VM::pc[1];
VM::B[1] = VM::B[0];
machine_call(1, [Known(call_var(1, 1, 0)), Unknown(call_var(1, 1, 1)), Unknown(call_var(1, 1, 2))]);
VM::instr_add[1] = call_var(1, 1, 1);
VM::instr_mul[1] = call_var(1, 1, 2);
Expand Down Expand Up @@ -280,12 +280,12 @@ if (VM::instr_add[0] == 1) {
assert_eq!(
format_code(&code),
"\
VM::pc[1] = VM::pc[0];
call_var(2, 0, 0) = VM::pc[0];
call_var(2, 0, 1) = 0;
call_var(2, 0, 1) = VM::instr_add[0];
call_var(2, 0, 2) = VM::instr_mul[0];
VM::instr_add[1] = 0;
VM::pc[1] = VM::pc[0];
call_var(2, 1, 0) = VM::pc[1];
VM::instr_add[1] = 0;
call_var(2, 1, 1) = 0;
call_var(2, 1, 2) = 1;
machine_call(2, [Known(call_var(2, 1, 0)), Known(call_var(2, 1, 1)), Unknown(call_var(2, 1, 2))]);
Expand Down
Loading
Loading