diff --git a/Cargo.lock b/Cargo.lock index d0ee30820..ea6e30885 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -410,7 +410,7 @@ dependencies = [ "tempfile", "tracing", "tracing-forest", - "tracing-subscriber 0.3.19", + "tracing-subscriber", "vergen-git2", ] @@ -1910,6 +1910,7 @@ dependencies = [ "p3-commit", "p3-dft", "p3-field", + "p3-fri", "p3-goldilocks", "p3-matrix", "p3-maybe-rayon", @@ -1993,6 +1994,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "p3-fri" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/plonky3?rev=20c3cb9#20c3cb9119da0d5196d7a2e7b2e4f1a62dc8cae7" +dependencies = [ + "itertools 0.14.0", + "p3-challenger", + "p3-commit", + "p3-dft", + "p3-field", + "p3-interpolation", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "rand", + "serde", + "tracing", +] + [[package]] name = "p3-goldilocks" version = "0.1.0" @@ -2011,6 +2031,17 @@ dependencies = [ "serde", ] +[[package]] +name = "p3-interpolation" +version = "0.1.0" +source = "git+https://github.com/scroll-tech/plonky3?rev=20c3cb9#20c3cb9119da0d5196d7a2e7b2e4f1a62dc8cae7" +dependencies = [ + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", +] + [[package]] name = "p3-matrix" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 4424d9eba..cbeec6fbc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ p3-challenger = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9 p3-commit = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } p3-dft = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } p3-field = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } +p3-fri = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } p3-goldilocks = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } p3-matrix = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } p3-maybe-rayon = { git = "https://github.com/scroll-tech/plonky3", rev = "20c3cb9" } diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 76436bc2f..d1e2c59af 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{collections::BTreeMap, time::Duration}; use ceno_zkvm::{ self, @@ -73,14 +73,16 @@ fn bench_add(c: &mut Criterion) { for _ in 0..iters { // generate mock witness let num_instances = 1 << instance_num_vars; - let rmm = - RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin as usize); + let rmms = BTreeMap::from([( + 0, + RowMajorMatrix::rand(&mut OsRng, num_instances, num_witin as usize), + )]); let instant = std::time::Instant::now(); let num_instances = 1 << instance_num_vars; let mut transcript = BasicTranscript::new(b"riscv"); let commit = - Pcs::batch_commit_and_write(&prover.pk.pp, rmm, &mut transcript) + Pcs::batch_commit_and_write(&prover.pk.pp, rmms, &mut transcript) .unwrap(); let polys = Pcs::get_arc_mle_witness_from_commitment(&commit); let challenges = [ @@ -91,10 +93,8 @@ fn bench_add(c: &mut Criterion) { let _ = prover .create_opcode_proof( "ADD", - &prover.pk.pp, circuit_pk, polys, - commit, &[], num_instances, &mut transcript, diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 42c3fd21f..59f307a01 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -3,7 +3,6 @@ use serde::de::DeserializeOwned; use std::{collections::HashMap, iter::once, marker::PhantomData}; use ff_ext::ExtensionField; -use mpcs::PolynomialCommitmentScheme; 
use crate::{ ROMType, @@ -13,7 +12,6 @@ use crate::{ structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId}, }; use p3::field::PrimeCharacteristicRing; -use witness::RowMajorMatrix; /// namespace used for annotation, preserve meta info during circuit construction #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] @@ -178,24 +176,9 @@ impl ConstraintSystem { } } - pub fn key_gen>( - self, - pp: &PCS::ProverParam, - fixed_traces: Option>, - ) -> ProvingKey { - // transpose from row-major to column-major - let fixed_traces_polys = fixed_traces.as_ref().map(|rmm| rmm.to_mles()); - - let fixed_commit_wd = fixed_traces.map(|traces| PCS::batch_commit(pp, traces).unwrap()); - let fixed_commit = fixed_commit_wd.as_ref().map(PCS::get_pure_commitment); - + pub fn key_gen(self) -> ProvingKey { ProvingKey { - fixed_traces: fixed_traces_polys, - fixed_commit_wd, - vk: VerifyingKey { - cs: self, - fixed_commit, - }, + vk: VerifyingKey { cs: self }, } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index e377442d4..e5b5621f3 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -675,7 +675,11 @@ pub fn verify( // print verification statistics like proof size and hash count let stat_recorder = StatisticRecorder::default(); let transcript = BasicTranscriptWithStat::new(&stat_recorder, b"riscv"); - verifier.verify_proof_halt(zkvm_proof.clone(), transcript, zkvm_proof.has_halt())?; + verifier.verify_proof_halt( + zkvm_proof.clone(), + transcript, + zkvm_proof.has_halt(&verifier.vk), + )?; info!("e2e proof stat: {}", zkvm_proof); info!( "hashes count = {}", diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 905438c48..00cabba85 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -23,6 +23,6 @@ fn test_multiple_opcode() { |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs)), ); let param = Pcs::setup(1 << 10).unwrap(); - let (pp, _) = Pcs::trim(param, 1 << 10).unwrap(); - cs.key_gen::(&pp, None); + let (_, _) = Pcs::trim(param, 1 << 10).unwrap(); + cs.key_gen(); } diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index f97a1886d..473ea339e 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use crate::{ error::ZKVMError, structs::{ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMProvingKey}, @@ -12,24 +14,27 @@ impl ZKVMConstraintSystem { vp: PCS::VerifierParam, mut vm_fixed_traces: ZKVMFixedTraces, ) -> Result, ZKVMError> { - let mut vm_pk = ZKVMProvingKey::new(pp, vp); + let mut vm_pk = ZKVMProvingKey::new(pp.clone(), vp); + let mut fixed_traces = BTreeMap::new(); - for (c_name, cs) in self.circuit_css { + for (circuit_index, (c_name, cs)) in self.circuit_css.into_iter().enumerate() { // fixed_traces is optional // verifier will check it existent if cs.num_fixed > 0 - let fixed_traces = if cs.num_fixed > 0 { - vm_fixed_traces + if cs.num_fixed > 0 { + let fixed_trace_rmm = vm_fixed_traces .circuit_fixed_traces .remove(&c_name) - .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))? 
- } else { - None + .flatten() + .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone()))?; + fixed_traces.insert(circuit_index, fixed_trace_rmm); }; - let circuit_pk = cs.key_gen(&vm_pk.pp, fixed_traces); + let circuit_pk = cs.key_gen(); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); } + vm_pk.commit_fixed(fixed_traces)?; + vm_pk.initial_global_state_expr = self.initial_global_state_expr; vm_pk.finalize_global_state_expr = self.finalize_global_state_expr; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index d47ed5108..38fc49aa1 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -12,7 +12,7 @@ use sumcheck::structs::IOPProverMessage; use crate::{ instructions::{Instruction, riscv::ecall::HaltInstruction}, - structs::TowerProofs, + structs::{TowerProofs, ZKVMVerifyingKey}, }; pub mod constants; @@ -29,10 +29,7 @@ mod tests; serialize = "E::BaseField: Serialize", deserialize = "E::BaseField: DeserializeOwned" ))] -pub struct ZKVMOpcodeProof> { - // TODO support >1 opcodes - pub num_instances: usize, - +pub struct ZKVMOpcodeProof { // product constraints pub record_r_out_evals: Vec, pub record_w_out_evals: Vec, @@ -51,8 +48,6 @@ pub struct ZKVMOpcodeProof pub w_records_in_evals: Vec, pub lk_records_in_evals: Vec, - pub wits_commit: PCS::Commitment, - pub wits_opening_proof: PCS::Proof, pub wits_in_evals: Vec, } @@ -61,7 +56,7 @@ pub struct ZKVMOpcodeProof serialize = "E::BaseField: Serialize", deserialize = "E::BaseField: DeserializeOwned" ))] -pub struct ZKVMTableProof> { +pub struct ZKVMTableProof { // tower evaluation at layer 1 pub r_out_evals: Vec<[E; 2]>, pub w_out_evals: Vec<[E; 2]>, @@ -73,14 +68,8 @@ pub struct ZKVMTableProof> pub tower_proof: TowerProofs, - // num_vars hint for rw dynamic address to work - pub rw_hints_num_vars: Vec, - pub fixed_in_evals: Vec, - pub fixed_opening_proof: Option, - pub wits_commit: PCS::Commitment, pub wits_in_evals: Vec, - pub wits_opening_proof: PCS::Proof, } /// each field will be interpret to (constant) polynomial @@ -144,14 +133,37 @@ pub struct ZKVMProof> { pub raw_pi: Vec>, // the evaluation of raw_pi. 
pub pi_evals: Vec, - opcode_proofs: BTreeMap)>, - table_proofs: BTreeMap)>, + // circuit size -> instance mapping + pub num_instances: Vec<(usize, usize)>, + opcode_proofs: BTreeMap>, + table_proofs: BTreeMap>, + witin_commit: >::Commitment, + pub fixed_witin_opening_proof: PCS::Proof, } impl> ZKVMProof { - pub fn empty(pv: PublicValues) -> Self { - let raw_pi = pv.to_vec::(); - let pi_evals = raw_pi + pub fn new( + raw_pi: Vec>, + pi_evals: Vec, + opcode_proofs: BTreeMap>, + table_proofs: BTreeMap>, + witin_commit: >::Commitment, + fixed_witin_opening_proof: PCS::Proof, + num_instances: Vec<(usize, usize)>, + ) -> Self { + Self { + raw_pi, + pi_evals, + opcode_proofs, + table_proofs, + witin_commit, + fixed_witin_opening_proof, + num_instances, + } + } + + pub fn pi_evals(raw_pi: &[Vec]) -> Vec { + raw_pi .iter() .map(|pv| { if pv.len() == 1 { @@ -163,13 +175,7 @@ impl> ZKVMProof { E::ZERO } }) - .collect_vec(); - Self { - raw_pi, - pi_evals, - opcode_proofs: BTreeMap::new(), - table_proofs: BTreeMap::new(), - } + .collect_vec() } pub fn update_pi_eval(&mut self, idx: usize, v: E) { @@ -180,11 +186,19 @@ impl> ZKVMProof { self.opcode_proofs.len() + self.table_proofs.len() } - pub fn has_halt(&self) -> bool { + pub fn has_halt(&self, vk: &ZKVMVerifyingKey) -> bool { let halt_instance_count = self - .opcode_proofs - .get(&HaltInstruction::::name()) - .map(|(_, p)| p.num_instances) + .num_instances + .iter() + .find_map(|(circuit_index, num_instances)| { + (*circuit_index + == vk + .circuit_vks + .keys() + .position(|circuit_name| *circuit_name == HaltInstruction::::name()) + .expect("halt circuit not exist")) + .then_some(*num_instances) + }) .unwrap_or(0); if halt_instance_count > 0 { assert_eq!( @@ -205,40 +219,19 @@ impl + Serialize> fmt::Dis // also provide by-circuit stats let mut by_circuitname_stats = HashMap::new(); // opcode circuit mpcs size - let mpcs_opcode_commitment = self - .opcode_proofs - .iter() - .map(|(circuit_name, (_, proof))| { - let size = bincode::serialized_size(&proof.wits_commit); - size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; - }) - }) - .collect::, _>>() - .expect("serialization error") - .iter() - .sum::(); - let mpcs_opcode_opening = self - .opcode_proofs - .iter() - .map(|(circuit_name, (_, proof))| { - let size = bincode::serialized_size(&proof.wits_opening_proof); - size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; - }) - }) - .collect::, _>>() - .expect("serialization error") - .iter() - .sum::(); + let mpcs_opcode_commitment = + bincode::serialized_size(&self.witin_commit).expect("serialization error"); + let mpcs_opcode_opening = + bincode::serialized_size(&self.fixed_witin_opening_proof).expect("serialization error"); + // opcode circuit for tower proof size let tower_proof_opcode = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.tower_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -249,50 +242,10 @@ impl + Serialize> fmt::Dis let main_sumcheck_opcode = self .opcode_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.main_sel_sumcheck_proofs); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; - }) - }) - .collect::, _>>() - 
.expect("serialization error") - .iter() - .sum::(); - // table circuit mpcs size - let mpcs_table_commitment = self - .table_proofs - .iter() - .map(|(circuit_name, (_, proof))| { - let size = bincode::serialized_size(&proof.wits_commit); - size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; - }) - }) - .collect::, _>>() - .expect("serialization error") - .iter() - .sum::(); - let mpcs_table_opening = self - .table_proofs - .iter() - .map(|(circuit_name, (_, proof))| { - let size = bincode::serialized_size(&proof.wits_opening_proof); - size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; - }) - }) - .collect::, _>>() - .expect("serialization error") - .iter() - .sum::(); - let mpcs_table_fixed_opening = self - .table_proofs - .iter() - .map(|(circuit_name, (_, proof))| { - let size = bincode::serialized_size(&proof.fixed_opening_proof); - size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -303,10 +256,10 @@ impl + Serialize> fmt::Dis let tower_proof_table = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.tower_proof); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -317,10 +270,10 @@ impl + Serialize> fmt::Dis let same_r_sumcheck_table = self .table_proofs .iter() - .map(|(circuit_name, (_, proof))| { + .map(|(circuit_index, proof)| { let size = bincode::serialized_size(&proof.same_r_sumcheck_proofs); size.inspect(|size| { - *by_circuitname_stats.entry(circuit_name).or_insert(0) += size; + *by_circuitname_stats.entry(circuit_index).or_insert(0) += size; }) }) .collect::, _>>() @@ -350,13 +303,10 @@ impl + Serialize> fmt::Dis write!( f, "overall_size {:.2}mb. 
\n\ - opcode mpcs commitment {:?}% \n\ - opcode mpcs opening {:?}% \n\ + mpcs commitment {:?}% \n\ + mpcs opening {:?}% \n\ opcode tower proof {:?}% \n\ opcode main sumcheck proof {:?}% \n\ - table mpcs commitment {:?}% \n\ - table mpcs opening {:?}% \n\ - table mpcs fixed opening {:?}% \n\ table tower proof {:?}% \n\ table same r sumcheck proof {:?}% \n\n\ by circuit_name break down: \n\ @@ -367,9 +317,6 @@ impl + Serialize> fmt::Dis (mpcs_opcode_opening * 100).div(overall_size), (tower_proof_opcode * 100).div(overall_size), (main_sumcheck_opcode * 100).div(overall_size), - (mpcs_table_commitment * 100).div(overall_size), - (mpcs_table_opening * 100).div(overall_size), - (mpcs_table_fixed_opening * 100).div(overall_size), (tower_proof_table * 100).div(overall_size), (same_r_sumcheck_table * 100).div(overall_size), by_circuitname_stats, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c63cf2759..cea2cb9aa 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,13 +1,10 @@ use ff_ext::ExtensionField; -use std::{ - collections::{BTreeMap, BTreeSet, HashMap}, - sync::Arc, -}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; use itertools::{Itertools, enumerate, izip}; -use mpcs::PolynomialCommitmentScheme; +use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, + mle::IntoMLE, util::ceil_log2, virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec}, virtual_polys::VirtualPolynomials, @@ -33,14 +30,14 @@ use crate::{ }, }, structs::{ - Point, ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, + ProvingKey, TowerProofs, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, utils::{add_mle_list_by_expr, get_challenge_pows}, }; use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; -type ResultCreateTableProof = (ZKVMTableProof, HashMap); +type CreateTableProof = (ZKVMTableProof, HashMap, Point); pub struct ZKVMProver> { pub pk: ZKVMProvingKey, @@ -64,16 +61,19 @@ impl> ZKVMProver { pi: PublicValues, mut transcript: impl Transcript, ) -> Result, ZKVMError> { - let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); - let mut vm_proof = ZKVMProof::empty(pi); + let raw_pi = pi.to_vec::(); + let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); + let mut opcode_proofs: BTreeMap> = BTreeMap::new(); + let mut table_proofs: BTreeMap> = BTreeMap::new(); + let span = entered_span!("commit_to_pi", profiling_1 = true); // including raw public input to transcript - for v in vm_proof.raw_pi.iter().flatten() { + for v in raw_pi.iter().flatten() { transcript.append_field_element(v); } + exit_span!(span); - let pi: Vec> = vm_proof - .raw_pi + let pi: Vec> = raw_pi .iter() .map(|p| { let pi_mle: ArcMultilinearExtension = p.to_vec().into_mle().into(); @@ -82,19 +82,52 @@ impl> ZKVMProver { .collect(); // commit to fixed commitment - for pk in self.pk.circuit_pks.values() { - if let Some(fixed_commit) = &pk.vk.fixed_commit { - PCS::write_commitment(fixed_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; - } + let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); + if let Some(fixed_commit) = &self.pk.get_vk().fixed_commit { + PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } exit_span!(span); // commit to main traces - let mut commitments = BTreeMap::new(); - let mut wits = BTreeMap::new(); + let circuit_name_index_mapping = self + .pk + 
.circuit_pks + .keys() + .enumerate() + .map(|(k, v)| (v, k)) + .collect::>(); + let mut wits_instances = BTreeMap::new(); + let mut wits_rmms = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); + let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); + for (index, (circuit_name, _)) in self.pk.circuit_pks.iter().enumerate() { + if let Some(num_instance) = witnesses + .get_opcode_witness(circuit_name) + .or_else(|| { + witnesses + .get_table_witness(circuit_name) + .map(|rmms| &rmms[0]) + }) + .map(|rmm| rmm.num_instances()) + .and_then(|num_instance| { + if num_instance > 0 { + Some(num_instance) + } else { + None + } + }) + { + num_instances.push((index, num_instance)); + } + } + + // write (circuit_size, num_var) to transcript + for (circuit_size, num_var) in &num_instances { + transcript.append_message(&circuit_size.to_le_bytes()); + transcript.append_message(&num_var.to_le_bytes()); + } + let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, mut rmm) in witnesses.into_iter_sorted() { @@ -106,26 +139,13 @@ impl> ZKVMProver { RowMajorMatrix::empty() }; let num_instances = witness_rmm.num_instances(); - let span = entered_span!( - "commit to iteration", - circuit_name = circuit_name, - profiling_2 = true - ); + wits_instances.insert(circuit_name.clone(), num_instances); + if num_instances == 0 { + continue; + } - let (witness, structural_witness) = match num_instances { - 0 => (vec![], vec![]), - _ => { - let structural_witness = structural_witness_rmm.to_mles(); - let commit = - PCS::batch_commit_and_write(&self.pk.pp, witness_rmm, &mut transcript) - .map_err(ZKVMError::PCSError)?; - let witness = PCS::get_arc_mle_witness_from_commitment(&commit); - commitments.insert(circuit_name.clone(), commit); - (witness, structural_witness) - } - }; - exit_span!(span); - wits.insert(circuit_name.clone(), (witness, num_instances)); + let structural_witness = structural_witness_rmm.to_mles(); + wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); structural_wits.insert( circuit_name, ( @@ -137,8 +157,26 @@ impl> ZKVMProver { ), ); } + + debug_assert_eq!(num_instances.len(), wits_rmms.len()); + + // batch commit witness + let span = entered_span!("batch commit to witness", profiling_2 = true); + let witin_commit_with_witness = + PCS::batch_commit_and_write(&self.pk.pp, wits_rmms, &mut transcript) + .map_err(ZKVMError::PCSError)?; + exit_span!(span); + // retrieve mle from commitment + let mut witness_mles = PCS::get_arc_mle_witness_from_commitment(&witin_commit_with_witness); + let witin_commit = PCS::get_pure_commitment(&witin_commit_with_witness); exit_span!(commit_to_traces_span); + // retrive fixed mle from pk + let mut fixed_mles = + PCS::get_arc_mle_witness_from_commitment(self.pk.fixed_commit_wd.as_ref().ok_or( + ZKVMError::FixedTraceNotFound("there is no fixed trace witness".to_string()), + )?); + // squeeze two challenges from transcript let challenges = [ transcript.read_challenge().elements, @@ -147,17 +185,23 @@ impl> ZKVMProver { tracing::debug!("challenges in prover: {:?}", challenges); let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); - for (index, (circuit_name, pk)) in self.pk.circuit_pks.iter().enumerate() { - let (witness, num_instances) = wits - .remove(circuit_name) + let (points, evaluations) = self + .pk + .circuit_pks + .iter() + .enumerate() + .try_fold((vec![], vec![]), |(mut 
points, mut evaluations), (index, (circuit_name, pk))| { + let num_instances = *wits_instances + .get(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string()))?; - if witness.is_empty() { - continue; + if num_instances == 0 { + // do nothing without point and evaluation insertion + return Ok::<(Vec<_>, Vec>), ZKVMError>((points,evaluations)); } transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let wits_commit = commitments.remove(circuit_name).unwrap(); // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); + let witness_mle = witness_mles.drain(..cs.num_witin as usize).collect_vec(); let is_opcode_circuit = cs.lk_table_expressions.is_empty() && cs.r_table_expressions.is_empty() && cs.w_table_expressions.is_empty(); @@ -171,12 +215,10 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); - let opcode_proof = self.create_opcode_proof( + let (opcode_proof, point) = self.create_opcode_proof( circuit_name, - &self.pk.pp, pk, - witness, - wits_commit, + witness_mle, &pi, num_instances, &mut transcript, @@ -187,38 +229,76 @@ impl> ZKVMProver { circuit_name, num_instances ); - vm_proof - .opcode_proofs - .insert(circuit_name.clone(), (index, opcode_proof)); + points.push(point); + evaluations.push(opcode_proof.wits_in_evals.clone()); + opcode_proofs + .insert(index, opcode_proof); } else { + let fixed_mle = fixed_mles.drain(..cs.num_fixed).collect_vec(); let (structural_witness, structural_num_instances) = structural_wits .remove(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; - let (table_proof, pi_in_evals) = self.create_table_proof( + let (table_proof, pi_in_evals, point) = self.create_table_proof( circuit_name, - &self.pk.pp, pk, - witness, - wits_commit, + fixed_mle, + witness_mle, structural_witness, &pi, &mut transcript, &challenges, )?; + points.push(point); + evaluations.push(table_proof.wits_in_evals.clone()); + if cs.num_fixed > 0 { + evaluations.push(table_proof.fixed_in_evals.clone()); + } tracing::info!( "generated proof for table {} with num_instances={}, structural_num_instances={}", circuit_name, num_instances, structural_num_instances ); - vm_proof - .table_proofs - .insert(circuit_name.clone(), (index, table_proof)); + table_proofs.insert(index, table_proof); for (idx, eval) in pi_in_evals { - vm_proof.update_pi_eval(idx, eval); + pi_evals[idx]= eval; } - } - } + }; + Ok((points,evaluations)) + })?; + + // batch opening pcs + // generate static info from prover key for expected num variable + let circuit_num_polys = self + .pk + .circuit_pks + .values() + .map(|pk| (pk.get_cs().num_witin as usize, pk.get_cs().num_fixed)) + .collect_vec(); + let pcs_opening = entered_span!("pcs_opening"); + let mpcs_opening_proof = PCS::batch_open( + &self.pk.pp, + &num_instances, + self.pk.fixed_commit_wd.as_ref(), + &witin_commit_with_witness, + &points, + &evaluations, + &circuit_num_polys, + &mut transcript, + ) + .map_err(ZKVMError::PCSError)?; + exit_span!(pcs_opening); + + let vm_proof = ZKVMProof::new( + raw_pi, + pi_evals, + opcode_proofs, + table_proofs, + witin_commit, + mpcs_opening_proof, + // verifier need this information from prover to achieve non-uniform design. 
+ num_instances, + ); exit_span!(main_proofs_span); Ok(vm_proof) @@ -232,15 +312,13 @@ impl> ZKVMProver { pub fn create_opcode_proof( &self, name: &str, - pp: &PCS::ProverParam, - circuit_pk: &ProvingKey, + circuit_pk: &ProvingKey, witnesses: Vec>, - wits_commit: PCS::CommitmentWithWitness, pi: &[ArcMultilinearExtension<'_, E>], num_instances: usize, transcript: &mut impl Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { + ) -> Result<(ZKVMOpcodeProof, Point), ZKVMError> { let cs = circuit_pk.get_cs(); let next_pow2_instances = next_pow2_instance_padding(num_instances); let log2_num_instances = ceil_log2(next_pow2_instances); @@ -575,7 +653,7 @@ impl> ZKVMProver { let (main_sel_sumcheck_proofs, state) = IOPProverState::prove(virtual_polys, transcript); tracing::debug!("main sel sumcheck end"); - let main_sel_evals = state.get_mle_final_evaluations(); + let main_sel_evals = state.get_mle_flatten_final_evaluations(); assert_eq!( main_sel_evals.len(), r_counts_per_instance @@ -630,40 +708,29 @@ impl> ZKVMProver { name, witnesses.len() ); - let wits_opening_proof = PCS::simple_batch_open( - pp, - &witnesses, - &wits_commit, - &input_open_point, - wits_in_evals.as_slice(), - transcript, - ) - .map_err(ZKVMError::PCSError)?; tracing::info!( "[opcode {}] build opening proof took {:?}", name, opening_dur.elapsed(), ); exit_span!(pcs_open_span); - let wits_commit = PCS::get_pure_commitment(&wits_commit); - - Ok(ZKVMOpcodeProof { - num_instances, - record_r_out_evals, - record_w_out_evals, - lk_p1_out_eval, - lk_p2_out_eval, - lk_q1_out_eval, - lk_q2_out_eval, - tower_proof, - main_sel_sumcheck_proofs: main_sel_sumcheck_proofs.proofs, - r_records_in_evals, - w_records_in_evals, - lk_records_in_evals, - wits_commit, - wits_opening_proof, - wits_in_evals, - }) + Ok(( + ZKVMOpcodeProof { + record_r_out_evals, + record_w_out_evals, + lk_p1_out_eval, + lk_p2_out_eval, + lk_q1_out_eval, + lk_q2_out_eval, + tower_proof, + main_sel_sumcheck_proofs: main_sel_sumcheck_proofs.proofs, + r_records_in_evals, + w_records_in_evals, + lk_records_in_evals, + wits_in_evals, + }, + input_open_point, + )) } #[allow(clippy::too_many_arguments)] @@ -674,26 +741,15 @@ impl> ZKVMProver { pub fn create_table_proof( &self, name: &str, - pp: &PCS::ProverParam, - circuit_pk: &ProvingKey, + circuit_pk: &ProvingKey, + fixed: Vec>, witnesses: Vec>, - wits_commit: PCS::CommitmentWithWitness, structural_witnesses: Vec>, pi: &[ArcMultilinearExtension<'_, E>], transcript: &mut impl Transcript, challenges: &[E; 2], - ) -> Result, ZKVMError> { + ) -> Result, ZKVMError> { let cs = circuit_pk.get_cs(); - let fixed = circuit_pk - .fixed_traces - .as_ref() - .map(|fixed_traces| { - fixed_traces - .iter() - .map(|f| -> ArcMultilinearExtension { Arc::new(f.get_ranged_mle(1, 0)) }) - .collect::>>() - }) - .unwrap_or_default(); // sanity check assert_eq!(witnesses.len(), cs.num_witin as usize); assert_eq!(structural_witnesses.len(), cs.num_structural_witin as usize); @@ -915,15 +971,6 @@ impl> ZKVMProver { }) .collect_vec(); - // (non uniform) collect dynamic address hints as witness for verifier - let rw_hints_num_vars = structural_witnesses - .iter() - .map(|mle| mle.num_vars()) - .collect_vec(); - for var in rw_hints_num_vars.iter() { - transcript.append_message(&var.to_le_bytes()); - } - let (rt_tower, tower_proof) = TowerProver::create_proof( // pattern [r1, w1, r2, w2, ...] 
same pair are chain together r_wit_layers @@ -1013,7 +1060,7 @@ impl> ZKVMProver { let (same_r_sumcheck_proofs, state) = IOPProverState::prove(virtual_polys, transcript); - let evals = state.get_mle_final_evaluations(); + let evals = state.get_mle_flatten_final_evaluations(); let mut evals_iter = evals.into_iter(); let rw_in_evals = cs // r, w table len are identical @@ -1068,44 +1115,6 @@ impl> ZKVMProver { } exit_span!(span); - let pcs_opening = entered_span!("pcs_opening"); - let (fixed_opening_proof, _fixed_commit) = if !fixed.is_empty() { - ( - Some( - PCS::simple_batch_open( - pp, - &fixed, - circuit_pk.fixed_commit_wd.as_ref().unwrap(), - &input_open_point, - fixed_in_evals.as_slice(), - transcript, - ) - .map_err(ZKVMError::PCSError)?, - ), - Some(PCS::get_pure_commitment( - circuit_pk.fixed_commit_wd.as_ref().unwrap(), - )), - ) - } else { - (None, None) - }; - - tracing::debug!( - "[table {}] build opening proof for {} fixed polys", - name, - fixed.len() - ); - let wits_opening_proof = PCS::simple_batch_open( - pp, - &witnesses, - &wits_commit, - &input_open_point, - wits_in_evals.as_slice(), - transcript, - ) - .map_err(ZKVMError::PCSError)?; - exit_span!(pcs_opening); - let wits_commit = PCS::get_pure_commitment(&wits_commit); tracing::debug!( "[table {}] build opening proof for {} polys", name, @@ -1122,13 +1131,10 @@ impl> ZKVMProver { lk_in_evals, tower_proof, fixed_in_evals, - fixed_opening_proof, - rw_hints_num_vars, wits_in_evals, - wits_commit, - wits_opening_proof, }, pi_in_evals, + input_open_point, )) } } @@ -1281,7 +1287,7 @@ impl TowerProver { prod_specs.len() +logup_specs.len() * 2, // logup occupy 2 sumcheck: numerator and denominator transcript, ); - let evals = state.get_mle_final_evaluations(); + let evals = state.get_mle_flatten_final_evaluations(); let mut evals_iter = evals.iter(); evals_iter.next(); // skip first eq for (i, s) in enumerate(&prod_specs) { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 190ab3ede..b9b516c13 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{collections::BTreeMap, marker::PhantomData}; use crate::{ circuit_builder::CircuitBuilder, @@ -127,20 +127,25 @@ fn test_rw_lk_expression_combination() { let rmm = zkvm_witness.into_iter_sorted().next().unwrap().1.remove(0); let wits_in = rmm.to_mles(); // commit to main traces - let commit = Pcs::batch_commit_and_write(&prover.pk.pp, rmm, &mut transcript).unwrap(); + let commit_with_witness = Pcs::batch_commit_and_write( + &prover.pk.pp, + vec![(0, rmm)].into_iter().collect::>(), + &mut transcript, + ) + .unwrap(); + let witin_commit = Pcs::get_pure_commitment(&commit_with_witness); + let wits_in = wits_in.into_iter().map(|v| v.into()).collect_vec(); let prover_challenges = [ transcript.read_challenge().elements, transcript.read_challenge().elements, ]; - let proof = prover + let (proof, _) = prover .create_opcode_proof( name.as_str(), - &prover.pk.pp, prover.pk.circuit_pks.get(&name).unwrap(), wits_in, - commit, &[], num_instances, &mut transcript, @@ -153,7 +158,7 @@ fn test_rw_lk_expression_combination() { let verifier = ZKVMVerifier::new(vk.clone()); let mut v_transcript = BasicTranscriptWithStat::new(&stat_recorder, b"test"); // write commitment into transcript and derive challenges from it - Pcs::write_commitment(&proof.wits_commit, &mut v_transcript).unwrap(); + Pcs::write_commitment(&witin_commit, &mut v_transcript).unwrap(); let verifier_challenges = [ 
v_transcript.read_challenge().elements, v_transcript.read_challenge().elements, @@ -163,9 +168,9 @@ fn test_rw_lk_expression_combination() { let _rt_input = verifier .verify_opcode_proof( name.as_str(), - &vk.vp, verifier.vk.circuit_vks.get(&name).unwrap(), &proof, + num_instances, &[], &mut v_transcript, NUM_FANIN, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 7de1b0a64..d1ee3c24c 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -4,13 +4,14 @@ use ark_std::iterable::Iterable; use ff_ext::ExtensionField; use itertools::{Itertools, chain, interleave, izip}; -use mpcs::PolynomialCommitmentScheme; +use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; use p3::field::PrimeCharacteristicRing; +use std::collections::HashSet; use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; @@ -22,7 +23,7 @@ use crate::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, utils::eval_by_expr_with_instance, }, - structs::{Point, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, + structs::{PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{eq_eval_less_or_equal_than, eval_wellform_address_vec, get_challenge_pows}, }; @@ -61,7 +62,7 @@ impl> ZKVMVerifier expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - let has_halt = vm_proof.has_halt(); + let has_halt = vm_proof.has_halt(&self.vk); if has_halt != expect_halt { return Err(ZKVMError::VerifyError(format!( "ecall/halt mismatch: expected {expect_halt} != {has_halt}", ))); } @@ -83,6 +84,45 @@ impl> ZKVMVerifier let pi_evals = &vm_proof.pi_evals; + // make sure circuit indices are + // 1. unique + // 2. less than self.vk.circuit_vks.len() + assert!( + vm_proof + .num_instances + .iter() + .fold(None, |prev, &(circuit_index, _)| { + (circuit_index < self.vk.circuit_vks.len() + && prev.is_none_or(|p| p < circuit_index)) + .then_some(circuit_index) + }) + .is_some(), + "num_instances validity check failed" + ); + + assert_eq!( + vm_proof + .num_instances + .iter() + .map(|(x, _)| x) + .collect::>(), + vm_proof + .opcode_proofs + .keys() + .chain(vm_proof.table_proofs.keys()) + .collect::>(), + "num_instances circuit indices must exactly match the provided proofs" + ); + + assert!( + vm_proof + .opcode_proofs + .keys() + .collect::>() + .is_disjoint(&vm_proof.table_proofs.keys().collect::>()), + "there are duplicated circuit indices" + ); + // TODO fix soundness: construct raw public input by ourself and trustless from proof // including raw public input to transcript vm_proof .raw_pi @@ -104,28 +144,26 @@ impl> ZKVMVerifier Ok(()) } })?; + // write fixed commitment to transcript - for (_, vk) in self.vk.circuit_vks.iter() { - if let Some(fixed_commit) = vk.fixed_commit.as_ref() { - PCS::write_commitment(fixed_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; - } + // TODO check soundness: what if there is no fixed_commit but the proof still carries fixed openings?
+ if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() { + PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } - for (circuit_name, _) in self.vk.circuit_vks.iter() { - if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) { - tracing::debug!("read {}'s commit", circuit_name); - PCS::write_commitment(&opcode_proof.wits_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; - } else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) { - tracing::debug!("read {}'s commit", circuit_name); - PCS::write_commitment(&table_proof.wits_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; - } else { - // all proof are optional - } + // write (circuit_size, num_var) to transcript + for (circuit_size, num_var) in &vm_proof.num_instances { + transcript.append_message(&circuit_size.to_le_bytes()); + transcript.append_message(&num_var.to_le_bytes()); } + let circuit_vks: Vec<&VerifyingKey> = self.vk.circuit_vks.values().collect_vec(); + let circuit_names: Vec<&String> = self.vk.circuit_vks.keys().collect_vec(); + + // write witin commitment to transcript + PCS::write_commitment(&vm_proof.witin_commit, &mut transcript) + .map_err(ZKVMError::PCSError)?; + // alpha, beta let challenges = [ transcript.read_challenge().elements, @@ -136,30 +174,33 @@ impl> ZKVMVerifier let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); - for (index, (circuit_name, circuit_vk)) in self.vk.circuit_vks.iter().enumerate() { - if let Some((_, opcode_proof)) = vm_proof.opcode_proofs.get(circuit_name) { - transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let name = circuit_name; - let _rand_point = self.verify_opcode_proof( + let mut rt_points = Vec::with_capacity(vm_proof.num_instances.len()); + let mut evaluations = Vec::with_capacity(2 * vm_proof.num_instances.len()); // witin + fixed thus *2 + for (index, num_instances) in &vm_proof.num_instances { + let circuit_vk = circuit_vks[*index]; + let name = circuit_names[*index]; + if let Some(opcode_proof) = vm_proof.opcode_proofs.get(index) { + transcript.append_field_element(&E::BaseField::from_u64(*index as u64)); + rt_points.push(self.verify_opcode_proof( name, - &self.vk.vp, circuit_vk, opcode_proof, + *num_instances, pi_evals, &mut transcript, NUM_FANIN, &point_eval, &challenges, - )?; + )?); + evaluations.push(opcode_proof.wits_in_evals.clone()); tracing::info!("verified proof for opcode {}", name); // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().lk_expressions.len(); let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; - let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances) - - opcode_proof.num_instances; - dummy_table_item_multiplicity += num_padded_lks_per_instance - * opcode_proof.num_instances + let num_padded_instance = + next_pow2_instance_padding(*num_instances) - num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance * num_instances + num_lks.next_power_of_two() * num_padded_instance; prod_r *= opcode_proof @@ -175,21 +216,24 @@ impl> ZKVMVerifier logup_sum += opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.inverse(); logup_sum += opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.inverse(); - } else if let Some((_, table_proof)) = vm_proof.table_proofs.get(circuit_name) { - 
transcript.append_field_element(&E::BaseField::from_u64(index as u64)); - let name = circuit_name; - let _rand_point = self.verify_table_proof( + } else if let Some(table_proof) = vm_proof.table_proofs.get(index) { + transcript.append_field_element(&E::BaseField::from_u64(*index as u64)); + rt_points.push(self.verify_table_proof( name, - &self.vk.vp, circuit_vk, table_proof, + *num_instances, &vm_proof.raw_pi, &vm_proof.pi_evals, &mut transcript, NUM_FANIN_LOGUP, &point_eval, &challenges, - )?; + )?); + evaluations.push(table_proof.wits_in_evals.clone()); + if circuit_vk.cs.num_fixed > 0 { + evaluations.push(table_proof.fixed_in_evals.clone()); + } tracing::info!("verified proof for table {}", name); logup_sum = table_proof @@ -212,7 +256,7 @@ impl> ZKVMVerifier .copied() .product::(); } else { - // all proof are optional + unreachable!("respective proof of index {} should exist", index) } } logup_sum -= E::from_u64(dummy_table_item_multiplicity as u64) * dummy_table_item.inverse(); @@ -225,6 +269,20 @@ impl> ZKVMVerifier ))); } + // verify mpcs + PCS::batch_verify( + &self.vk.vp, + &vm_proof.num_instances, + &rt_points, + self.vk.fixed_commit.as_ref(), + &vm_proof.witin_commit, + &evaluations, + &vm_proof.fixed_witin_opening_proof, + &self.vk.circuit_num_polys, + &mut transcript, + ) + .map_err(ZKVMError::PCSError)?; + let initial_global_state = eval_by_expr_with_instance( &[], &[], @@ -255,10 +313,10 @@ impl> ZKVMVerifier #[allow(clippy::too_many_arguments)] pub fn verify_opcode_proof( &self, - name: &str, - vp: &PCS::VerifierParam, - circuit_vk: &VerifyingKey, - proof: &ZKVMOpcodeProof, + _name: &str, + circuit_vk: &VerifyingKey, + proof: &ZKVMOpcodeProof, + num_instances: usize, pi: &[E], transcript: &mut impl Transcript, num_product_fanin: usize, @@ -278,7 +336,6 @@ impl> ZKVMVerifier ); let (chip_record_alpha, _) = (challenges[0], challenges[1]); - let num_instances = proof.num_instances; let next_pow2_instance = next_pow2_instance_padding(num_instances); let log2_num_instances = ceil_log2(next_pow2_instance); @@ -480,21 +537,6 @@ impl> ZKVMVerifier return Err(ZKVMError::VerifyError("zero expression != 0".into())); } - tracing::debug!( - "[opcode {}] verify opening proof for {} polys", - name, - proof.wits_in_evals.len(), - ); - PCS::simple_batch_verify( - vp, - &proof.wits_commit, - &input_opening_point, - &proof.wits_in_evals, - &proof.wits_opening_proof, - transcript, - ) - .map_err(ZKVMError::PCSError)?; - Ok(input_opening_point) } @@ -502,9 +544,9 @@ impl> ZKVMVerifier pub fn verify_table_proof( &self, name: &str, - vp: &PCS::VerifierParam, - circuit_vk: &VerifyingKey, - proof: &ZKVMTableProof, + circuit_vk: &VerifyingKey, + proof: &ZKVMTableProof, + num_instances: usize, raw_pi: &[Vec], pi: &[E], transcript: &mut impl Transcript, @@ -519,6 +561,9 @@ impl> ZKVMVerifier .zip_eq(cs.w_table_expressions.iter()) .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); + + let log2_num_instances = ceil_log2(num_instances); + // in table proof, we always skip same point sumcheck for now // as tower sumcheck batch product argument/logup in same length let is_skip_same_point_sumcheck = true; @@ -526,6 +571,7 @@ impl> ZKVMVerifier // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; + // NOTE: all structural witnesses within the same constraint system should share the same hint num_vars, namely `log2_num_instances` let expected_rounds = cs // only iterate r set, as read/write set round should match .r_table_expressions .iter() .map(|r| { // iterate through structural witins and collect max round. let num_vars = r.table_spec.len.map(ceil_log2).unwrap_or_else(|| {
r.table_spec .structural_witins .iter() - .map(|StructuralWitIn { id, max_len, .. }| { - let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + .map(|StructuralWitIn { max_len, .. }| { + let hint_num_vars = log2_num_instances; assert!((1 << hint_num_vars) <= *max_len); hint_num_vars }) .max() .unwrap() }); + assert_eq!(num_vars, log2_num_instances); [num_vars, num_vars] // format: [read_round, write_round] }) .chain(cs.lk_table_expressions.iter().map(|l| { @@ -552,22 +599,19 @@ impl> ZKVMVerifier l.table_spec .structural_witins .iter() - .map(|StructuralWitIn { id, max_len, .. }| { - let hint_num_vars = proof.rw_hints_num_vars[*id as usize]; + .map(|StructuralWitIn { max_len, .. }| { + let hint_num_vars = log2_num_instances; assert!((1 << hint_num_vars) <= *max_len); hint_num_vars }) .max() .unwrap() }); + assert_eq!(num_vars, log2_num_instances); num_vars })) .collect_vec(); - for var in proof.rw_hints_num_vars.iter() { - transcript.append_message(&var.to_le_bytes()); - } - let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( @@ -777,45 +821,6 @@ impl> ZKVMVerifier ); } - // do optional check of fixed_commitment openings by vk - if circuit_vk.fixed_commit.is_some() { - let Some(fixed_opening_proof) = &proof.fixed_opening_proof else { - return Err(ZKVMError::VerifyError( - "fixed openning proof shoudn't be none".into(), - )); - }; - PCS::simple_batch_verify( - vp, - circuit_vk.fixed_commit.as_ref().unwrap(), - &input_opening_point, - &proof.fixed_in_evals, - fixed_opening_proof, - transcript, - ) - .map_err(ZKVMError::PCSError)?; - } - - tracing::debug!( - "[table {}] verified opening proof for {} fixed polys", - name, - proof.fixed_in_evals.len(), - ); - - PCS::simple_batch_verify( - vp, - &proof.wits_commit, - &input_opening_point, - &proof.wits_in_evals, - &proof.wits_opening_proof, - transcript, - ) - .map_err(ZKVMError::PCSError)?; - tracing::debug!( - "[table {}] verified opening proof for {} polys", - name, - proof.wits_in_evals.len(), - ); - Ok(input_opening_point) } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 26959961b..978a67b77 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -10,10 +10,8 @@ use crate::{ use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; use ff_ext::ExtensionField; use itertools::Itertools; -use mpcs::PolynomialCommitmentScheme; -use multilinear_extensions::{ - mle::DenseMultilinearExtension, virtual_poly::ArcMultilinearExtension, -}; +use mpcs::{Point, PolynomialCommitmentScheme}; +use multilinear_extensions::virtual_poly::ArcMultilinearExtension; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::collections::{BTreeMap, HashMap}; use strum_macros::EnumIter; @@ -71,9 +69,6 @@ pub enum RAMType { Memory, } -/// A point is a vector of num_var length -pub type Point = Vec; - /// A point and the evaluation of this point. 
#[derive(Clone, Debug, PartialEq)] pub struct PointAndEval { @@ -108,13 +103,11 @@ impl PointAndEval { } #[derive(Clone)] -pub struct ProvingKey> { - pub fixed_traces: Option>>, - pub fixed_commit_wd: Option, - pub vk: VerifyingKey, +pub struct ProvingKey { + pub vk: VerifyingKey, } -impl> ProvingKey { +impl ProvingKey { pub fn get_cs(&self) -> &ConstraintSystem { self.vk.get_cs() } @@ -122,12 +115,11 @@ impl> ProvingKey { #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] -pub struct VerifyingKey> { +pub struct VerifyingKey { pub(crate) cs: ConstraintSystem, - pub fixed_commit: Option, } -impl> VerifyingKey { +impl VerifyingKey { pub fn get_cs(&self) -> &ConstraintSystem { &self.cs } @@ -346,7 +338,7 @@ impl ZKVMWitnesses { Ok(()) } - /// Iterate opcode circuits, then table circuits, sorted by name. + /// Iterate opcode/table circuits, sorted by alphabetical order. pub fn into_iter_sorted( self, ) -> impl Iterator>)> { @@ -366,7 +358,9 @@ pub struct ZKVMProvingKey> pub pp: PCS::ProverParam, pub vp: PCS::VerifierParam, // pk for opcode and table circuits - pub circuit_pks: BTreeMap>, + pub circuit_pks: BTreeMap>, + pub fixed_commit_wd: Option<>::CommitmentWithWitness>, + pub fixed_commit: Option<>::Commitment>, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -381,8 +375,27 @@ impl> ZKVMProvingKey::BaseField>>, + ) -> Result<(), ZKVMError> { + if !fixed_traces.is_empty() { + let fixed_commit_wd = + PCS::batch_commit(&self.pp, fixed_traces).map_err(ZKVMError::PCSError)?; + let fixed_commit = PCS::get_pure_commitment(&fixed_commit_wd); + self.fixed_commit_wd = Some(fixed_commit_wd); + self.fixed_commit = Some(fixed_commit); + } else { + self.fixed_commit_wd = None; + self.fixed_commit = None; + } + Ok(()) + } } impl> ZKVMProvingKey { @@ -394,9 +407,15 @@ impl> ZKVMProvingKey> ZKVMProvingKey> { pub vp: PCS::VerifierParam, // vk for opcode and table circuits - pub circuit_vks: BTreeMap>, + pub circuit_vks: BTreeMap>, + pub fixed_commit: Option<>::Commitment>, // expression for global state in/out pub initial_global_state_expr: Expression, pub finalize_global_state_expr: Expression, + // circuit index -> (witin num_polys, fixed_num_polys) + pub circuit_num_polys: Vec<(usize, usize)>, } diff --git a/mpcs/benches/basefold.rs b/mpcs/benches/basefold.rs index adfaaadfe..fa0686ce9 100644 --- a/mpcs/benches/basefold.rs +++ b/mpcs/benches/basefold.rs @@ -8,8 +8,9 @@ use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, test_util::{get_point_from_challenge, setup_pcs}, }; +use std::collections::BTreeMap; -use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension}; +use multilinear_extensions::mle::MultilinearExtension; use transcript::{BasicTranscript, Transcript}; use witness::RowMajorMatrix; @@ -45,17 +46,18 @@ fn bench_commit_open_verify_goldilocks>( let mut transcript = T::new(b"BaseFold"); let mut rng = rand::thread_rng(); - let rmm = RowMajorMatrix::rand(&mut rng, 1 << num_vars, 1); - let comm = Pcs::commit_and_write(&pp, rmm, &mut transcript).unwrap(); + let rmms = BTreeMap::from([(0, RowMajorMatrix::rand(&mut rng, 1 << num_vars, 1))]); + let comm = Pcs::batch_commit_and_write(&pp, rmms.clone(), &mut transcript).unwrap(); let poly = Pcs::get_arc_mle_witness_from_commitment(&comm).remove(0); group.bench_function(BenchmarkId::new("commit", format!("{}", num_vars)), |b| { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters 
{ - let rmm = RowMajorMatrix::rand(&mut rng, 1 << num_vars, 1); + let mut transcript = T::new(b"BaseFold"); + let rmms = rmms.clone(); let instant = std::time::Instant::now(); - Pcs::commit(&pp, rmm).unwrap(); + Pcs::batch_commit_and_write(&pp, rmms, &mut transcript).unwrap(); let elapsed = instant.elapsed(); time += elapsed; } @@ -98,34 +100,37 @@ fn bench_commit_open_verify_goldilocks>( } } -fn bench_simple_batch_commit_open_verify_goldilocks>( +fn bench_batch_commit_open_verify_goldilocks>( c: &mut Criterion, id: &str, ) { - let mut group = - c.benchmark_group(format!("simple_batch_commit_open_verify_goldilocks_{}", id,)); + let mut group = c.benchmark_group(format!("bench_batch_commit_open_verify_goldilocks{}", id,)); group.sample_size(NUM_SAMPLES); // Challenge is over extension field, poly over the base field for num_vars in NUM_VARS_START..=NUM_VARS_END { for batch_size_log in BATCH_SIZE_LOG_START..=BATCH_SIZE_LOG_END { let batch_size = 1 << batch_size_log; + let num_instances = vec![(0, 1 << num_vars)]; + let circuit_num_polys = vec![(batch_size, 0)]; let (pp, vp) = setup_pcs::(num_vars); let mut transcript = T::new(b"BaseFold"); let mut rng = rand::thread_rng(); - let rmm = RowMajorMatrix::rand(&mut rng, 1 << num_vars, batch_size); - let polys = rmm.to_mles(); - let mut transcript_for_bench = transcript.clone(); - let comm = Pcs::batch_commit_and_write(&pp, rmm, &mut transcript).unwrap(); + let rmms = + BTreeMap::from([(0, RowMajorMatrix::rand(&mut rng, 1 << num_vars, batch_size))]); + + let polys = rmms[&0].to_mles(); + let comm = Pcs::batch_commit_and_write(&pp, rmms.clone(), &mut transcript).unwrap(); + group.bench_function( BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let rmm = RowMajorMatrix::rand(&mut rng, 1 << num_vars, batch_size); + let mut transcript = T::new(b"BaseFold"); + let rmms = rmms.clone(); let instant = std::time::Instant::now(); - Pcs::batch_commit_and_write(&pp, rmm, &mut transcript_for_bench) - .unwrap(); + Pcs::batch_commit_and_write(&pp, rmms, &mut transcript).unwrap(); let elapsed = instant.elapsed(); time += elapsed; } @@ -138,12 +143,17 @@ fn bench_simple_batch_commit_open_verify_goldilocks>(); - let proof = Pcs::simple_batch_open(&pp, &polys, &comm, &point, &evals, &mut transcript) - .unwrap(); + let proof = Pcs::batch_open( + &pp, + &num_instances, + None, + &comm, + &[point.clone()], + &[evals.clone()], + &circuit_num_polys, + &mut transcript, + ) + .unwrap(); group.bench_function( BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), @@ -151,12 +161,14 @@ fn bench_simple_batch_commit_open_verify_goldilocks(c, "rs"); } -fn bench_simple_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { - bench_simple_batch_commit_open_verify_goldilocks::(c, "rs"); +fn bench_batch_commit_open_verify_goldilocks_base_rs(c: &mut Criterion) { + bench_batch_commit_open_verify_goldilocks::(c, "rs"); } criterion_group! 
{ name = bench_basefold; config = Criterion::default().warm_up_time(Duration::from_millis(3000)); targets = - bench_simple_batch_commit_open_verify_goldilocks_base_rs, + bench_batch_commit_open_verify_goldilocks_base_rs, bench_commit_open_verify_goldilocks_base_rs, } diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index d76e924cd..36a7defef 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -41,7 +41,7 @@ fn bench_encoding(c: &mut Criterion) { for _ in 0..iters { let rmm = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, batch_size); let instant = std::time::Instant::now(); - <>::EncodingScheme as EncodingScheme>::encode(&pp.encoding_params, rmm); + <>::EncodingScheme as EncodingScheme>::encode(&pp.encoding_params, rmm).expect("encode error"); let elapsed = instant.elapsed(); time += elapsed; } diff --git a/mpcs/benches/whir.rs b/mpcs/benches/whir.rs index 09f86cd31..ede7d48f2 100644 --- a/mpcs/benches/whir.rs +++ b/mpcs/benches/whir.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{collections::BTreeMap, time::Duration}; use criterion::*; use ff_ext::GoldilocksExt2; @@ -107,7 +107,8 @@ fn bench_simple_batch_commit_open_verify_goldilocks(num_vars); - let rmm = RowMajorMatrix::rand(&mut rng, 1 << num_vars, batch_size); + let rmms = + BTreeMap::from([(0, RowMajorMatrix::rand(&mut rng, 1 << num_vars, batch_size))]); group.bench_function( BenchmarkId::new("batch_commit", format!("{}-{}", num_vars, batch_size)), @@ -116,8 +117,9 @@ fn bench_simple_batch_commit_open_verify_goldilocks { - Normal(Box>), - TooSmall(Box>), // The polynomial is too small to apply FRI - TooBig(usize), -} impl> Basefold where @@ -213,91 +206,89 @@ where fn batch_commit( pp: &Self::ProverParam, - rmm: witness::RowMajorMatrix<::BaseField>, + rmms: BTreeMap::BaseField>>, ) -> Result { - let span = entered_span!("to_mles", profiling_3 = true); - let polys = rmm.to_mles(); - exit_span!(span); - // assumptions - // 1. there must be at least one polynomial - // 2. all polynomials must exist in the same field type - // (TODO: eliminate this assumption by supporting commiting - // and opening mixed-type polys) - // 3. 
all polynomials must have the same number of variables. - - if polys.is_empty() { + if rmms.is_empty() { return Err(Error::InvalidPcsParam( "cannot batch commit to zero polynomials".to_string(), )); } - let num_vars = polys[0].num_vars(); - let num_polys = polys.len(); + let mmcs = poseidon2_merkle_tree::(); - let is_base = match polys[0].evaluations() { - FieldType::Ext(_) => false, - FieldType::Base(_) => true, - _ => unreachable!(), - }; + let span = entered_span!("to_mles", profiling_3 = true); + let (polys, rmm_to_batch_commit, trivial_proofdata, circuit_codeword_index): ( + BTreeMap>>, + Vec<_>, + _, + _, + ) = rmms.into_iter().fold( + (BTreeMap::new(), vec![], BTreeMap::new(), BTreeMap::new()), + |( + mut polys, + mut rmm_to_batch_commit, + mut trivial_proofdata, + mut circuit_codeword_index, + ), + (index, rmm)| { + // attach column-based poly + polys.insert( + index, + rmm.to_mles().into_iter().map(|p| p.into()).collect_vec(), + ); + + // for smaller polys, we commit them separately, since the prover sends them in full without the blow-up factor + if rmm.num_vars() <= Spec::get_basecode_msg_size_log() { + let rmm = rmm.into_default_padded_p3_rmm(); + trivial_proofdata.insert(index, mmcs.commit_matrix(rmm)); + } else { + rmm_to_batch_commit.push(rmm); + circuit_codeword_index.insert(index, rmm_to_batch_commit.len() - 1); + } + ( + polys, + rmm_to_batch_commit, + trivial_proofdata, + circuit_codeword_index, + ) + }, + ); + exit_span!(span); - if !polys.iter().map(|poly| poly.num_vars()).all_equal() { - return Err(Error::InvalidPcsParam( - "cannot batch commit to polynomials with different number of variables".to_string(), - )); + if rmm_to_batch_commit.is_empty() { + todo!("support the case where all commitments are trivial") } - let timer = start_timer!(|| "Basefold::batch commit"); - let encode_timer = start_timer!(|| "Basefold::batch commit::encoding and interpolations"); - let span = entered_span!("encode_codeword_and_mle", profiling_3 = true); - - let evals_codewords = Spec::EncodingScheme::encode(&pp.encoding_params, rmm); + let span = entered_span!("encode_codeword_and_mle", profiling_3 = true); + let evals_codewords = rmm_to_batch_commit + .into_iter() + .map(|rmm| Spec::EncodingScheme::encode(&pp.encoding_params, rmm)) + .collect::>, _>>()?; exit_span!(span); - end_timer!(encode_timer); let span = entered_span!("build mt", profiling_3 = true); - // build merkle tree - let ret = match evals_codewords { - PolyEvalsCodeword::Normal(codewords) => { - let mmcs = poseidon2_merkle_tree::(); - let (comm, codeword) = mmcs.commit_matrix(*codewords); - let polys: Vec> = - polys.into_iter().map(|poly| poly.into()).collect_vec(); - Self::CommitmentWithWitness { - pi_d_digest: comm, - codeword, - polynomials_bh_evals: polys, - num_vars, - is_base, - num_polys, - } - } - PolyEvalsCodeword::TooSmall(rmm_padded) => { - let mmcs = poseidon2_merkle_tree::(); - let (comm, codeword) = mmcs.commit_matrix(*rmm_padded); - let polys: Vec> = - polys.into_iter().map(|poly| poly.into()).collect_vec(); - Self::CommitmentWithWitness { - pi_d_digest: comm, - codeword, - polynomials_bh_evals: polys, - num_vars, - is_base, - num_polys, - } - } - PolyEvalsCodeword::TooBig(_) => return Err(Error::PolynomialTooLarge(num_vars)), - }; + let (comm, codeword) = mmcs.commit(evals_codewords); exit_span!(span); - end_timer!(timer); - - Ok(ret) + Ok(BasefoldCommitmentWithWitness::new( + comm, + codeword, + polys, + trivial_proofdata, + circuit_codeword_index, + )) } fn write_commitment( comm: &Self::Commitment, transcript: &mut impl 
Transcript, ) -> Result<(), Error> { - write_digest_to_transcript(&comm.pi_d_digest(), transcript); + write_digest_to_transcript(&comm.commit(), transcript); + // write trivial_commits to transcript + for trivial_commit in &comm.trivial_commits { + write_digest_to_transcript(trivial_commit, transcript); + } + transcript + .append_field_element(&E::BaseField::from_u64(comm.log2_max_codeword_size as u64)); Ok(()) } @@ -320,158 +311,373 @@ where } /// Open a batch of polynomial commitments at several points. - /// The current version only supports one polynomial per commitment. - /// Because otherwise it is complex to match the polynomials and - /// the commitments, and because currently this high flexibility is - /// not very useful in ceno. fn batch_open( - _pp: &Self::ProverParam, - _polys: &[ArcMultilinearExtension], - _comms: &[Self::CommitmentWithWitness], - _points: &[Vec], - _evals: &[Evaluation], - _transcript: &mut impl Transcript, - ) -> Result { - unimplemented!() - } - - /// This is a simple version of batch open: - /// 1. Open at one point - /// 2. All the polynomials share the same commitment and have the same - /// number of variables. - /// 3. The point is already a random point generated by a sum-check. - fn simple_batch_open( pp: &Self::ProverParam, - polys: &[ArcMultilinearExtension], - comm: &Self::CommitmentWithWitness, - point: &[E], - evals: &[E], + num_instances: &[(usize, usize)], + fixed_comms: Option<&Self::CommitmentWithWitness>, + witin_comms: &Self::CommitmentWithWitness, + points: &[Point], + // TODO this is only for debug purpose + evals: &[Vec], + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result { - let timer = start_timer!(|| "Basefold::batch_open"); - let num_vars = polys[0].num_vars(); + let span = entered_span!("Basefold::batch_open"); + + // sanity check + // number of point match with commitment length, assuming each commitment are opening under same point + assert!([points.len(), witin_comms.polys.len(),].iter().all_equal()); + assert!(izip!(&witin_comms.polys, num_instances).all( + |((circuit_index, witin_polys), (_, num_instance))| { + // check num_vars & num_poly match + let (expected_witin_num_polys, expected_fixed_num_polys) = + circuit_num_polys[*circuit_index]; + let num_var = next_pow2_instance_padding(*num_instance).ilog2() as usize; + witin_polys + .iter() + .chain( + fixed_comms + .and_then(|fixed_comms| fixed_comms.polys.get(circuit_index)) + .into_iter() + .flatten(), + ) + .all(|p| p.num_vars() == num_var) + && witin_polys.len() == expected_witin_num_polys + && fixed_comms + .and_then(|fixed_comms| fixed_comms.polys.get(circuit_index)) + .map(|fixed_polys| fixed_polys.len() == expected_fixed_num_polys) + .unwrap_or(true) + }, + )); + assert_eq!(num_instances.len(), witin_comms.polys.len()); - if comm.is_trivial::() { - let mmcs = poseidon2_merkle_tree::(); - return Ok(Self::Proof::trivial( - mmcs.get_matrices(&comm.codeword)[0].clone(), - )); + if cfg!(feature = "sanity-check") { + // check poly evaluation on point equal eval + let point_poly_pair = witin_comms + .polys + .iter() + .zip_eq(points) + .flat_map(|((circuit_index, witin_polys), point)| { + let fixed_iter = fixed_comms + .and_then(|fixed_comms| fixed_comms.polys.get(circuit_index)) + .into_iter() + .flatten() + .cloned(); + let flatten_polys = witin_polys + .iter() + .cloned() + .chain(fixed_iter) + .collect::>>(); + let points = vec![point.clone(); flatten_polys.len()]; + izip!(points, flatten_polys) + }) + .collect::, 
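// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the sanity check here derives
// each circuit's variable count from its instance count by padding up to a
// power of two. The helper below is a hypothetical stand-in for
// next_pow2_instance_padding(..).ilog2(), assuming it simply rounds the
// instance count up to the next power of two.
fn padded_num_vars(num_instances: usize) -> usize {
    // e.g. 5 instances pad to 8 rows, i.e. a 3-variable multilinear
    num_instances.next_power_of_two().trailing_zeros() as usize
}

#[test]
fn padded_num_vars_examples() {
    assert_eq!(padded_num_vars(1), 0);
    assert_eq!(padded_num_vars(5), 3);
    assert_eq!(padded_num_vars(8), 3);
}
// --------------------------------------------------------------------------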
ArcMultilinearExtension)>>(); + assert!( + point_poly_pair + .into_iter() + .zip_eq(evals.iter().flatten()) + .all(|((point, poly), eval)| poly.evaluate(&point) == *eval) + ); } - polys + // identify trivial/non trivial based on size + let (trivial_witin_polys_and_meta, witin_polys_and_meta): (Vec<_>, _) = + izip!(points, &witin_comms.polys) + .map(|(point, (circuit_index, polys))| (point, (*circuit_index, polys))) + .partition(|(_, (_, polys))| { + polys[0].num_vars() <= Spec::get_basecode_msg_size_log() + }); + + let max_num_vars = witin_polys_and_meta .iter() - .for_each(|poly| assert_eq!(poly.num_vars(), num_vars)); - assert!(num_vars >= Spec::get_basecode_msg_size_log()); - assert_eq!(comm.num_polys, polys.len()); - assert_eq!(comm.num_polys, evals.len()); + .map(|(_, (_, polys))| polys[0].num_vars()) + .max() + .unwrap(); - if cfg!(feature = "sanity-check") { - evals - .iter() - .zip(polys) - .for_each(|(eval, poly)| assert_eq!(&poly.evaluate(point), eval)) - } - // evals.len() is the batch size, i.e., how many polynomials are being opened together - let batch_coeffs = &transcript - .sample_and_append_challenge_pows(evals.len(), b"batch coeffs")[0..evals.len()]; - let _target_sum = inner_product(evals, batch_coeffs); - - // Now the verifier has obtained the new target sum, and is able to compute the random - // linear coefficients. - // The remaining tasks for the prover is to prove that - // sum_i coeffs[i] poly_evals[i] is equal to - // the new target sum, where coeffs is computed as follows - let (trees, commit_phase_proof) = simple_batch_commit_phase::( + // Basefold IOP commit phase + let commit_phase_span = entered_span!("Basefold::open::commit_phase"); + let (trees, commit_phase_proof) = batch_commit_phase::( &pp.encoding_params, - point, - batch_coeffs, - comm, + fixed_comms, + &witin_comms.codeword, + witin_polys_and_meta, transcript, - num_vars, - num_vars - Spec::get_basecode_msg_size_log(), + max_num_vars, + max_num_vars - Spec::get_basecode_msg_size_log(), + circuit_num_polys, ); + exit_span!(commit_phase_span); - let query_timer = start_timer!(|| "Basefold::open::query_phase"); + // for smaller poly, we pass their merkle tree leafs directly + let commit_trivial_span = entered_span!("Basefold::open::commit_trivial"); + let trivial_proof = if !trivial_witin_polys_and_meta.is_empty() { + let mmcs = poseidon2_merkle_tree::(); + Some( + trivial_witin_polys_and_meta + .iter() + .map(|(_, (circuit_index, _))| { + mmcs.get_matrices(&witin_comms.trivial_proofdata[circuit_index].1) + .into_iter() + .take(1) + .chain( + // fixed proof is optional + fixed_comms + .and_then(|fixed_comms| { + fixed_comms.trivial_proofdata.get(circuit_index) + }) + .iter() + .flat_map(|(_, proof_data)| { + mmcs.get_matrices(proof_data).into_iter().take(1) + }), + ) + .cloned() + .collect_vec() + }) + .collect_vec(), + ) + } else { + None + }; + exit_span!(commit_trivial_span); + + let query_span = entered_span!("Basefold::open::query_phase"); // Each entry in queried_els stores a list of triples (F, F, i) indicating the // position opened at each round and the two values at that round - let query_opening_proof = - simple_batch_prover_query_phase(transcript, comm, &trees, Spec::get_number_queries()); - end_timer!(query_timer); + let query_opening_proof = batch_query_phase( + transcript, + fixed_comms, + witin_comms, + &trees, + Spec::get_number_queries(), + ); + exit_span!(query_span); - end_timer!(timer); + exit_span!(span); Ok(Self::Proof { - sumcheck_messages: commit_phase_proof.sumcheck_messages, 
commits: commit_phase_proof.commits, final_message: commit_phase_proof.final_message, query_opening_proof, - sumcheck_proof: None, - trivial_proof: None, + sumcheck_proof: Some(commit_phase_proof.sumcheck_messages), + trivial_proof, }) } - fn verify( - _vp: &Self::VerifierParam, - _comm: &Self::Commitment, + /// This is a simple version of batch open: + /// 1. Open at one point + /// 2. All the polynomials share the same commitment and have the same + /// number of variables. + /// 3. The point is already a random point generated by a sum-check. + fn simple_batch_open( + _pp: &Self::ProverParam, + _polys: &[ArcMultilinearExtension], + _comm: &Self::CommitmentWithWitness, _point: &[E], - _eval: &E, - _proof: &Self::Proof, + _evals: &[E], _transcript: &mut impl Transcript, - ) -> Result<(), Error> { + ) -> Result { unimplemented!() } - fn batch_verify( + fn verify( _vp: &Self::VerifierParam, - _comms: &[Self::Commitment], - _points: &[Vec], - _evals: &[Evaluation], + _comm: &Self::Commitment, + _point: &[E], + _eval: &E, _proof: &Self::Proof, _transcript: &mut impl Transcript, ) -> Result<(), Error> { unimplemented!() } - fn simple_batch_verify( + fn batch_verify( vp: &Self::VerifierParam, - comm: &Self::Commitment, - point: &[E], - evals: &[E], + num_instances: &[(usize, usize)], + points: &[Point], + fixed_comms: Option<&Self::Commitment>, + witin_comms: &Self::Commitment, + evals: &[Vec], proof: &Self::Proof, + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result<(), Error> { - let timer = start_timer!(|| "Basefold::simple batch verify"); - let batch_size = evals.len(); - if let Some(num_polys) = comm.num_polys { - assert_eq!(num_polys, batch_size); - } + let mmcs = poseidon2_merkle_tree::(); - if proof.is_trivial() { - let trivial_proof = proof.trivial_proof.as_ref().unwrap(); - let mmcs = poseidon2_merkle_tree::(); - let (root, _) = mmcs.commit_matrix(trivial_proof.clone()); - if comm.pi_d_digest() == root { - return Ok(()); - } else { - return Err(Error::MerkleRootMismatch); - } + assert_eq!(num_instances.len(), points.len()); + + let circuit_num_vars = num_instances + .iter() + .map(|(index, num_instance)| { + ( + *index, + next_pow2_instance_padding(*num_instance).ilog2() as usize, + ) + }) + .collect_vec(); + + assert!( + izip!(&circuit_num_vars, points) + .all(|((_, circuit_num_var), point)| point.len() == *circuit_num_var) + ); + + // preprocess data into respective group, in particularly, trivials vs non-trivials + let mut circuit_metas = vec![]; + let mut circuit_trivial_metas = vec![]; + let mut evals_iter = evals.iter().cloned(); + let (trivial_point_evals, point_evals) = izip!(&circuit_num_vars, points).fold( + (vec![], vec![]), + |(mut trivial_point_evals, mut point_evals), ((circuit_index, num_var), point)| { + let (expected_witins_num_poly, expected_fixed_num_poly) = + &circuit_num_polys[*circuit_index]; + let mut circuit_meta = CircuitIndexMeta { + witin_num_vars: *num_var, + witin_num_polys: *expected_witins_num_poly, + ..Default::default() + }; + // NOTE: for evals, we concat witin with fixed to make process easier + if *num_var <= Spec::get_basecode_msg_size_log() { + trivial_point_evals.push(( + point.clone(), + evals_iter.next().unwrap()[0..*expected_witins_num_poly].to_vec(), + )); + if *expected_fixed_num_poly > 0 { + circuit_meta.fixed_num_vars = *num_var; + circuit_meta.fixed_num_polys = *expected_fixed_num_poly; + trivial_point_evals.last_mut().unwrap().1.extend( + evals_iter.next().unwrap()[0..*expected_fixed_num_poly].to_vec(), + 
) + } + circuit_trivial_metas.push(circuit_meta); + } else { + point_evals.push(( + point.clone(), + evals_iter.next().unwrap()[0..*expected_witins_num_poly].to_vec(), + )); + if *expected_fixed_num_poly > 0 { + circuit_meta.fixed_num_vars = *num_var; + circuit_meta.fixed_num_polys = *expected_fixed_num_poly; + point_evals.last_mut().unwrap().1.extend( + evals_iter.next().unwrap()[0..*expected_fixed_num_poly].to_vec(), + ); + } + circuit_metas.push(circuit_meta); + } + + (trivial_point_evals, point_evals) + }, + ); + assert!(evals_iter.next().is_none()); + + // check trivial proofs + if !circuit_trivial_metas.is_empty() { + let mut trivial_fixed_commit = fixed_comms + .as_ref() + .map(|fc| fc.trivial_commits.iter().cloned()) + .unwrap_or_default(); + assert!(proof.trivial_proof.is_some()); + assert!( + [ + circuit_trivial_metas.len(), + proof.trivial_proof.as_ref().unwrap().len(), + witin_comms.trivial_commits.len() + ] + .iter() + .all_equal() + ); + + // 1. check mmcs verify opening + // 2. check mle.evaluate(point) == evals + circuit_trivial_metas + .iter() + .zip_eq(proof.trivial_proof.as_ref().unwrap()) + .zip_eq(&trivial_point_evals) + .zip_eq(&witin_comms.trivial_commits) + .try_for_each( + |( + ( + ( + CircuitIndexMeta { + fixed_num_polys, .. + }, + proof, + ), + (point, witin_fixed_evals), + ), + witin_commit, + )| { + let witin_rmm = proof[0].clone(); + let (commit, _) = mmcs.commit_matrix(witin_rmm.clone()); + if commit != *witin_commit { + Err(Error::MerkleRootMismatch)?; + } + let mut mles = RowMajorMatrix::new_by_inner_matrix( + witin_rmm, + InstancePaddingStrategy::Default, + ) + .to_mles(); + + if *fixed_num_polys > 0 { + let fixed_rmm = proof[1].clone(); + let fixed_commit = + trivial_fixed_commit.next().expect("proof must exist"); + // NOTE rmm clone here is ok since trivial proof is relatively small + let (commit, _) = mmcs.commit_matrix(fixed_rmm.clone()); + if commit != fixed_commit { + Err(Error::MerkleRootMismatch)?; + } + mles.extend( + RowMajorMatrix::new_by_inner_matrix( + fixed_rmm, + InstancePaddingStrategy::Default, + ) + .to_mles(), + ); + } + + mles.iter() + .zip_eq(witin_fixed_evals) + .all(|(mle, eval)| mle.evaluate(point) == *eval) + .then_some(()) + .ok_or_else(|| { + Error::PointEvalMismatch("trivial point eval mismatch".to_string()) + }) + }, + )?; } - let num_vars = point.len(); - if let Some(comm_num_vars) = comm.num_vars() { - assert_eq!(num_vars, comm_num_vars); - assert!(num_vars >= Spec::get_basecode_msg_size_log()); + if circuit_metas.is_empty() { + return Ok(()); + } else { + assert!( + !proof.final_message.is_empty() + && proof + .final_message + .iter() + .map(|final_message| { final_message.len() }) + .chain(std::iter::once(1 << Spec::get_basecode_msg_size_log())) + .all_equal(), + "final message size should be equal to 1 << Spec::get_basecode_msg_size_log()" + ); + assert!(proof.sumcheck_proof.is_some(), "sumcheck proof must exist"); + assert_eq!(proof.query_opening_proof.len(), Spec::get_number_queries()) } - let num_rounds = num_vars - Spec::get_basecode_msg_size_log(); - // evals.len() is the batch size, i.e., how many polynomials are being opened together + // verify non trivial proof + let total_num_polys = circuit_metas + .iter() + .map(|circuit_meta| circuit_meta.witin_num_polys + circuit_meta.fixed_num_polys) + .sum(); let batch_coeffs = - transcript.sample_and_append_challenge_pows(evals.len(), b"batch coeffs"); + &transcript.sample_and_append_challenge_pows(total_num_polys, b"batch coeffs"); + + let max_num_var = 
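// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): for the trivial commitments the
// verifier re-commits the opened matrix, compares the root, and then evaluates
// the multilinears directly from their hypercube evaluations. The fold below
// shows that evaluation for one common variable ordering (the crate's
// DenseMultilinearExtension may fix variables in a different order); plain u64
// arithmetic modulo the Goldilocks prime stands in for the field, and inputs
// are assumed already reduced.
const GOLDILOCKS: u128 = 0xffff_ffff_0000_0001;

fn mle_evaluate(mut evals: Vec<u64>, point: &[u64]) -> u64 {
    assert_eq!(evals.len(), 1 << point.len());
    for &r in point {
        // fix one variable: f(.., 0) and f(.., 1) sit in adjacent slots, and
        // f(.., r) = f(.., 0) + r * (f(.., 1) - f(.., 0))
        evals = evals
            .chunks_exact(2)
            .map(|pair| {
                let (lo, hi) = (pair[0] as u128, pair[1] as u128);
                let diff = (hi + GOLDILOCKS - lo) % GOLDILOCKS;
                ((lo + (r as u128 * diff) % GOLDILOCKS) % GOLDILOCKS) as u64
            })
            .collect();
    }
    evals[0]
}
// --------------------------------------------------------------------------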
*circuit_num_vars.iter().map(|(_, n)| n).max().unwrap(); + let num_rounds = max_num_var - Spec::get_basecode_msg_size_log(); - let mut fold_challenges: Vec = Vec::with_capacity(num_vars); + // prepare folding challenges via sumcheck round msg + FRI commitment + let mut fold_challenges: Vec = Vec::with_capacity(max_num_var); let commits = &proof.commits; - let sumcheck_messages = &proof.sumcheck_messages; + let sumcheck_messages = proof.sumcheck_proof.as_ref().unwrap(); for i in 0..num_rounds { - transcript.append_field_element_exts(sumcheck_messages[i].as_slice()); + transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice()); fold_challenges.push( transcript .sample_and_append_challenge(b"commit round") @@ -482,44 +688,49 @@ where } } let final_message = &proof.final_message; - transcript.append_field_element_exts(final_message.as_slice()); + transcript.append_field_element_exts_iter(proof.final_message.iter().flatten()); - let queries: Vec<_> = transcript - .sample_and_append_vec(b"query indices", Spec::get_number_queries()) - .into_iter() - .map(|r| ext_to_usize(&r) % (1 << (num_vars + Spec::get_rate_log()))) - .collect(); - - // coeff is the eq polynomial evaluated at the first challenge.len() variables - let coeff = eq_eval(&point[..fold_challenges.len()], &fold_challenges); - // Compute eq as the partially evaluated eq polynomial - let mut eq = build_eq_x_r_vec(&point[fold_challenges.len()..]); - eq.par_iter_mut().for_each(|e| *e *= coeff); + let queries: Vec<_> = transcript.sample_bits_and_append_vec( + b"query indices", + Spec::get_number_queries(), + max_num_var + Spec::get_rate_log(), + ); - simple_batch_verifier_query_phase::( - queries.as_slice(), + // verify basefold sumcheck + FRI codeword query + batch_verifier_query_phase::( + max_num_var, + &queries, &vp.encoding_params, + final_message, + batch_coeffs, &proof.query_opening_proof, - sumcheck_messages, + fixed_comms, + witin_comms, + &circuit_metas, + &proof.commits, &fold_challenges, - &batch_coeffs, - num_rounds, - num_vars, - final_message, - commits, - comm, - eq.as_slice(), - evals, + sumcheck_messages, + &point_evals, ); - end_timer!(timer); Ok(()) } + fn simple_batch_verify( + _vp: &Self::VerifierParam, + _comm: &Self::Commitment, + _point: &[E], + _evals: &[E], + _proof: &Self::Proof, + _transcript: &mut impl Transcript, + ) -> Result<(), Error> { + unimplemented!() + } + fn get_arc_mle_witness_from_commitment( commitment: &Self::CommitmentWithWitness, ) -> Vec> { - commitment.polynomials_bh_evals.clone() + commitment.polys.values().flatten().cloned().collect_vec() } } @@ -529,7 +740,7 @@ mod test { use crate::{ basefold::Basefold, - test_util::{run_commit_open_verify, run_simple_batch_commit_open_verify}, + test_util::{run_batch_commit_open_verify, run_commit_open_verify}, }; use super::BasefoldRSParams; @@ -537,20 +748,19 @@ mod test { type PcsGoldilocksRSCode = Basefold; #[test] - fn simple_batch_commit_open_verify_goldilocks() { + fn batch_commit_open_verify_goldilocks() { // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::(10, 11, 1); - run_simple_batch_commit_open_verify::(10, 11, 4); - // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::(4, 6, 4); + run_batch_commit_open_verify::(10, 11, 1); + run_batch_commit_open_verify::(10, 11, 4); + // TODO support all trivial proof } #[test] #[ignore = "For benchmarking and profiling only"] - fn bench_basefold_simple_batch_commit_open_verify_goldilocks() { + fn 
bench_basefold_batch_commit_open_verify_goldilocks() { { run_commit_open_verify::(20, 21); - run_simple_batch_commit_open_verify::(20, 21, 64); + run_batch_commit_open_verify::(20, 21, 64); } } } diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index f7767f3d6..ee4c300b7 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -2,107 +2,222 @@ use std::sync::Arc; use super::{ encoding::EncodingScheme, - structure::{BasefoldCommitPhaseProof, BasefoldSpec, MerkleTreeExt}, + structure::{BasefoldCommitPhaseProof, BasefoldSpec, MerkleTree, MerkleTreeExt}, }; -use crate::util::{ - hash::write_digest_to_transcript, - merkle_tree::{Poseidon2ExtMerkleMmcs, poseidon2_merkle_tree}, +use crate::{ + Point, + util::{ + codeword_fold_with_challenge, + hash::write_digest_to_transcript, + merkle_tree::{Poseidon2ExtMerkleMmcs, poseidon2_merkle_tree}, + pop_front_while, split_by_sizes, + }, }; -use ark_std::{end_timer, start_timer}; use ff_ext::{ExtensionField, PoseidonField}; -use itertools::izip; +use itertools::{Itertools, izip}; use p3::{ commit::{ExtensionMmcs, Mmcs}, field::{Field, PrimeCharacteristicRing, dot_product}, - matrix::dense::RowMajorMatrix, + matrix::{ + Matrix, + dense::{DenseMatrix, RowMajorMatrix}, + }, util::log2_strict_usize, }; use serde::{Serialize, de::DeserializeOwned}; +use std::collections::VecDeque; use sumcheck::{ macros::{entered_span, exit_span}, - structs::IOPProverState, + structs::{IOPProverMessage, IOPProverState}, util::{AdditiveVec, merge_sumcheck_prover_state, optimal_sumcheck_threads}, }; use transcript::{Challenge, Transcript}; use multilinear_extensions::{ - commutative_op_mle_pair, mle::{DenseMultilinearExtension, IntoMLE}, - util::ceil_log2, virtual_poly::{ArcMultilinearExtension, build_eq_x_r_vec}, virtual_polys::VirtualPolynomials, }; use rayon::{ - iter::{IntoParallelRefIterator, IntoParallelRefMutIterator}, - prelude::{IndexedParallelIterator, ParallelIterator, ParallelSlice}, + iter::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, + prelude::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, }; use super::structure::BasefoldCommitmentWithWitness; -// outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) #[allow(clippy::too_many_arguments)] -pub fn simple_batch_commit_phase>( +#[allow(clippy::type_complexity)] +pub fn batch_commit_phase>( pp: &>::ProverParameters, - point: &[E], - batch_coeffs: &[E], - comm: &BasefoldCommitmentWithWitness, + fixed_comms: Option<&BasefoldCommitmentWithWitness>, + witin_commitment_with_witness: &MerkleTree, + witin_polys_and_meta: Vec<( + &Point, + (usize, &Vec>), + )>, transcript: &mut impl Transcript, - num_vars: usize, + max_num_vars: usize, num_rounds: usize, + circuit_num_polys: &[(usize, usize)], ) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, as Mmcs>::Commitment: IntoIterator + PartialEq, { - assert_eq!(point.len(), num_vars); - assert_eq!(comm.num_polys, batch_coeffs.len()); - let prepare_timer = start_timer!(|| "Prepare"); + let prepare_span = entered_span!("Prepare"); let mmcs_ext = ExtensionMmcs::::new(poseidon2_merkle_tree::()); let mmcs = poseidon2_merkle_tree::(); - let mut trees: Vec> = Vec::with_capacity(num_vars); + let mut trees: Vec> = Vec::with_capacity(max_num_vars); - let batch_codewords_timer = start_timer!(|| "Batch codewords"); - let initial_oracle = mmcs.get_matrices(&comm.codeword)[0] - .values - .par_chunks(comm.num_polys) - .map(|row| 
dot_product(batch_coeffs.iter().copied(), row.iter().copied())) + // concat witin mle with fixed mle under same circuit index + let witin_concat_with_fixed_polys: Vec>> = witin_polys_and_meta + .iter() + .map(|(_, (circuit_index, witin_polys))| { + let fixed_iter = fixed_comms + .and_then(|fixed_comms| fixed_comms.polys.get(circuit_index)) + .into_iter() + .flatten() + .cloned(); + witin_polys.iter().cloned().chain(fixed_iter).collect() + }) + .collect::>>(); + let batch_group_size = witin_concat_with_fixed_polys + .iter() + .map(|v| v.len()) + .collect_vec(); + let total_num_polys = batch_group_size.iter().sum(); + + let batch_coeffs = + &transcript.sample_and_append_challenge_pows(total_num_polys, b"batch coeffs"); + // split batch coeffs to match with batch group for easier handling + let batch_coeffs_splitted = split_by_sizes(batch_coeffs, &batch_group_size); + + // prepare + // - codeword oracle => for FRI + // - evals => for sumcheck + let witins_codeword = mmcs.get_matrices(witin_commitment_with_witness); + let fixed_codeword = fixed_comms + .map(|fixed_comms| mmcs.get_matrices(&fixed_comms.codeword)) + .unwrap_or_default(); + + let batch_codeword_span = entered_span!("batch_codeword"); + // we random linear combination of rmm under same circuit into single codeword, as they shared same height + let batched_codewords: Vec> = witins_codeword + .iter() + .zip_eq(&batch_coeffs_splitted) + .zip_eq(&witin_polys_and_meta) + .map( + |((witin_codewords, batch_coeffs), (_, (circuit_index, _)))| { + let (expected_witins_num_poly, expected_fixed_num_poly) = + circuit_num_polys[*circuit_index]; + // batch_coeffs concat witin follow by fixed, where fixed is optional + let witin_fixed_concated_codeword: Vec<(_, usize)> = + std::iter::once((witin_codewords, expected_witins_num_poly)) + .chain( + fixed_comms + .and_then(|fixed_comms| { + fixed_comms.circuit_codeword_index.get(circuit_index) + }) + .and_then(|idx| { + fixed_codeword + .get(*idx) + .map(|rmm| (rmm, expected_fixed_num_poly)) + }), + ) + .collect_vec(); + // final poly size is 2 * height because we commit left: poly[j] and right: poly[j + ni] under same mk path (due to bit-reverse) + let size = witin_fixed_concated_codeword[0].0.height() * 2; + RowMajorMatrix::new( + (0..size) + .into_par_iter() + .map(|j| { + witin_fixed_concated_codeword + .iter() + .scan(0, |start_index, (rmm, num_polys)| { + let batch_coeffs = batch_coeffs + [*start_index..*start_index + num_polys] + .iter() + .copied(); + *start_index += num_polys; + Some(dot_product( + batch_coeffs, + rmm.values[j * num_polys..(j + 1) * num_polys] + .iter() + .copied(), + )) + }) + .sum::() + }) + .collect::>(), + 2, + ) + }, + ) + .collect_vec(); + assert!( + [witin_polys_and_meta.len(), batched_codewords.len(),] + .iter() + .all_equal() + ); + // sorted batch codewords by height in descending order + let mut batched_codewords = VecDeque::from( + batched_codewords + .into_iter() + .sorted_by_key(|codeword| std::cmp::Reverse(codeword.height())) + .collect_vec(), + ); + exit_span!(batch_codeword_span); + + let batched_evals = entered_span!("batched_evals"); + let initial_rlc_evals: Vec> = witin_concat_with_fixed_polys + .par_iter() + .zip_eq(batch_coeffs_splitted.par_iter()) + .map(|(witin_fixed_mle, batch_coeffs)| { + assert_eq!(witin_fixed_mle.len(), batch_coeffs.len()); + let num_vars = witin_fixed_mle[0].num_vars(); + let mle_base_vec = witin_fixed_mle + .iter() + .map(|mle| mle.get_base_field_vec()) + .collect_vec(); + let running_evals: ArcMultilinearExtension<_> = + 
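// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): both the batched codeword and
// the batched running evaluations are plain random linear combinations, where
// row j of the result is sum_i coeff[i] * column_i[j] and the coefficient slice
// for each circuit is cut out of one long vector of challenge powers. The
// stand-alone version below uses f64 in place of the field.
fn rlc_rows(columns: &[Vec<f64>], coeffs: &[f64]) -> Vec<f64> {
    assert_eq!(columns.len(), coeffs.len());
    let height = columns[0].len();
    assert!(columns.iter().all(|c| c.len() == height));
    (0..height)
        .map(|j| columns.iter().zip(coeffs).map(|(col, c)| c * col[j]).sum())
        .collect()
}

#[test]
fn rlc_rows_example() {
    let cols = vec![vec![1.0, 2.0], vec![10.0, 20.0]];
    // coeffs (3, 5): row 0 = 3*1 + 5*10 = 53, row 1 = 3*2 + 5*20 = 106
    assert_eq!(rlc_rows(&cols, &[3.0, 5.0]), vec![53.0, 106.0]);
}
// --------------------------------------------------------------------------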
Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + num_vars, + (0..witin_fixed_mle[0].evaluations().len()) + .into_par_iter() + .map(|j| { + dot_product( + batch_coeffs.iter().copied(), + mle_base_vec.iter().map(|mle| mle[j]), + ) + }) + .collect::>(), + )); + running_evals + }) .collect::>(); - end_timer!(batch_codewords_timer); - - let Some((running_evals, _)): Option<(ArcMultilinearExtension, E)> = izip!( - comm.polynomials_bh_evals.iter().cloned(), - batch_coeffs.iter().cloned() - ) - .reduce(|(poly_a, coeff_a), (poly_b, coeff_b)| { - let next_poly = commutative_op_mle_pair!(|poly_a, poly_b| { - // TODO we can save a bit cost if first batch_coeffs is E::ONE so we can skip the first base * ext operation - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - num_vars, - poly_a - .par_iter() - .zip(poly_b.par_iter()) - .map(|(a, b)| coeff_a * *a + coeff_b * *b) - .collect(), - )) - }); - (next_poly, E::ONE) - }) else { - unimplemented!() - }; - end_timer!(prepare_timer); + exit_span!(batched_evals); + exit_span!(prepare_span); // eq is the evaluation representation of the eq(X,r) polynomial over the hypercube - let build_eq_timer = start_timer!(|| "Basefold::build eq"); - let eq = build_eq_x_r_vec(point).into_mle().into(); - end_timer!(build_eq_timer); + let build_eq_span = entered_span!("Basefold::build eq"); + let eq: Vec> = witin_polys_and_meta + .par_iter() + .map(|(point, _)| build_eq_x_r_vec(point).into_mle().into()) + .collect::>(); + exit_span!(build_eq_span); + + let num_threads = optimal_sumcheck_threads(max_num_vars); + let log2_num_threads = log2_strict_usize(num_threads); + + // sumcheck formula: \sum_i \sum_b eq[point_i; b_i] * running_eval_i[b_i], |b_i| <= b and aligned on suffix + let mut polys = VirtualPolynomials::new(num_threads, max_num_vars); - let num_threads = optimal_sumcheck_threads(num_vars); - let log_num_threads = ceil_log2(num_threads); + izip!(&eq, &initial_rlc_evals) + .for_each(|(eq, running_evals)| polys.add_mle_list(vec![&eq, &running_evals], E::ONE)); - let mut polys = VirtualPolynomials::new(num_threads, num_vars); - polys.add_mle_list(vec![&eq, &running_evals], E::ONE); let (batched_polys, poly_meta) = polys.get_batched_polys(); let mut prover_states = batched_polys @@ -113,26 +228,25 @@ where thread_id == 0, // set thread_id 0 to be main worker poly, vec![(vec![], vec![])], - Some(log_num_threads), + Some(log2_num_threads), Some(poly_meta.clone()), ) }) .collect::>(); - - let mut challenge = None; - let mut sumcheck_messages = Vec::with_capacity(num_rounds); let mut commits = Vec::with_capacity(num_rounds - 1); + let mut challenge = None; let sumcheck_phase1 = entered_span!("sumcheck_phase1"); - let phase1_rounds = num_rounds.min(num_vars - log2_strict_usize(num_threads)); + let phase1_rounds = num_rounds.min(max_num_vars - log2_num_threads); + for i in 0..phase1_rounds { challenge = basefold_one_round::( pp, &mut prover_states, challenge, &mut sumcheck_messages, - &initial_oracle, + &mut batched_codewords, transcript, &mut trees, &mut commits, @@ -140,6 +254,7 @@ where i == num_rounds - 1, ); } + exit_span!(sumcheck_phase1); if let Some(p) = challenge { @@ -150,7 +265,7 @@ where // deal with log(#thread) basefold rounds let merge_sumcheck_prover_state_span = entered_span!("merge_sumcheck_prover_state"); - let poly = merge_sumcheck_prover_state(prover_states); + let poly = merge_sumcheck_prover_state(&prover_states); let mut prover_states = vec![IOPProverState::prover_init_with_extrapolation_aux( true, poly, @@ -163,36 
+278,46 @@ where let mut challenge = None; let sumcheck_phase2 = entered_span!("sumcheck_phase2"); - let remaining_rounds = num_rounds.saturating_sub(num_vars - log2_strict_usize(num_threads)); + let remaining_rounds = num_rounds.saturating_sub(max_num_vars - log2_num_threads); + for i in 0..remaining_rounds { challenge = basefold_one_round::( pp, &mut prover_states, challenge, &mut sumcheck_messages, - &initial_oracle, + &mut batched_codewords, transcript, &mut trees, &mut commits, &mmcs_ext, - i == num_rounds.saturating_sub(num_vars - log2_strict_usize(num_threads)) - 1, + i == remaining_rounds - 1, ); } + exit_span!(sumcheck_phase2); if let Some(p) = challenge { prover_states[0].fix_var(p.elements); } - let mut final_message = prover_states[0].get_mle_final_evaluations(); - // skip first half which is eq evaluations - let final_message = final_message.split_off(final_message.len() / 2); + let final_message = prover_states[0].get_mle_final_evaluations(); + // skip even index which is eq evaluations + let final_message = final_message + .into_iter() + .enumerate() + .filter(|(i, _)| i % 2 == 1) + .map(|(_, m)| m) + .collect_vec(); if cfg!(feature = "sanity-check") { - // If the prover is honest, in the last round, the running oracle - // on the prover side should be exactly the encoding of the folded polynomial. + assert!(final_message.iter().map(|m| m.len()).all_equal()); + let final_message_agg: Vec = (0..final_message[0].len()) + .map(|j| final_message.iter().map(|row| row[j]).sum()) + .collect_vec(); + // last round running oracle should be exactly the encoding of the folded polynomial. let basecode = >::encode_slow_ext( - p3::matrix::dense::DenseMatrix::new(final_message.clone(), 1), + p3::matrix::dense::DenseMatrix::new(final_message_agg, 1), ); assert_eq!( basecode.values, @@ -201,7 +326,7 @@ where // remove last tree/commmitment which is only for debug purpose let _ = (trees.pop(), commits.pop()); } - transcript.append_field_element_exts(&final_message); + transcript.append_field_element_exts_iter(final_message.iter().flatten()); (trees, BasefoldCommitPhaseProof { sumcheck_messages, commits, @@ -209,46 +334,132 @@ where }) } -// TODO define it within codeword -pub(crate) fn basefold_one_round_by_interpolation_weights< - E: ExtensionField, - Spec: BasefoldSpec, ->( +/// basefold fri round to fold codewords +#[allow(clippy::too_many_arguments)] +pub(crate) fn basefold_fri_round>( pp: &>::ProverParameters, - level: usize, - values: &[E], + codewords: &mut VecDeque>, + trees: &mut Vec>, + commits: &mut Vec< as Mmcs>::Commitment>, + mmcs_ext: &ExtensionMmcs< + E::BaseField, + E, + <::BaseField as PoseidonField>::MMCS, + >, challenge: E, -) -> RowMajorMatrix { - // assume values in bit_reverse_format - // thus chunks(2) is equivalent to left, right traverse + is_last_round: bool, + transcript: &mut impl Transcript, +) where + E::BaseField: Serialize + DeserializeOwned, + as Mmcs>::Commitment: + IntoIterator + PartialEq, +{ + let running_codeword_opt = trees + .last() + .and_then(|mktree| mmcs_ext.get_matrices(mktree).pop()) + .map(|m| m.as_view()); + let target_len = running_codeword_opt + .map(|running_codeword| running_codeword.values.len()) + .unwrap_or_else(|| { + codewords + .iter() + .map(|v| v.values.len()) + .max() + .expect("empty codeword") + }); + let next_level_target_len = target_len >> 1; + let level = log2_strict_usize(target_len) - 1; let folding_coeffs = >::prover_folding_coeffs_level(pp, level); - debug_assert_eq!(folding_coeffs.len(), 1 << level); let inv_2 = 
E::BaseField::from_u64(2).inverse(); - RowMajorMatrix::new( - values - .par_chunks_exact(2) - .zip(folding_coeffs) - .map(|(ys, coeff)| { - let (left, right) = (ys[0], ys[1]); - // original (left, right) = (lo + hi*x, lo - hi*x), lo, hi are codeword, but after times x it's not codeword - // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) - let (lo, hi) = ((left + right) * inv_2, (left - right) * *coeff); // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code - // we do fold on folded = (1-r) * left_codeword + r * right_codeword, as it match perfectly with raw message in lagrange domain fixed variable - lo + (hi - lo) * challenge + debug_assert_eq!(folding_coeffs.len(), 1 << level); + + // take codewords match with target length then fold + let codewords_matched = + pop_front_while(codewords, |codeword| codeword.values.len() == target_len); + // take codewords match next target length in preparation of being committed together + let codewords_next_level_matched = pop_front_while(codewords, |codeword| { + codeword.values.len() == next_level_target_len + }); + + // optimize for single codeword match + let folded_codeword = if (usize::from(running_codeword_opt.is_some()) + codewords_matched.len()) + == 1 + && codewords_next_level_matched.is_empty() + { + RowMajorMatrix::new( + running_codeword_opt + .or_else(|| codewords_matched.first().map(|m| m.as_view())) + .unwrap() + .values + .par_chunks_exact(2) + .zip(folding_coeffs) + .map(|(ys, coeff)| codeword_fold_with_challenge(ys, challenge, *coeff, inv_2)) + .collect::>(), + 2, + ) + } else { + // aggregate codeword with same length + let codeword_to_fold = (0..target_len) + .into_par_iter() + .map(|index| { + running_codeword_opt + .into_iter() + .chain(codewords_matched.iter().map(|m| m.as_view())) + .map(|codeword| codeword.values[index]) + .sum::() }) - .collect::>(), - 2, - ) + .collect::>(); + + RowMajorMatrix::new( + (0..target_len) + .into_par_iter() + .step_by(2) + .map(|index| { + let coeff = &folding_coeffs[index >> 1]; + + // 1st part folded with challenge then sum + let cur_same_pos_sum = codeword_fold_with_challenge( + &codeword_to_fold[index..index + 2], + challenge, + *coeff, + inv_2, + ); + // 2nd part: retrieve respective index then sum + let next_same_pos_sum = codewords_next_level_matched + .iter() + .map(|codeword| codeword.values[index >> 1]) + .sum::(); + cur_same_pos_sum + next_same_pos_sum + }) + .collect::>(), + 2, + ) + }; + + if cfg!(feature = "sanity-check") && is_last_round { + let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword.clone()); + commits.push(commitment); + trees.push(merkle_tree); + } + + // skip last round commitment as verifer need to derive encode(final_message) = final_codeword itself + if !is_last_round { + let (commitment, merkle_tree) = mmcs_ext.commit_matrix(folded_codeword); + write_digest_to_transcript(&commitment, transcript); + commits.push(commitment); + trees.push(merkle_tree); + } } +// do sumcheck interleaving with FRI step #[allow(clippy::too_many_arguments)] fn basefold_one_round>( pp: &>::ProverParameters, prover_states: &mut Vec>, challenge: Option>, - sumcheck_messages: &mut Vec>, - initial_oracle: &[E], + sumcheck_messages: &mut Vec>, + codewords: &mut VecDeque>, transcript: &mut impl Transcript, trees: &mut Vec>, commits: &mut Vec< as Mmcs>::Commitment>, @@ -264,12 +475,12 @@ where as Mmcs>::Commitment: IntoIterator + PartialEq, { - let prove_round_and_update_state_span = entered_span!("prove_round_and_update_state"); + // 1. 
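// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): each leaf stores the pair
// (left, right) = (lo + hi*x, lo - hi*x), so the fold first recovers
// lo = (left + right) / 2 and hi = (left - right) * coeff with coeff = 1/(2x),
// then combines them as lo + (hi - lo) * r = (1 - r)*lo + r*hi. The numeric
// check below verifies that identity with f64 in place of the field (x chosen
// as a power of two so the arithmetic stays exact); it mirrors what the
// codeword_fold_with_challenge helper is used for here, but is not that helper.
fn fold_pair(left: f64, right: f64, r: f64, coeff: f64, inv_2: f64) -> f64 {
    let lo = (left + right) * inv_2;
    let hi = (left - right) * coeff;
    lo + (hi - lo) * r
}

#[test]
fn fold_matches_direct_interpolation() {
    let (lo, hi, x, r) = (3.0_f64, 5.0, 2.0, 0.25);
    let (left, right) = (lo + hi * x, lo - hi * x); // what the codeword stores
    let folded = fold_pair(left, right, r, 1.0 / (2.0 * x), 0.5);
    assert_eq!(folded, (1.0 - r) * lo + r * hi); // 3.5 on both sides
}
// --------------------------------------------------------------------------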
sumcheck part + let sumcheck_round_span = entered_span!("basefold::sumcheck_one_round"); let prover_msgs = prover_states .par_iter_mut() .map(|prover_state| IOPProverState::prove_round_and_update_state(prover_state, &challenge)) .collect::>(); - exit_span!(prove_round_and_update_state_span); // for each round, we must collect #SIZE prover message let evaluations: AdditiveVec = @@ -279,42 +490,27 @@ where acc += AdditiveVec(prover_msg.evaluations); acc }); - transcript.append_field_element_exts(&evaluations.0); - sumcheck_messages.push(evaluations.0); + sumcheck_messages.push(IOPProverMessage { + evaluations: evaluations.0, + }); + exit_span!(sumcheck_round_span); let next_challenge = transcript.sample_and_append_challenge(b"commit round"); - // Fold the current oracle for FRI - let new_running_oracle = if trees.is_empty() { - basefold_one_round_by_interpolation_weights::( - pp, - log2_strict_usize(initial_oracle.len()) - 1, - initial_oracle, - next_challenge.elements, - ) - } else { - let values = &mmcs_ext.get_matrices(trees.last().unwrap())[0].values; - basefold_one_round_by_interpolation_weights::( - pp, - log2_strict_usize(values.len()) - 1, - values, - next_challenge.elements, - ) - }; - - if cfg!(feature = "sanity-check") && is_last_round { - let (commitment, merkle_tree) = mmcs_ext.commit_matrix(new_running_oracle.clone()); - commits.push(commitment); - trees.push(merkle_tree); - } - - if !is_last_round { - let (commitment, merkle_tree) = mmcs_ext.commit_matrix(new_running_oracle); - write_digest_to_transcript(&commitment, transcript); - commits.push(commitment); - trees.push(merkle_tree); - } + // 2. fri part + let fri_round_span = entered_span!("basefold::fri_one_round"); + basefold_fri_round::( + pp, + codewords, + trees, + commits, + mmcs_ext, + next_challenge.elements, + is_last_round, + transcript, + ); + exit_span!(fri_round_span); Some(next_challenge) } diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index bb119ce3f..05de77108 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -30,7 +30,10 @@ pub trait EncodingScheme: std::fmt::Debug + Clone { max_msg_size_log: usize, ) -> Result<(Self::ProverParameters, Self::VerifierParameters), Error>; - fn encode(pp: &Self::ProverParameters, rmm: RowMajorMatrix) -> Self::EncodedData; + fn encode( + pp: &Self::ProverParameters, + rmm: RowMajorMatrix, + ) -> Result; fn encode_slow_ext( rmm: p3::matrix::dense::RowMajorMatrix, diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 25d23ad86..208d4879e 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use super::{EncodingProverParameters, EncodingScheme}; -use crate::{Error, basefold::PolyEvalsCodeword}; +use crate::Error; use ff_ext::ExtensionField; use itertools::Itertools; use p3::{ @@ -100,7 +100,7 @@ where type VerifierParameters = RSCodeVerifierParameters; - type EncodedData = PolyEvalsCodeword; + type EncodedData = DenseMatrix; fn setup(_max_message_size_log: usize) -> Self::PublicParameters { RSCodeParameters { @@ -184,18 +184,14 @@ where )) } - fn encode(pp: &Self::ProverParameters, rmm: RowMajorMatrix) -> Self::EncodedData { + fn encode( + pp: &Self::ProverParameters, + rmm: RowMajorMatrix, + ) -> Result { let num_vars = rmm.num_vars(); let num_polys = rmm.width(); if num_vars > pp.get_max_message_size_log() { - return PolyEvalsCodeword::TooBig(num_vars); - } - - // In this case, the polynomial is so small that the 
opening is trivial. - // So we just build the Merkle tree over the polynomial evaluations. - // No codeword is needed. - if num_vars <= Spec::get_basecode_msg_size_log() { - return PolyEvalsCodeword::TooSmall(Box::new(rmm.into_default_padded_p3_rmm())); + return Err(Error::PolynomialTooLarge(num_vars)); } // here 2 resize happend. first is padding to next pow2 height, second is extend to 2^get_rate_log times size @@ -218,8 +214,7 @@ where // to make 2 consecutive position to be open together, we trickily "concat" 2 consecutive leafs // so both can be open under same row index let codeword = DenseMatrix::new(codeword, num_polys * 2); - - PolyEvalsCodeword::Normal(Box::new(codeword)) + Ok(codeword) } fn encode_small( @@ -296,13 +291,21 @@ where #[cfg(test)] mod tests { + use std::collections::VecDeque; + use ff_ext::GoldilocksExt2; use itertools::izip; - use p3::{goldilocks::Goldilocks, util::log2_strict_usize}; + use p3::{ + commit::{ExtensionMmcs, Mmcs}, + goldilocks::Goldilocks, + }; use rand::rngs::OsRng; + use transcript::BasicTranscript; - use crate::basefold::commit_phase::basefold_one_round_by_interpolation_weights; + use crate::{ + basefold::commit_phase::basefold_fri_round, util::merkle_tree::poseidon2_merkle_tree, + }; use super::*; @@ -314,15 +317,11 @@ mod tests { #[test] pub fn test_message_codeword_linearity() { let num_vars = 10; + let mmcs_ext = ExtensionMmcs::::new(poseidon2_merkle_tree::()); let rmm: RowMajorMatrix = RowMajorMatrix::rand(&mut OsRng, 1 << num_vars, 1); let pp = >::setup(num_vars); let (pp, vp) = Code::trim(pp, num_vars).unwrap(); - let codeword = Code::encode(&pp, rmm.clone()); - let codeword = match codeword { - PolyEvalsCodeword::Normal(dense_matrix) => dense_matrix, - PolyEvalsCodeword::TooSmall(_) => todo!(), - PolyEvalsCodeword::TooBig(_) => todo!(), - }; + let codeword = Code::encode(&pp, rmm.clone()).expect("encode error"); assert_eq!( codeword.values.len(), 1 << (num_vars + >::get_rate_log()) @@ -338,13 +337,21 @@ mod tests { izip!(&codeword.values, &codeword_ext.values).all(|(base, ext)| E::from(*base) == *ext) ); + let mut codeword_ext = VecDeque::from(vec![codeword_ext]); + let mut transcript = BasicTranscript::new(b"test"); + // test basefold.encode(raw_message.fold(1-r, r)) ?= codeword.fold(1-r, r) + let mut prove_data = vec![]; let r = E::from_u64(97); - let folded_codeword = basefold_one_round_by_interpolation_weights::( + basefold_fri_round::( &pp, - log2_strict_usize(codeword_ext.values.len()) - 1, - &codeword_ext.values, + &mut codeword_ext, + &mut prove_data, + &mut vec![], + &mmcs_ext, r, + false, + &mut transcript, ); // encoded folded raw message @@ -358,6 +365,9 @@ mod tests { 1, ), ); - assert_eq!(&folded_codeword.values, &codeword_from_folded_rmm.values); + assert_eq!( + &mmcs_ext.get_matrices(&prove_data[0])[0].values, + &codeword_from_folded_rmm.values + ); } } diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index 566df2248..27296e248 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -1,21 +1,24 @@ -use std::slice; +use std::{cmp::Reverse, slice}; use crate::{ - basefold::structure::MerkleTreeExt, - util::{arithmetic::inner_product, ext_to_usize, merkle_tree::poseidon2_merkle_tree}, + Point, + basefold::structure::{CircuitIndexMeta, MerkleTreeExt, QueryOpeningProof}, + util::{codeword_fold_with_challenge, merkle_tree::poseidon2_merkle_tree}, }; -use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; +use 
multilinear_extensions::virtual_poly::{build_eq_x_r_vec, eq_eval}; use p3::{ commit::{ExtensionMmcs, Mmcs}, - field::dot_product, + field::{Field, PrimeCharacteristicRing, dot_product}, + fri::{BatchOpening, CommitPhaseProofStep}, matrix::{Dimensions, dense::RowMajorMatrix}, util::log2_strict_usize, }; use serde::{Serialize, de::DeserializeOwned}; use sumcheck::{ macros::{entered_span, exit_span}, + structs::IOPProverMessage, util::interpolate_uni_poly, }; use transcript::Transcript; @@ -28,9 +31,10 @@ use super::{ structure::{BasefoldCommitment, BasefoldCommitmentWithWitness, BasefoldSpec}, }; -pub fn simple_batch_prover_query_phase( +pub fn batch_query_phase( transcript: &mut impl Transcript, - comm: &BasefoldCommitmentWithWitness, + fixed_comms: Option<&BasefoldCommitmentWithWitness>, + witin_comms: &BasefoldCommitmentWithWitness, trees: &[MerkleTreeExt], num_verifier_queries: usize, ) -> QueryOpeningProofs @@ -41,17 +45,16 @@ where let mmcs = poseidon2_merkle_tree::(); // Transform the challenge queries from field elements into integers - // TODO simplify with sample_bit - let queries: Vec<_> = transcript.sample_and_append_vec(b"query indices", num_verifier_queries); - let queries_usize: Vec = queries - .iter() - .map(|x_index| ext_to_usize(x_index) % comm.codeword_size()) - .collect_vec(); + let queries: Vec<_> = transcript.sample_bits_and_append_vec( + b"query indices", + num_verifier_queries, + witin_comms.log2_max_codeword_size, + ); - queries_usize + queries .iter() .map(|idx| { - let opening = { + let witin_base_proof = { // extract the even part of `idx` // --------------------------------- // the oracle values are committed in a row-bit-reversed format. @@ -61,173 +64,409 @@ where // however, since `p_d[j]` and `p_d[j + n_{d-1}]` are already concatenated in the same merkle leaf, // we can simply mask out the least significant bit (lsb) by performing a right shift by 1. let idx = idx >> 1; - let (mut values, proof) = mmcs.open_batch(idx, &comm.codeword); - let leafs = values.pop().unwrap(); - (leafs, proof) + let (opened_values, opening_proof) = mmcs.open_batch(idx, &witin_comms.codeword); + BatchOpening { + opened_values, + opening_proof, + } + }; + + let fixed_base_proof = if let Some(fixed_comms) = fixed_comms { + // follow same rule as witin base proof + let idx_shift = witin_comms.log2_max_codeword_size as i32 + - fixed_comms.log2_max_codeword_size as i32; + let idx = if idx_shift > 0 { + idx >> idx_shift + } else { + idx << -idx_shift + }; + let idx = idx >> 1; + let (opened_values, opening_proof) = mmcs.open_batch(idx, &fixed_comms.codeword); + Some(BatchOpening { + opened_values, + opening_proof, + }) + } else { + None }; // this is equivalent with "idx = idx % n_{d-1}" operation in non row bit reverse format let idx = idx >> 1; - let (_, opening_ext) = trees.iter().fold((idx, vec![]), |(idx, mut proofs), tree| { - // differentiate interpolate to left or right position at next layer - let is_interpolate_to_right_index = (idx & 1) == 1; - // mask the least significant bit (LSB) for the same reason as above: - // 1. we only need the even part of the index. - // 2. since even and odd parts are concatenated in the same leaf, - // the overall merkle tree height is effectively halved, - // so we divide by 2. 
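// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): because every Merkle leaf holds
// the two codeword positions that fold into one value of the next layer, a
// query index is consumed one bit per round: the low bit says which half of the
// pair the query sits in, the remaining bits give the row to open in the next,
// half-height tree. The walk below is a rough stand-alone picture of that
// bookkeeping; the exact shift schedule in the crate also accounts for the
// bit-reversed, concatenated leaf layout, so treat this as a sketch only.
fn query_walk(mut idx: usize, rounds: usize) -> Vec<(usize, bool)> {
    let mut steps = Vec::with_capacity(rounds);
    for _ in 0..rounds {
        let is_right = idx & 1 == 1; // which sibling the folded value lands on
        idx >>= 1;                   // row index inside the next, half-height tree
        steps.push((idx, is_right));
    }
    steps
}

#[test]
fn query_walk_example() {
    // index 0b1101 over three rounds: rows 0b110, 0b11, 0b1 with parities 1, 0, 1
    assert_eq!(
        query_walk(0b1101, 3),
        vec![(0b110, true), (0b11, false), (0b1, true)]
    );
}
// --------------------------------------------------------------------------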
- let (mut values, proof) = mmcs_ext.open_batch(idx >> 1, tree); - let leafs = values.pop().unwrap(); - debug_assert_eq!(leafs.len(), 2); - let sibling = leafs[(!is_interpolate_to_right_index) as usize]; - proofs.push((sibling, proof)); - (idx >> 1, proofs) - }); - (opening, opening_ext) + let (_, commit_phase_openings) = + trees + .iter() + .fold((idx, vec![]), |(idx, mut commit_phase_openings), tree| { + // differentiate interpolate to left or right position at next layer + let is_interpolate_to_right_index = (idx & 1) == 1; + // mask the least significant bit (LSB) for the same reason as above: + // 1. we only need the even part of the index. + // 2. since even and odd parts are concatenated in the same leaf, + // the overall merkle tree height is effectively halved, + // so we divide by 2. + let (mut values, opening_proof) = mmcs_ext.open_batch(idx >> 1, tree); + let leafs = values.pop().unwrap(); + debug_assert_eq!(leafs.len(), 2); + let sibling_value = leafs[(!is_interpolate_to_right_index) as usize]; + commit_phase_openings.push(CommitPhaseProofStep { + sibling_value, + opening_proof, + }); + (idx >> 1, commit_phase_openings) + }); + QueryOpeningProof { + witin_base_proof, + fixed_base_proof, + commit_phase_openings, + } }) .collect_vec() } #[allow(clippy::too_many_arguments)] -pub fn simple_batch_verifier_query_phase>( +pub fn batch_verifier_query_phase>( + max_num_var: usize, indices: &[usize], vp: &>::VerifierParameters, + final_message: &[Vec], + batch_coeffs: &[E], queries: &QueryOpeningProofs, - sum_check_messages: &[Vec], + fixed_comm: Option<&BasefoldCommitment>, + witin_comm: &BasefoldCommitment, + circuit_meta: &[CircuitIndexMeta], + commits: &[Digest], fold_challenges: &[E], - batch_coeffs: &[E], - _num_rounds: usize, - num_vars: usize, - final_message: &[E], - roots: &[Digest], - comm: &BasefoldCommitment, - partial_eq: &[E], - evals: &[E], + sumcheck_messages: &[IOPProverMessage], + point_evals: &[(Point, Vec)], ) where E::BaseField: Serialize + DeserializeOwned, { - let timer = start_timer!(|| "Verifier query phase"); - let num_polys = evals.len(); - - let encode_timer = start_timer!(|| "Encode final codeword"); + let inv_2 = E::BaseField::from_u64(2).inverse(); + debug_assert_eq!(point_evals.len(), circuit_meta.len()); + let encode_span = entered_span!("encode_final_codeword"); let final_codeword = >::encode_small( vp, - RowMajorMatrix::new(final_message.to_vec(), 1), + RowMajorMatrix::new( + (0..final_message[0].len()) + .map(|j| final_message.iter().map(|row| row[j]).sum()) + .collect_vec(), + 1, + ), ); - end_timer!(encode_timer); + exit_span!(encode_span); let mmcs_ext = ExtensionMmcs::::new(poseidon2_merkle_tree::()); let mmcs = poseidon2_merkle_tree::(); + let check_queries_span = entered_span!("check_queries"); + // can't use witin_comm.log2_max_codeword_size since it's untrusted + let log2_witin_max_codeword_size = + max_num_var + >::get_rate_log(); + + // an vector with same length as circuit_meta, which is sorted by num_var in descending order and keep its index + // for reverse lookup when retrieving next base codeword to involve into batching + let folding_sorted_order = circuit_meta + .iter() + .enumerate() + .sorted_by_key(|(_, CircuitIndexMeta { witin_num_vars, .. })| Reverse(witin_num_vars)) + .map(|(index, CircuitIndexMeta { witin_num_vars, .. 
})| (witin_num_vars, index)) + .collect_vec(); - let span = entered_span!("check queries"); - izip!(indices, queries).for_each(|(idx, ((commit_leafs, commit_proof), opening_ext))| { - // refer to prover documentation for the reason of right shift by 1 - let idx = idx >> 1; - mmcs.verify_batch( - &comm.pi_d_digest, - &[Dimensions { - // width size is double num_polys due to leaf + right leafs are concat - width: num_polys * 2, - height: 1 - << (num_vars + >::get_rate_log() - 1), - }], + indices.iter().zip_eq(queries).for_each( + |( idx, - slice::from_ref(commit_leafs), - commit_proof, - ) - .expect("verify batch failed"); - let (left, right) = commit_leafs.split_at(commit_leafs.len() / 2); - let (left, right): (E, E) = ( - dot_product(batch_coeffs.iter().copied(), left.iter().copied()), - dot_product(batch_coeffs.iter().copied(), right.iter().copied()), - ); - let r = fold_challenges.first().unwrap(); - let coeff = >::verifier_folding_coeffs_level( - vp, - num_vars + >::get_rate_log() - 1, - )[idx]; - let (lo, hi) = ((left + right).halve(), (left - right) * coeff); - let folded = lo + (hi - lo) * *r; - - let rounds = - num_vars - >::get_basecode_msg_size_log() - 1; - let n_d_next = - 1 << (num_vars + >::get_rate_log() - 1); - debug_assert_eq!(rounds, fold_challenges.len() - 1); - debug_assert_eq!(rounds, roots.len(),); - debug_assert_eq!(rounds, opening_ext.len(),); - let (final_idx, final_folded, _) = roots - .iter() - .zip_eq(fold_challenges.iter().skip(1)) - .zip_eq(opening_ext) - .fold( - (idx, folded, n_d_next), - |(idx, folded, n_d_i), ((pi_comm, r), (leaf, proof))| { - let is_interpolate_to_right_index = (idx & 1) == 1; - let mut leafs = vec![*leaf; 2]; - leafs[is_interpolate_to_right_index as usize] = folded; - - let idx = idx >> 1; - mmcs_ext - .verify_batch( - pi_comm, - &[Dimensions { - width: 2, - // width is 2, thus height divide by 2 via right shift - height: n_d_i >> 1, - }], - idx, - slice::from_ref(&leafs), - proof, - ) - .expect("verify failed"); + QueryOpeningProof { + witin_base_proof: + BatchOpening { + opened_values: witin_opened_values, + opening_proof: witin_opening_proof, + }, + fixed_base_proof: fixed_commit_option, + commit_phase_openings: opening_ext, + }, + )| { + // verify base oracle query proof + // refer to prover documentation for the reason of right shift by 1 + let mut idx = idx >> 1; + + let (witin_dimentions, fixed_dimentions) = + get_base_codeword_dimentions::(circuit_meta); + // verify witness + mmcs.verify_batch( + &witin_comm.commit, + &witin_dimentions, + idx, + witin_opened_values, + witin_opening_proof, + ) + .expect("verify witin commit batch failed"); + + // verify fixed + let fixed_commit_leafs = if let Some(fixed_comm) = fixed_comm { + let BatchOpening { + opened_values: fixed_opened_values, + opening_proof: fixed_opening_proof, + } = &fixed_commit_option.as_ref().unwrap(); + mmcs.verify_batch( + &fixed_comm.commit, + &fixed_dimentions, + { + let idx_shift = log2_witin_max_codeword_size as i32 + - fixed_comm.log2_max_codeword_size as i32; + if idx_shift > 0 { + idx >> idx_shift + } else { + idx << -idx_shift + } + }, + fixed_opened_values, + fixed_opening_proof, + ) + .expect("verify fixed commit batch failed"); + fixed_opened_values + } else { + &vec![] + }; + + let mut fixed_commit_leafs_iter = fixed_commit_leafs.iter(); + let mut batch_coeffs_iter = batch_coeffs.iter(); + + let base_codeword_lo_hi = circuit_meta + .iter() + .zip_eq(witin_opened_values) + .map( + |( + CircuitIndexMeta { + witin_num_polys, + fixed_num_vars, + 
fixed_num_polys, + .. + }, + witin_leafs, + )| { + let (lo, hi) = std::iter::once((witin_leafs, *witin_num_polys)) + .chain((*fixed_num_vars > 0).then(|| { + (fixed_commit_leafs_iter.next().unwrap(), *fixed_num_polys) + })) + .map(|(leafs, num_polys)| { + let batch_coeffs = batch_coeffs_iter + .by_ref() + .take(num_polys) + .copied() + .collect_vec(); + let (lo, hi): (&[E::BaseField], &[E::BaseField]) = + leafs.split_at(leafs.len() / 2); + ( + dot_product::( + batch_coeffs.iter().copied(), + lo.iter().copied(), + ), + dot_product::( + batch_coeffs.iter().copied(), + hi.iter().copied(), + ), + ) + }) + // fold witin/fixed lo, hi together because they share the same num_vars + .reduce(|(lo_wit, hi_wit), (lo_fixed, hi_fixed)| { + (lo_wit + lo_fixed, hi_wit + hi_fixed) + }) + .expect("unreachable"); + (lo, hi) + }, + ) + .collect_vec(); + debug_assert_eq!(folding_sorted_order.len(), base_codeword_lo_hi.len()); + debug_assert!(fixed_commit_leafs_iter.next().is_none()); + debug_assert!(batch_coeffs_iter.next().is_none()); + + // fold and query + let mut cur_num_var = max_num_var; + // -1 because for there are only #max_num_var-1 openings proof + let rounds = cur_num_var + - >::get_basecode_msg_size_log() + - 1; + let n_d_next = 1 + << (cur_num_var + >::get_rate_log() - 1); + debug_assert_eq!(rounds, fold_challenges.len() - 1); + debug_assert_eq!(rounds, commits.len(),); + debug_assert_eq!(rounds, opening_ext.len(),); + + // first folding challenge + let r = fold_challenges.first().unwrap(); + + let mut folding_sorted_order_iter = folding_sorted_order.iter(); + // take first batch which num_vars match max_num_var to initial fold value + let mut folded = folding_sorted_order_iter + .by_ref() + .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) + .map(|(_, index)| { + let (lo, hi) = &base_codeword_lo_hi[*index]; let coeff = >::verifier_folding_coeffs_level( vp, - log2_strict_usize(n_d_i) - 1, + cur_num_var + + >::get_rate_log() + - 1, )[idx]; - debug_assert_eq!( - >::verifier_folding_coeffs_level( - vp, - log2_strict_usize(n_d_i) - 1, - ) - .len(), - n_d_i >> 1 - ); - let (left, right) = (leafs[0], leafs[1]); - let (lo, hi) = ((left + right).halve(), (left - right) * coeff); - (idx, lo + (hi - lo) * *r, n_d_i >> 1) + codeword_fold_with_challenge(&[*lo, *hi], *r, coeff, inv_2) + }) + .sum::(); + + let mut n_d_i = n_d_next; + for ( + (pi_comm, r), + CommitPhaseProofStep { + sibling_value: leaf, + opening_proof: proof, }, + ) in commits + .iter() + .zip_eq(fold_challenges.iter().skip(1)) + .zip_eq(opening_ext) + { + cur_num_var -= 1; + + let is_interpolate_to_right_index = (idx & 1) == 1; + let new_involved_codewords = folding_sorted_order_iter + .by_ref() + .peeking_take_while(|(num_vars, _)| **num_vars == cur_num_var) + .map(|(_, index)| { + let (lo, hi) = &base_codeword_lo_hi[*index]; + if is_interpolate_to_right_index { + *hi + } else { + *lo + } + }) + .sum::(); + + let mut leafs = vec![*leaf; 2]; + leafs[is_interpolate_to_right_index as usize] = folded + new_involved_codewords; + idx >>= 1; + mmcs_ext + .verify_batch( + pi_comm, + &[Dimensions { + width: 2, + // width is 2, thus height divide by 2 via right shift + height: n_d_i >> 1, + }], + idx, + slice::from_ref(&leafs), + proof, + ) + .expect("verify failed"); + let coeff = + >::verifier_folding_coeffs_level( + vp, + log2_strict_usize(n_d_i) - 1, + )[idx]; + debug_assert_eq!( + >::verifier_folding_coeffs_level( + vp, + log2_strict_usize(n_d_i) - 1, + ) + .len(), + n_d_i >> 1 + ); + folded = 
codeword_fold_with_challenge(&[leafs[0], leafs[1]], *r, coeff, inv_2); + n_d_i >>= 1; + } + debug_assert!(folding_sorted_order_iter.next().is_none()); + assert!( + final_codeword.values[idx] == folded, + "final_codeword.values[idx] value {:?} != folded {:?}", + final_codeword.values[idx], + folded ); - assert!( - final_codeword.values[final_idx] == final_folded, - "final_codeword.values[idx] value {:?} != folded {:?}", - final_codeword.values[final_idx], - final_folded - ); - }); - exit_span!(span); + }, + ); + exit_span!(check_queries_span); // 1. check initial claim match with first round sumcheck value assert_eq!( - dot_product::(batch_coeffs.iter().copied(), evals.iter().copied()), - { sum_check_messages[0][0] + sum_check_messages[0][1] } + // we need to scale up with scalar for witin_num_vars < max_num_var + dot_product::( + batch_coeffs.iter().copied(), + point_evals.iter().zip_eq(circuit_meta.iter()).flat_map( + |((_, evals), CircuitIndexMeta { witin_num_vars, .. })| { + evals.iter().copied().map(move |eval| { + eval * E::from_u64(1 << (max_num_var - witin_num_vars) as u64) + }) + } + ) + ), + { sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1] } ); // 2. check every round of sumcheck match with prev claims for i in 0..fold_challenges.len() - 1 { assert_eq!( - interpolate_uni_poly(&sum_check_messages[i], fold_challenges[i]), - { sum_check_messages[i + 1][0] + sum_check_messages[i + 1][1] } + interpolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]), + { sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] } ); } // 3. check final evaluation are correct assert_eq!( interpolate_uni_poly( - &sum_check_messages[fold_challenges.len() - 1], + &sumcheck_messages[fold_challenges.len() - 1].evaluations, fold_challenges[fold_challenges.len() - 1] ), - inner_product(final_message, partial_eq) + izip!(final_message, point_evals.iter().map(|(point, _)| point)) + .map(|(final_message, point)| { + // coeff is the eq polynomial evaluated at the first challenge.len() variables + let num_vars_evaluated = point.len() + - >::get_basecode_msg_size_log(); + let coeff = eq_eval( + &point[..num_vars_evaluated], + &fold_challenges[fold_challenges.len() - num_vars_evaluated..], + ); + // Compute eq as the partially evaluated eq polynomial + let eq = build_eq_x_r_vec(&point[num_vars_evaluated..]); + dot_product( + final_message.iter().copied(), + eq.into_iter().map(|e| e * coeff), + ) + }) + .sum() ); +} - end_timer!(timer); +fn get_base_codeword_dimentions>( + circuit_meta_map: &[CircuitIndexMeta], +) -> (Vec, Vec) { + let (wit_dim, fixed_dim): (Vec<_>, Vec<_>) = circuit_meta_map + .iter() + .map( + |CircuitIndexMeta { + witin_num_vars, + witin_num_polys, + fixed_num_vars, + fixed_num_polys, + }| { + ( + Dimensions { + // width size is double num_polys due to leaf + right leafs are concat + width: witin_num_polys * 2, + height: 1 + << (witin_num_vars + + >::get_rate_log() + - 1), + }, + if *fixed_num_vars > 0 { + Some(Dimensions { + // width size is double num_polys due to leaf + right leafs are concat + width: fixed_num_polys * 2, + height: 1 + << (fixed_num_vars + + >::get_rate_log() + - 1), + }) + } else { + None + }, + ) + }, + ) + .unzip(); + let fixed_dim = fixed_dim.into_iter().flatten().collect_vec(); + (wit_dim, fixed_dim) } diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index dd6ae7810..081fa70ab 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -1,9 +1,13 @@ -use 
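// --------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the three verifier checks chain
// the sumcheck rounds together. The first message must open (at 0 plus 1) to
// the batched claim, every later message must open to the previous message
// evaluated at that round's challenge, and the last expected value is compared
// against the eq-weighted final message. The stand-alone version below assumes
// degree-2 round messages given by their values at x = 0, 1, 2 and uses f64
// with exact inputs in place of the extension field.
fn eval_at(p: &[f64; 3], r: f64) -> f64 {
    // Lagrange interpolation through (0, p[0]), (1, p[1]), (2, p[2])
    p[0] * (r - 1.0) * (r - 2.0) / 2.0 - p[1] * r * (r - 2.0) + p[2] * r * (r - 1.0) / 2.0
}

fn chain_rounds(claim: f64, msgs: &[[f64; 3]], challenges: &[f64]) -> Option<f64> {
    // returns the value the final message must match, or None if a round
    // does not open to the running claim
    let mut expected = claim;
    for (msg, &r) in msgs.iter().zip(challenges) {
        if msg[0] + msg[1] != expected {
            return None;
        }
        expected = eval_at(msg, r);
    }
    Some(expected)
}
// --------------------------------------------------------------------------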
crate::util::merkle_tree::{Poseidon2ExtMerkleMmcs, poseidon2_merkle_tree}; +use crate::{ + basefold::log2_strict_usize, + util::merkle_tree::{Poseidon2ExtMerkleMmcs, poseidon2_merkle_tree}, +}; use core::fmt::Debug; use ff_ext::{ExtensionField, PoseidonField}; -use itertools::izip; +use itertools::{Itertools, izip}; use p3::{ - commit::Mmcs, + commit::{ExtensionMmcs, Mmcs}, + fri::{BatchOpening, CommitPhaseProofStep}, matrix::{Matrix, dense::DenseMatrix}, }; use serde::{Deserialize, Serialize, Serializer, de::DeserializeOwned}; @@ -11,7 +15,7 @@ use sumcheck::structs::IOPProverMessage; use multilinear_extensions::virtual_poly::ArcMultilinearExtension; -use std::marker::PhantomData; +use std::{collections::BTreeMap, marker::PhantomData}; pub type Digest = as Mmcs>::Commitment; pub type MerkleTree = <::MMCS as Mmcs>::ProverData>; @@ -61,52 +65,75 @@ pub struct BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, { - pub(crate) pi_d_digest: Digest, + pub(crate) commit: Digest, pub(crate) codeword: MerkleTree, - pub(crate) polynomials_bh_evals: Vec>, - pub(crate) num_vars: usize, - pub(crate) is_base: bool, - pub(crate) num_polys: usize, + + pub(crate) log2_max_codeword_size: usize, + // for small polynomials, the prover commits the entire polynomial as merkle leaves without encoding to codeword + // the verifier performs direct checks on these leaves without requiring a proximity test. + pub(crate) trivial_proofdata: BTreeMap, MerkleTree)>, + // poly groups w.r.t circuit index + pub(crate) polys: BTreeMap>>, + // keep codeword index w.r.t circuit index + pub circuit_codeword_index: BTreeMap, } impl BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, { + pub fn new( + commit: Digest, + codeword: MerkleTree, + polys: BTreeMap>>, + trivial_proofdata: BTreeMap, MerkleTree)>, + circuit_codeword_index: BTreeMap, + ) -> Self { + let mmcs = poseidon2_merkle_tree::(); + // size = height * 2 because we split codeword leafs into left/right, concat and commit under same row index + let log2_max_codeword_size = log2_strict_usize( + mmcs.get_matrices(&codeword) + .iter() + .map(|m| m.height() * 2) + .max() + .unwrap(), + ); + Self { + commit, + codeword, + polys, + trivial_proofdata, + circuit_codeword_index, + log2_max_codeword_size, + } + } + pub fn to_commitment(&self) -> BasefoldCommitment { BasefoldCommitment::new( - self.pi_d_digest.clone(), - self.num_vars, - self.is_base, - self.num_polys, + self.commit.clone(), + self.trivial_proofdata + .iter() + .map(|(_, (digest, _))| digest.clone()) + .collect_vec(), + self.log2_max_codeword_size, ) } - pub fn poly_size(&self) -> usize { - 1 << self.num_vars - } + // pub fn poly_size(&self) -> usize { + // 1 << self.num_vars + // } - pub fn is_base(&self) -> bool { - self.is_base - } + // pub fn trivial_num_vars>(num_vars: usize) -> bool { + // num_vars <= Spec::get_basecode_msg_size_log() + // } - pub fn trivial_num_vars>(num_vars: usize) -> bool { - num_vars <= Spec::get_basecode_msg_size_log() - } + // pub fn is_trivial>(&self) -> bool { + // Self::trivial_num_vars::(self.num_vars) + // } - pub fn is_trivial>(&self) -> bool { - Self::trivial_num_vars::(self.num_vars) - } - - pub fn codeword_size(&self) -> usize { + pub fn get_codewords(&self) -> Vec<&DenseMatrix> { let mmcs = poseidon2_merkle_tree::(); - // size = height * 2 because we concat pi[left]/pi[right] under same row index - mmcs.get_matrices(&self.codeword)[0].height() * 2 - } - - pub fn get_codewords(&self) -> &DenseMatrix { - let mmcs = 
poseidon2_merkle_tree::(); - mmcs.get_matrices(&self.codeword)[0] + mmcs.get_matrices(&self.codeword) } } @@ -116,35 +143,29 @@ pub struct BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - pub(super) pi_d_digest: Digest, - pub(super) num_vars: Option, - pub(super) is_base: bool, - pub(super) num_polys: Option, + pub(super) commit: Digest, + pub(crate) log2_max_codeword_size: usize, + pub(crate) trivial_commits: Vec>, } impl BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, { - pub fn new(pi_d_digest: Digest, num_vars: usize, is_base: bool, num_polys: usize) -> Self { + pub fn new( + commit: Digest, + trivial_commits: Vec>, + log2_max_codeword_size: usize, + ) -> Self { Self { - pi_d_digest, - num_vars: Some(num_vars), - is_base, - num_polys: Some(num_polys), + commit, + trivial_commits, + log2_max_codeword_size, } } - pub fn pi_d_digest(&self) -> Digest { - self.pi_d_digest.clone() - } - - pub fn num_vars(&self) -> Option { - self.num_vars - } - - pub fn is_base(&self) -> bool { - self.is_base + pub fn commit(&self) -> Digest { + self.commit.clone() } } @@ -153,10 +174,12 @@ where E::BaseField: Serialize + DeserializeOwned, { fn eq(&self, other: &Self) -> bool { - self.get_codewords().eq(other.get_codewords()) - && izip!(&self.polynomials_bh_evals, &other.polynomials_bh_evals).all( - |(bh_evals_a, bh_evals_b)| bh_evals_a.evaluations() == bh_evals_b.evaluations(), - ) + izip!(self.get_codewords(), other.get_codewords()) + .all(|(codeword_a, codeword_b)| codeword_a.eq(codeword_b)) + && izip!(self.polys.values(), other.polys.values()).all(|(evals_a, evals_b)| { + izip!(evals_a, evals_b) + .all(|(evals_a, evals_b)| evals_a.evaluations() == evals_b.evaluations()) + }) } } @@ -211,25 +234,35 @@ impl> Clone for Basefold { } } -pub type MKProofNTo1 = (Vec, P); -// for 2 to 1, leaf layer just need one value, as the other can be interpolated from previous layer -pub type MKProof2To1 = (F1, P); -pub type QueryOpeningProofs = Vec<( - MKProofNTo1< +pub type ExtMmcs = ExtensionMmcs< + ::BaseField, + E, + <::BaseField as PoseidonField>::MMCS, +>; + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct QueryOpeningProof { + pub witin_base_proof: BatchOpening< ::BaseField, - <<::BaseField as PoseidonField>::MMCS as Mmcs< - ::BaseField, - >>::Proof, + <::BaseField as PoseidonField>::MMCS, >, - Vec< - MKProof2To1< - E, - <<::BaseField as PoseidonField>::MMCS as Mmcs< - ::BaseField, - >>::Proof, + pub fixed_base_proof: Option< + BatchOpening< + ::BaseField, + <::BaseField as PoseidonField>::MMCS, >, >, -)>; + #[allow(clippy::type_complexity)] + pub commit_phase_openings: Vec>>, +} + +pub type QueryOpeningProofs = Vec>; + +pub type TrivialProof = Vec::BaseField>>>; #[derive(Clone, Serialize, Deserialize)] #[serde(bound( @@ -240,32 +273,12 @@ pub struct BasefoldProof where E::BaseField: Serialize + DeserializeOwned, { - pub(crate) sumcheck_messages: Vec>, pub(crate) commits: Vec>, - pub(crate) final_message: Vec, + pub(crate) final_message: Vec>, pub(crate) query_opening_proof: QueryOpeningProofs, pub(crate) sumcheck_proof: Option>>, - pub(crate) trivial_proof: Option>, -} - -impl BasefoldProof -where - E::BaseField: Serialize + DeserializeOwned, -{ - pub fn trivial(evals: DenseMatrix) -> Self { - Self { - sumcheck_messages: vec![], - commits: vec![], - final_message: vec![], - query_opening_proof: Default::default(), - sumcheck_proof: None, - trivial_proof: 
Some(evals), - } - } - - pub fn is_trivial(&self) -> bool { - self.trivial_proof.is_some() - } + // vec![witness, fixed], where fixed is optional + pub(crate) trivial_proof: Option>, } #[derive(Clone, Serialize, Deserialize)] @@ -277,7 +290,15 @@ pub struct BasefoldCommitPhaseProof where E::BaseField: Serialize + DeserializeOwned, { - pub(crate) sumcheck_messages: Vec>, + pub(crate) sumcheck_messages: Vec>, pub(crate) commits: Vec>, - pub(crate) final_message: Vec, + pub(crate) final_message: Vec>, +} + +#[derive(Clone, Default)] +pub(crate) struct CircuitIndexMeta { + pub witin_num_vars: usize, + pub witin_num_polys: usize, + pub fixed_num_vars: usize, + pub fixed_num_polys: usize, } diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index 3b99db622..e54de81f8 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -1,7 +1,7 @@ #![deny(clippy::cargo)] use ff_ext::ExtensionField; use serde::{Serialize, de::DeserializeOwned}; -use std::fmt::Debug; +use std::{collections::BTreeMap, fmt::Debug}; use transcript::Transcript; use witness::RowMajorMatrix; @@ -16,6 +16,9 @@ pub type Param = >::Param; pub type ProverParam = >::ProverParam; pub type VerifierParam = >::VerifierParam; +/// A point is a vector of num_var length +pub type Point = Vec; + pub fn pcs_setup>( poly_size: usize, ) -> Result { @@ -38,9 +41,9 @@ pub fn pcs_commit>( pub fn pcs_batch_commit>( pp: &Pcs::ProverParam, - rmm: RowMajorMatrix<::BaseField>, + rmms: BTreeMap::BaseField>>, ) -> Result { - Pcs::batch_commit(pp, rmm) + Pcs::batch_commit(pp, rmms) } pub fn pcs_open>( @@ -54,15 +57,27 @@ pub fn pcs_open>( Pcs::open(pp, poly, comm, point, eval, transcript) } +#[allow(clippy::too_many_arguments)] pub fn pcs_batch_open>( pp: &Pcs::ProverParam, - polys: &[ArcMultilinearExtension], - comms: &[Pcs::CommitmentWithWitness], - points: &[Vec], - evals: &[Evaluation], + num_instances: &[(usize, usize)], + fixed_comms: Option<&Pcs::CommitmentWithWitness>, + witin_comms: &Pcs::CommitmentWithWitness, + points: &[Point], + evals: &[Vec], + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result { - Pcs::batch_open(pp, polys, comms, points, evals, transcript) + Pcs::batch_open( + pp, + num_instances, + fixed_comms, + witin_comms, + points, + evals, + circuit_num_polys, + transcript, + ) } pub fn pcs_verify>( @@ -76,18 +91,32 @@ pub fn pcs_verify>( Pcs::verify(vp, comm, point, eval, proof, transcript) } +#[allow(clippy::too_many_arguments)] pub fn pcs_batch_verify<'a, E: ExtensionField, Pcs: PolynomialCommitmentScheme>( vp: &Pcs::VerifierParam, - comms: &[Pcs::Commitment], - points: &[Vec], - evals: &[Evaluation], + num_instances: &[(usize, usize)], + points: &[Point], + fixed_comms: Option<&Pcs::Commitment>, + witin_comms: &Pcs::Commitment, + evals: &[Vec], proof: &Pcs::Proof, + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result<(), Error> where Pcs::Commitment: 'a, { - Pcs::batch_verify(vp, comms, points, evals, proof, transcript) + Pcs::batch_verify( + vp, + num_instances, + points, + fixed_comms, + witin_comms, + evals, + proof, + circuit_num_polys, + transcript, + ) } pub trait PolynomialCommitmentScheme: Clone { @@ -134,15 +163,15 @@ pub trait PolynomialCommitmentScheme: Clone { fn batch_commit( pp: &Self::ProverParam, - polys: RowMajorMatrix, + rmms: BTreeMap::BaseField>>, ) -> Result; fn batch_commit_and_write( pp: &Self::ProverParam, - rmm: RowMajorMatrix<::BaseField>, + rmms: BTreeMap::BaseField>>, transcript: &mut impl Transcript, ) -> Result { - let comm = 
Self::batch_commit(pp, rmm)?; + let comm = Self::batch_commit(pp, rmms)?; Self::write_commitment(&Self::get_pure_commitment(&comm), transcript)?; Ok(comm) } @@ -156,12 +185,15 @@ pub trait PolynomialCommitmentScheme: Clone { transcript: &mut impl Transcript, ) -> Result; + #[allow(clippy::too_many_arguments)] fn batch_open( pp: &Self::ProverParam, - polys: &[ArcMultilinearExtension], - comms: &[Self::CommitmentWithWitness], - points: &[Vec], - evals: &[Evaluation], + num_instances: &[(usize, usize)], + fixed_comms: Option<&Self::CommitmentWithWitness>, + witin_comms: &Self::CommitmentWithWitness, + points: &[Point], + evals: &[Vec], + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result; @@ -187,12 +219,16 @@ pub trait PolynomialCommitmentScheme: Clone { transcript: &mut impl Transcript, ) -> Result<(), Error>; + #[allow(clippy::too_many_arguments)] fn batch_verify( vp: &Self::VerifierParam, - comms: &[Self::Commitment], - points: &[Vec], - evals: &[Evaluation], + num_instances: &[(usize, usize)], + points: &[Point], + fixed_comms: Option<&Self::Commitment>, + witin_comms: &Self::Commitment, + evals: &[Vec], proof: &Self::Proof, + circuit_num_polys: &[(usize, usize)], transcript: &mut impl Transcript, ) -> Result<(), Error>; @@ -247,6 +283,7 @@ pub enum Error { PolynomialTooLarge(usize), PolynomialSizesNotEqual, MerkleRootMismatch, + PointEvalMismatch(String), WhirError(whir_external::error::Error), } @@ -381,6 +418,8 @@ pub mod test_util { Pcs: PolynomialCommitmentScheme, Standard: Distribution, { + use std::collections::BTreeMap; + let mut rng = rand::thread_rng(); for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); @@ -389,7 +428,9 @@ pub mod test_util { let mut transcript = BasicTranscript::new(b"BaseFold"); let rmm = RowMajorMatrix::::rand(&mut rng, 1 << num_vars, batch_size); let polys = rmm.to_mles(); - let comm = Pcs::batch_commit_and_write(&pp, rmm, &mut transcript).unwrap(); + let comm = + Pcs::batch_commit_and_write(&pp, BTreeMap::from([(0, rmm)]), &mut transcript) + .unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); transcript.append_field_element_exts(&evals); @@ -428,4 +469,85 @@ pub mod test_util { } } } + + #[cfg(test)] + pub(super) fn run_batch_commit_open_verify( + num_vars_start: usize, + num_vars_end: usize, + batch_size: usize, + ) where + E: ExtensionField, + Pcs: PolynomialCommitmentScheme, + Standard: Distribution, + { + use std::collections::BTreeMap; + + for num_vars in num_vars_start..num_vars_end { + let (pp, vp) = setup_pcs::(num_vars); + let num_instances = vec![(0, 1 << num_vars)]; + let circuit_num_polys = vec![(batch_size, 0)]; + + let (comm, evals, proof, challenge) = { + let mut transcript = BasicTranscript::new(b"BaseFold"); + let rmm = + RowMajorMatrix::::rand(&mut OsRng, 1 << num_vars, batch_size); + + let polys = rmm.to_mles(); + + let comm = + Pcs::batch_commit_and_write(&pp, BTreeMap::from([(0, rmm)]), &mut transcript) + .unwrap(); + let point = get_point_from_challenge(num_vars, &mut transcript); + let evals = polys.iter().map(|poly| poly.evaluate(&point)).collect_vec(); + transcript.append_field_element_exts(&evals); + + let proof = Pcs::batch_open( + &pp, + &num_instances, + None, + &comm, + &[point.clone()], + &[evals.clone()], + &circuit_num_polys, + &mut transcript, + ) + .unwrap(); + ( + Pcs::get_pure_commitment(&comm), + evals, + proof, + transcript.read_challenge(), + ) + }; + // 
Batch verify + { + let mut transcript = BasicTranscript::new(b"BaseFold"); + Pcs::write_commitment(&comm, &mut transcript).unwrap(); + + let point = get_point_from_challenge(num_vars, &mut transcript); + transcript.append_field_element_exts(&evals); + + Pcs::batch_verify( + &vp, + &num_instances, + &[point.clone()], + None, + &comm, + &[evals.clone()], + &proof, + &circuit_num_polys, + &mut transcript, + ) + .unwrap(); + + let v_challenge = transcript.read_challenge(); + assert_eq!(challenge, v_challenge); + + println!( + "Proof size for simple batch: {} bytes", + bincode::serialized_size(&proof).unwrap() + ); + } + } + } } diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 5baf7680e..d0ec02060 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -2,6 +2,8 @@ pub mod arithmetic; pub mod expression; pub mod hash; pub mod parallel; +use std::collections::VecDeque; + use ff_ext::{ExtensionField, SmallField}; use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; @@ -266,6 +268,93 @@ pub fn ext_try_into_base(x: &E) -> Result(input: &'a [T], sizes: &[usize]) -> Vec<&'a [T]> { + let total_size: usize = sizes.iter().sum(); + + if total_size != input.len() { + panic!( + "total size of chunks ({}) doesn't match input length ({})", + total_size, + input.len() + ); + } + + // `scan` keeps track of the current start index and produces each slice + sizes + .iter() + .scan(0, |start, &size| { + let end = *start + size; + let slice = &input[*start..end]; + *start = end; + Some(slice) + }) + .collect() +} + +/// removes and returns elements from the front of the deque +/// as long as they satisfy the given predicate. +/// +/// # arguments +/// * `deque` - the mutable VecDeque to operate on. +/// * `pred` - a predicate function that takes a reference to an element +/// and returns `true` if the element should be removed. +/// +/// # returns +/// a `Vec` containing all the elements that were removed. +pub fn pop_front_while(deque: &mut VecDeque, mut pred: F) -> Vec +where + F: FnMut(&T) -> bool, +{ + let mut result = Vec::new(); + while let Some(front) = deque.front() { + if pred(front) { + result.push(deque.pop_front().unwrap()); + } else { + break; + } + } + result +} + +#[inline(always)] +pub(crate) fn codeword_fold_with_challenge( + codeword: &[E], + challenge: E, + coeff: E::BaseField, + inv_2: E::BaseField, +) -> E { + let (left, right) = (codeword[0], codeword[1]); + // original (left, right) = (lo + hi*x, lo - hi*x), lo, hi are codeword, but after times x it's not codeword + // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) + let (lo, hi) = ((left + right) * inv_2, (left - right) * coeff); // e.g. 
coeff = (2 * dit_butterfly)^(-1) in rs code + // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi (with lo, hi are two codewords), as it match perfectly with raw message in lagrange domain fixed variable + lo + challenge * (hi - lo) +} + #[cfg(any(test, feature = "benchmark"))] pub mod test { #[cfg(test)] diff --git a/mpcs/src/whir.rs b/mpcs/src/whir.rs index f9fa4cc10..7e7c9d5c8 100644 --- a/mpcs/src/whir.rs +++ b/mpcs/src/whir.rs @@ -1,6 +1,10 @@ mod spec; mod structure; +use std::collections::BTreeMap; + +use crate::Point; + use super::PolynomialCommitmentScheme; use ff_ext::{ExtensionField, PoseidonField}; use multilinear_extensions::{mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension}; @@ -10,6 +14,7 @@ pub use spec::WhirDefaultSpec; use spec::WhirSpec; use structure::WhirCommitment; pub use structure::{Whir, WhirDefault}; +use transcript::Transcript; use whir_external::{ crypto::{DigestExt, MerklePathBase, MerklePathExt, MerkleTreeBase, MerkleTreeExt}, parameters::MultivariateParameters, @@ -133,15 +138,15 @@ where fn batch_commit( _pp: &Self::ProverParam, - rmm: witness::RowMajorMatrix, + mut rmms: BTreeMap::BaseField>>, ) -> Result { let parameters = Spec::get_whir_parameters(true); let whir_config = WhirConfig::new( - MultivariateParameters::new(log2_strict_usize(rmm.num_instances())), + MultivariateParameters::new(log2_strict_usize(rmms[&0].num_instances())), parameters, ); let (witness, _commitment) = Committer::new(whir_config) - .batch_commit(rmm) + .batch_commit(rmms.remove(&0).unwrap()) .map_err(crate::Error::WhirError)?; Ok(witness) @@ -149,11 +154,13 @@ where fn batch_open( _pp: &Self::ProverParam, - _polys: &[ArcMultilinearExtension], - _comms: &[Self::CommitmentWithWitness], - _points: &[Vec], - _evals: &[crate::Evaluation], - _transcript: &mut impl transcript::Transcript, + _num_instances: &[(usize, usize)], + _fixed_comms: Option<&Self::CommitmentWithWitness>, + _witin_comms: &Self::CommitmentWithWitness, + _points: &[Point], + _evals: &[Vec], + _circuit_num_polys: &[(usize, usize)], + _transcript: &mut impl Transcript, ) -> Result { todo!() } @@ -177,11 +184,14 @@ where fn batch_verify( _vp: &Self::VerifierParam, - _comms: &[Self::Commitment], - _points: &[Vec], - _evals: &[crate::Evaluation], + _num_instances: &[(usize, usize)], + _points: &[Point], + _fixed_comms: Option<&Self::Commitment>, + _witin_comms: &Self::Commitment, + _evals: &[Vec], _proof: &Self::Proof, - _transcript: &mut impl transcript::Transcript, + _circuit_num_polys: &[(usize, usize)], + _transcript: &mut impl Transcript, ) -> Result<(), crate::Error> { todo!() } @@ -258,7 +268,7 @@ mod tests { } #[test] - fn whir_simple_batch_commit_open_verify_goldilocks() { + fn whir_batch_commit_open_verify_goldilocks() { { // Both challenge and poly are over base field run_simple_batch_commit_open_verify::(10, 16, 1); diff --git a/p3/Cargo.toml b/p3/Cargo.toml index 9a94cf1e6..676a363fc 100644 --- a/p3/Cargo.toml +++ b/p3/Cargo.toml @@ -15,6 +15,7 @@ p3-challenger.workspace = true p3-commit.workspace = true p3-dft.workspace = true p3-field.workspace = true +p3-fri.workspace = true p3-goldilocks.workspace = true p3-matrix.workspace = true p3-maybe-rayon = { workspace = true, features = ["parallel"] } diff --git a/p3/src/lib.rs b/p3/src/lib.rs index c3ea57fa4..525ae94dd 100644 --- a/p3/src/lib.rs +++ b/p3/src/lib.rs @@ -3,6 +3,7 @@ pub use p3_challenger as challenger; pub use p3_commit as commit; pub use p3_dft as dft; pub use p3_field as field; +pub use p3_fri as fri; pub use 
p3_goldilocks as goldilocks; pub use p3_matrix as matrix; pub use p3_maybe_rayon as maybe_rayon; diff --git a/poseidon/src/challenger.rs b/poseidon/src/challenger.rs index 26bd646fc..08791d3d5 100644 --- a/poseidon/src/challenger.rs +++ b/poseidon/src/challenger.rs @@ -56,8 +56,8 @@ impl CanSampleBits for DefaultChallenger where F: PoseidonField, { - fn sample_bits(&mut self, _bits: usize) -> usize { - todo!() + fn sample_bits(&mut self, bits: usize) -> usize { + self.inner.sample_bits(bits) } } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 0207b825a..9e1355753 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -98,7 +98,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .iter() .map(|c| c.elements) .collect_vec(); - let poly = merge_sumcheck_prover_state(prover_states); + let poly = merge_sumcheck_prover_state(&prover_states); ( point, @@ -448,12 +448,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } /// collect all mle evaluation (claim) after sumcheck - /// NOTE final evaluation size of each mle could be >= 1 - pub fn get_mle_final_evaluations(&self) -> Vec { + pub fn get_mle_final_evaluations(&self) -> Vec> { self.poly .flattened_ml_extensions .iter() - .flat_map(|mle| { + .map(|mle| { op_mle! { |mle| mle.to_vec(), |mle| mle.into_iter().map(E::from).collect_vec() @@ -462,6 +461,15 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .collect() } + /// collect all mle evaluation (claim) after sumcheck + /// NOTE final evaluation size of each mle could be >= 1 + pub fn get_mle_flatten_final_evaluations(&self) -> Vec { + self.get_mle_final_evaluations() + .into_iter() + .flatten() + .collect_vec() + } + pub fn expected_numvars_at_round(&self) -> usize { // first round start from 1 let num_vars = self.max_num_variables + 1 - self.round; diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index d31960daa..590ad38c6 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -287,10 +287,15 @@ pub fn merge_sumcheck_polys<'a, E: ExtensionField>( let blow_factor = 1 << (merged_num_vars - poly.num_vars()); DenseMultilinearExtension::from_evaluations_ext_vec( merged_num_vars, - poly.get_base_field_vec() - .iter() - .flat_map(|e| std::iter::repeat(E::from(*e)).take(blow_factor)) - .collect_vec(), + op_mle!( + poly, + |poly| { + poly.iter() + .flat_map(|e| std::iter::repeat(*e).take(blow_factor)) + .collect_vec() + }, + |base_poly| base_poly.iter().map(|e| E::from(*e)).collect_vec() + ), ) } }; @@ -300,9 +305,9 @@ pub fn merge_sumcheck_polys<'a, E: ExtensionField>( } /// retrieve virtual poly from sumcheck prover state to single virtual poly -pub fn merge_sumcheck_prover_state( - prover_states: Vec>, -) -> VirtualPolynomial<'_, E> { +pub fn merge_sumcheck_prover_state<'a, E: ExtensionField>( + prover_states: &[IOPProverState<'a, E>], +) -> VirtualPolynomial<'a, E> { merge_sumcheck_polys( prover_states.iter().map(|ps| &ps.poly).collect_vec(), Some(prover_states[0].poly_meta.clone()), diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index d333c334f..cc380f314 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -1,4 +1,5 @@ use ff_ext::{ExtensionField, FieldChallengerExt}; +use p3::challenger::CanSampleBits; use poseidon::challenger::{CanObserve, DefaultChallenger, FieldChallenger}; use crate::{Challenge, ForkableTranscript, Transcript}; @@ -53,6 +54,10 @@ impl Transcript for BasicTranscript { fn sample_vec(&mut self, n: usize) -> Vec { self.challenger.sample_ext_vec(n) } + + fn sample_bits(&mut self, 
bits: usize) -> usize { + self.challenger.sample_bits(bits) + } } impl ForkableTranscript for BasicTranscript {} diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index 57f763094..64bf56e86 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -63,6 +63,13 @@ pub trait Transcript { } } + /// Append a iterator of extension field elements to the transcript. + fn append_field_element_exts_iter<'a>(&mut self, element: impl Iterator) { + for e in element { + self.append_field_element_ext(e); + } + } + /// Append a challenge to the transcript. fn append_challenge(&mut self, challenge: Challenge) { self.append_field_element_ext(&challenge.elements) @@ -83,8 +90,24 @@ pub trait Transcript { self.sample_vec(n) } + fn sample_bits_and_append_vec( + &mut self, + label: &'static [u8], + n: usize, + bits: usize, + ) -> Vec { + self.append_message(label); + self.sample_vec_bits(n, bits) + } + fn sample_vec(&mut self, n: usize) -> Vec; + fn sample_bits(&mut self, bits: usize) -> usize; + + fn sample_vec_bits(&mut self, n: usize, bits: usize) -> Vec { + (0..n).map(|_| self.sample_bits(bits)).collect_vec() + } + /// derive one challenge from transcript and return all pows result fn sample_and_append_challenge_pows(&mut self, size: usize, label: &'static [u8]) -> Vec { let alpha = self.sample_and_append_challenge(label).elements; diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index 116532c12..c6e5fa6a2 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -58,6 +58,10 @@ impl Transcript for BasicTranscriptWithStat<'_, E> { fn sample_vec(&mut self, n: usize) -> Vec { self.inner.sample_vec(n) } + + fn sample_bits(&mut self, bits: usize) -> usize { + self.inner.sample_bits(bits) + } } impl ForkableTranscript for BasicTranscriptWithStat<'_, E> {} diff --git a/transcript/src/syncronized.rs b/transcript/src/syncronized.rs index a60b7b786..e22c58a09 100644 --- a/transcript/src/syncronized.rs +++ b/transcript/src/syncronized.rs @@ -54,6 +54,12 @@ impl Transcript for TranscriptSyncronized { .unwrap(); } + fn append_field_element_exts_iter<'a>(&mut self, element: impl Iterator) { + self.ef_append_tx[self.rolling_index] + .send(element.copied().collect::>()) + .unwrap(); + } + fn append_challenge(&mut self, _challenge: Challenge) { unimplemented!() } @@ -89,4 +95,8 @@ impl Transcript for TranscriptSyncronized { fn sample_vec(&mut self, _n: usize) -> Vec { unimplemented!() } + + fn sample_bits(&mut self, _bits: usize) -> usize { + unimplemented!() + } }
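
A note on the algebra behind `codeword_fold_with_challenge` in the hunk above: the comment there states that a sibling pair of a codeword row is (left, right) = (lo + hi*x, lo - hi*x), that the verifier recovers (lo, hi) using inv_2 = 1/2 and coeff = 1/(2x) (with x the butterfly twiddle for the queried position, per the "dit_butterfly" remark), and then folds to lo + r*(hi - lo) = (1 - r)*lo + r*hi. Below is a minimal standalone sanity check of that identity over the Goldilocks prime; the `add`/`sub`/`mul`/`pow`/`inv` helpers are ad-hoc for the illustration and are not crate APIs.

// toy_fold_check.rs -- standalone sanity check of the fold relation, not part of the patch
const P: u128 = 0xffff_ffff_0000_0001; // Goldilocks modulus

fn add(a: u128, b: u128) -> u128 { (a + b) % P }
fn sub(a: u128, b: u128) -> u128 { (a + P - b) % P }
fn mul(a: u128, b: u128) -> u128 { (a * b) % P }
fn pow(mut b: u128, mut e: u128) -> u128 {
    let mut acc = 1;
    while e > 0 {
        if e & 1 == 1 { acc = mul(acc, b); }
        b = mul(b, b);
        e >>= 1;
    }
    acc
}
fn inv(a: u128) -> u128 { pow(a, P - 2) } // Fermat inverse, requires a != 0

fn main() {
    // a sibling pair of one codeword row: left = lo + hi * x, right = lo - hi * x
    let (lo, hi, x, r) = (123_456_789u128, 987_654_321u128, 7u128, 42u128);
    let left = add(lo, mul(hi, x));
    let right = sub(lo, mul(hi, x));

    // verifier-side constants, mirroring the arguments of codeword_fold_with_challenge
    let inv_2 = inv(2);          // 1/2
    let coeff = inv(mul(2, x));  // 1/(2x)

    // recover (lo, hi) from the pair, then fold with the challenge r
    let lo_rec = mul(add(left, right), inv_2);
    let hi_rec = mul(sub(left, right), coeff);
    let folded = add(lo_rec, mul(r, sub(hi_rec, lo_rec)));

    // the fold equals the degree-1 interpolation (1 - r) * lo + r * hi
    let expected = add(mul(sub(1, r), lo), mul(r, hi));
    assert_eq!(folded, expected);
}

In the verifier query loop, the same relation is applied once per round, with `coeff` taken from `verifier_folding_coeffs_level` at the queried index, which is why a single precomputed `inv_2` suffices across all rounds.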
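
A note on the "scale up with scalar for witin_num_vars < max_num_var" comment in check (1) of the batch verifier: when a polynomial with fewer variables is treated as a polynomial over the full max_num_var-variable hypercube that does not depend on the extra variables, its hypercube sum gains a factor of 2 per extra variable, which is presumably why each claimed evaluation is multiplied by 2^(max_num_var - witin_num_vars) before being compared against s_0(0) + s_0(1). A small numeric sketch of that identity (plain integers, helper names are ad-hoc):

// toy check: lifting a multilinear to more variables scales its hypercube sum by 2^diff
fn hypercube_sum(evals: &[u64]) -> u64 {
    evals.iter().sum()
}

fn main() {
    // evaluations of a 2-variable multilinear on {0,1}^2
    let small = vec![3u64, 1, 4, 1];

    // lift to 4 variables by making it constant in the two extra variables;
    // each evaluation is repeated 2^2 = 4 times (the repetition order does not affect the sum)
    let diff = 2u32;
    let lifted: Vec<u64> = small
        .iter()
        .flat_map(|&e| std::iter::repeat(e).take(1 << diff))
        .collect();

    // sum over the larger cube = 2^diff * sum over the smaller cube
    assert_eq!(hypercube_sum(&lifted), (1 << diff) * hypercube_sum(&small));
}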
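
A note on the "height * 2" convention appearing in both `BasefoldCommitmentWithWitness::new` and `get_base_codeword_dimentions`: a codeword of length 2^(num_vars + rate_log) is committed as a Merkle matrix whose rows hold the left and right halves side by side, so the tree height is half the codeword length while the row width doubles per polynomial, and `log2_max_codeword_size` is recovered as log2(height * 2). A quick arithmetic sketch under assumed example values (rate_log = 1, three polynomials):

// toy check of the left/right-concatenation bookkeeping, not part of the patch
fn main() {
    let (num_vars, rate_log, num_polys) = (10usize, 1usize, 3usize);

    let codeword_len: usize = 1 << (num_vars + rate_log);
    // left/right halves are concatenated under the same row index
    let height: usize = 1 << (num_vars + rate_log - 1);
    let width: usize = num_polys * 2;

    // the tree height is half the codeword length, so height * 2 recovers its size
    assert_eq!(height * 2, codeword_len);
    assert_eq!((height * 2).ilog2() as usize, num_vars + rate_log);
    // the matrix still stores every codeword entry of every polynomial exactly once
    assert_eq!(width * height, num_polys * codeword_len);
}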