Skip to content

Commit

Permalink
first pass of witgen; lots of bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
qwang98 committed Feb 6, 2025
1 parent 1af5270 commit 6998967
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
21 changes: 20 additions & 1 deletion executor/src/witgen/bus_accumulator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ impl<'a, T: FieldElement, Ext: ExtensionField<T> + Sync> BusAccumulatorGenerator
.flat_map(|bus_interaction| {
let (folded, acc) = self.interaction_columns(bus_interaction);
collect_folded_columns(bus_interaction, folded)
.chain(collect_acc_columns(bus_interaction, acc))
.chain(collect_acc_columns(bus_interaction, acc.clone()))
.chain(collect_helper_columns(bus_interaction, acc))
.collect::<Vec<_>>()
})
// Each thread builds its own BTreeMap.
Expand Down Expand Up @@ -179,6 +180,13 @@ impl<'a, T: FieldElement, Ext: ExtensionField<T> + Sync> BusAccumulatorGenerator
result
}

/// New version of interaction_columns that “batches” several bus interactions
/// according to bus_multi_interaction_2.
///
/// Returns a triple:
/// - the folded columns (one per bus interaction),
/// - the accumulator column (shared by all interactions),
/// - one helper column per pair of bus interactions.
fn interaction_columns(
&self,
bus_interaction: &PhantomBusInteractionIdentity<T>,
Expand Down Expand Up @@ -279,3 +287,14 @@ fn collect_acc_columns<T>(
.zip_eq(acc)
.map(|(column_reference, column)| (column_reference.poly_id, column))
}

fn collect_helper_columns<T>(
bus_interaction: &PhantomBusInteractionIdentity<T>,
helper: Vec<Vec<T>>,
) -> impl Iterator<Item = (PolyID, Vec<T>)> + '_ {
bus_interaction
.helper_columns
.iter()
.zip_eq(helper)
.map(|(column_reference, column)| (column_reference.poly_id, column))
}
20 changes: 12 additions & 8 deletions std/protocols/bus.asm
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ let bus_multi_interaction_2: expr[], expr[][], expr[], expr[] -> () = constr |id

// Create helper columns to bound degree to 3 for arbitrary number of bus interactions.
// Each helper processes two bus interactions:
// helper_i = multiplicity_{2*i}' / folded_{2*i}' + multiplicity_{2*i+1}' / folded_{2*i+1}'
// helper_i = multiplicity_{2*i} / folded_{2*i} + multiplicity_{2*i+1} / folded_{2*i+1}
// Or equivalently when expanded:
// folded_{2*i}' * folded_{2*i+1}' * helper_i - folded_{2*i+1}' * multiplicity_{2*i}' - folded_{2*i}' * multiplicity_{2*i+1}' = 0
// folded_{2*i} * folded_{2*i+1}' * helper_i - folded_{2*i+1} * multiplicity_{2*i} - folded_{2*i} * multiplicity_{2*i+1} = 0
let helper_arr: expr[][] = array::new(
input_len / 2,
|helper|
Expand All @@ -108,28 +108,32 @@ let bus_multi_interaction_2: expr[], expr[][], expr[], expr[] -> () = constr |id
helper_arr,
|helper| from_array(helper)
);
let helper_ext_next_arr = array::map(
helper_ext_arr,
|helper_ext| next_ext(helper_ext)
);
// The expression to constrain.
let helper_expr_arr = array::new( // Ext<expr>[]
input_len / 2,
|i| sub_ext(
sub_ext(
mul_ext(
mul_ext(folded_next_arr[2 * i], folded_next_arr[2 * i + 1]),
mul_ext(folded_arr[2 * i], folded_arr[2 * i + 1]),
helper_ext_arr[i]
),
mul_ext(folded_next_arr[2 * i + 1], m_ext_next_arr[2 * i])
mul_ext(folded_arr[2 * i + 1], m_ext_arr[2 * i])
),
mul_ext(folded_next_arr[2 * i], m_ext_next_arr[2 * i + 1])
mul_ext(folded_arr[2 * i], m_ext_arr[2 * i + 1])
)
);
// Return a flattened array of constraints. (Must use `array::fold` or the compiler won't allow nested Constr[][].)
array::fold(helper_expr_arr, [], |init, helper_expr| constrain_eq_ext(helper_expr, from_base(0)));

// Update rule:
// acc' = acc * (1 - is_first') + helper_0 + helper_1 + ...
// acc' = acc * (1 - is_first') + helper_0' + helper_1' + ...
// Add up all helper columns.
// Or equivalently:
// acc * (1 - is_first') + helper_0 + helper_1 + ... - acc' = 0
// acc * (1 - is_first') + helper_0' + helper_1' + ... - acc' = 0
let update_expr =
sub_ext(
add_ext(
Expand All @@ -138,7 +142,7 @@ let bus_multi_interaction_2: expr[], expr[][], expr[], expr[] -> () = constr |id
sub_ext(from_base(1), is_first_next)
),
// Sum of all helper columns.
array::fold(helper_ext_arr, from_base(0), |sum, helper_ext| add_ext(sum, helper_ext))
array::fold(helper_ext_next_arr, from_base(0), |sum, helper_ext_next| add_ext(sum, helper_ext_next))
),
next_acc
);
Expand Down

0 comments on commit 6998967

Please sign in to comment.