Skip to content

Commit

Permalink
Support machine calls. (#2241)
Browse files Browse the repository at this point in the history
Depends on #2244
  • Loading branch information
chriseth authored Dec 19, 2024
1 parent 309279a commit a2a67a6
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 106 deletions.
2 changes: 1 addition & 1 deletion executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

/// Repeatedly processes all identities on all rows, until no progress is made.
/// Fails iff there are incomplete machine calls in the latch row.
fn solve_block(&self, witgen: &mut WitgenInference<T, &Self>) -> Result<(), String> {
fn solve_block(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
let mut complete = HashSet::new();
for iteration in 0.. {
let mut progress = false;
Expand Down
212 changes: 156 additions & 56 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<T: FieldElement> WitgenFunction<T> {
/// This function always succeeds (unless it panics).
pub fn call<Q: QueryCallback<T>>(
&self,
_mutable_state: &MutableState<'_, T, Q>,
mutable_state: &MutableState<'_, T, Q>,
params: &mut [LookupCell<T>],
mut data: CompactDataRef<'_, T>,
) {
Expand All @@ -47,6 +47,7 @@ impl<T: FieldElement> WitgenFunction<T> {
known: known.as_mut_ptr(),
row_offset,
params: params.into(),
mutable_state: mutable_state as *const _ as *const c_void,
call_machine: call_machine::<T, Q>,
});
}
Expand Down Expand Up @@ -100,6 +101,8 @@ struct WitgenFunctionParams<'a, T: 'a> {
row_offset: u64,
/// Input and output parameters if this is a machine call.
params: MutSlice<LookupCell<'a, T>>,
/// The pointer to the mutable state.
mutable_state: *const c_void,
/// A callback to call submachines.
call_machine: extern "C" fn(*const c_void, u64, MutSlice<LookupCell<'_, T>>) -> bool,
}
Expand Down Expand Up @@ -147,6 +150,9 @@ fn witgen_code<T: FieldElement>(
format!("get(data, row_offset, {}, {})", c.row_offset, c.id)
}
Variable::Param(i) => format!("get_param(params, {i})"),
Variable::MachineCallReturnValue(_) => {
unreachable!("Machine call return values should not be pre-known.")
}
};
format!(" let {var_name} = {value};")
})
Expand All @@ -158,17 +164,17 @@ fn witgen_code<T: FieldElement>(
.collect_vec();
let store_values = vars_known
.iter()
.map(|var| {
.filter_map(|var| {
let value = variable_to_string(var);
match var {
Variable::Cell(cell) => {
format!(
" set(data, row_offset, {}, {}, {value});",
cell.row_offset, cell.id,
)
}
Variable::Param(i) => {
format!(" set_param(params, {i}, {value});")
Variable::Cell(cell) => Some(format!(
" set(data, row_offset, {}, {}, {value});",
cell.row_offset, cell.id,
)),
Variable::Param(i) => Some(format!(" set_param(params, {i}, {value});")),
Variable::MachineCallReturnValue(_) => {
// This is just an internal variable.
None
}
}
})
Expand All @@ -179,7 +185,7 @@ fn witgen_code<T: FieldElement>(
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell),
Variable::Param(_) => None,
Variable::Param(_) | Variable::MachineCallReturnValue(_) => None,
})
.map(|cell| {
format!(
Expand All @@ -197,6 +203,7 @@ extern "C" fn witgen(
known,
row_offset,
params,
mutable_state,
call_machine
}}: WitgenFunctionParams<FieldElement>,
) {{
Expand Down Expand Up @@ -226,7 +233,7 @@ fn written_vars_in_effect<T: FieldElement>(
Effect::RangeConstraint(..) => unreachable!(),
Effect::Assertion(..) => iter::empty(),
Effect::MachineCall(_, arguments) => arguments.iter().flat_map(|e| match e {
MachineCallArgument::Unknown(e) => Some(e.single_unknown_variable().unwrap()),
MachineCallArgument::Unknown(v) => Some(v),
MachineCallArgument::Known(_) => None,
}),
}
Expand Down Expand Up @@ -254,7 +261,31 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>) -> String {
if *expected_equal { "==" } else { "!=" },
format_expression(rhs)
),
Effect::MachineCall(..) => todo!(),
Effect::MachineCall(id, arguments) => {
let mut result_vars = vec![];
let args = arguments
.iter()
.map(|a| match a {
MachineCallArgument::Unknown(v) => {
let var_name = variable_to_string(v);
result_vars.push(var_name.clone());
format!("LookupCell::Output(&mut {var_name})")
}
MachineCallArgument::Known(v) => {
format!("LookupCell::Input(&{})", format_expression(v))
}
})
.format(", ")
.to_string();
let var_decls = result_vars
.iter()
.map(|var_name| format!(" let mut {var_name} = FieldElement::default();"))
.format("\n");
format!(
"{var_decls}
assert!(call_machine(mutable_state, {id}, MutSlice::from((&mut [{args}]).as_mut_slice())));"
)
}
}
}

Expand Down Expand Up @@ -298,6 +329,14 @@ fn variable_to_string(v: &Variable) -> String {
format_row_offset(cell.row_offset)
),
Variable::Param(i) => format!("p_{i}"),
Variable::MachineCallReturnValue(ret) => {
format!(
"ret_{}_{}_{}",
ret.identity_id,
format_row_offset(ret.row_offset),
ret.index
)
}
}
}

Expand Down Expand Up @@ -382,6 +421,7 @@ mod tests {
use powdr_number::GoldilocksField;

use crate::witgen::jit::variable::Cell;
use crate::witgen::jit::variable::MachineCallReturnVariable;

use super::*;

Expand Down Expand Up @@ -409,6 +449,14 @@ mod tests {
Variable::Param(i)
}

fn ret_val(identity_id: u64, row_offset: i32, index: usize) -> Variable {
Variable::MachineCallReturnValue(MachineCallReturnVariable {
identity_id,
row_offset,
index,
})
}

fn symbol(var: &Variable) -> SymbolicExpression<GoldilocksField, Variable> {
SymbolicExpression::from_symbol(var.clone(), None)
}
Expand All @@ -425,14 +473,22 @@ mod tests {
}

#[test]
fn simple_effects() {
fn code_for_effects() {
let a0 = cell("a", 2, 0);
let x0 = cell("x", 0, 0);
let ym1 = cell("y", 1, -1);
let yp1 = cell("y", 1, 1);
let r1 = ret_val(7, 1, 1);
let effects = vec![
assignment(&x0, number(7) * symbol(&a0)),
assignment(&ym1, symbol(&x0)),
Effect::MachineCall(
7,
vec![
MachineCallArgument::Unknown(r1.clone()),
MachineCallArgument::Known(symbol(&x0)),
],
),
assignment(&ym1, symbol(&r1)),
assignment(&yp1, symbol(&ym1) + symbol(&x0)),
Effect::Assertion(Assertion {
lhs: symbol(&ym1),
Expand All @@ -452,6 +508,7 @@ extern \"C\" fn witgen(
known,
row_offset,
params,
mutable_state,
call_machine
}: WitgenFunctionParams<FieldElement>,
) {
Expand All @@ -462,7 +519,9 @@ extern \"C\" fn witgen(
let c_a_2_0 = get(data, row_offset, 0, 2);
let c_x_0_0 = (FieldElement::from(7) * c_a_2_0);
let c_y_1_m1 = c_x_0_0;
let mut ret_7_1_1 = FieldElement::default();
assert!(call_machine(mutable_state, 7, MutSlice::from((&mut [LookupCell::Output(&mut ret_7_1_1), LookupCell::Input(&c_x_0_0)]).as_mut_slice())));
let c_y_1_m1 = ret_7_1_1;
let c_y_1_1 = (c_y_1_m1 + c_x_0_0);
assert!(c_y_1_m1 == c_x_0_0);
Expand All @@ -486,6 +545,20 @@ extern \"C\" fn witgen(
false
}

fn witgen_fun_params<'a>(
data: &mut [GoldilocksField],
known: &mut [u32],
) -> WitgenFunctionParams<'a, GoldilocksField> {
WitgenFunctionParams {
data: data.into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
}
}

#[test]
fn load_code() {
let x = cell("x", 0, 0);
Expand All @@ -497,14 +570,7 @@ extern \"C\" fn witgen(
let f = compile_effects(0, 1, &[], &effects).unwrap();
let mut data = vec![GoldilocksField::from(0); 2];
let mut known = vec![0; 1];
let params = WitgenFunctionParams {
data: MutSlice::from(data.as_mut_slice()),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
call_machine: no_call_machine,
};
(f.function)(params);
(f.function)(witgen_fun_params(&mut data, &mut known));
assert_eq!(data[0], GoldilocksField::from(7));
assert_eq!(data[1], GoldilocksField::from(9));
assert_eq!(known[0], 3);
Expand All @@ -528,14 +594,7 @@ extern \"C\" fn witgen(
let f2 = compile_effects(0, column_count, &[], &effects2).unwrap();
let mut data = vec![GoldilocksField::from(0); data_len];
let mut known = vec![0; row_count];
let params1 = WitgenFunctionParams {
data: MutSlice::from(data.as_mut_slice()),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
call_machine: no_call_machine,
};
(f1.function)(params1);
(f1.function)(witgen_fun_params(&mut data, &mut known));
assert_eq!(data[0], GoldilocksField::from(7));
assert_eq!(data[1], GoldilocksField::from(0));
assert_eq!(data[2], GoldilocksField::from(0));
Expand All @@ -546,6 +605,7 @@ extern \"C\" fn witgen(
known: known.as_mut_ptr(),
row_offset: 1,
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
};
(f2.function)(params2);
Expand All @@ -568,14 +628,7 @@ extern \"C\" fn witgen(
let f = compile_effects(0, 1, &[], &effects).unwrap();
let mut data = vec![GoldilocksField::from(0); 5];
let mut known = vec![0; 5];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
call_machine: no_call_machine,
};
(f.function)(params);
(f.function)(witgen_fun_params(&mut data, &mut known));
assert_eq!(data[0], GoldilocksField::from(4));
assert_eq!(
data[1],
Expand All @@ -600,14 +653,7 @@ extern \"C\" fn witgen(
-GoldilocksField::from(4),
];
let mut known = vec![0; 1];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
call_machine: no_call_machine,
};
(f.function)(params);
(f.function)(witgen_fun_params(&mut data, &mut known));
assert_eq!(data[0], -GoldilocksField::from(12));
}

Expand All @@ -628,14 +674,7 @@ extern \"C\" fn witgen(
GoldilocksField::from(0),
];
let mut known = vec![0; 1];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
call_machine: no_call_machine,
};
(f.function)(params);
(f.function)(witgen_fun_params(&mut data, &mut known));
assert_eq!(data[0], GoldilocksField::from(23));
assert_eq!(data[1], GoldilocksField::from(2));
assert_eq!(data[2], GoldilocksField::from(0));
Expand All @@ -657,6 +696,7 @@ extern \"C\" fn witgen(
known: known.as_mut_ptr(),
row_offset: 0,
params: params.as_mut_slice().into(),
mutable_state: std::ptr::null(),
call_machine: no_call_machine,
};
(f.function)(params);
Expand All @@ -679,4 +719,64 @@ extern \"C\" fn witgen(
let code = witgen_code(&known_inputs, &effects);
assert!(code.contains(&format!("let c_x_1_0 = (c_a_0_0 & {large_num});")));
}

