Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stwo publics #2438

Merged
merged 67 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
db7572b
No sound challenge, function draft
ShuangWu121 Jan 21, 2025
ed8572c
basic function works, need to be simplified.
ShuangWu121 Jan 21, 2025
301811a
clean up a bit
ShuangWu121 Jan 22, 2025
ef60cdf
optimization
ShuangWu121 Jan 22, 2025
dcf7ad9
fix lint
ShuangWu121 Jan 22, 2025
a72df0a
fix test fail, no need to use witgen_callback if only 1 stage
ShuangWu121 Jan 22, 2025
a804766
add challenges to test
ShuangWu121 Jan 22, 2025
ead9c7a
add challenge.pil test file
ShuangWu121 Jan 22, 2025
59b34ee
polish
ShuangWu121 Jan 22, 2025
c920790
Update backend/src/stwo/circuit_builder.rs
ShuangWu121 Jan 23, 2025
0d2d2df
Update backend/src/stwo/circuit_builder.rs
ShuangWu121 Jan 23, 2025
ce51ac4
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 23, 2025
6213276
add field conversation, remove generic type, just use Mersene31
ShuangWu121 Jan 24, 2025
491779a
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 24, 2025
a403c78
avoid clone in building witness evals for stage 1, by using options.
ShuangWu121 Jan 24, 2025
9a744c3
Update test_data/pil/challenges.pil
ShuangWu121 Jan 24, 2025
363ac77
use draw_felt to draw challenges
ShuangWu121 Jan 26, 2025
386dd5e
optimize more
ShuangWu121 Jan 27, 2025
fba030c
fix the order of stage1 witness cols
ShuangWu121 Jan 27, 2025
2749358
clean up
ShuangWu121 Jan 27, 2025
3e648a8
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 27, 2025
bac5be0
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 27, 2025
dd392e3
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 27, 2025
3004711
add comments
ShuangWu121 Jan 27, 2025
244dfc0
remove 'a
ShuangWu121 Jan 28, 2025
4bf55ea
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
6053620
Update backend/src/stwo/params.rs
ShuangWu121 Jan 28, 2025
39ab885
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
0a648b6
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
2c71935
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
da4d041
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
d41cb46
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
f6e2fb3
Update backend/src/stwo/prover.rs
ShuangWu121 Jan 28, 2025
92d948e
modify param.rs
ShuangWu121 Jan 28, 2025
6305ee2
polish
ShuangWu121 Jan 28, 2025
babe6b8
simplify param.rs
ShuangWu121 Jan 28, 2025
e08a01f
simplify challenge draw to be lazily
ShuangWu121 Jan 28, 2025
77debd7
introduce type T=mersene31
ShuangWu121 Jan 28, 2025
220e986
split witness cols in different stage in circuit build
ShuangWu121 Jan 31, 2025
180adee
split stage 0 and stage 1 witness in prover
ShuangWu121 Jan 31, 2025
f7ea58a
Merge remote-tracking branch 'origin/main' into stwo-challenges-Sound
ShuangWu121 Jan 31, 2025
bd257cf
split stage witnesses works
ShuangWu121 Jan 31, 2025
2939e6f
sound challenge works
ShuangWu121 Feb 2, 2025
b6cebb4
simplification, adding more test cases
ShuangWu121 Feb 3, 2025
997cd58
polish, add more tests
ShuangWu121 Feb 3, 2025
9218cc8
fix comments
ShuangWu121 Feb 3, 2025
9c43930
split witness eval of stage0 and stage1 in terminall access
ShuangWu121 Feb 3, 2025
74d2de0
fix comment
ShuangWu121 Feb 3, 2025
6e51a7c
move poly stage map to PowdrEval new function
ShuangWu121 Feb 4, 2025
2186ae1
start adding publics
ShuangWu121 Feb 4, 2025
defbb09
add more publics
ShuangWu121 Feb 5, 2025
bf2998a
Merge branch 'main' into stwo-publics
ShuangWu121 Feb 5, 2025
5cbf0ce
add public selector into setup phase
ShuangWu121 Feb 5, 2025
dc6c86d
add constraints in circuit builder
ShuangWu121 Feb 6, 2025
85e6df5
basic function works. test case needs to be updated. simplification n…
ShuangWu121 Feb 7, 2025
41df6e0
polish
ShuangWu121 Feb 7, 2025
f158647
polish, add one test
ShuangWu121 Feb 7, 2025
dfb9482
remove stage1_public test
ShuangWu121 Feb 8, 2025
c6385b4
fix comments
ShuangWu121 Feb 11, 2025
0400534
add test for stage 1 publics
ShuangWu121 Feb 11, 2025
7b8c51d
fix build
ShuangWu121 Feb 11, 2025
f48f842
fix build
ShuangWu121 Feb 11, 2025
04ee936
no stage (0) in pil, it casue problem that 0 not equal to None, in pi…
ShuangWu121 Feb 11, 2025
d69f3c2
simplification
ShuangWu121 Feb 11, 2025
b4f3b5b
add fail test
ShuangWu121 Feb 12, 2025
bfd5037
split faild test with valid test
ShuangWu121 Feb 12, 2025
7d85c19
Merge branch 'main' into stwo-publics
ShuangWu121 Feb 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions backend/src/stwo/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct PowdrEval {
// the pre-processed are indexed in the whole proof, instead of in each component.
// this offset represents the index of the first pre-processed column in this component
preprocess_col_offset: usize,
// The name of the public, the poly-id of the witness poly that this public is related to, the public value
pub(crate) publics_values: Vec<(String, PolyID, M31)>,
stage0_witness_columns: BTreeMap<PolyID, usize>,
stage1_witness_columns: BTreeMap<PolyID, usize>,
constant_shifted: BTreeMap<PolyID, usize>,
Expand All @@ -73,6 +75,7 @@ impl PowdrEval {
preprocess_col_offset: usize,
log_degree: u32,
challenges: BTreeMap<u64, M31>,
public_values: BTreeMap<String, M31>,
) -> Self {
let stage0_witness_columns: BTreeMap<PolyID, usize> = analyzed
.definitions_in_source_order(PolynomialType::Committed)
Expand Down Expand Up @@ -107,6 +110,12 @@ impl PowdrEval {
.map(|(index, (_, id))| (id, index))
.collect();

let publics_values = analyzed
.get_publics()
.into_iter()
.map(|(name, _, id, _, _)| (name.clone(), id, *public_values.get(&name).unwrap()))
.collect();

let poly_stage_map: BTreeMap<PolyID, usize> = stage0_witness_columns
.keys()
.map(|k| (*k, 0))
Expand All @@ -117,6 +126,7 @@ impl PowdrEval {
log_degree,
analyzed,
preprocess_col_offset,
publics_values,
stage0_witness_columns,
stage1_witness_columns,
constant_shifted,
Expand All @@ -132,6 +142,7 @@ struct Data<'a, F> {
stage1_witness_eval: &'a BTreeMap<PolyID, [F; 2]>,
constant_shifted_eval: &'a BTreeMap<PolyID, F>,
constant_eval: &'a BTreeMap<PolyID, F>,
publics_values: &'a BTreeMap<String, F>,
// challenges for stage 1
challenges: &'a BTreeMap<u64, F>,
poly_stage_map: &'a BTreeMap<PolyID, usize>,
Expand All @@ -157,8 +168,11 @@ impl<F: Clone> TerminalAccess<F> for &Data<'_, F> {
}
}

fn get_public(&self, _public: &str) -> F {
unimplemented!("Public references are not supported in stwo yet")
fn get_public(&self, public: &str) -> F {
self.publics_values
.get(public)
.expect("Referenced public value does not exist")
.clone()
}

fn get_challenge(&self, challenge: &Challenge) -> F {
Expand All @@ -174,11 +188,6 @@ impl FrameworkEval for PowdrEval {
self.log_degree + 1
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
assert!(
self.analyzed.publics_count() == 0,
"Error: Expected no public inputs, as they are not supported yet.",
);

let stage0_witness_eval: BTreeMap<PolyID, [<E as EvalAtRow>::F; 2]> = self
.stage0_witness_columns
.keys()
Expand Down Expand Up @@ -235,14 +244,44 @@ impl FrameworkEval for PowdrEval {
.collect();

let intermediate_definitions = self.analyzed.intermediate_definitions();
let public_values_terminal = self
.publics_values
.iter()
.map(|(name, _, value)| (name.clone(), E::F::from(into_stwo_field(value))))
.collect();
let data = Data {
stage0_witness_eval: &stage0_witness_eval,
stage1_witness_eval: &stage1_witness_eval,
publics_values: &public_values_terminal,
constant_shifted_eval: &constant_shifted_eval,
constant_eval: &constant_eval,
challenges: &challenges,
poly_stage_map: &self.poly_stage_map,
};

// build selector columns and constraints for publics
self.publics_values
.iter()
.enumerate()
.for_each(|(index, (_, poly_id, value))| {
let selector = eval.get_preprocessed_column(PreprocessedColumn::Plonk(
index
+ constant_eval.len()
+ self.preprocess_col_offset
+ constant_shifted_eval.len(),
));

let stage = self.poly_stage_map[poly_id];
let witness_col = match stage {
0 => stage0_witness_eval[poly_id][0].clone(),
1 => stage1_witness_eval[poly_id][0].clone(),
_ => unreachable!(),
};

// constraining s(i) * (pub[i] - x(i)) = 0
eval.add_constraint(selector * (E::F::from(into_stwo_field(value)) - witness_col));
});

let mut evaluator =
ExpressionEvaluator::new_with_custom_expr(&data, &intermediate_definitions, |v| {
E::F::from(into_stwo_field(v))
Expand Down
157 changes: 126 additions & 31 deletions backend/src/stwo/prover.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use itertools::Itertools;
use num_traits::Zero;
use num_traits::{One, Zero};
use powdr_ast::analyzed::{AlgebraicExpression, Analyzed, DegreeRange};
use powdr_ast::parsed::visitor::AllChildren;
use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns};
Expand Down Expand Up @@ -30,14 +30,15 @@ use crate::stwo::proof::{
use stwo_prover::constraint_framework::TraceLocationAllocator;

use stwo_prover::core::air::{Component, ComponentProver};
use stwo_prover::core::backend::{Backend, BackendForChannel};
use stwo_prover::core::backend::{Backend, BackendForChannel, Col, Column};
use stwo_prover::core::channel::{Channel, MerkleChannel};
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::fields::qm31::SecureField;
use stwo_prover::core::fri::FriConfig;
use stwo_prover::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig};
use stwo_prover::core::poly::circle::{CanonicCoset, CircleDomain, CircleEvaluation};
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use stwo_prover::core::ColumnVec;

const FRI_LOG_BLOWUP: usize = 1;
Expand Down Expand Up @@ -68,7 +69,7 @@ pub struct StwoProver<B: BackendForChannel<MC> + Send, MC: MerkleChannel, C: Cha

/// Proving key
proving_key: StarkProvingKey<B>,
/// Verifying key placeholder
/// TODO: Add verification key.
_verifying_key: Option<()>,
_channel_marker: PhantomData<C>,
_merkle_channel_marker: PhantomData<MC>,
Expand Down Expand Up @@ -152,7 +153,7 @@ where
.iter()
.filter_map(|(namespace, pil)| {
// if we have no fixed columns, we don't need to commit to anything.
if pil.constant_count() == 0 {
if pil.constant_count() + pil.publics_count() == 0 {
None
} else {
let fixed_columns = machine_fixed_columns(&self.fixed, pil);
Expand All @@ -166,13 +167,14 @@ where
.map(|size| {
//Group the fixed columns by size
let fixed_columns = &fixed_columns[&size];
let log_size = size.ilog2();
let mut constant_trace: ColumnVec<
CircleEvaluation<B, BaseField, BitReversedOrder>,
> = fixed_columns
.iter()
.map(|(_, vec)| {
gen_stwo_circle_column::<_, BaseField>(
*domain_map.get(&(vec.len().ilog2() as usize)).unwrap(),
*domain_map.get(&(log_size as usize)).unwrap(),
vec,
)
})
Expand All @@ -189,16 +191,41 @@ where
let mut rotated_values = values.to_vec();
rotated_values.rotate_left(1);
gen_stwo_circle_column::<_, BaseField>(
*domain_map
.get(&(values.len().ilog2() as usize))
.unwrap(),
*domain_map.get(&(log_size as usize)).unwrap(),
&rotated_values,
)
})
.collect();

constant_trace.extend(constant_shifted_trace);

// get selector columns for the public inputs
let publics_selectors: ColumnVec<
CircleEvaluation<B, BaseField, BitReversedOrder>,
> = pil
.get_publics()
.into_iter()
.map(|(_, _, _, row_id, _)| {
// Create a column with a single 1 at the row_id-th (in circle domain bitreverse order) position
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
col.set(
bit_reverse_index(
coset_index_to_circle_domain_index(
row_id, log_size,
),
log_size,
),
BaseField::one(),
);
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
*domain_map.get(&(log_size as usize)).unwrap(),
col,
)
})
.collect();

constant_trace.extend(publics_selectors);

(
size as usize,
TableProvingKey {
Expand Down Expand Up @@ -301,6 +328,28 @@ where
})
.collect::<BTreeMap<_, _>>();

// get publics of stage0
let publics_by_stage = self.analyzed.get_publics().into_iter().fold(
vec![vec![]; self.analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
acc[stage as usize].push((name, column_name, id, row));
acc
},
);

let mut public_values: BTreeMap<String, M31> = publics_by_stage[0]
.iter()
.flat_map(|(name, ref_witness_col_name, _, row)| {
let namespace = ref_witness_col_name.split("::").next().unwrap();
witness_by_machine
.get(namespace)
.unwrap()
.iter()
.filter(move |(witness_col_name, _)| ref_witness_col_name == witness_col_name)
.map(|(_, col)| (name.clone(), col[*row]))
})
.collect();

// Get witness columns in circle domain for stage 0
let stage0_witness_cols_circle_domain_eval: ColumnVec<
CircleEvaluation<B, BaseField, BitReversedOrder>,
Expand Down Expand Up @@ -349,27 +398,56 @@ where
if self.analyzed.stage_count() > 1 {
// Build witness columns for stage 1 using the callback function, with the generated challenges
let span = span!(Level::INFO, "Generate stage 1 witnesses").entered();
let stage1_witness_cols_circle_domain_eval = witness_by_machine
.into_iter()
let stage0_witness_name_list = witness_by_machine
.values()
.map(|machine_witness| {
machine_witness
.iter()
.map(|(k, _)| k.clone())
.collect::<BTreeSet<_>>()
})
.collect_vec();

let stage1_witness_cols = witness_by_machine
.iter()
.map(|(machine_name, machine_witness)| {
(
machine_witness
.iter()
.map(|(k, _)| k.clone())
.collect::<BTreeSet<_>>(),
witgen_callback.next_stage_witness(
&self.split[&machine_name],
&machine_witness,
stage0_challenges.clone(),
1,
),
witgen_callback.next_stage_witness(
&self.split[&machine_name.clone()],
machine_witness,
stage0_challenges.clone(),
1,
)
})
.flat_map(move |(stage0_columns, callback_result)| {
.collect_vec();

// Get publics of stage 1
let public_values_stage1: BTreeMap<String, M31> = stage0_witness_name_list
.iter()
.zip_eq(stage1_witness_cols.iter())
.flat_map(|(stage0_witness_name_list, callback_result)| {
callback_result.iter().filter_map(|(witness_name, vec)| {
if stage0_witness_name_list.contains(witness_name) {
None
} else {
publics_by_stage[1].iter().find_map(
|(name, ref_witness_col_name, _, row)| {
(witness_name == ref_witness_col_name)
.then(|| (name.clone(), vec[*row]))
},
)
}
})
})
.collect();

let stage1_witness_cols_circle_domain_eval = stage0_witness_name_list
.iter()
.zip_eq(stage1_witness_cols.iter())
.flat_map(|(stage0_witness_name_list, callback_result)| {
callback_result
.iter()
.filter_map(|(witness_name, vec)| {
if stage0_columns.contains(witness_name) {
if stage0_witness_name_list.contains(witness_name) {
None
} else {
Some(gen_stwo_circle_column::<B, BaseField>(
Expand All @@ -388,6 +466,7 @@ where
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(stage1_witness_cols_circle_domain_eval);
tree_builder.commit(prover_channel);
public_values.extend(public_values_stage1);
span.exit();
}

Expand All @@ -410,6 +489,7 @@ where
constant_cols_offset_acc,
machine_log_size,
stage0_challenges.clone(),
public_values.clone(),
),
(SecureField::zero(), None),
);
Expand Down Expand Up @@ -445,12 +525,23 @@ where
Ok(bincode::serialize(&proof).unwrap())
}

pub fn verify(&self, proof: &[u8], _instances: &[M31]) -> Result<(), String> {
assert!(
_instances.is_empty(),
"Expected _instances slice to be empty, but it has {} elements.",
_instances.len()
);
pub fn verify(&self, proof: &[u8], instances: &[M31]) -> Result<(), String> {
// get public values
let publics = self.analyzed.get_publics();

if publics.len() != instances.len() {
return Err(format!(
"Instance size mismatch: expected {}, got {}",
publics.len(),
instances.len()
));
};

let public_values: BTreeMap<String, M31> = publics
.iter()
.zip_eq(instances.iter())
.map(|((public_name, _, _, _, _), value)| (public_name.to_string(), *value))
.collect();

let config = get_config();

Expand Down Expand Up @@ -479,8 +570,11 @@ where
let constant_col_log_sizes = iter
.clone()
.flat_map(|(pil, machine_log_size)| {
repeat(machine_log_size)
.take(pil.constant_count() + get_constant_with_next_list(pil).len())
repeat(machine_log_size).take(
pil.constant_count()
+ get_constant_with_next_list(pil).len()
+ pil.publics_count(),
)
})
.collect_vec();

Expand Down Expand Up @@ -524,6 +618,7 @@ where
constant_cols_offset_acc,
machine_log_size,
stage0_challenges.clone(),
public_values.clone(),
),
(SecureField::zero(), None),
);
Expand Down
Loading
Loading