diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 8a1b2e2e13..c4e8da4385 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -58,7 +58,8 @@ pub struct PowdrEval { // the pre-processed are indexed in the whole proof, instead of in each component. // this offset represents the index of the first pre-processed column in this component preprocess_col_offset: usize, - // for each stage, for each public input of that stage, the name, the column name, the poly_id, the row index + // for each stage, for each public input of that stage, the name of the public, + // the name of the witness column that this public is related to, the poly_id, the row index and its value pub(crate) publics_by_stage: Vec>, stage0_witness_columns: BTreeMap, stage1_witness_columns: BTreeMap, @@ -67,6 +68,7 @@ pub struct PowdrEval { // stwo supports maximum 2 stages, challenges are only created after stage 0 pub challenges: BTreeMap, poly_stage_map: BTreeMap, + public_values: BTreeMap, } impl PowdrEval { @@ -75,6 +77,7 @@ impl PowdrEval { preprocess_col_offset: usize, log_degree: u32, challenges: BTreeMap, + public_values: BTreeMap, ) -> Self { let stage0_witness_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Committed) @@ -108,7 +111,7 @@ impl PowdrEval { .enumerate() .map(|(index, (_, id))| (id, index)) .collect(); - + // TODO:maybe only need in the prove function, before creating PowdrEval let publics_by_stage = analyzed.get_publics().into_iter().fold( vec![vec![]; analyzed.stage_count()], |mut acc, (name, column_name, id, row, stage)| { @@ -134,6 +137,7 @@ impl PowdrEval { constant_columns, challenges, poly_stage_map, + public_values, } } } @@ -264,14 +268,14 @@ impl FrameworkEval for PowdrEval { + self.preprocess_col_offset + constant_shifted_eval.len(), )); - let stage= self.poly_stage_map[poly_id]; - let witness_col= match stage { - 0 => self.stage0_witness_columns[poly_id], - 1 => self.stage1_witness_columns[poly_id], + let stage = self.poly_stage_map[poly_id]; + let witness_col = match stage { + 0 => stage0_witness_eval[poly_id][0].clone(), + 1 => stage1_witness_eval[poly_id][0].clone(), _ => unreachable!(), }; // constraining s(i) * (pub[i] - x(i)) = 0 - // eval.add_constraint(selector * (public_value.into() - witness_col)); + eval.add_constraint(selector * (E::F::from(into_stwo_field(self.public_values.get(name).unwrap()))) - witness_col); }, ); diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index e470dc6092..e933775c5c 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -275,6 +275,14 @@ where ) }) .collect(); + // TODO:in polonky3, this is built in ConstraintSystem, which is machine-specific, check if it's the same here + let publics_by_stage = self.analyzed.get_publics().into_iter().fold( + vec![vec![]; self.analyzed.stage_count()], + |mut acc, (name, column_name, id, row, stage)| { + acc[stage as usize].push((name, column_name, id, row)); + acc + }, + ); // Generate witness for stage 0, build constant columns in circle domain at the same time let mut machine_log_sizes: BTreeMap = BTreeMap::new(); @@ -318,6 +326,22 @@ where }) .collect::>(); + // get publics of stage0 + // TODO:when publics are supplied, the order of the witness might cause problem + let mut public_values: BTreeMap = witness_by_machine + .iter() + .flat_map(|(_, witness_cols)| { + publics_by_stage[0] + .iter() + .filter_map(|(name, ref_witness_col_name, _, row)| { + witness_cols.iter().find_map(|(witness_col_name, v)| { + (ref_witness_col_name == witness_col_name) + .then(|| (name.clone(), v[*row])) + }) + }) + }) + .collect(); + // Get witness columns in circle domain for stage 0 let stage0_witness_cols_circle_domain_eval: ColumnVec< CircleEvaluation, @@ -365,6 +389,46 @@ where if self.analyzed.stage_count() > 1 { // Build witness columns for stage 1 using the callback function, with the generated challenges + + let stage0_witness_names_stage1_witness = witness_by_machine + .iter() + .map(|(machine_name, machine_witness)| { + ( + machine_witness + .iter() + .map(|(k, _)| k.clone()) + .collect::>(), + witgen_callback.next_stage_witness( + &self.split[&machine_name.clone()], + &machine_witness, + stage0_challenges.clone(), + 1, + ), + ) + }) + .collect_vec(); + + // TODO: previous publics are built with the order in publics_by_stage (find map from witness machine, that matches publics by stage)) + // here is with the order in witness_by_machine, (find map from publics by stage that matach the witness. ) + // if the orders are different, the publics will be wrong, check + let public_values_stage1: BTreeMap = stage0_witness_names_stage1_witness + .iter() + .flat_map(|(stage0_columns, callback_result)| { + callback_result.iter().filter_map(|(witness_name, vec)| { + if stage0_columns.contains(witness_name) { + None + } else { + publics_by_stage[1].iter().find_map( + |(name, ref_witness_col_name, _, row)| { + (witness_name == ref_witness_col_name) + .then(|| (name.clone(), vec[*row])) + }, + ) + } + }) + }) + .collect(); + let stage1_witness_cols_circle_domain_eval = witness_by_machine .into_iter() .map(|(machine_name, machine_witness)| { @@ -402,6 +466,7 @@ where let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals(stage1_witness_cols_circle_domain_eval); tree_builder.commit(prover_channel); + public_values.extend(public_values_stage1); } let tree_span_provider = &mut TraceLocationAllocator::default(); @@ -423,6 +488,7 @@ where constant_cols_offset_acc, machine_log_size, stage0_challenges.clone(), + public_values.clone(), ), (SecureField::zero(), None), ); @@ -497,11 +563,11 @@ where .get_publics() .iter() .zip_eq(instances.iter()) - .map(|((_, poly_name, _, _, stage), value)| { + .map(|((public_name, poly_name, _, _, stage), value)| { let namespace = poly_name.split("::").next().unwrap(); - (namespace, stage, value) + (public_name,namespace, stage, value) }) - .for_each(|(namespace, stage, value)| { + .for_each(|(public_name,namespace, stage, value)| { instance_map.get_mut(namespace).unwrap()[*stage as usize].push(*value); });