extern "C" fn mock_call_machine(
_: *const c_void,
id: u64,
params: MutSlice<LookupCell<'_, GoldilocksField>>,
) -> bool {
assert_eq!(id, 7);
assert_eq!(params.len, 3);

let params: &mut [LookupCell<GoldilocksField>] = params.into();
match &params[0] {
LookupCell::Input(x) => assert_eq!(**x, 7.into()),
_ => panic!(),
}
match &mut params[1] {
LookupCell::Output(y) => **y = 9.into(),
_ => panic!(),
}
match &mut params[2] {
LookupCell::Output(z) => **z = 18.into(),
_ => panic!(),
}
true
}

#[test]
fn submachine_calls() {
let x = cell("x", 0, 0);
let y = cell("y", 1, 0);
let r1 = ret_val(7, 0, 1);
let r2 = ret_val(7, 0, 2);
let effects = vec![
Effect::MachineCall(
7,
vec![
MachineCallArgument::Known(number(7)),
MachineCallArgument::Unknown(r1.clone()),
MachineCallArgument::Unknown(r2.clone()),
],
),
Effect::Assignment(x.clone(), symbol(&r1)),
Effect::Assignment(y.clone(), symbol(&r2)),
];
let known_inputs = vec![];
let f = compile_effects(0, 3, &known_inputs, &effects).unwrap();
let mut data = vec![GoldilocksField::from(0); 3];
let mut known = vec![0; 1];
let params = WitgenFunctionParams {
data: data.as_mut_slice().into(),
known: known.as_mut_ptr(),
row_offset: 0,
params: Default::default(),
mutable_state: std::ptr::null(),
call_machine: mock_call_machine,
};
(f.function)(params);
assert_eq!(data[0], GoldilocksField::from(9));
assert_eq!(data[1], GoldilocksField::from(18));
assert_eq!(data[2], GoldilocksField::from(0));
}
}
Loading

0 comments on commit a2a67a6

Please sign in to comment.