diff --git a/backend/src/stwo/circuit_builder.rs b/backend/src/stwo/circuit_builder.rs index 2020328a21..85fc94d22f 100644 --- a/backend/src/stwo/circuit_builder.rs +++ b/backend/src/stwo/circuit_builder.rs @@ -2,7 +2,6 @@ use core::unreachable; use powdr_ast::parsed::visitor::AllChildren; use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, TerminalAccess}; use std::collections::HashSet; -use std::sync::Arc; extern crate alloc; use alloc::collections::btree_map::BTreeMap; @@ -53,14 +52,18 @@ where } pub struct PowdrEval { - analyzed: Arc>, + log_degree: u32, + analyzed: Analyzed, + // the pre-processed are indexed in the whole proof, instead of in each component. + // this offset represents the index of the first pre-processed column in this component + preprocess_col_offset: usize, witness_columns: BTreeMap, constant_shifted: BTreeMap, constant_columns: BTreeMap, } impl PowdrEval { - pub fn new(analyzed: Arc>) -> Self { + pub fn new(analyzed: Analyzed, preprocess_col_offset: usize, log_degree: u32) -> Self { let witness_columns: BTreeMap = analyzed .definitions_in_source_order(PolynomialType::Committed) .flat_map(|(symbol, _)| symbol.array_elements()) @@ -86,7 +89,9 @@ impl PowdrEval { .collect(); Self { + log_degree, analyzed, + preprocess_col_offset, witness_columns, constant_shifted, constant_columns, @@ -126,10 +131,10 @@ impl TerminalAccess for &Data<'_, F> { impl FrameworkEval for PowdrEval { fn log_size(&self) -> u32 { - self.analyzed.degree().ilog2() + self.log_degree } fn max_constraint_log_degree_bound(&self) -> u32 { - self.analyzed.degree().ilog2() + 1 + self.log_degree + 1 } fn evaluate(&self, mut eval: E) -> E { assert!( @@ -155,8 +160,9 @@ impl FrameworkEval for PowdrEval { .map(|(i, poly_id)| { ( *poly_id, - // PreprocessedColumn::Plonk(i) is unused argument in get_preprocessed_column - eval.get_preprocessed_column(PreprocessedColumn::Plonk(i)), + eval.get_preprocessed_column(PreprocessedColumn::Plonk( + i + self.preprocess_col_offset, + )), ) }) .collect(); @@ -169,7 +175,7 @@ impl FrameworkEval for PowdrEval { ( *poly_id, eval.get_preprocessed_column(PreprocessedColumn::Plonk( - i + constant_eval.len(), + i + constant_eval.len() + self.preprocess_col_offset, )), ) }) diff --git a/backend/src/stwo/mod.rs b/backend/src/stwo/mod.rs index 14630dfc08..3be79b1f8a 100644 --- a/backend/src/stwo/mod.rs +++ b/backend/src/stwo/mod.rs @@ -39,10 +39,6 @@ impl BackendFactory for RestrictedFactory { return Err(Error::BackendError("Proving key unused".to_string())); } - if pil.degrees().len() > 1 { - return Err(Error::NoVariableDegreeAvailable); - } - let mut stwo: Box> = Box::new(StwoProver::new(pil, fixed)?); diff --git a/backend/src/stwo/proof.rs b/backend/src/stwo/proof.rs index 8177511f97..3d9f9fa64d 100644 --- a/backend/src/stwo/proof.rs +++ b/backend/src/stwo/proof.rs @@ -1,13 +1,15 @@ -use serde::Deserialize; -use serde::Serialize; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use stwo_prover::core::backend::Backend; use stwo_prover::core::backend::Column; use stwo_prover::core::backend::ColumnOps; +use stwo_prover::core::channel::MerkleChannel; use stwo_prover::core::fields::m31::BaseField; use stwo_prover::core::fields::m31::M31; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::prover::StarkProof; use stwo_prover::core::ColumnVec; /// For each possible size, the commitment and prover data @@ -124,3 +126,14 @@ impl From> for SerializableStarkProvingKey { Self { preprocessed } } } + +#[derive(Serialize, Deserialize)] +pub struct Proof +where + MC::H: DeserializeOwned + Serialize, +{ + pub stark_proof: StarkProof, + pub constant_col_log_sizes: Vec, + pub witness_col_log_sizes: Vec, + pub machine_log_sizes: Vec, +} diff --git a/backend/src/stwo/prover.rs b/backend/src/stwo/prover.rs index c4e9d7bdb5..3affb7c817 100644 --- a/backend/src/stwo/prover.rs +++ b/backend/src/stwo/prover.rs @@ -1,6 +1,7 @@ +use itertools::Itertools; use num_traits::Zero; -use powdr_ast::analyzed::Analyzed; -use powdr_backend_utils::machine_fixed_columns; +use powdr_ast::analyzed::{Analyzed, DegreeRange}; +use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns}; use powdr_executor::constant_evaluator::VariablySizedColumn; use powdr_number::FieldElement; use serde::de::DeserializeOwned; @@ -14,23 +15,21 @@ use crate::stwo::circuit_builder::{ gen_stwo_circle_column, get_constant_with_next_list, PowdrComponent, PowdrEval, }; use crate::stwo::proof::{ - SerializableStarkProvingKey, StarkProvingKey, TableProvingKey, TableProvingKeyCollection, + Proof, SerializableStarkProvingKey, StarkProvingKey, TableProvingKey, TableProvingKeyCollection, }; use stwo_prover::constraint_framework::{ TraceLocationAllocator, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, }; -use stwo_prover::core::prover::StarkProof; use stwo_prover::core::air::{Component, ComponentProver}; -use stwo_prover::core::backend::{Backend, BackendForChannel}; +use stwo_prover::core::backend::{Backend, BackendForChannel, Column}; use stwo_prover::core::channel::{Channel, MerkleChannel}; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::fields::qm31::SecureField; use stwo_prover::core::fri::FriConfig; use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation}; -use stwo_prover::core::poly::twiddles::TwiddleTree; use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::ColumnVec; @@ -114,15 +113,28 @@ where } pub fn setup(&mut self) { - // machines with varying sizes are not supported yet, and it is checked in backendfactory create function. - //TODO: support machines with varying sizes - let domain_map: BTreeMap = self - .analyzed - .degrees() + let domain_degree_range = DegreeRange { + min: self + .analyzed + .degree_ranges() + .iter() + .map(|range| range.min) + .min() + .unwrap(), + max: self + .analyzed + .degree_ranges() + .iter() + .map(|range| range.max) + .max() + .unwrap(), + }; + + let domain_map: BTreeMap = domain_degree_range .iter() .map(|size| { ( - (size.ilog2() as usize), + size.ilog2() as usize, CanonicCoset::new(size.ilog2()).circle_domain(), ) }) @@ -145,19 +157,17 @@ where .unwrap() .iter() .map(|size| { + //Group the fixed columns by size + let fixed_columns = &fixed_columns[&size]; let mut constant_trace: ColumnVec< CircleEvaluation, > = fixed_columns - .values() - .flat_map(|vec| { - vec.iter().map(|(_name, values)| { - gen_stwo_circle_column::( - *domain_map - .get(&(values.len().ilog2() as usize)) - .unwrap(), - values, - ) - }) + .iter() + .map(|(_, vec)| { + gen_stwo_circle_column::( + *domain_map.get(&(vec.len().ilog2() as usize)).unwrap(), + vec, + ) }) .collect(); @@ -166,22 +176,17 @@ where let constant_shifted_trace: ColumnVec< CircleEvaluation, > = fixed_columns - .values() - .flat_map(|vec| { - vec.iter() - .filter(|(name, _)| { - constant_with_next_list.contains(name) - }) - .map(|(_, values)| { - let mut rotated_values = values.to_vec(); - rotated_values.rotate_left(1); - gen_stwo_circle_column::( - *domain_map - .get(&(values.len().ilog2() as usize)) - .unwrap(), - &rotated_values, - ) - }) + .iter() + .filter(|(name, _)| constant_with_next_list.contains(name)) + .map(|(_, values)| { + let mut rotated_values = values.to_vec(); + rotated_values.rotate_left(1); + gen_stwo_circle_column::( + *domain_map + .get(&(values.len().ilog2() as usize)) + .unwrap(), + &rotated_values, + ) }) .collect(); @@ -206,106 +211,165 @@ where } pub fn prove(&self, witness: &[(String, Vec)]) -> Result, String> { - assert!( - witness + let config = get_config(); + let domain_degree_range = DegreeRange { + min: self + .analyzed + .degree_ranges() .iter() - .all(|(_name, vec)| vec.len() == witness[0].1.len()), - "All Vec in witness must have the same length. Mismatch found!" - ); + .map(|range| range.min) + .min() + .unwrap(), + max: self + .analyzed + .degree_ranges() + .iter() + .map(|range| range.max) + .max() + .unwrap(), + }; - let config = get_config(); - let domain_map: BTreeMap = self - .analyzed - .degrees() + let domain_map: BTreeMap = domain_degree_range .iter() .map(|size| { ( - (size.ilog2() as usize), + size.ilog2() as usize, CanonicCoset::new(size.ilog2()).circle_domain(), ) }) .collect(); - let twiddles_map: BTreeMap> = self + + let tree_span_provider = &mut TraceLocationAllocator::default(); + //Each column size in machines needs its own component, the components from different machines are stored in this vector + let mut components = Vec::new(); + + //The preprocessed columns needs to be indexed in the whole execution instead of each machine, so we need to keep track of the offset + let mut constant_cols_offset_acc = 0; + let mut machine_log_sizes = Vec::new(); + + let mut constant_cols = Vec::new(); + + let witness_by_machine: ColumnVec> = self .split - .values() - .flat_map(|pil| { - // Precompute twiddles for all sizes in the PIL - pil.committed_polys_in_source_order() - .flat_map(|(s, _)| { - s.degree.iter().flat_map(|range| { - range - .iter() - .filter(|&size| size.is_power_of_two()) - .map(|size| { - let twiddles = B::precompute_twiddles( - CanonicCoset::new(size.ilog2() + 1 + FRI_LOG_BLOWUP as u32) - .circle_domain() - .half_coset, - ); - (size as usize, twiddles) - }) - .collect::>() + .iter() + .filter_map(|(machine, pil)| { + let witness_columns = machine_witness_columns(witness, pil, machine); + if witness_columns[0].1.is_empty() { + //TODO: Empty machines can be removed entirely, but in verification it is not removed, need to be handled + None + } else { + let witness_by_machine = machine_witness_columns(witness, pil, machine); + let machine_length = witness_by_machine[0].1.len(); + assert!( + witness_by_machine + .iter() + .all(|(_, vec)| vec.len() == machine_length), + "All witness columns in a single machine must have the same length" + ); + + if let Some(constant_trace) = self + .proving_key + .preprocessed + .as_ref() + .and_then(|preprocessed| preprocessed.get(machine)) + .and_then(|table_provingkey| table_provingkey.get(&machine_length)) + .map(|table_provingkey_machine_size| { + table_provingkey_machine_size + .constant_trace_circle_domain + .clone() }) - }) - .collect::>() + { + constant_cols.extend(constant_trace) + } + + let component = PowdrComponent::new( + tree_span_provider, + PowdrEval::new( + (*pil).clone(), + constant_cols_offset_acc, + machine_length.ilog2(), + ), + (SecureField::zero(), None), + ); + components.push(component); + + machine_log_sizes.push(machine_length.ilog2()); + + constant_cols_offset_acc += + pil.constant_count() + get_constant_with_next_list(pil).len(); + + Some( + witness_by_machine + .into_iter() + .map(|(_name, vec)| { + gen_stwo_circle_column::( + *domain_map + .get(&(machine_length.ilog2() as usize)) + .expect("Domain not found for given size"), + &vec, + ) + }) + .collect::>(), + ) + } }) + .flatten() .collect(); - // only the first one is used, machines with varying sizes are not supported yet, and it is checked in backendfactory create function. + + let constant_col_log_sizes = constant_cols + .iter() + .map(|eval| eval.len().ilog2()) + .collect::>(); + + let witness_col_log_sizes = witness_by_machine + .iter() + .map(|eval| eval.len().ilog2()) + .collect::>(); + + let twiddles_max_degree = B::precompute_twiddles( + CanonicCoset::new(domain_degree_range.max.ilog2() + 1 + FRI_LOG_BLOWUP as u32) + .circle_domain() + .half_coset, + ); + let prover_channel = &mut ::C::default(); let mut commitment_scheme = - CommitmentSchemeProver::<'_, B, MC>::new(config, twiddles_map.iter().next().unwrap().1); + CommitmentSchemeProver::<'_, B, MC>::new(config, &twiddles_max_degree); let mut tree_builder = commitment_scheme.tree_builder(); - //commit to the constant and shifted constant polynomials - if let Some((_, table_proving_key)) = - self.proving_key - .preprocessed - .as_ref() - .and_then(|preprocessed| { - preprocessed - .iter() - .find_map(|(_, table_collection)| table_collection.iter().next()) - }) - { - tree_builder.extend_evals(table_proving_key.constant_trace_circle_domain.clone()); - } else { - tree_builder.extend_evals([]); - } - tree_builder.commit(prover_channel); + tree_builder.extend_evals(constant_cols); - let trace: ColumnVec> = witness - .iter() - .map(|(_name, values)| { - gen_stwo_circle_column::( - *domain_map.get(&(values.len().ilog2() as usize)).unwrap(), - values, - ) - }) - .collect(); + tree_builder.commit(prover_channel); let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(trace); + tree_builder.extend_evals(witness_by_machine); tree_builder.commit(prover_channel); - let component = PowdrComponent::new( - &mut TraceLocationAllocator::default(), - PowdrEval::new(self.analyzed.clone()), - // This parameter is used for the logup functionality. If logup is not required, this default value should be passed. - (SecureField::zero(), None), - ); + let mut components_slice: Vec<&dyn ComponentProver> = components + .iter_mut() + .map(|component| component as &dyn ComponentProver) + .collect(); + + let components_slice = components_slice.as_mut_slice(); let proof_result = stwo_prover::core::prover::prove::( - &[&component], + components_slice, prover_channel, commitment_scheme, ); - let proof = match proof_result { + let stark_proof = match proof_result { Ok(value) => value, Err(e) => return Err(e.to_string()), // Propagate the error instead of panicking }; + let proof: Proof = Proof { + stark_proof, + constant_col_log_sizes, + witness_col_log_sizes, + machine_log_sizes, + }; Ok(bincode::serialize(&proof).unwrap()) } @@ -317,36 +381,60 @@ where ); let config = get_config(); - let proof: StarkProof = + + let proof: Proof = bincode::deserialize(proof).map_err(|e| format!("Failed to deserialize proof: {e}"))?; let verifier_channel = &mut ::C::default(); let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); //Constraints that are to be proved - let component = PowdrComponent::new( - &mut TraceLocationAllocator::default(), - PowdrEval::new(self.analyzed.clone()), - (SecureField::zero(), None), - ); - // Retrieve the expected column sizes in each commitment interaction, from the AIR. - // the sizes include the degrees of the constant, witness, native lookups. Native lookups are not used yet. - let sizes = component.trace_log_degree_bounds(); + let tree_span_provider = &mut TraceLocationAllocator::default(); + + let mut constant_cols_offset_acc = 0; + + let mut components = self + .split + .iter() + .zip_eq(proof.machine_log_sizes.iter()) + .map(|((_, pil), &machine_size)| { + constant_cols_offset_acc += pil.constant_count(); + + constant_cols_offset_acc += get_constant_with_next_list(pil).len(); + PowdrComponent::new( + tree_span_provider, + PowdrEval::new((*pil).clone(), constant_cols_offset_acc, machine_size), + (SecureField::zero(), None), + ) + }) + .collect::>(); + + let mut components_slice: Vec<&dyn Component> = components + .iter_mut() + .map(|component| component as &dyn Component) + .collect(); + + let components_slice = components_slice.as_mut_slice(); commitment_scheme.commit( - proof.commitments[PREPROCESSED_TRACE_IDX], - &sizes[PREPROCESSED_TRACE_IDX], + proof.stark_proof.commitments[PREPROCESSED_TRACE_IDX], + &proof.constant_col_log_sizes, verifier_channel, ); commitment_scheme.commit( - proof.commitments[ORIGINAL_TRACE_IDX], - &sizes[ORIGINAL_TRACE_IDX], + proof.stark_proof.commitments[ORIGINAL_TRACE_IDX], + &proof.witness_col_log_sizes, verifier_channel, ); - stwo_prover::core::prover::verify(&[&component], verifier_channel, commitment_scheme, proof) - .map_err(|e| e.to_string()) + stwo_prover::core::prover::verify( + components_slice, + verifier_channel, + commitment_scheme, + proof.stark_proof, + ) + .map_err(|e| e.to_string()) } } diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index 59a553ff80..500e0b645b 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -293,7 +293,6 @@ fn stwo_constant_next_test() { let f = "pil/fixed_with_incremental.pil"; test_stwo(f, Default::default()); } - #[test] fn fibonacci_invalid_witness_stwo() { let f = "pil/fibo_no_publics.pil"; @@ -389,6 +388,7 @@ fn different_degrees() { // Because machines have different lengths, this can only be proven // with a composite proof. regular_test_gl(f, Default::default()); + test_stwo(f, Default::default()); } #[test]