Skip to content

Commit

Permalink
add constraints in circuit builder
Browse files Browse the repository at this point in the history
  • Loading branch information
ShuangWu121 committed Feb 6, 2025
1 parent 5cbf0ce commit dc6c86d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 10 deletions.
18 changes: 11 additions & 7 deletions backend/src/stwo/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ pub struct PowdrEval {
// the pre-processed are indexed in the whole proof, instead of in each component.
// this offset represents the index of the first pre-processed column in this component
preprocess_col_offset: usize,
// for each stage, for each public input of that stage, the name, the column name, the poly_id, the row index
// for each stage, for each public input of that stage, the name of the public,
// the name of the witness column that this public is related to, the poly_id, the row index and its value
pub(crate) publics_by_stage: Vec<Vec<(String, String, PolyID, usize)>>,
stage0_witness_columns: BTreeMap<PolyID, usize>,
stage1_witness_columns: BTreeMap<PolyID, usize>,
Expand All @@ -67,6 +68,7 @@ pub struct PowdrEval {
// stwo supports maximum 2 stages, challenges are only created after stage 0
pub challenges: BTreeMap<u64, M31>,
poly_stage_map: BTreeMap<PolyID, usize>,
public_values: BTreeMap<String, M31>,
}

impl PowdrEval {
Expand All @@ -75,6 +77,7 @@ impl PowdrEval {
preprocess_col_offset: usize,
log_degree: u32,
challenges: BTreeMap<u64, M31>,
public_values: BTreeMap<String, M31>,
) -> Self {
let stage0_witness_columns: BTreeMap<PolyID, usize> = analyzed
.definitions_in_source_order(PolynomialType::Committed)
Expand Down Expand Up @@ -108,7 +111,7 @@ impl PowdrEval {
.enumerate()
.map(|(index, (_, id))| (id, index))
.collect();

// TODO:maybe only need in the prove function, before creating PowdrEval
let publics_by_stage = analyzed.get_publics().into_iter().fold(
vec![vec![]; analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
Expand All @@ -134,6 +137,7 @@ impl PowdrEval {
constant_columns,
challenges,
poly_stage_map,
public_values,
}
}
}
Expand Down Expand Up @@ -264,14 +268,14 @@ impl FrameworkEval for PowdrEval {
+ self.preprocess_col_offset
+ constant_shifted_eval.len(),
));
let stage= self.poly_stage_map[poly_id];
let witness_col= match stage {
0 => self.stage0_witness_columns[poly_id],
1 => self.stage1_witness_columns[poly_id],
let stage = self.poly_stage_map[poly_id];
let witness_col = match stage {
0 => stage0_witness_eval[poly_id][0].clone(),
1 => stage1_witness_eval[poly_id][0].clone(),
_ => unreachable!(),
};
// constraining s(i) * (pub[i] - x(i)) = 0
// eval.add_constraint(selector * (public_value.into() - witness_col));
eval.add_constraint(selector * (E::F::from(into_stwo_field(self.public_values.get(name).unwrap()))) - witness_col);
},
);

Expand Down
72 changes: 69 additions & 3 deletions backend/src/stwo/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,14 @@ where
)
})
.collect();
// TODO:in polonky3, this is built in ConstraintSystem, which is machine-specific, check if it's the same here
let publics_by_stage = self.analyzed.get_publics().into_iter().fold(
vec![vec![]; self.analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
acc[stage as usize].push((name, column_name, id, row));
acc
},
);

// Generate witness for stage 0, build constant columns in circle domain at the same time
let mut machine_log_sizes: BTreeMap<String, u32> = BTreeMap::new();
Expand Down Expand Up @@ -318,6 +326,22 @@ where
})
.collect::<BTreeMap<_, _>>();

// get publics of stage0
// TODO:when publics are supplied, the order of the witness might cause problem
let mut public_values: BTreeMap<String, M31> = witness_by_machine
.iter()
.flat_map(|(_, witness_cols)| {
publics_by_stage[0]
.iter()
.filter_map(|(name, ref_witness_col_name, _, row)| {
witness_cols.iter().find_map(|(witness_col_name, v)| {
(ref_witness_col_name == witness_col_name)
.then(|| (name.clone(), v[*row]))
})
})
})
.collect();

// Get witness columns in circle domain for stage 0
let stage0_witness_cols_circle_domain_eval: ColumnVec<
CircleEvaluation<B, BaseField, BitReversedOrder>,
Expand Down Expand Up @@ -365,6 +389,46 @@ where

if self.analyzed.stage_count() > 1 {
// Build witness columns for stage 1 using the callback function, with the generated challenges

let stage0_witness_names_stage1_witness = witness_by_machine
.iter()
.map(|(machine_name, machine_witness)| {
(
machine_witness
.iter()
.map(|(k, _)| k.clone())
.collect::<BTreeSet<_>>(),
witgen_callback.next_stage_witness(
&self.split[&machine_name.clone()],
&machine_witness,
stage0_challenges.clone(),
1,
),
)
})
.collect_vec();

// TODO: previous publics are built with the order in publics_by_stage (find map from witness machine, that matches publics by stage))
// here is with the order in witness_by_machine, (find map from publics by stage that matach the witness. )
// if the orders are different, the publics will be wrong, check
let public_values_stage1: BTreeMap<String, M31> = stage0_witness_names_stage1_witness
.iter()
.flat_map(|(stage0_columns, callback_result)| {
callback_result.iter().filter_map(|(witness_name, vec)| {
if stage0_columns.contains(witness_name) {
None
} else {
publics_by_stage[1].iter().find_map(
|(name, ref_witness_col_name, _, row)| {
(witness_name == ref_witness_col_name)
.then(|| (name.clone(), vec[*row]))
},
)
}
})
})
.collect();

let stage1_witness_cols_circle_domain_eval = witness_by_machine
.into_iter()
.map(|(machine_name, machine_witness)| {
Expand Down Expand Up @@ -402,6 +466,7 @@ where
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(stage1_witness_cols_circle_domain_eval);
tree_builder.commit(prover_channel);
public_values.extend(public_values_stage1);
}

let tree_span_provider = &mut TraceLocationAllocator::default();
Expand All @@ -423,6 +488,7 @@ where
constant_cols_offset_acc,
machine_log_size,
stage0_challenges.clone(),
public_values.clone(),
),
(SecureField::zero(), None),
);
Expand Down Expand Up @@ -497,11 +563,11 @@ where
.get_publics()
.iter()
.zip_eq(instances.iter())
.map(|((_, poly_name, _, _, stage), value)| {
.map(|((public_name, poly_name, _, _, stage), value)| {
let namespace = poly_name.split("::").next().unwrap();
(namespace, stage, value)
(public_name,namespace, stage, value)
})
.for_each(|(namespace, stage, value)| {
.for_each(|(public_name,namespace, stage, value)| {
instance_map.get_mut(namespace).unwrap()[*stage as usize].push(*value);
});

Expand Down

0 comments on commit dc6c86d

Please sign in to comment.