From 6bc67a064dfaa432763fcfa0ba3105bd372ea722 Mon Sep 17 00:00:00 2001 From: ljedrz Date: Tue, 27 Jan 2026 14:25:44 +0100 Subject: [PATCH 1/5] perf: use a singe committer_key and fft_precomputation for batch proving Signed-off-by: ljedrz --- algorithms/benches/snark/varuna.rs | 41 ++-- .../src/polycommit/kzg10/data_structures.rs | 147 +++++++++++- .../src/polycommit/kzg10/degree_info.rs | 63 +++++ algorithms/src/polycommit/kzg10/mod.rs | 3 + .../polycommit/sonic_pc/data_structures.rs | 84 +------ algorithms/src/polycommit/sonic_pc/mod.rs | 132 ++--------- algorithms/src/polycommit/test_templates.rs | 216 ++++++++---------- algorithms/src/snark/varuna/ahp/ahp.rs | 42 +--- .../src/snark/varuna/ahp/indexer/circuit.rs | 58 +---- .../snark/varuna/ahp/indexer/circuit_info.rs | 45 +++- .../src/snark/varuna/ahp/indexer/indexer.rs | 55 +---- .../ahp/prover/round_functions/first.rs | 26 ++- .../ahp/prover/round_functions/fourth.rs | 6 +- .../varuna/ahp/prover/round_functions/mod.rs | 5 +- .../prover/round_functions/prepare_third.rs | 14 +- .../ahp/prover/round_functions/second.rs | 26 ++- .../ahp/prover/round_functions/third.rs | 9 +- .../src/snark/varuna/ahp/prover/state.rs | 15 +- .../data_structures/circuit_proving_key.rs | 14 +- algorithms/src/snark/varuna/tests.rs | 58 +++-- algorithms/src/snark/varuna/varuna.rs | 83 +++---- algorithms/src/srs/universal_prover.rs | 13 +- circuit/environment/src/helpers/assignment.rs | 3 +- circuit/environment/src/helpers/converter.rs | 3 +- console/network/src/canary_v0.rs | 5 +- console/network/src/lib.rs | 3 +- console/network/src/mainnet_v0.rs | 14 +- console/network/src/testnet_v0.rs | 5 +- synthesizer/snark/src/certificate/mod.rs | 5 +- synthesizer/snark/src/lib.rs | 2 +- synthesizer/snark/src/proving_key/mod.rs | 20 +- 31 files changed, 592 insertions(+), 623 deletions(-) create mode 100644 algorithms/src/polycommit/kzg10/degree_info.rs diff --git a/algorithms/benches/snark/varuna.rs b/algorithms/benches/snark/varuna.rs index 8ccba91ff7..49a0147f8c 100644 --- a/algorithms/benches/snark/varuna.rs +++ b/algorithms/benches/snark/varuna.rs @@ -20,6 +20,7 @@ use snarkvm_algorithms::{ AlgebraicSponge, SNARK, crypto_hash::PoseidonSponge, + polycommit::kzg10::DegreeInfo, snark::varuna::{CircuitVerifyingKey, TestCircuit, VarunaHidingMode, VarunaSNARK, VarunaVersion, ahp::AHPForR1CS}, }; use snarkvm_curves::bls12_377::{Bls12_377, Fq, Fr}; @@ -67,25 +68,22 @@ fn snark_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000, 1000, 1000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let params = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let degree_info = pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); c.bench_function("snark_prove_v1", |b| { let varuna_version = VarunaVersion::V1; - b.iter(|| { - VarunaInst::prove(universal_prover, &fs_parameters, ¶ms.0, varuna_version, &circuit, rng).unwrap() - }) + b.iter(|| VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); c.bench_function("snark_prove_v2", |b| { let varuna_version = VarunaVersion::V2; - b.iter(|| { - VarunaInst::prove(universal_prover, &fs_parameters, ¶ms.0, varuna_version, &circuit, rng).unwrap() - }) + b.iter(|| VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); } @@ -99,7 +97,6 @@ fn snark_batch_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000000, 1000000, 1000000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let circuit_batch_size = 5; @@ -109,6 +106,8 @@ fn snark_batch_prove(c: &mut Criterion) { let mut all_circuits = Vec::with_capacity(circuit_batch_size); let mut keys_to_constraints = BTreeMap::new(); + let mut degree_info = None; + for i in 0..circuit_batch_size { let num_constraints = num_constraints_base + i; let num_variables = num_variables_base + i; @@ -120,13 +119,16 @@ fn snark_batch_prove(c: &mut Criterion) { circuits.push(circuit); } let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); - pks.push(pk); all_circuits.push(circuits); + let degree_info_i = pk.circuit.index_info.degree_info::().unwrap(); + degree_info = degree_info.map(|i: DegreeInfo| i.union(°ree_info_i)).or(Some(degree_info_i)); + pks.push(pk); } for i in 0..circuit_batch_size { keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); } + let universal_prover = &universal_srs.to_universal_prover(degree_info.unwrap()).unwrap(); let varuna_version = VarunaVersion::V2; b.iter(|| { @@ -146,13 +148,14 @@ fn snark_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100, 100, 100).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let (circuit, public_inputs) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let degree_info = pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); let varuna_version = VarunaVersion::V2; let proof = VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap(); @@ -180,13 +183,14 @@ fn snark_batch_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000, 1000, 100).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let circuit_batch_size = 5; let instance_batch_size = 5; + let mut degree_info = None; + let mut pks = Vec::with_capacity(circuit_batch_size); let mut vks = Vec::with_capacity(circuit_batch_size); let mut all_circuits = Vec::with_capacity(circuit_batch_size); @@ -205,16 +209,19 @@ fn snark_batch_verify(c: &mut Criterion) { inputs.push(public_inputs); } let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); - pks.push(pk); - vks.push(vk); all_circuits.push(circuits); all_inputs.push(inputs); + let degree_info_i = pk.circuit.index_info.degree_info::().unwrap(); + degree_info = degree_info.map(|i: DegreeInfo| i.union(°ree_info_i)).or(Some(degree_info_i)); + pks.push(pk); + vks.push(vk); } for i in 0..circuit_batch_size { keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); keys_to_inputs.insert(&vks[i], all_inputs[i].as_slice()); } + let universal_prover = &universal_srs.to_universal_prover(degree_info.unwrap()).unwrap(); let varuna_version = VarunaVersion::V2; let proof = @@ -300,7 +307,6 @@ fn snark_certificate_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100000, 100000, 100000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let fs_p = &fs_parameters; @@ -310,6 +316,8 @@ fn snark_certificate_prove(c: &mut Criterion) { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let degree_info = pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); c.bench_function(&format!("snark_certificate_prove_{size}"), |b| { b.iter(|| VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap()) @@ -322,7 +330,6 @@ fn snark_certificate_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100_000, 100_000, 100_000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let fs_p = &fs_parameters; @@ -333,6 +340,8 @@ fn snark_certificate_verify(c: &mut Criterion) { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let degree_info = pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); let certificate = VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap(); c.bench_function(&format!("snark_certificate_verify_{size}"), |b| { diff --git a/algorithms/src/polycommit/kzg10/data_structures.rs b/algorithms/src/polycommit/kzg10/data_structures.rs index bdd0488f99..6394fb7322 100644 --- a/algorithms/src/polycommit/kzg10/data_structures.rs +++ b/algorithms/src/polycommit/kzg10/data_structures.rs @@ -13,9 +13,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::DegreeInfo; use crate::{ AlgebraicSponge, - fft::{DensePolynomial, EvaluationDomain}, + fft::{ + DensePolynomial, + EvaluationDomain, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, + polycommit::sonic_pc::CommitterKey, + prelude::PCError, + srs::{UniversalProver, UniversalVerifier}, }; use snarkvm_curves::{AffineCurve, PairingCurve, PairingEngine, ProjectiveCurve}; use snarkvm_fields::{ConstraintFieldError, ToConstraintField, Zero}; @@ -28,9 +36,9 @@ use snarkvm_utilities::{ serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate}, }; -use crate::srs::{UniversalProver, UniversalVerifier}; -use anyhow::Result; +use anyhow::{Result, bail}; use core::ops::{Add, AddAssign}; +use itertools::Itertools; use rand::RngCore; use std::{ borrow::Cow, @@ -98,8 +106,19 @@ impl UniversalParams { self.powers.max_num_powers() - 1 } - pub fn to_universal_prover(&self) -> Result> { - Ok(UniversalProver:: { max_degree: self.max_degree(), _unused: None }) + pub fn to_universal_prover(&self, degree_info: DegreeInfo) -> Result> { + // Construct the committer key. + let committer_key = self.to_committer_key(°ree_info)?; + // Construct the FFT and iFFT precomputation. + let (fft_precomputation, ifft_precomputation) = + Self::fft_precomputation(degree_info.max_fft_size).ok_or(SerializationError::InvalidData)?; + // Return the universal prover. + Ok(UniversalProver:: { + committer_key, + fft_precomputation, + ifft_precomputation, + max_degree: self.max_degree(), + }) } pub fn to_universal_verifier(&self) -> Result> { @@ -115,6 +134,124 @@ impl UniversalParams { prepared_negative_powers_of_beta_h: self.powers.prepared_negative_powers_of_beta_h(), }) } + + pub fn to_committer_key(&self, degree_info: &DegreeInfo) -> Result> { + let trim_time = start_timer!(|| "Trimming public parameters"); + + // Retrieve the supported parameters from the degree information. + let supported_degree = degree_info.max_degree; + let supported_lagrange_sizes = degree_info.lagrange_sizes.clone().map(|mut s| s.drain().collect_vec()); + let enforced_degree_bounds = degree_info.degree_bounds.clone().map(|mut s| s.drain().collect_vec()); + let supported_hiding_bound = degree_info.hiding_bound; + + // The maximum degree of the entire SRS. + let max_degree = self.max_degree(); + + // TODO (howardwu): This should never be happening... Use a BTreeSet. + let enforced_degree_bounds = enforced_degree_bounds.map(|bounds| { + let mut v = bounds.to_vec(); + v.sort_unstable(); + v.dedup(); + v + }); + + let (shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g) = if let Some(enforced_degree_bounds) = + enforced_degree_bounds.as_ref() + { + if enforced_degree_bounds.is_empty() { + (None, None) + } else { + let highest_enforced_degree_bound = *enforced_degree_bounds.last().unwrap(); + if highest_enforced_degree_bound > supported_degree { + bail!( + "The highest enforced degree bound {highest_enforced_degree_bound} is larger than the supported degree {supported_degree}" + ); + } + + let lowest_shift_degree = max_degree - highest_enforced_degree_bound; + + let shifted_ck_time = start_timer!(|| format!( + "Constructing `shifted_powers_of_beta_g` of size {}", + max_degree - lowest_shift_degree + 1 + )); + + let shifted_powers_of_beta_g = + self.powers_of_beta_g(lowest_shift_degree, self.max_degree() + 1)?.to_vec(); + let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); + // Also add degree 0. + for degree_bound in enforced_degree_bounds { + let shift_degree = max_degree - degree_bound; + let mut powers_for_degree_bound = Vec::with_capacity((max_degree + 2).saturating_sub(shift_degree)); + for i in 0..=supported_hiding_bound + 1 { + // We have an additional degree in `powers_of_beta_times_gamma_g` beyond + // `powers_of_beta_g`. + if shift_degree + i < max_degree + 2 { + powers_for_degree_bound.push(self.powers_of_beta_times_gamma_g()[&(shift_degree + i)]); + } + } + shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); + } + + end_timer!(shifted_ck_time); + + (Some(shifted_powers_of_beta_g), Some(shifted_powers_of_beta_times_gamma_g)) + } + } else { + (None, None) + }; + + let powers_of_beta_g = self.powers_of_beta_g(0, supported_degree + 1)?.to_vec(); + let powers_of_beta_times_gamma_g = (0..=(supported_hiding_bound + 1)) + .map(|i| { + self.powers_of_beta_times_gamma_g() + .get(&i) + .copied() + .ok_or(PCError::HidingBoundToolarge { hiding_poly_degree: supported_hiding_bound, num_powers: 0 }) + }) + .collect::, _>>()?; + + let mut lagrange_bases_at_beta_g = BTreeMap::new(); + if let Some(supported_lagrange_sizes) = supported_lagrange_sizes { + for size in supported_lagrange_sizes { + let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); + if !size.is_power_of_two() { + bail!("The Lagrange basis size ({size}) is not a power of two") + } + if size > self.max_degree() + 1 { + bail!( + "The Lagrange basis size ({size}) is larger than the supported degree ({})", + self.max_degree() + 1 + ) + } + let domain = crate::fft::EvaluationDomain::new(size).unwrap(); + let lagrange_basis_at_beta_g = self.lagrange_basis(domain)?; + assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); + lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); + end_timer!(lagrange_time); + } + } + + let ck = CommitterKey { + powers_of_beta_g, + lagrange_bases_at_beta_g, + powers_of_beta_times_gamma_g, + shifted_powers_of_beta_g, + shifted_powers_of_beta_times_gamma_g, + enforced_degree_bounds, + }; + + end_timer!(trim_time); + Ok(ck) + } + + /// Returns the FFT and iFFT precomputation for the largest supported + /// domain. + fn fft_precomputation(max_domain_size: usize) -> Option<(FFTPrecomputation, IFFTPrecomputation)> { + let domain = EvaluationDomain::new(max_domain_size)?; + let fft_precomputation = domain.precompute_fft(); + let ifft_precomputation = fft_precomputation.to_ifft_precomputation(); + Some((fft_precomputation, ifft_precomputation)) + } } impl FromBytes for UniversalParams { diff --git a/algorithms/src/polycommit/kzg10/degree_info.rs b/algorithms/src/polycommit/kzg10/degree_info.rs new file mode 100644 index 0000000000..43ccabf713 --- /dev/null +++ b/algorithms/src/polycommit/kzg10/degree_info.rs @@ -0,0 +1,63 @@ +// Copyright (c) 2019-2026 Provable Inc. +// This file is part of the snarkVM library. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashSet; + +#[derive(Clone)] +pub struct DegreeInfo { + /// The maximum degree of the required SRS to commit to the polynomials. + pub max_degree: usize, + /// The maximum IOP poly degree used for (i)fft_precomputation. + pub max_fft_size: usize, + /// TODO (howardwu): This should never be happening... Use a BTreeSet. + /// The degree bounds on IOP polynomials. + pub degree_bounds: Option>, + /// The hiding bound for polynomial queries. + pub hiding_bound: usize, + /// The supported sizes for the lagrange-basis SRS. + pub lagrange_sizes: Option>, +} + +impl DegreeInfo { + /// Initializes a new degree info. + pub const fn new( + max_degree: usize, + max_fft_size: usize, + degree_bounds: Option>, + hiding_bound: usize, + lagrange_sizes: Option>, + ) -> Self { + Self { max_degree, max_fft_size, degree_bounds, hiding_bound, lagrange_sizes } + } +} + +impl DegreeInfo { + pub fn union(self, other: &Self) -> Self { + let max_degree = self.max_degree.max(other.max_degree); + let max_fft_size = self.max_fft_size.max(other.max_fft_size); + let degree_bounds = match (&self.degree_bounds, &other.degree_bounds) { + (Some(a), Some(b)) => Some(a | b), + (Some(a), None) | (None, Some(a)) => Some(a.clone()), + (None, None) => None, + }; + let hiding_bound = self.hiding_bound.max(other.hiding_bound); + let lagrange_sizes = match (&self.lagrange_sizes, &other.lagrange_sizes) { + (Some(a), Some(b)) => Some(a | b), + (Some(a), None) | (None, Some(a)) => Some(a.clone()), + (None, None) => None, + }; + Self::new(max_degree, max_fft_size, degree_bounds, hiding_bound, lagrange_sizes) + } +} diff --git a/algorithms/src/polycommit/kzg10/mod.rs b/algorithms/src/polycommit/kzg10/mod.rs index 8b21c3a9e7..a40cad2b4f 100644 --- a/algorithms/src/polycommit/kzg10/mod.rs +++ b/algorithms/src/polycommit/kzg10/mod.rs @@ -41,6 +41,9 @@ use rayon::prelude::*; mod data_structures; pub use data_structures::*; +mod degree_info; +pub use degree_info::*; + use super::sonic_pc::LabeledPolynomialWithBasis; #[derive(Debug, PartialEq, Eq)] diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index f2c09c876b..50b555ebf4 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -275,44 +275,11 @@ impl ToBytes for CommitterKey { } impl CommitterKey { - fn len(&self) -> usize { - if self.shifted_powers_of_beta_g.is_some() { self.shifted_powers_of_beta_g.as_ref().unwrap().len() } else { 0 } - } -} - -/// `CommitterUnionKey` is a union of `CommitterKey`s, useful for multi-circuit -/// batch proofs. -#[derive(Debug)] -pub struct CommitterUnionKey<'a, E: PairingEngine> { - /// The key used to commit to polynomials. - pub powers_of_beta_g: Option<&'a Vec>, - - /// The key used to commit to polynomials in Lagrange basis. - pub lagrange_bases_at_beta_g: BTreeMap>, - - /// The key used to commit to hiding polynomials. - pub powers_of_beta_times_gamma_g: Option<&'a Vec>, - - /// The powers used to commit to shifted polynomials. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub shifted_powers_of_beta_g: Option<&'a Vec>, - - /// The powers used to commit to shifted hiding polynomials. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub shifted_powers_of_beta_times_gamma_g: Option>>, - - /// The degree bounds that are supported by `self`. - /// Sorted in ascending order from smallest bound to largest bound. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub enforced_degree_bounds: Option>, -} - -impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { /// Obtain powers for the underlying KZG10 construction pub fn powers(&self) -> kzg10::Powers { kzg10::Powers { - powers_of_beta_g: self.powers_of_beta_g.unwrap().as_slice().into(), - powers_of_beta_times_gamma_g: self.powers_of_beta_times_gamma_g.unwrap().as_slice().into(), + powers_of_beta_g: self.powers_of_beta_g.as_slice().into(), + powers_of_beta_times_gamma_g: self.powers_of_beta_times_gamma_g.as_slice().into(), } } @@ -345,55 +312,10 @@ impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { pub fn lagrange_basis(&self, domain: EvaluationDomain) -> Option> { self.lagrange_bases_at_beta_g.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { lagrange_basis_at_beta_g: Cow::Borrowed(basis), - powers_of_beta_times_gamma_g: Cow::Borrowed(self.powers_of_beta_times_gamma_g.unwrap()), + powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), domain, }) } - - pub fn union>>(committer_keys: T) -> Self { - let mut ck_union = CommitterUnionKey:: { - powers_of_beta_g: None, - lagrange_bases_at_beta_g: BTreeMap::new(), - powers_of_beta_times_gamma_g: None, - shifted_powers_of_beta_g: None, - shifted_powers_of_beta_times_gamma_g: None, - enforced_degree_bounds: None, - }; - let mut enforced_degree_bounds = vec![]; - let mut biggest_ck: Option<&CommitterKey> = None; - let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); - for ck in committer_keys { - if biggest_ck.is_none() || biggest_ck.unwrap().len() < ck.len() { - biggest_ck = Some(ck); - } - let lagrange_bases = &ck.lagrange_bases_at_beta_g; - for (bound_base, bases) in lagrange_bases.iter() { - ck_union.lagrange_bases_at_beta_g.entry(*bound_base).or_insert(bases); - } - if let Some(shifted_powers) = ck.shifted_powers_of_beta_times_gamma_g.as_ref() { - for (bound_power, powers) in shifted_powers.iter() { - shifted_powers_of_beta_times_gamma_g.entry(*bound_power).or_insert(powers); - } - } - if let Some(degree_bounds) = &ck.enforced_degree_bounds { - enforced_degree_bounds.append(&mut degree_bounds.clone()); - } - } - - let biggest_ck = biggest_ck.unwrap(); - ck_union.powers_of_beta_g = Some(&biggest_ck.powers_of_beta_g); - ck_union.powers_of_beta_times_gamma_g = Some(&biggest_ck.powers_of_beta_times_gamma_g); - ck_union.shifted_powers_of_beta_g = biggest_ck.shifted_powers_of_beta_g.as_ref(); - - if !enforced_degree_bounds.is_empty() { - enforced_degree_bounds.sort(); - enforced_degree_bounds.dedup(); - ck_union.enforced_degree_bounds = Some(enforced_degree_bounds); - ck_union.shifted_powers_of_beta_times_gamma_g = Some(shifted_powers_of_beta_times_gamma_g); - } - - ck_union - } } /// Evaluation proof at a query set. diff --git a/algorithms/src/polycommit/sonic_pc/mod.rs b/algorithms/src/polycommit/sonic_pc/mod.rs index 251ae6b4f3..0b6b71254b 100644 --- a/algorithms/src/polycommit/sonic_pc/mod.rs +++ b/algorithms/src/polycommit/sonic_pc/mod.rs @@ -61,109 +61,6 @@ impl> SonicKZG10 { kzg10::KZG10::load_srs(max_degree) } - pub fn trim( - pp: &UniversalParams, - supported_degree: usize, - supported_lagrange_sizes: impl IntoIterator, - supported_hiding_bound: usize, - enforced_degree_bounds: Option<&[usize]>, - ) -> Result<(CommitterKey, UniversalVerifier)> { - let trim_time = start_timer!(|| "Trimming public parameters"); - let max_degree = pp.max_degree(); - - let enforced_degree_bounds = enforced_degree_bounds.map(|bounds| { - let mut v = bounds.to_vec(); - v.sort_unstable(); - v.dedup(); - v - }); - - let (shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g) = if let Some(enforced_degree_bounds) = - enforced_degree_bounds.as_ref() - { - if enforced_degree_bounds.is_empty() { - (None, None) - } else { - let highest_enforced_degree_bound = *enforced_degree_bounds.last().unwrap(); - if highest_enforced_degree_bound > supported_degree { - bail!( - "The highest enforced degree bound {highest_enforced_degree_bound} is larger than the supported degree {supported_degree}" - ); - } - - let lowest_shift_degree = max_degree - highest_enforced_degree_bound; - - let shifted_ck_time = start_timer!(|| format!( - "Constructing `shifted_powers_of_beta_g` of size {}", - max_degree - lowest_shift_degree + 1 - )); - - let shifted_powers_of_beta_g = pp.powers_of_beta_g(lowest_shift_degree, pp.max_degree() + 1)?; - let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); - // Also add degree 0. - for degree_bound in enforced_degree_bounds { - let shift_degree = max_degree - degree_bound; - // We have an additional degree in `powers_of_beta_times_gamma_g` beyond - // `powers_of_beta_g`. - let powers_for_degree_bound = pp - .powers_of_beta_times_gamma_g() - .range(shift_degree..max_degree.min(shift_degree + supported_hiding_bound) + 2) - .map(|(_k, v)| *v) - .collect(); - shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); - } - - end_timer!(shifted_ck_time); - - (Some(shifted_powers_of_beta_g), Some(shifted_powers_of_beta_times_gamma_g)) - } - } else { - (None, None) - }; - - let powers_of_beta_g = pp.powers_of_beta_g(0, supported_degree + 1)?; - let powers_of_beta_times_gamma_g = pp - .powers_of_beta_times_gamma_g() - .range(0..=(supported_hiding_bound + 1)) - .map(|(_k, v)| *v) - .collect::>(); - if powers_of_beta_times_gamma_g.len() != supported_hiding_bound + 2 { - return Err( - PCError::HidingBoundToolarge { hiding_poly_degree: supported_hiding_bound, num_powers: 0 }.into() - ); - } - - let mut lagrange_bases_at_beta_g = BTreeMap::new(); - for size in supported_lagrange_sizes { - let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); - if !size.is_power_of_two() { - bail!("The Lagrange basis size ({size}) is not a power of two") - } - if size > pp.max_degree() + 1 { - bail!("The Lagrange basis size ({size}) is larger than the supported degree ({})", pp.max_degree() + 1) - } - let domain = crate::fft::EvaluationDomain::new(size).unwrap(); - let lagrange_basis_at_beta_g = pp.lagrange_basis(domain)?; - assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); - lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); - end_timer!(lagrange_time); - } - - let ck = CommitterKey { - powers_of_beta_g, - lagrange_bases_at_beta_g, - powers_of_beta_times_gamma_g, - shifted_powers_of_beta_g, - shifted_powers_of_beta_times_gamma_g, - enforced_degree_bounds, - }; - - let vk = pp.to_universal_verifier()?; - - end_timer!(trim_time); - Ok((ck, vk)) - } - /// Outputs commitments to `polynomials`. /// /// If `polynomials[i].is_hiding()`, then the `i`-th commitment is hiding @@ -180,12 +77,12 @@ impl> SonicKZG10 { #[allow(clippy::format_push_string)] pub fn commit<'b>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, polynomials: impl IntoIterator>, rng: Option<&mut dyn RngCore>, ) -> Result<(Vec>>, Vec>), PCError> { let rng = &mut OptionalRng(rng); let commit_time = start_timer!(|| "Committing to polynomials"); + let ck = &universal_prover.committer_key; let mut pool = snarkvm_utilities::ExecutionPool::>::new(); for p in polynomials { @@ -262,7 +159,6 @@ impl> SonicKZG10 { pub fn combine_for_open<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, labeled_polynomials: impl ExactSizeIterator>, rands: impl ExactSizeIterator>, fs_rng: &mut S, @@ -273,6 +169,7 @@ impl> SonicKZG10 { { ensure!(labeled_polynomials.len() == rands.len()); let mut to_combine = Vec::with_capacity(labeled_polynomials.len()); + let ck = &universal_prover.committer_key; for (p, r) in labeled_polynomials.zip_eq(rands) { let enforced_degree_bounds: Option<&[usize]> = ck.enforced_degree_bounds.as_deref(); @@ -290,7 +187,6 @@ impl> SonicKZG10 { /// set. pub fn batch_open<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, labeled_polynomials: impl ExactSizeIterator>, query_set: &QuerySet, rands: impl ExactSizeIterator>, @@ -301,6 +197,7 @@ impl> SonicKZG10 { Commitment: 'a, { ensure!(labeled_polynomials.len() == rands.len()); + let ck = &universal_prover.committer_key; let poly_rand: HashMap<_, _> = labeled_polynomials.into_iter().zip_eq(rands).map(|(poly, r)| (poly.label(), (poly, r))).collect(); @@ -330,7 +227,7 @@ impl> SonicKZG10 { query_rands.push(*rand); } let (polynomial, rand) = - Self::combine_for_open(universal_prover, ck, query_polys.into_iter(), query_rands.into_iter(), fs_rng)?; + Self::combine_for_open(universal_prover, query_polys.into_iter(), query_rands.into_iter(), fs_rng)?; let _randomizer = fs_rng.squeeze_short_nonnative_field_element::(); pool.add_job(move || { @@ -417,7 +314,6 @@ impl> SonicKZG10 { pub fn open_combinations<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, linear_combinations: impl IntoIterator>, polynomials: impl IntoIterator>, rands: impl IntoIterator>, @@ -444,7 +340,7 @@ impl> SonicKZG10 { let num_polys = lc.len(); // We filter out l.is_one() entries because those constants are not committed to - // and used directly by the verifier. + // but used directly by the verifier. for (coeff, label) in lc.iter().filter(|(_, l)| !l.is_one()) { let label: &String = label.try_into().expect("cannot be one!"); let (cur_poly, cur_rand) = @@ -472,8 +368,7 @@ impl> SonicKZG10 { lc_info.push((lc_label, degree_bound)); } - let proof = - Self::batch_open(universal_prover, ck, lc_polynomials.iter(), query_set, lc_randomness.iter(), fs_rng)?; + let proof = Self::batch_open(universal_prover, lc_polynomials.iter(), query_set, lc_randomness.iter(), fs_rng)?; Ok(BatchLCProof { proof }) } @@ -691,7 +586,10 @@ mod tests { #![allow(non_camel_case_types)] use super::{CommitterKey, SonicKZG10}; - use crate::{crypto_hash::PoseidonSponge, polycommit::test_templates::*}; + use crate::{ + crypto_hash::PoseidonSponge, + polycommit::{kzg10::DegreeInfo, test_templates::*}, + }; use snarkvm_curves::bls12_377::{Bls12_377, Fq}; use snarkvm_utilities::{FromBytes, ToBytes, rand::TestRng}; @@ -704,13 +602,15 @@ mod tests { fn test_committer_key_serialization() { let rng = &mut TestRng::default(); let max_degree = rand::distributions::Uniform::from(8..=64).sample(rng); - let supported_degree = rand::distributions::Uniform::from(1..=max_degree).sample(rng); + let pp = PC_Bls12_377::load_srs(max_degree).unwrap(); + let supported_degree = rand::distributions::Uniform::from(1..=max_degree).sample(rng); + let hiding_bound = 0; let lagrange_size = |d: usize| if d.is_power_of_two() { d } else { d.next_power_of_two() >> 1 }; + let lagrange_sizes = Some([lagrange_size(supported_degree)].into_iter().collect()); + let degree_info = DegreeInfo::new(supported_degree, supported_degree, None, hiding_bound, lagrange_sizes); - let pp = PC_Bls12_377::load_srs(max_degree).unwrap(); - - let (ck, _vk) = PC_Bls12_377::trim(&pp, supported_degree, [lagrange_size(supported_degree)], 0, None).unwrap(); + let ck = pp.to_committer_key(°ree_info).unwrap(); let ck_bytes = ck.to_bytes_le().unwrap(); let ck_recovered: CommitterKey = FromBytes::read_le(&ck_bytes[..]).unwrap(); diff --git a/algorithms/src/polycommit/test_templates.rs b/algorithms/src/polycommit/test_templates.rs index d47c1c30ba..0eac7b0354 100644 --- a/algorithms/src/polycommit/test_templates.rs +++ b/algorithms/src/polycommit/test_templates.rs @@ -17,7 +17,6 @@ use super::sonic_pc::{ BatchLCProof, BatchProof, Commitment, - CommitterUnionKey, Evaluations, LabeledCommitment, QuerySet, @@ -29,6 +28,7 @@ use crate::{ fft::DensePolynomial, polycommit::{ PCError, + kzg10::DegreeInfo, sonic_pc::{LabeledPolynomial, LabeledPolynomialWithBasis, LinearCombination}, }, srs::UniversalVerifier, @@ -42,7 +42,7 @@ use rand::{ Rng, distributions::{self, Distribution}, }; -use std::marker::PhantomData; +use std::{collections::HashSet, marker::PhantomData}; #[derive(Default)] struct TestInfo { @@ -56,7 +56,7 @@ struct TestInfo { } pub struct TestComponents> { - pub verification_key: UniversalVerifier, + pub universal_verifier: UniversalVerifier, pub commitments: Vec>>, pub query_set: QuerySet, pub evaluations: Evaluations, @@ -69,8 +69,9 @@ pub struct TestComponents> { pub fn bad_degree_bound_test>() -> Result<(), PCError> { let rng = &mut TestRng::default(); let max_degree = 100; + let hiding_bound = 1; let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); + let universal_verifier = pp.to_universal_verifier().unwrap(); for _ in 0..10 { let supported_degree = distributions::Uniform::from(1..=max_degree).sample(rng); @@ -78,7 +79,7 @@ pub fn bad_degree_bound_test>() - let mut labels = Vec::new(); let mut polynomials = Vec::new(); - let mut degree_bounds = Vec::new(); + let mut degree_bounds = HashSet::new(); for i in 0..10 { let label = format!("Test{i}"); @@ -86,22 +87,22 @@ pub fn bad_degree_bound_test>() - let poly = DensePolynomial::rand(supported_degree, rng); let degree_bound = 1usize; - let hiding_bound = Some(1); - degree_bounds.push(degree_bound); + degree_bounds.insert(degree_bound); - polynomials.push(LabeledPolynomial::new(label, poly, Some(degree_bound), hiding_bound)) + polynomials.push(LabeledPolynomial::new(label, poly, Some(degree_bound), Some(hiding_bound))) } - println!("supported degree: {supported_degree:?}"); - let (ck, vk) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_degree, Some(degree_bounds.as_slice())) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds: Some(degree_bounds), + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; let mut query_set = QuerySet::new(); let mut values = Evaluations::new(); @@ -116,14 +117,14 @@ pub fn bad_degree_bound_test>() - let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; assert!(result, "proof was incorrect, Query set: {query_set:#?}"); } Ok(()) @@ -141,17 +142,17 @@ pub fn lagrange_test_template>() let rng = &mut TestRng::default(); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); let mut lagrange_polynomials = Vec::new(); - let mut supported_lagrange_sizes = Vec::new(); - let degree_bounds = None; + let mut supported_lagrange_sizes = HashSet::new(); + let degree_bounds = HashSet::new(); + let hiding_bound = 1; let mut labels = Vec::new(); - println!("Sampled supported degree"); // Generate polynomials let num_points_in_query_set = distributions::Uniform::from(1..=max_num_queries).sample(rng); @@ -166,34 +167,26 @@ pub fn lagrange_test_template>() let domain = crate::fft::EvaluationDomain::new(evals.len()).unwrap(); let evals = crate::fft::Evaluations::from_vec_and_domain(evals, domain); let poly = evals.interpolate_by_ref(); - supported_lagrange_sizes.push(domain.size()); + supported_lagrange_sizes.insert(domain.size()); assert_eq!(poly.evaluate_over_domain_by_ref(domain), evals); let degree_bound = None; - let hiding_bound = Some(1); - polynomials.push(LabeledPolynomial::new(label.clone(), poly, degree_bound, hiding_bound)); - lagrange_polynomials.push(LabeledPolynomialWithBasis::new_lagrange_basis(label, evals, hiding_bound)) + polynomials.push(LabeledPolynomial::new(label.clone(), poly, degree_bound, Some(hiding_bound))); + lagrange_polynomials.push(LabeledPolynomialWithBasis::new_lagrange_basis(label, evals, Some(hiding_bound))) } let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - let (ck, vk_orig) = SonicKZG10::::trim( - &pp, - supported_degree, - supported_lagrange_sizes, - supported_hiding_bound, - degree_bounds, - ) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); - - let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, lagrange_polynomials, Some(rng)).unwrap(); + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds: Some(degree_bounds), + hiding_bound, + lagrange_sizes: Some(supported_lagrange_sizes), + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); + + let (comms, rands) = SonicKZG10::::commit(universal_prover, lagrange_polynomials, Some(rng)).unwrap(); // Construct query set let mut query_set = QuerySet::new(); @@ -212,14 +205,14 @@ pub fn lagrange_test_template>() let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; if !result { println!("Failed with {num_polynomials} polynomials, num_points_in_query_set: {num_points_in_query_set:?}"); println!("Degree of polynomials:"); @@ -230,7 +223,7 @@ pub fn lagrange_test_template>() assert!(result, "proof was incorrect, Query set: {query_set:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, @@ -263,15 +256,16 @@ where let rng = &mut TestRng::default(); let max_degree = max_degree.unwrap_or_else(|| distributions::Uniform::from(8..=64).sample(rng)); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); let supported_degree_bounds = [1 << 10, 1 << 15, 1 << 20, 1 << 25, 1 << 30]; + let hiding_bound = 1; for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = if enforce_degree_bounds { Some(Vec::new()) } else { None }; + let mut degree_bounds = HashSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -290,40 +284,31 @@ where .filter(|x| *x >= degree && *x < supported_degree) .collect::>(); - let degree_bound = if let Some(degree_bounds) = &mut degree_bounds { - if !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen() { - let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); - let idx = range.sample(rng); - - let degree_bound = supported_degree_bounds_after_trimmed[idx]; - degree_bounds.push(degree_bound); - Some(degree_bound) - } else { - None - } - } else { - None - }; - - let hiding_bound = Some(1); - println!("Hiding bound: {hiding_bound:?}"); - - polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, hiding_bound)) + let compute_degree_bound = + enforce_degree_bounds && !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen(); + let degree_bound = compute_degree_bound.then(|| { + let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); + let idx = range.sample(rng); + supported_degree_bounds_after_trimmed[idx] + }); + if let Some(degree_bound) = degree_bound { + degree_bounds.insert(degree_bound); + } + polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - let (ck, vk_orig) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_hiding_bound, degree_bounds.as_deref()) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds: Some(degree_bounds), + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; // Construct query set let mut query_set = QuerySet::new(); @@ -342,14 +327,14 @@ where let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; if !result { println!("Failed with {num_polynomials} polynomials, num_points_in_query_set: {num_points_in_query_set:?}"); println!("Degree of polynomials:"); @@ -360,7 +345,7 @@ where assert!(result, "proof was incorrect, Query set: {query_set:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, @@ -391,15 +376,16 @@ fn equation_test_template>( let rng = &mut TestRng::default(); let max_degree = max_degree.unwrap_or_else(|| distributions::Uniform::from(8..=64).sample(rng)); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); let supported_degree_bounds = [1 << 10, 1 << 15, 1 << 20, 1 << 25, 1 << 30]; + let hiding_bound = 1; for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = if enforce_degree_bounds { Some(Vec::new()) } else { None }; + let mut degree_bounds = HashSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -418,44 +404,31 @@ fn equation_test_template>( .filter(|x| *x >= degree && *x < supported_degree) .collect::>(); - let degree_bound = if let Some(degree_bounds) = &mut degree_bounds { - if !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen() { - let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); - let idx = range.sample(rng); - - let degree_bound = supported_degree_bounds_after_trimmed[idx]; - degree_bounds.push(degree_bound); - Some(degree_bound) - } else { - None - } - } else { - None - }; - - let hiding_bound = Some(1); - println!("Hiding bound: {hiding_bound:?}"); - - polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, hiding_bound)) + let compute_degree_bound = + enforce_degree_bounds && !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen(); + let degree_bound = compute_degree_bound.then(|| { + let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); + let idx = range.sample(rng); + supported_degree_bounds_after_trimmed[idx] + }); + if let Some(degree_bound) = degree_bound { + degree_bounds.insert(degree_bound); + } + polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - println!("{degree_bounds:?}"); - println!("{num_polynomials}"); - println!("{enforce_degree_bounds}"); - - let (ck, vk_orig) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_hiding_bound, degree_bounds.as_deref()) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds: Some(degree_bounds), + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; // Let's construct our equations let mut linear_combinations = Vec::new(); @@ -497,13 +470,10 @@ fn equation_test_template>( if linear_combinations.is_empty() { continue; } - println!("Generated query set"); - println!("Linear combinations: {linear_combinations:?}"); let mut sponge_for_open = S::new(); let proof = SonicKZG10::open_combinations( universal_prover, - &ck, &linear_combinations, polynomials, &rands, @@ -513,7 +483,7 @@ fn equation_test_template>( println!("Generated proof"); let mut sponge_for_check = S::new(); let result = SonicKZG10::check_combinations( - &vk, + &universal_verifier, &linear_combinations, &comms, &query_set, @@ -527,7 +497,7 @@ fn equation_test_template>( assert!(result, "proof was incorrect, equations: {linear_combinations:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, diff --git a/algorithms/src/snark/varuna/ahp/ahp.rs b/algorithms/src/snark/varuna/ahp/ahp.rs index ca368316be..d86dd613e1 100644 --- a/algorithms/src/snark/varuna/ahp/ahp.rs +++ b/algorithms/src/snark/varuna/ahp/ahp.rs @@ -14,10 +14,7 @@ // limitations under the License. use crate::{ - fft::{ - EvaluationDomain, - domain::{FFTPrecomputation, IFFTPrecomputation}, - }, + fft::EvaluationDomain, polycommit::sonic_pc::{LCTerm, LabeledPolynomial, LinearCombination}, r1cs::SynthesisError, snark::varuna::{ @@ -110,20 +107,6 @@ impl AHPForR1CS { .ok_or(anyhow!("Could not find max_degree")) } - /// Get all the strict degree bounds enforced in the AHP. - pub fn get_degree_bounds(info: &CircuitInfo) -> Result<[usize; 4]> { - let num_variables = info.num_public_and_private_variables; - let num_non_zero_a = info.num_non_zero_a; - let num_non_zero_b = info.num_non_zero_b; - let num_non_zero_c = info.num_non_zero_c; - Ok([ - EvaluationDomain::::compute_size_of_domain(num_variables).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_a).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_b).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_c).ok_or(SynthesisError::PolyTooLarge)? - 2, - ]) - } - pub(crate) fn cmp_non_zero_domains( info: &CircuitInfo, max_candidate: Option>, @@ -144,29 +127,6 @@ impl AHPForR1CS { Ok(NonZeroDomains { max_non_zero_domain, domain_a, domain_b, domain_c }) } - pub fn fft_precomputation( - constraint_domain_size: usize, - variable_domain_size: usize, - non_zero_a_domain_size: usize, - non_zero_b_domain_size: usize, - non_zero_c_domain_size: usize, - ) -> Option<(FFTPrecomputation, IFFTPrecomputation)> { - let largest_domain_size = [ - 2 * constraint_domain_size, - 2 * variable_domain_size, - 2 * non_zero_a_domain_size, - 2 * non_zero_b_domain_size, - 2 * non_zero_c_domain_size, - ] - .into_iter() - .max()?; - let largest_mul_domain = EvaluationDomain::new(largest_domain_size)?; - - let fft_precomputation = largest_mul_domain.precompute_fft(); - let ifft_precomputation = fft_precomputation.to_ifft_precomputation(); - Some((fft_precomputation, ifft_precomputation)) - } - /// Construct the linear combinations that are checked by the AHP. /// Public input should be unformatted. /// We construct the linear combinations as per section 5 of our protocol diff --git a/algorithms/src/snark/varuna/ahp/indexer/circuit.rs b/algorithms/src/snark/varuna/ahp/indexer/circuit.rs index bd7cded702..97a08c1c56 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/circuit.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/circuit.rs @@ -16,21 +16,10 @@ use core::marker::PhantomData; use crate::{ - fft::{ - EvaluationDomain, - domain::{FFTPrecomputation, IFFTPrecomputation}, - }, polycommit::sonic_pc::LabeledPolynomial, - snark::varuna::{ - AHPForR1CS, - CircuitInfo, - Matrix, - SNARKMode, - ahp::matrices::MatrixEvals, - matrices::MatrixArithmetization, - }, + snark::varuna::{CircuitInfo, Matrix, SNARKMode, ahp::matrices::MatrixEvals, matrices::MatrixArithmetization}, }; -use anyhow::{Result, anyhow}; +use anyhow::Result; use blake2::Digest; use hex::FromHex; use snarkvm_fields::PrimeField; @@ -81,8 +70,6 @@ pub struct Circuit { pub b_arith: MatrixEvals, pub c_arith: MatrixEvals, - pub fft_precomputation: FFTPrecomputation, - pub ifft_precomputation: IFFTPrecomputation, pub(crate) _mode: PhantomData, pub(crate) id: CircuitId, } @@ -121,25 +108,6 @@ impl Circuit { Ok(CircuitId(blake2.finalize().into())) } - /// The maximum degree required to represent polynomials of this index. - pub fn max_degree(&self) -> Result { - self.index_info.max_degree::() - } - - /// The size of the constraint (i. e. row) domain in this R1CS instance. - pub fn constraint_domain_size(&self) -> Result { - Ok(crate::fft::EvaluationDomain::::new(self.index_info.num_constraints) - .ok_or(anyhow!("Cannot create EvaluationDomain"))? - .size()) - } - - /// The size of the variable (i. e. column) domain in this R1CS instance. - pub fn variable_domain_size(&self) -> Result { - Ok(crate::fft::EvaluationDomain::::new(self.index_info.num_public_and_private_variables) - .ok_or(anyhow!("Cannot create EvaluationDomain"))? - .size()) - } - /// Compute the row, col, rowcol and rowcolval polynomials of the three /// matrices in this R1CS instance. pub fn interpolate_matrix_evals(&self) -> Result>> { @@ -200,26 +168,6 @@ impl CanonicalDeserialize for Circuit { validate: Validate, ) -> Result { let index_info: CircuitInfo = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; - let constraint_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_constraints) - .ok_or(SerializationError::InvalidData)?; - let variable_domain_size = - EvaluationDomain::::compute_size_of_domain(index_info.num_public_and_private_variables) - .ok_or(SerializationError::InvalidData)?; - let non_zero_a_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_a) - .ok_or(SerializationError::InvalidData)?; - let non_zero_b_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_b) - .ok_or(SerializationError::InvalidData)?; - let non_zero_c_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_c) - .ok_or(SerializationError::InvalidData)?; - - let (fft_precomputation, ifft_precomputation) = AHPForR1CS::::fft_precomputation( - variable_domain_size, - constraint_domain_size, - non_zero_a_domain_size, - non_zero_b_domain_size, - non_zero_c_domain_size, - ) - .ok_or(SerializationError::InvalidData)?; let a = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; let b = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; let c = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; @@ -232,8 +180,6 @@ impl CanonicalDeserialize for Circuit { a_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, b_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, c_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, - fft_precomputation, - ifft_precomputation, _mode: PhantomData, id, }) diff --git a/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs b/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs index 7f40a0783b..51018aec16 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs @@ -13,7 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::snark::varuna::{SNARKMode, ahp::AHPForR1CS}; +use crate::{ + fft::EvaluationDomain, + polycommit::kzg10::DegreeInfo, + snark::varuna::{SNARKMode, ahp::AHPForR1CS}, +}; use anyhow::Result; use snarkvm_fields::PrimeField; use snarkvm_utilities::{ToBytes, serialize::*}; @@ -44,12 +48,51 @@ pub struct CircuitInfo { } impl CircuitInfo { + /// The max degrees and bounds required to represent polynomials of this + /// circuit. + pub fn degree_info(&self) -> Result { + let max_degree = self.max_degree::()?; + let max_fft_size = self.max_fft_size::(); + let degree_bounds = Some(self.degree_bounds::().into_iter().collect()); + let hiding_bound = AHPForR1CS::::zk_bound().unwrap_or(0); + Ok(DegreeInfo::new(max_degree, max_fft_size, degree_bounds, hiding_bound, None)) + } + /// The maximum degree of polynomial required to represent this index in the /// AHP. pub fn max_degree(&self) -> Result { let max_non_zero = self.num_non_zero_a.max(self.num_non_zero_b).max(self.num_non_zero_c); AHPForR1CS::::max_degree(self.num_constraints, self.num_public_and_private_variables, max_non_zero) } + + /// The maximum size of polynomials we need to do (i)fft for + pub fn max_fft_size(&self) -> usize { + let size = [ + 2 * self.num_constraints, // zerocheck poly degree + 2 * self.num_public_and_private_variables, // lineval sumcheck poly degree + 2 * self.num_non_zero_a, // matrix sumcheck poly degree + 2 * self.num_non_zero_b, // matrix sumcheck poly degree + 2 * self.num_non_zero_c, // matrix sumcheck poly degree + ] + .into_iter() + .max() + .unwrap(); + crate::fft::EvaluationDomain::::compute_size_of_domain(size).unwrap() + } + + /// Get all the strict degree bounds enforced in the AHP. + pub fn degree_bounds(&self) -> [usize; 4] { + [ + // The degree bound for g_1 + EvaluationDomain::::compute_size_of_domain(self.num_public_and_private_variables).unwrap() - 2, + // The degree bound for g_A + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_a).unwrap() - 2, + // The degree bound for g_B + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_b).unwrap() - 2, + // The degree bound for g_C + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_c).unwrap() - 2, + ] + } } impl ToBytes for CircuitInfo { diff --git a/algorithms/src/snark/varuna/ahp/indexer/indexer.rs b/algorithms/src/snark/varuna/ahp/indexer/indexer.rs index ae7d4b9869..de9c44dc4a 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/indexer.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/indexer.rs @@ -43,51 +43,10 @@ use super::Matrix; impl AHPForR1CS { /// Generate the index polynomials for this constraint system. pub fn index>(c: &C) -> Result> { - let IndexerState { - constraint_domain, - variable_domain, + let IndexerState { a, a_arith, b, b_arith, c, c_arith, index_info, id, .. } = + Self::index_helper(c).map_err(|e| anyhow!("{e:?}"))?; - a, - non_zero_a_domain, - a_arith, - - b, - non_zero_b_domain, - b_arith, - - c, - non_zero_c_domain, - c_arith, - - index_info, - id, - } = Self::index_helper(c).map_err(|e| anyhow!("{e:?}"))?; - - let fft_precomp_time = start_timer!(|| format!("Precomputing roots of unity {id}")); - - let (fft_precomputation, ifft_precomputation) = Self::fft_precomputation( - constraint_domain.size(), - variable_domain.size(), - non_zero_a_domain.size(), - non_zero_b_domain.size(), - non_zero_c_domain.size(), - ) - .ok_or(anyhow!("The polynomial degree is too large"))?; - end_timer!(fft_precomp_time); - - Ok(Circuit { - index_info, - a, - b, - c, - a_arith, - b_arith, - c_arith, - fft_precomputation, - ifft_precomputation, - id, - _mode: PhantomData, - }) + Ok(Circuit { index_info, a, b, c, a_arith, b_arith, c_arith, id, _mode: PhantomData }) } pub fn index_polynomial_info<'a>( @@ -207,9 +166,6 @@ impl AHPForR1CS { let id = Circuit::::hash(&index_info, &a, &b, &c)?; let result = Ok(IndexerState { - constraint_domain, - variable_domain, - a, non_zero_a_domain, a_arith, @@ -264,11 +220,6 @@ impl AHPForR1CS { } pub(crate) struct IndexerState { - // R_i in the Varuna spec - constraint_domain: EvaluationDomain, - // C_i in the Varuna spec - variable_domain: EvaluationDomain, - a: Matrix, non_zero_a_domain: EvaluationDomain, a_arith: MatrixEvals, diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs index f26d3df5c9..e1ae7c8619 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs @@ -14,10 +14,15 @@ // limitations under the License. use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain, SparsePolynomial}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + SparsePolynomial, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, polycommit::sonic_pc::{LabeledPolynomial, PolynomialInfo, PolynomialLabel}, snark::varuna::{ - Circuit, CircuitId, SNARKMode, ahp::{AHPError, AHPForR1CS}, @@ -63,6 +68,8 @@ impl AHPForR1CS { rng: &mut R, ) -> Result, AHPError> { let round_time = start_timer!(|| "AHP::Prover::FirstRound"); + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity(state.total_instances); for (circuit, circuit_state) in state.circuit_specific_states.iter_mut() { let batch_size = circuit_state.batch_size; @@ -77,7 +84,9 @@ impl AHPForR1CS { for (j, (private_vars, x_poly)) in itertools::izip!(private_variables, x_polys).enumerate() { let w_label = witness_label(circuit.id, "w", j); - job_pool.add_job(move || Self::calculate_w(w_label, private_vars, x_poly, v_domain, i_domain, circuit)); + job_pool.add_job(move || { + Self::calculate_w(w_label, private_vars, x_poly, v_domain, i_domain, fft_precomp, ifft_precomp) + }); } } let mut batches = @@ -125,13 +134,14 @@ impl AHPForR1CS { } // Compute the shifted witness \overline{w}(X) - fn calculate_w( + fn calculate_w<'a>( label: String, private_variables: Vec, x_poly: DensePolynomial, variable_domain: EvaluationDomain, input_domain: EvaluationDomain, - circuit: &Circuit, + fft_precomp: &'a FFTPrecomputation, + ifft_precomp: &'a IFFTPrecomputation, ) -> Witness { let mut w_extended = private_variables; let ratio = variable_domain.size() / input_domain.size(); @@ -141,7 +151,7 @@ impl AHPForR1CS { let x_evals = { let mut coeffs = x_poly.coeffs; coeffs.resize(variable_domain.size(), F::zero()); - variable_domain.in_order_fft_in_place_with_pc(&mut coeffs, &circuit.fft_precomputation); + variable_domain.in_order_fft_in_place_with_pc(&mut coeffs, fft_precomp); coeffs }; @@ -158,8 +168,8 @@ impl AHPForR1CS { // Interpolating \widetilde{z} - \widetilde{x} and dividing by the // vanishing polynomial over variable_domain. - let w_poly = EvaluationsOnDomain::from_vec_and_domain(w_poly_evals, variable_domain) - .interpolate_with_pc(&circuit.ifft_precomputation); + let w_poly = + EvaluationsOnDomain::from_vec_and_domain(w_poly_evals, variable_domain).interpolate_with_pc(ifft_precomp); let (w_poly, remainder) = w_poly.divide_by_vanishing_poly(input_domain).unwrap(); assert!(remainder.is_zero()); diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs index 673d8366fc..a337e47948 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs @@ -83,6 +83,8 @@ impl AHPForR1CS { let verifier::SecondMessage { alpha, .. } = second_message; let verifier::ThirdMessage { beta } = third_message; + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; let mut pool = ExecutionPool::with_capacity(3 * state.circuit_specific_states.len()); @@ -109,8 +111,8 @@ impl AHPForR1CS { *beta, v_R_i_alpha_v_C_i_beta, max_non_zero_domain_size, - &circuit.fft_precomputation, - &circuit.ifft_precomputation, + fft_precomp, + ifft_precomp, ); (circuit, result) }); diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs index 57c6d8a2df..275ed94df3 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs @@ -14,6 +14,7 @@ // limitations under the License. use crate::{ + fft::domain::{FFTPrecomputation, IFFTPrecomputation}, r1cs::ConstraintSynthesizer, snark::varuna::{ SNARKMode, @@ -41,6 +42,8 @@ impl AHPForR1CS { /// Initialize the AHP prover. pub fn init_prover<'a, C: ConstraintSynthesizer, R: Rng + CryptoRng>( circuits_to_constraints: &BTreeMap<&'a Circuit, &[C]>, + fft_precomp: &'a FFTPrecomputation, + ifft_precomp: &'a IFFTPrecomputation, rng: &mut R, ) -> Result, AHPError> { let init_time = start_timer!(|| "AHP::Prover::Init"); @@ -162,7 +165,7 @@ impl AHPForR1CS { }) .collect::, Vec>>, AHPError>>()?; - let state = prover::State::initialize(indices_and_assignments)?; + let state = prover::State::initialize(indices_and_assignments, fft_precomp, ifft_precomp)?; end_timer!(init_time); Ok(state) diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs index a90fdb735b..f824c9bba7 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs @@ -82,11 +82,11 @@ impl AHPForR1CS { let total_instances = num_instances.iter().sum::(); let matrix_labels = ["a", "b", "c"]; - let fft_precomputations = state - .circuit_specific_states - .keys() - .map(|circuit| (circuit.fft_precomputation.clone(), circuit.ifft_precomputation.clone())) - .collect_vec(); + let fft_precomputation = &state.fft_precomputation; + let ifft_precomputation = &state.ifft_precomputation; + + let fft_precomputations = + state.circuit_specific_states.keys().map(|_| (fft_precomputation, ifft_precomputation)).collect_vec(); // Compute lineval sumcheck witnesses let mut job_pool = ExecutionPool::with_capacity(total_instances * 3); @@ -108,8 +108,8 @@ impl AHPForR1CS { label, &circuit_specific_state.constraint_domain, &circuit_specific_state.variable_domain, - &precomp.0, - &precomp.1, + precomp.0, + precomp.1, assignment, matrix_transpose, *alpha, diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs index 7d93446d94..b2d0284b29 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs @@ -16,10 +16,15 @@ use std::collections::BTreeMap; use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain, polynomial::PolyMultiplier}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + domain::IFFTPrecomputation, + polynomial::PolyMultiplier, + }, polycommit::sonic_pc::{LabeledPolynomial, PolynomialInfo, PolynomialLabel}, snark::varuna::{ - Circuit, CircuitId, SNARKMode, ahp::{AHPForR1CS, verifier}, @@ -80,6 +85,9 @@ impl AHPForR1CS { let mut job_pool = ExecutionPool::with_capacity(state.circuit_specific_states.len()); let max_constraint_domain = state.max_constraint_domain; + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; + for (circuit, circuit_specific_state) in state.circuit_specific_states.iter_mut() { let z_a = circuit_specific_state.z_a.take().unwrap(); let z_b = circuit_specific_state.z_b.take().unwrap(); @@ -88,8 +96,6 @@ impl AHPForR1CS { let circuit_combiner = batch_combiners[&circuit.id].circuit_combiner; let instance_combiners = batch_combiners[&circuit.id].instance_combiners.clone(); let constraint_domain = circuit_specific_state.constraint_domain; - let fft_precomputation = &circuit.fft_precomputation; - let ifft_precomputation = &circuit.ifft_precomputation; let _circuit_id = &circuit.id; // seems like a compiler bug marks this as unused @@ -101,11 +107,11 @@ impl AHPForR1CS { let za_label = witness_label(circuit.id, "z_a", j); let zb_label = witness_label(circuit.id, "z_b", j); let zc_label = witness_label(circuit.id, "z_c", j); - let z_a = Self::calculate_z_m(za_label, z_a, constraint_domain, circuit); - let z_b = Self::calculate_z_m(zb_label, z_b, constraint_domain, circuit); - let z_c = Self::calculate_z_m(zc_label, z_c, constraint_domain, circuit); + let z_a = Self::calculate_z_m(za_label, z_a, constraint_domain, ifft_precomp); + let z_b = Self::calculate_z_m(zb_label, z_b, constraint_domain, ifft_precomp); + let z_c = Self::calculate_z_m(zc_label, z_c, constraint_domain, ifft_precomp); let mut multiplier_2 = PolyMultiplier::new(); - multiplier_2.add_precomputation(fft_precomputation, ifft_precomputation); + multiplier_2.add_precomputation(fft_precomp, ifft_precomp); multiplier_2.add_polynomial(z_a, "z_a"); multiplier_2.add_polynomial(z_b, "z_b"); let mut rowcheck = multiplier_2.multiply().unwrap(); @@ -145,13 +151,13 @@ impl AHPForR1CS { label: impl ToString, evaluations: Vec, constraint_domain: EvaluationDomain, - circuit: &Circuit, + ifft_precomp: &IFFTPrecomputation, ) -> DensePolynomial { let label = label.to_string(); let poly_time = start_timer!(|| format!("Computing {label}")); let evals = EvaluationsOnDomain::from_vec_and_domain(evaluations, constraint_domain); - let poly = evals.interpolate_with_pc_by_ref(&circuit.ifft_precomputation); + let poly = evals.interpolate_with_pc_by_ref(ifft_precomp); debug_assert!( poly.evaluate_over_domain_by_ref(constraint_domain) diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs index cd4bdb7b14..9dc675a658 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs @@ -154,12 +154,15 @@ impl AHPForR1CS { let matrix_labels = ["a", "b", "c"]; let matrix_combiners = [F::one(), *eta_b, *eta_c]; + let fft_precomputation = &state.fft_precomputation; + let ifft_precomputation = &state.ifft_precomputation; + // Compute lineval sumcheck witnesses let mut job_pool = ExecutionPool::with_capacity(total_instances * 3); // Iterate for each circuit in the batch. - for ((((circuit, circuit_specific_state), batch_combiner), assignments_i), matrix_transposes_i) in state + for (((circuit_specific_state, batch_combiner), assignments_i), matrix_transposes_i) in state .circuit_specific_states - .iter_mut() + .values_mut() .zip_eq(third_round_batch_combiners.values()) .zip_eq(assignments.values()) .zip_eq(matrix_transposes.values()) @@ -168,8 +171,6 @@ impl AHPForR1CS { let instance_combiners = &batch_combiner.instance_combiners; let constraint_domain = &circuit_specific_state.constraint_domain; let variable_domain = &circuit_specific_state.variable_domain; - let fft_precomputation = &circuit.fft_precomputation; - let ifft_precomputation = &circuit.ifft_precomputation; // Iterate for each instance in the batch. for (instance_combiner, assignment) in itertools::izip!(instance_combiners, assignments_i) { diff --git a/algorithms/src/snark/varuna/ahp/prover/state.rs b/algorithms/src/snark/varuna/ahp/prover/state.rs index 4f3684854f..4428b15393 100644 --- a/algorithms/src/snark/varuna/ahp/prover/state.rs +++ b/algorithms/src/snark/varuna/ahp/prover/state.rs @@ -16,7 +16,12 @@ use std::collections::{BTreeMap, VecDeque}; use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, polycommit::sonic_pc::LabeledPolynomial, r1cs::{SynthesisError, SynthesisResult}, snark::varuna::{AHPError, AHPForR1CS, Circuit, SNARKMode}, @@ -91,6 +96,10 @@ pub struct State<'a, F: PrimeField, SM: SNARKMode> { pub(in crate::snark) max_variable_domain: EvaluationDomain, /// The total number of instances we're proving in the batch. pub(in crate::snark) total_instances: usize, + /// The precomputed roots for calculating FFTs. + pub fft_precomputation: &'a FFTPrecomputation, + /// The precomputed inverse roots for calculating inverse FFTs. + pub ifft_precomputation: &'a IFFTPrecomputation, } /// The public inputs for a single instance. @@ -116,6 +125,8 @@ pub(super) struct Assignments( impl<'a, F: PrimeField, SM: SNARKMode> State<'a, F, SM> { pub(super) fn initialize( indices_and_assignments: BTreeMap<&'a Circuit, Vec>>, + fft_precomputation: &'a FFTPrecomputation, + ifft_precomputation: &'a IFFTPrecomputation, ) -> Result { let mut max_non_zero_domain: Option> = None; let mut max_num_constraints = 0; @@ -195,6 +206,8 @@ impl<'a, F: PrimeField, SM: SNARKMode> State<'a, F, SM> { circuit_specific_states, total_instances, first_round_oracles: None, + fft_precomputation, + ifft_precomputation, }) } diff --git a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs index eef12185ef..ffc8cbe51c 100644 --- a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs +++ b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs @@ -13,10 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::{ - polycommit::sonic_pc, - snark::varuna::{CircuitVerifyingKey, SNARKMode, ahp::indexer::*}, -}; +use crate::snark::varuna::{CircuitVerifyingKey, SNARKMode, ahp::indexer::*}; use snarkvm_curves::PairingEngine; use snarkvm_utilities::{FromBytes, ToBytes, serialize::*}; use std::io::{self, Read, Write}; @@ -31,16 +28,12 @@ pub struct CircuitProvingKey { // NOTE: The circuit verifying key's circuit_info and circuit id are also stored in Circuit for convenience. /// The circuit itself. pub circuit: Arc>, - /// The committer key for this index, trimmed from the universal SRS. - pub committer_key: Arc>, } impl ToBytes for CircuitProvingKey { fn write_le(&self, mut writer: W) -> io::Result<()> { CanonicalSerialize::serialize_compressed(&self.circuit_verifying_key, &mut writer)?; - CanonicalSerialize::serialize_compressed(&self.circuit, &mut writer)?; - - self.committer_key.write_le(&mut writer) + Ok(CanonicalSerialize::serialize_compressed(&self.circuit, &mut writer)?) } } @@ -49,9 +42,8 @@ impl FromBytes for CircuitProvingKey { fn read_le(mut reader: R) -> io::Result { let circuit_verifying_key = CanonicalDeserialize::deserialize_compressed(&mut reader)?; let circuit = CanonicalDeserialize::deserialize_compressed(&mut reader)?; - let committer_key = Arc::new(FromBytes::read_le(&mut reader)?); - Ok(Self { circuit_verifying_key, circuit, committer_key }) + Ok(Self { circuit_verifying_key, circuit }) } } diff --git a/algorithms/src/snark/varuna/tests.rs b/algorithms/src/snark/varuna/tests.rs index f1d564e284..20884b9171 100644 --- a/algorithms/src/snark/varuna/tests.rs +++ b/algorithms/src/snark/varuna/tests.rs @@ -16,6 +16,7 @@ #[cfg(any(test, feature = "test"))] mod varuna { use crate::{ + polycommit::kzg10::DegreeInfo, snark::varuna::{ AHPForR1CS, CircuitVerifyingKey, @@ -53,7 +54,6 @@ mod varuna { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = $snark_inst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); @@ -72,6 +72,9 @@ mod varuna { let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); println!("Called circuit setup"); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circ, &index_vk, &certificate).unwrap()); println!("verified vk"); @@ -97,6 +100,7 @@ mod varuna { println!("running test with circuit_batch_size: {circuit_batch_size} and instance_batch_size: {instance_batch_size}"); let mut constraints = BTreeMap::new(); let mut inputs = BTreeMap::new(); + let mut degree_info = None; for i in 0..circuit_batch_size { let (circuit_batch, input_batch): (Vec<_>, Vec<_>) = (0..instance_batch_size) @@ -120,6 +124,12 @@ mod varuna { let mut vks_to_inputs = BTreeMap::new(); for (index_pk, index_vk) in index_keys.iter() { + let degree_info_i = index_pk.circuit.index_info.degree_info::().unwrap(); + degree_info = degree_info.map(|i: DegreeInfo|i.union(°ree_info_i)).or(Some(degree_info_i.clone())); + let universal_prover = &universal_srs.to_universal_prover( + degree_info_i + ).unwrap(); + let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); let circuits = constraints[&index_pk.circuit.id].as_slice(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circuits[0], &index_vk, &certificate).unwrap()); @@ -128,6 +138,10 @@ mod varuna { } println!("verified vks"); + let universal_prover = &universal_srs.to_universal_prover( + degree_info.unwrap() + ).unwrap(); + let proof = $snark_inst::prove_batch(universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap(); println!("Called prover"); @@ -244,8 +258,8 @@ mod varuna { fn prove_and_verify_with_tall_matrix_big() { let num_constraints = 100; let num_variables = 25; - let pk_size_zk = 91971; - let pk_size_posw = 91633; + let pk_size_zk = 53767; + let pk_size_posw = 53623; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -265,8 +279,8 @@ mod varuna { fn prove_and_verify_with_tall_matrix_small() { let num_constraints = 26; let num_variables = 25; - let pk_size_zk = 25428; - let pk_size_posw = 25090; + let pk_size_zk = 15463; + let pk_size_posw = 15319; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -286,8 +300,8 @@ mod varuna { fn prove_and_verify_with_squat_matrix_big() { let num_constraints = 25; let num_variables = 100; - let pk_size_zk = 53523; - let pk_size_posw = 53185; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -307,8 +321,8 @@ mod varuna { fn prove_and_verify_with_squat_matrix_small() { let num_constraints = 25; let num_variables = 26; - let pk_size_zk = 25284; - let pk_size_posw = 24946; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -328,8 +342,8 @@ mod varuna { fn prove_and_verify_with_square_matrix() { let num_constraints = 25; let num_variables = 25; - let pk_size_zk = 25284; - let pk_size_posw = 24946; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -381,7 +395,6 @@ mod varuna_hiding { ) { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); @@ -399,6 +412,9 @@ mod varuna_hiding { let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); println!("Called circuit setup"); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let proof = VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); @@ -568,7 +584,6 @@ mod varuna_hiding { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); for varuna_version in [VarunaVersion::V1, VarunaVersion::V2] { @@ -579,6 +594,9 @@ mod varuna_hiding { let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); println!("Called circuit setup"); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let proof = VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); @@ -629,7 +647,6 @@ mod varuna_hiding { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let varuna_version = VarunaVersion::V2; @@ -643,6 +660,9 @@ mod varuna_hiding { let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &circuit1).unwrap(); println!("Called circuit setup"); + let degree_info = pk1.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let proof1 = VarunaInst::prove(universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); println!("Called prover"); assert!( @@ -669,6 +689,9 @@ mod varuna_hiding { let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &circuit2).unwrap(); println!("Called circuit setup"); + let degree_info = pk2.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let proof2 = VarunaInst::prove(universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap(); println!("Called prover"); assert!( @@ -824,7 +847,12 @@ mod varuna_test_vectors { keys_to_constraints.insert(index_pk.circuit.deref(), std::slice::from_ref(&circ)); // Begin the Varuna protocol execution. - let prover_state = AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, rng).unwrap(); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = universal_srs.to_universal_prover(degree_info).unwrap(); + let fft_precomp = &universal_prover.fft_precomputation; + let ifft_precomp = &universal_prover.ifft_precomputation; + let prover_state = + AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, fft_precomp, ifft_precomp, rng).unwrap(); let mut prover_state = AHPForR1CS::<_, MM>::prover_first_round(prover_state, rng).unwrap(); let first_round_oracles = Arc::new(prover_state.first_round_oracles.as_ref().unwrap()); diff --git a/algorithms/src/snark/varuna/varuna.rs b/algorithms/src/snark/varuna/varuna.rs index 24d554041a..625dabf805 100644 --- a/algorithms/src/snark/varuna/varuna.rs +++ b/algorithms/src/snark/varuna/varuna.rs @@ -19,14 +19,9 @@ use crate::{ SNARK, SNARKError, fft::EvaluationDomain, - polycommit::sonic_pc::{ - Commitment, - CommitterUnionKey, - Evaluations, - LabeledCommitment, - QuerySet, - Randomness, - SonicKZG10, + polycommit::{ + kzg10::DegreeInfo, + sonic_pc::{Commitment, Evaluations, LabeledCommitment, QuerySet, Randomness, SonicKZG10}, }, r1cs::{ConstraintSynthesizer, SynthesisError}, snark::varuna::{ @@ -67,56 +62,53 @@ impl, SM: SNARKMode> VarunaSNARK /// Used to personalize the Fiat-Shamir RNG. pub const PROTOCOL_NAME: &'static [u8] = b"VARUNA-2023"; - // TODO: implement optimizations resulting from batching - // (e.g. computing a common set of Lagrange powers, FFT precomputations, - // etc) + /// Creates a set of proving and verifying keys for a set of circuits + /// We commit to each circuit's index polys separately + /// Performing a single batch commitment can be a future optimization pub fn batch_circuit_setup>( universal_srs: &UniversalSRS, circuits: &[&C], ) -> Result, CircuitVerifyingKey)>> { - let index_time = start_timer!(|| "Varuna::CircuitSetup"); + ensure!(!circuits.is_empty()); - let universal_prover = &universal_srs.to_universal_prover()?; + let index_time = start_timer!(|| "Varuna::CircuitSetup"); - let mut circuit_keys = Vec::with_capacity(circuits.len()); + let mut indexed_circuits = Vec::with_capacity(circuits.len()); + let mut degree_info = None::; for circuit in circuits { - let mut indexed_circuit = AHPForR1CS::<_, SM>::index(*circuit)?; - // TODO: Add check that c is in the correct mode. - // Ensure the universal SRS supports the circuit size. - universal_srs.download_powers_for(0..indexed_circuit.max_degree()?).map_err(|e| { - anyhow!("Failed to download powers for degree {}: {e}", indexed_circuit.max_degree().unwrap()) - })?; - let coefficient_support = AHPForR1CS::::get_degree_bounds(&indexed_circuit.index_info)?; - - // Varuna only needs degree 2 random polynomials. - let supported_hiding_bound = 1; - let supported_lagrange_sizes = [].into_iter(); // TODO: consider removing lagrange_bases_at_beta_g from CommitterKey - let (committer_key, _) = SonicKZG10::::trim( - universal_srs, - indexed_circuit.max_degree()?, - supported_lagrange_sizes, - supported_hiding_bound, - Some(coefficient_support.as_slice()), - )?; + let indexed_circuit = AHPForR1CS::<_, SM>::index(*circuit)?; + let degree_info_i = indexed_circuit.index_info.degree_info::()?; + degree_info = match degree_info { + Some(degree_info) => Some(degree_info.union(°ree_info_i)), + None => Some(degree_info_i), + }; + indexed_circuits.push(indexed_circuit); + } + let degree_info = degree_info.unwrap(); + + universal_srs + .download_powers_for(0..degree_info.max_degree) + .map_err(|e| anyhow!("Failed to download powers for degree {}: {e}", degree_info.max_degree))?; - let ck = CommitterUnionKey::union(std::iter::once(&committer_key)); + let universal_prover = universal_srs.to_universal_prover(degree_info)?; + let mut circuit_keys = Vec::with_capacity(circuits.len()); + for mut indexed_circuit in indexed_circuits { let commit_time = start_timer!(|| format!("Commit to index polynomials for {}", indexed_circuit.id)); let setup_rng = None::<&mut dyn RngCore>; // We do not randomize the commitments let (mut circuit_commitments, commitment_randomnesses): (_, _) = SonicKZG10::::commit( - universal_prover, - &ck, + &universal_prover, indexed_circuit.interpolate_matrix_evals()?.map(Into::into), setup_rng, )?; + indexed_circuit.prune_row_col_evals(); // row_col were only needed during indexing let empty_randomness = Randomness::::empty(); ensure!(commitment_randomnesses.iter().all(|r| r == &empty_randomness)); end_timer!(commit_time); circuit_commitments.sort_by(|c1, c2| c1.label().cmp(c2.label())); let circuit_commitments = circuit_commitments.into_iter().map(|c| *c.commitment()).collect(); - indexed_circuit.prune_row_col_evals(); let circuit_verifying_key = CircuitVerifyingKey { circuit_info: indexed_circuit.index_info, circuit_commitments, @@ -125,7 +117,6 @@ impl, SM: SNARKMode> VarunaSNARK let circuit_proving_key = CircuitProvingKey { circuit_verifying_key: circuit_verifying_key.clone(), circuit: Arc::new(indexed_circuit), - committer_key: Arc::new(committer_key), }; circuit_keys.push((circuit_proving_key, circuit_verifying_key)); } @@ -267,14 +258,13 @@ where } let query_set = QuerySet::from_iter([("circuit_check".into(), ("challenge".into(), point))]); - let committer_key = CommitterUnionKey::union(std::iter::once(proving_key.committer_key.as_ref())); let empty_randomness = vec![Randomness::::empty(); 12]; + let index_polynomials = proving_key.circuit.interpolate_matrix_evals()?; let certificate = SonicKZG10::::open_combinations( universal_prover, - &committer_key, &[lc], - proving_key.circuit.interpolate_matrix_evals()?, + index_polynomials, &empty_randomness, &query_set, &mut sponge, @@ -364,7 +354,10 @@ where for (pk, constraints) in keys_to_constraints { circuits_to_constraints.insert(pk.circuit.deref(), *constraints); } - let prover_state = AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, zk_rng)?; + let fft_precomp = &universal_prover.fft_precomputation; + let ifft_precomp = &universal_prover.ifft_precomputation; + let prover_state = + AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, fft_precomp, ifft_precomp, zk_rng)?; // extract information from the prover key and state to consume in further // calculations @@ -391,8 +384,6 @@ where } ensure!(prover_state.total_instances == total_instances); - let committer_key = CommitterUnionKey::union(keys_to_constraints.keys().map(|pk| pk.committer_key.deref())); - let circuit_commitments = keys_to_constraints.keys().map(|pk| pk.circuit_verifying_key.circuit_commitments.as_slice()); @@ -408,7 +399,6 @@ where let first_round_oracles = prover_state.first_round_oracles.as_ref().unwrap(); SonicKZG10::::commit( universal_prover, - &committer_key, first_round_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )? @@ -436,7 +426,6 @@ where let second_round_comm_time = start_timer!(|| "Committing to second round polys"); let (second_commitments, second_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, second_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -496,7 +485,6 @@ where let third_round_comm_time = start_timer!(|| "Committing to third round polys"); let (third_commitments, third_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, third_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -544,7 +532,6 @@ where let fourth_round_comm_time = start_timer!(|| "Committing to fourth round polys"); let (fourth_commitments, fourth_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, fourth_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -570,7 +557,6 @@ where let fifth_round_comm_time = start_timer!(|| "Committing to fifth round polys"); let (fifth_commitments, fifth_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, fifth_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -672,7 +658,6 @@ where let pc_proof = SonicKZG10::::open_combinations( universal_prover, - &committer_key, lc_s.values(), polynomials, &commitment_randomnesses, diff --git a/algorithms/src/srs/universal_prover.rs b/algorithms/src/srs/universal_prover.rs index 3f0e1702c5..9d14a22157 100644 --- a/algorithms/src/srs/universal_prover.rs +++ b/algorithms/src/srs/universal_prover.rs @@ -13,13 +13,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::{ + fft::domain::{FFTPrecomputation, IFFTPrecomputation}, + polycommit::sonic_pc::CommitterKey, +}; use snarkvm_curves::PairingEngine; /// `UniversalProver` is used to compute evaluation proofs for a given /// commitment. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Debug)] pub struct UniversalProver { + /// The committer key for the underlying KZG10 scheme. + pub committer_key: CommitterKey, + /// The precomputed roots for calculating FFTs. + pub fft_precomputation: FFTPrecomputation, + /// The precomputed inverse roots for calculating inverse FFTs. + pub ifft_precomputation: IFFTPrecomputation, /// The maximum degree supported by the universal SRS. pub max_degree: usize, - pub _unused: Option, } diff --git a/circuit/environment/src/helpers/assignment.rs b/circuit/environment/src/helpers/assignment.rs index 3c42a4209d..d3f362322a 100644 --- a/circuit/environment/src/helpers/assignment.rs +++ b/circuit/environment/src/helpers/assignment.rs @@ -300,13 +300,14 @@ mod tests { let max_degree = AHPForR1CS::::max_degree(200, 200, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &assignment, rng).unwrap(); println!("Called prover"); diff --git a/circuit/environment/src/helpers/converter.rs b/circuit/environment/src/helpers/converter.rs index 5d3ad59a0f..660642f23a 100644 --- a/circuit/environment/src/helpers/converter.rs +++ b/circuit/environment/src/helpers/converter.rs @@ -265,13 +265,14 @@ mod tests { let max_degree = AHPForR1CS::::max_degree(200, 200, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &Circuit).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &Circuit, rng).unwrap(); println!("Called prover"); diff --git a/console/network/src/canary_v0.rs b/console/network/src/canary_v0.rs index c39b3bb73a..4afb3e11a7 100644 --- a/console/network/src/canary_v0.rs +++ b/console/network/src/canary_v0.rs @@ -15,6 +15,7 @@ use super::*; use crate::TRANSACTION_PREFIX; +use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -346,8 +347,8 @@ impl Network for CanaryV0 { } /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - MainnetV0::varuna_universal_prover() + fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { + MainnetV0::varuna_universal_prover(degree_info) } /// Returns the Varuna universal verifier. diff --git a/console/network/src/lib.rs b/console/network/src/lib.rs index 2d3f17a784..948915c35e 100644 --- a/console/network/src/lib.rs +++ b/console/network/src/lib.rs @@ -63,6 +63,7 @@ pub use crate::environment::prelude::*; use snarkvm_algorithms::{ AlgebraicSponge, crypto_hash::PoseidonSponge, + polycommit::kzg10::DegreeInfo, snark::varuna::{CircuitProvingKey, CircuitVerifyingKey, VarunaHidingMode}, srs::{UniversalProver, UniversalVerifier}, }; @@ -358,7 +359,7 @@ pub trait Network: fn g_scalar_multiply(scalar: &Scalar) -> Group; /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver; + fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver; /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier; diff --git a/console/network/src/mainnet_v0.rs b/console/network/src/mainnet_v0.rs index 819deed383..f8e0497677 100644 --- a/console/network/src/mainnet_v0.rs +++ b/console/network/src/mainnet_v0.rs @@ -14,6 +14,7 @@ // limitations under the License. use super::*; +use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -351,14 +352,11 @@ impl Network for MainnetV0 { } /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - static INSTANCE: OnceLock::PairingCurve>> = OnceLock::new(); - INSTANCE.get_or_init(|| { - snarkvm_algorithms::polycommit::kzg10::UniversalParams::load() - .expect("Failed to load universal SRS (KZG10).") - .to_universal_prover() - .expect("Failed to convert universal SRS (KZG10) to the prover.") - }) + fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { + snarkvm_algorithms::polycommit::kzg10::UniversalParams::load() + .expect("Failed to load universal SRS (KZG10).") + .to_universal_prover(degree_info) + .expect("Failed to convert universal SRS (KZG10) to the prover.") } /// Returns the Varuna universal verifier. diff --git a/console/network/src/testnet_v0.rs b/console/network/src/testnet_v0.rs index 0043ce7030..02dc3029b9 100644 --- a/console/network/src/testnet_v0.rs +++ b/console/network/src/testnet_v0.rs @@ -15,6 +15,7 @@ use super::*; use crate::TRANSACTION_PREFIX; +use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -346,8 +347,8 @@ impl Network for TestnetV0 { } /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - MainnetV0::varuna_universal_prover() + fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { + MainnetV0::varuna_universal_prover(degree_info) } /// Returns the Varuna universal verifier. diff --git a/synthesizer/snark/src/certificate/mod.rs b/synthesizer/snark/src/certificate/mod.rs index f963cc3629..0c2675085f 100644 --- a/synthesizer/snark/src/certificate/mod.rs +++ b/synthesizer/snark/src/certificate/mod.rs @@ -41,11 +41,12 @@ impl Certificate { let timer = std::time::Instant::now(); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let degree_info = proving_key.circuit.index_info.degree_info::()?; + let universal_prover = N::varuna_universal_prover(degree_info); let fiat_shamir = N::varuna_fs_parameters(); // Compute the certificate. - let certificate = Varuna::::prove_vk(universal_prover, fiat_shamir, verifying_key, proving_key)?; + let certificate = Varuna::::prove_vk(&universal_prover, fiat_shamir, verifying_key, proving_key)?; #[cfg(feature = "dev-print")] { diff --git a/synthesizer/snark/src/lib.rs b/synthesizer/snark/src/lib.rs index c01de3c343..1953a26624 100644 --- a/synthesizer/snark/src/lib.rs +++ b/synthesizer/snark/src/lib.rs @@ -21,7 +21,7 @@ extern crate snarkvm_circuit as circuit; extern crate snarkvm_console as console; use console::network::{FiatShamir, prelude::*}; -use snarkvm_algorithms::{snark::varuna, traits::SNARK}; +use snarkvm_algorithms::{polycommit::kzg10::DegreeInfo, snark::varuna, traits::SNARK}; use snarkvm_utilities::dev_println; use std::sync::{Arc, OnceLock}; diff --git a/synthesizer/snark/src/proving_key/mod.rs b/synthesizer/snark/src/proving_key/mod.rs index c95f7bb27d..5ece780a8f 100644 --- a/synthesizer/snark/src/proving_key/mod.rs +++ b/synthesizer/snark/src/proving_key/mod.rs @@ -45,12 +45,13 @@ impl ProvingKey { let timer = std::time::Instant::now(); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let degree_info = self.circuit.index_info.degree_info::()?; + let universal_prover = N::varuna_universal_prover(degree_info); let fiat_shamir = N::varuna_fs_parameters(); // Compute the proof. let proof = - Proof::new(Varuna::::prove(universal_prover, fiat_shamir, self, varuna_version, assignment, rng)?); + Proof::new(Varuna::::prove(&universal_prover, fiat_shamir, self, varuna_version, assignment, rng)?); #[cfg(feature = "dev-print")] { @@ -69,6 +70,9 @@ impl ProvingKey { assignments: &[(ProvingKey, Vec>)], rng: &mut R, ) -> Result> { + // Ensure that the assignments are not empty. + ensure!(!assignments.is_empty(), "Cannot prove an empty batch"); + #[cfg(feature = "dev-print")] let timer = std::time::Instant::now(); @@ -81,12 +85,20 @@ impl ProvingKey { ensure!(instances.len() == num_expected_instances, "Incorrect number of proving keys for batch proof"); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let mut degree_info: Option = None; + for (pk, _) in assignments { + let degree_info_i = pk.circuit.index_info.degree_info::()?; + degree_info = match degree_info { + Some(degree_info) => Some(degree_info.union(°ree_info_i)), + None => Some(degree_info_i), + }; + } + let universal_prover = N::varuna_universal_prover(degree_info.unwrap()); let fiat_shamir = N::varuna_fs_parameters(); // Compute the proof. let batch_proof = - Proof::new(Varuna::::prove_batch(universal_prover, fiat_shamir, varuna_version, &instances, rng)?); + Proof::new(Varuna::::prove_batch(&universal_prover, fiat_shamir, varuna_version, &instances, rng)?); #[cfg(feature = "dev-print")] { From ee03ab5c91dfe287bf8908c9c720856b3f5b00d1 Mon Sep 17 00:00:00 2001 From: ljedrz Date: Wed, 28 Jan 2026 13:43:38 +0100 Subject: [PATCH 2/5] perf: carry along an updateable universal_prover during synthesis and execution Signed-off-by: ljedrz --- algorithms/benches/snark/varuna.rs | 88 +++---- algorithms/src/errors.rs | 3 + .../src/polycommit/kzg10/data_structures.rs | 147 +---------- .../src/polycommit/kzg10/degree_info.rs | 13 +- .../polycommit/sonic_pc/data_structures.rs | 230 ++++++++++++++++-- algorithms/src/polycommit/sonic_pc/mod.rs | 2 +- algorithms/src/polycommit/test_templates.rs | 28 ++- algorithms/src/snark/varuna/ahp/ahp.rs | 12 +- .../src/snark/varuna/data_structures/mod.rs | 3 + algorithms/src/snark/varuna/tests.rs | 98 ++++---- algorithms/src/snark/varuna/varuna.rs | 27 +- algorithms/src/srs/universal_prover.rs | 56 ++++- algorithms/src/traits/snark.rs | 1 + circuit/environment/src/helpers/assignment.rs | 16 +- circuit/environment/src/helpers/converter.rs | 15 +- console/network/src/canary_v0.rs | 6 - console/network/src/lib.rs | 6 +- console/network/src/mainnet_v0.rs | 9 - console/network/src/testnet_v0.rs | 6 - ledger/puzzle/epoch/Cargo.toml | 3 +- parameters/examples/inclusion.rs | 20 +- synthesizer/Cargo.toml | 4 +- synthesizer/benches/kary_merkle_tree.rs | 24 +- synthesizer/process/Cargo.toml | 7 +- synthesizer/process/src/execute.rs | 2 +- synthesizer/process/src/lib.rs | 42 +++- synthesizer/process/src/stack/deploy.rs | 8 +- .../process/src/stack/helpers/initialize.rs | 1 + .../process/src/stack/helpers/synthesize.rs | 3 +- synthesizer/process/src/stack/mod.rs | 4 +- synthesizer/process/src/tests/test_execute.rs | 5 +- .../process/src/tests/test_serializers.rs | 2 +- synthesizer/process/src/trace/mod.rs | 30 ++- synthesizer/program/Cargo.toml | 4 + synthesizer/snark/Cargo.toml | 8 + synthesizer/snark/src/certificate/mod.rs | 12 +- synthesizer/snark/src/lib.rs | 42 +++- synthesizer/snark/src/proving_key/mod.rs | 35 ++- synthesizer/snark/src/universal_srs.rs | 53 +++- 39 files changed, 684 insertions(+), 391 deletions(-) diff --git a/algorithms/benches/snark/varuna.rs b/algorithms/benches/snark/varuna.rs index 49a0147f8c..2b7a20fdda 100644 --- a/algorithms/benches/snark/varuna.rs +++ b/algorithms/benches/snark/varuna.rs @@ -20,13 +20,21 @@ use snarkvm_algorithms::{ AlgebraicSponge, SNARK, crypto_hash::PoseidonSponge, - polycommit::kzg10::DegreeInfo, - snark::varuna::{CircuitVerifyingKey, TestCircuit, VarunaHidingMode, VarunaSNARK, VarunaVersion, ahp::AHPForR1CS}, + snark::varuna::{ + CircuitVerifyingKey, + TestCircuit, + UniversalProver, + VarunaHidingMode, + VarunaSNARK, + VarunaVersion, + ahp::AHPForR1CS, + }, }; use snarkvm_curves::bls12_377::{Bls12_377, Fq, Fr}; use snarkvm_utilities::{CanonicalDeserialize, CanonicalSerialize, TestRng}; use criterion::Criterion; +use itertools::Itertools; use std::{collections::BTreeMap, time::Duration}; type VarunaInst = VarunaSNARK; @@ -54,7 +62,8 @@ fn snark_circuit_setup(c: &mut Criterion) { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); c.bench_function(&format!("snark_circuit_setup_{size}"), |b| { - b.iter(|| VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap()) + let mut universal_prover = UniversalProver::default(); + b.iter(|| VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap()) }); } } @@ -72,18 +81,17 @@ fn snark_prove(c: &mut Criterion) { let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); - let degree_info = pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); c.bench_function("snark_prove_v1", |b| { let varuna_version = VarunaVersion::V1; - b.iter(|| VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) + b.iter(|| VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); c.bench_function("snark_prove_v2", |b| { let varuna_version = VarunaVersion::V2; - b.iter(|| VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) + b.iter(|| VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); } @@ -102,11 +110,10 @@ fn snark_batch_prove(c: &mut Criterion) { let circuit_batch_size = 5; let instance_batch_size = 5; - let mut pks = Vec::with_capacity(circuit_batch_size); let mut all_circuits = Vec::with_capacity(circuit_batch_size); let mut keys_to_constraints = BTreeMap::new(); - let mut degree_info = None; + let mut universal_prover = UniversalProver::default(); for i in 0..circuit_batch_size { let num_constraints = num_constraints_base + i; @@ -118,21 +125,20 @@ fn snark_batch_prove(c: &mut Criterion) { let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); circuits.push(circuit); } - let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); all_circuits.push(circuits); - let degree_info_i = pk.circuit.index_info.degree_info::().unwrap(); - degree_info = degree_info.map(|i: DegreeInfo| i.union(°ree_info_i)).or(Some(degree_info_i)); - pks.push(pk); } + let unique_circuits = all_circuits.iter().map(|c| &c[0]).collect_vec(); + let circuit_keys = + VarunaInst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_circuits.as_slice()).unwrap(); + for i in 0..circuit_batch_size { - keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); + keys_to_constraints.insert(&circuit_keys[i].0, all_circuits[i].as_slice()); } - let universal_prover = &universal_srs.to_universal_prover(degree_info.unwrap()).unwrap(); let varuna_version = VarunaVersion::V2; b.iter(|| { - VarunaInst::prove_batch(universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) + VarunaInst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) .unwrap() }) }); @@ -153,12 +159,11 @@ fn snark_verify(c: &mut Criterion) { let (circuit, public_inputs) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); - let degree_info = pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let varuna_version = VarunaVersion::V2; - let proof = VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap(); b.iter(|| { let verification = VarunaInst::verify( universal_verifier, @@ -184,15 +189,12 @@ fn snark_batch_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000, 1000, 100).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); + let mut universal_prover = UniversalProver::default(); let fs_parameters = FS::sample_parameters(); let circuit_batch_size = 5; let instance_batch_size = 5; - let mut degree_info = None; - - let mut pks = Vec::with_capacity(circuit_batch_size); - let mut vks = Vec::with_capacity(circuit_batch_size); let mut all_circuits = Vec::with_capacity(circuit_batch_size); let mut all_inputs = Vec::with_capacity(circuit_batch_size); let mut keys_to_constraints = BTreeMap::new(); @@ -208,24 +210,22 @@ fn snark_batch_verify(c: &mut Criterion) { circuits.push(circuit); inputs.push(public_inputs); } - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); all_circuits.push(circuits); all_inputs.push(inputs); - let degree_info_i = pk.circuit.index_info.degree_info::().unwrap(); - degree_info = degree_info.map(|i: DegreeInfo| i.union(°ree_info_i)).or(Some(degree_info_i)); - pks.push(pk); - vks.push(vk); } + let unique_circuits = all_circuits.iter().map(|c| &c[0]).collect_vec(); + let circuit_keys = + VarunaInst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_circuits.as_slice()).unwrap(); + for i in 0..circuit_batch_size { - keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); - keys_to_inputs.insert(&vks[i], all_inputs[i].as_slice()); + keys_to_constraints.insert(&circuit_keys[i].0, all_circuits[i].as_slice()); + keys_to_inputs.insert(&circuit_keys[i].1, all_inputs[i].as_slice()); } - let universal_prover = &universal_srs.to_universal_prover(degree_info.unwrap()).unwrap(); let varuna_version = VarunaVersion::V2; let proof = - VarunaInst::prove_batch(universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) + VarunaInst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) .unwrap(); b.iter(|| { let verification = @@ -253,7 +253,8 @@ fn snark_vk_serialize(c: &mut Criterion) { let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let mut bytes = Vec::with_capacity(10000); group.bench_function(name, |b| { b.iter(|| { @@ -288,7 +289,8 @@ fn snark_vk_deserialize(c: &mut Criterion) { let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let mut bytes = Vec::with_capacity(10000); vk.serialize_with_mode(&mut bytes, compress).unwrap(); group.bench_function(name, |b| { @@ -315,12 +317,11 @@ fn snark_certificate_prove(c: &mut Criterion) { let num_variables = size; let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); - let degree_info = pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); c.bench_function(&format!("snark_certificate_prove_{size}"), |b| { - b.iter(|| VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap()) + b.iter(|| VarunaInst::prove_vk(&universal_prover, fs_p, &vk, &pk).unwrap()) }); } } @@ -339,10 +340,9 @@ fn snark_certificate_verify(c: &mut Criterion) { let num_variables = size; let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); - let degree_info = pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - let certificate = VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); + let certificate = VarunaInst::prove_vk(&universal_prover, fs_p, &vk, &pk).unwrap(); c.bench_function(&format!("snark_certificate_verify_{size}"), |b| { b.iter(|| VarunaInst::verify_vk(universal_verifier, fs_p, &circuit, &vk, &certificate).unwrap()) diff --git a/algorithms/src/errors.rs b/algorithms/src/errors.rs index edbf8faa31..f5d9d7acaf 100644 --- a/algorithms/src/errors.rs +++ b/algorithms/src/errors.rs @@ -44,6 +44,9 @@ pub enum SNARKError { #[error("Public input size was different from the circuit")] PublicInputSizeMismatch, + + #[error("FFT precomputation not found")] + FFTPrecompNotFound, } impl From for SNARKError { diff --git a/algorithms/src/polycommit/kzg10/data_structures.rs b/algorithms/src/polycommit/kzg10/data_structures.rs index 6394fb7322..67b9131707 100644 --- a/algorithms/src/polycommit/kzg10/data_structures.rs +++ b/algorithms/src/polycommit/kzg10/data_structures.rs @@ -16,14 +16,9 @@ use super::DegreeInfo; use crate::{ AlgebraicSponge, - fft::{ - DensePolynomial, - EvaluationDomain, - domain::{FFTPrecomputation, IFFTPrecomputation}, - }, - polycommit::sonic_pc::CommitterKey, - prelude::PCError, - srs::{UniversalProver, UniversalVerifier}, + fft::{DensePolynomial, EvaluationDomain}, + snark::varuna::UniversalProver, + srs::UniversalVerifier, }; use snarkvm_curves::{AffineCurve, PairingCurve, PairingEngine, ProjectiveCurve}; use snarkvm_fields::{ConstraintFieldError, ToConstraintField, Zero}; @@ -36,9 +31,8 @@ use snarkvm_utilities::{ serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate}, }; -use anyhow::{Result, bail}; +use anyhow::Result; use core::ops::{Add, AddAssign}; -use itertools::Itertools; use rand::RngCore; use std::{ borrow::Cow, @@ -107,18 +101,9 @@ impl UniversalParams { } pub fn to_universal_prover(&self, degree_info: DegreeInfo) -> Result> { - // Construct the committer key. - let committer_key = self.to_committer_key(°ree_info)?; - // Construct the FFT and iFFT precomputation. - let (fft_precomputation, ifft_precomputation) = - Self::fft_precomputation(degree_info.max_fft_size).ok_or(SerializationError::InvalidData)?; - // Return the universal prover. - Ok(UniversalProver:: { - committer_key, - fft_precomputation, - ifft_precomputation, - max_degree: self.max_degree(), - }) + let mut universal_prover = UniversalProver::default(); + universal_prover.update(self, degree_info)?; + Ok(universal_prover) } pub fn to_universal_verifier(&self) -> Result> { @@ -134,124 +119,6 @@ impl UniversalParams { prepared_negative_powers_of_beta_h: self.powers.prepared_negative_powers_of_beta_h(), }) } - - pub fn to_committer_key(&self, degree_info: &DegreeInfo) -> Result> { - let trim_time = start_timer!(|| "Trimming public parameters"); - - // Retrieve the supported parameters from the degree information. - let supported_degree = degree_info.max_degree; - let supported_lagrange_sizes = degree_info.lagrange_sizes.clone().map(|mut s| s.drain().collect_vec()); - let enforced_degree_bounds = degree_info.degree_bounds.clone().map(|mut s| s.drain().collect_vec()); - let supported_hiding_bound = degree_info.hiding_bound; - - // The maximum degree of the entire SRS. - let max_degree = self.max_degree(); - - // TODO (howardwu): This should never be happening... Use a BTreeSet. - let enforced_degree_bounds = enforced_degree_bounds.map(|bounds| { - let mut v = bounds.to_vec(); - v.sort_unstable(); - v.dedup(); - v - }); - - let (shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g) = if let Some(enforced_degree_bounds) = - enforced_degree_bounds.as_ref() - { - if enforced_degree_bounds.is_empty() { - (None, None) - } else { - let highest_enforced_degree_bound = *enforced_degree_bounds.last().unwrap(); - if highest_enforced_degree_bound > supported_degree { - bail!( - "The highest enforced degree bound {highest_enforced_degree_bound} is larger than the supported degree {supported_degree}" - ); - } - - let lowest_shift_degree = max_degree - highest_enforced_degree_bound; - - let shifted_ck_time = start_timer!(|| format!( - "Constructing `shifted_powers_of_beta_g` of size {}", - max_degree - lowest_shift_degree + 1 - )); - - let shifted_powers_of_beta_g = - self.powers_of_beta_g(lowest_shift_degree, self.max_degree() + 1)?.to_vec(); - let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); - // Also add degree 0. - for degree_bound in enforced_degree_bounds { - let shift_degree = max_degree - degree_bound; - let mut powers_for_degree_bound = Vec::with_capacity((max_degree + 2).saturating_sub(shift_degree)); - for i in 0..=supported_hiding_bound + 1 { - // We have an additional degree in `powers_of_beta_times_gamma_g` beyond - // `powers_of_beta_g`. - if shift_degree + i < max_degree + 2 { - powers_for_degree_bound.push(self.powers_of_beta_times_gamma_g()[&(shift_degree + i)]); - } - } - shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); - } - - end_timer!(shifted_ck_time); - - (Some(shifted_powers_of_beta_g), Some(shifted_powers_of_beta_times_gamma_g)) - } - } else { - (None, None) - }; - - let powers_of_beta_g = self.powers_of_beta_g(0, supported_degree + 1)?.to_vec(); - let powers_of_beta_times_gamma_g = (0..=(supported_hiding_bound + 1)) - .map(|i| { - self.powers_of_beta_times_gamma_g() - .get(&i) - .copied() - .ok_or(PCError::HidingBoundToolarge { hiding_poly_degree: supported_hiding_bound, num_powers: 0 }) - }) - .collect::, _>>()?; - - let mut lagrange_bases_at_beta_g = BTreeMap::new(); - if let Some(supported_lagrange_sizes) = supported_lagrange_sizes { - for size in supported_lagrange_sizes { - let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); - if !size.is_power_of_two() { - bail!("The Lagrange basis size ({size}) is not a power of two") - } - if size > self.max_degree() + 1 { - bail!( - "The Lagrange basis size ({size}) is larger than the supported degree ({})", - self.max_degree() + 1 - ) - } - let domain = crate::fft::EvaluationDomain::new(size).unwrap(); - let lagrange_basis_at_beta_g = self.lagrange_basis(domain)?; - assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); - lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); - end_timer!(lagrange_time); - } - } - - let ck = CommitterKey { - powers_of_beta_g, - lagrange_bases_at_beta_g, - powers_of_beta_times_gamma_g, - shifted_powers_of_beta_g, - shifted_powers_of_beta_times_gamma_g, - enforced_degree_bounds, - }; - - end_timer!(trim_time); - Ok(ck) - } - - /// Returns the FFT and iFFT precomputation for the largest supported - /// domain. - fn fft_precomputation(max_domain_size: usize) -> Option<(FFTPrecomputation, IFFTPrecomputation)> { - let domain = EvaluationDomain::new(max_domain_size)?; - let fft_precomputation = domain.precompute_fft(); - let ifft_precomputation = fft_precomputation.to_ifft_precomputation(); - Some((fft_precomputation, ifft_precomputation)) - } } impl FromBytes for UniversalParams { diff --git a/algorithms/src/polycommit/kzg10/degree_info.rs b/algorithms/src/polycommit/kzg10/degree_info.rs index 43ccabf713..385c235a63 100644 --- a/algorithms/src/polycommit/kzg10/degree_info.rs +++ b/algorithms/src/polycommit/kzg10/degree_info.rs @@ -13,21 +13,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; +use std::collections::BTreeSet; -#[derive(Clone)] +#[derive(Clone, Debug, Default)] pub struct DegreeInfo { /// The maximum degree of the required SRS to commit to the polynomials. pub max_degree: usize, /// The maximum IOP poly degree used for (i)fft_precomputation. pub max_fft_size: usize, - /// TODO (howardwu): This should never be happening... Use a BTreeSet. /// The degree bounds on IOP polynomials. - pub degree_bounds: Option>, + pub degree_bounds: Option>, /// The hiding bound for polynomial queries. pub hiding_bound: usize, /// The supported sizes for the lagrange-basis SRS. - pub lagrange_sizes: Option>, + pub lagrange_sizes: Option>, } impl DegreeInfo { @@ -35,9 +34,9 @@ impl DegreeInfo { pub const fn new( max_degree: usize, max_fft_size: usize, - degree_bounds: Option>, + degree_bounds: Option>, hiding_bound: usize, - lagrange_sizes: Option>, + lagrange_sizes: Option>, ) -> Self { Self { max_degree, max_fft_size, degree_bounds, hiding_bound, lagrange_sizes } } diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index 50b555ebf4..a19be0b9be 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -14,15 +14,23 @@ // limitations under the License. use super::{LabeledPolynomial, PolynomialInfo}; -use crate::{crypto_hash::sha256::sha256, fft::EvaluationDomain, polycommit::kzg10}; +use crate::{ + crypto_hash::sha256::sha256, + fft::EvaluationDomain, + polycommit::kzg10::{self, DegreeInfo}, + prelude::PCError, +}; use snarkvm_curves::PairingEngine; use snarkvm_fields::{ConstraintFieldError, Field, PrimeField, ToConstraintField}; use snarkvm_parameters::mainnet::MAX_NUM_POWERS; use snarkvm_utilities::{FromBytes, ToBytes, error, into_io_error, serialize::*}; +use anyhow::{Result, anyhow, bail, ensure}; use hashbrown::HashMap; +use itertools::Itertools; use std::{ borrow::{Borrow, Cow}, + cmp::Ordering, collections::{BTreeMap, BTreeSet}, fmt, io, @@ -46,7 +54,8 @@ pub struct CommitterKey { pub powers_of_beta_g: Vec, /// The key used to commit to polynomials in Lagrange basis. - pub lagrange_bases_at_beta_g: BTreeMap>, + /// This is `None` if `self` does not support lagrange bases + pub lagrange_bases_at_beta_g: Option>>, /// The key used to commit to hiding polynomials. pub powers_of_beta_times_gamma_g: Vec, @@ -82,17 +91,24 @@ impl FromBytes for CommitterKey { } // Deserialize `lagrange_basis_at_beta`. - let lagrange_bases_at_beta_len: u32 = FromBytes::read_le(&mut reader)?; - let mut lagrange_bases_at_beta_g = BTreeMap::new(); - for _ in 0..lagrange_bases_at_beta_len { - let size: u32 = FromBytes::read_le(&mut reader)?; - let mut basis = Vec::with_capacity(size as usize); - for _ in 0..size { - let power: E::G1Affine = FromBytes::read_le(&mut reader)?; - basis.push(power); + let has_lagrange_bases: bool = FromBytes::read_le(&mut reader)?; + let lagrange_bases_at_beta_g = match has_lagrange_bases { + true => { + let lagrange_bases_at_beta_len: u32 = FromBytes::read_le(&mut reader)?; + let mut lagrange_bases_at_beta_g = BTreeMap::new(); + for _ in 0..lagrange_bases_at_beta_len { + let size: u32 = FromBytes::read_le(&mut reader)?; + let mut basis = Vec::with_capacity(size as usize); + for _ in 0..size { + let power: E::G1Affine = FromBytes::read_le(&mut reader)?; + basis.push(power); + } + lagrange_bases_at_beta_g.insert(size as usize, basis); + } + Some(lagrange_bases_at_beta_g) } - lagrange_bases_at_beta_g.insert(size as usize, basis); - } + false => None, + }; // Deserialize `powers_of_beta_times_gamma_g`. let powers_of_beta_times_gamma_g_len: u32 = FromBytes::read_le(&mut reader)?; @@ -198,18 +214,21 @@ impl FromBytes for CommitterKey { impl ToBytes for CommitterKey { fn write_le(&self, mut writer: W) -> io::Result<()> { - // Serialize `powers`. + // Serialize `powers_of_beta_g`. (self.powers_of_beta_g.len() as u32).write_le(&mut writer)?; for power in &self.powers_of_beta_g { power.write_le(&mut writer)?; } - // Serialize `powers`. - (self.lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; - for (size, powers) in &self.lagrange_bases_at_beta_g { - (*size as u32).write_le(&mut writer)?; - for power in powers { - power.write_le(&mut writer)?; + // Serialize `lagrange_bases_at_beta_g`. + self.lagrange_bases_at_beta_g.is_some().write_le(&mut writer)?; + if let Some(lagrange_bases_at_beta_g) = &self.lagrange_bases_at_beta_g { + (lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; + for (size, powers) in lagrange_bases_at_beta_g.iter() { + (*size as u32).write_le(&mut writer)?; + for power in powers { + power.write_le(&mut writer)?; + } } } @@ -275,6 +294,166 @@ impl ToBytes for CommitterKey { } impl CommitterKey { + /// Update the committer_key + /// The `powers_of_beta_g` and `shifted_powers_of_beta_g` may grow or shrink + /// as per the new supported_degree The `enforced_degree_bounds` and + /// `shifted_powers_of_beta_times_gamma_g` will adjust to the new + /// degree_bounds The `lagrange_bases_at_beta_g` will optionally adjust + /// to the new degree_bounds The `powers_of_beta_times_gamma_g` is only + /// dependent on the `hiding_bound` so doesn't change This only works if + /// the SRS max_degree is fixed across specializations Note that this + /// implementation is not atomic. If specialize fails halfway through, + pub fn update(&mut self, srs: &UniversalParams, degree_info: &DegreeInfo) -> Result<()> { + let trim_time = start_timer!(|| "Trimming public parameters"); + + // The maximum degree required by the circuit polynomials + let supported_degree = degree_info.max_degree; + // The maximum degree of the entire SRS + let max_degree = srs.max_degree(); + // The hiding bound for the circuit polynomials + let supported_hiding_bound = degree_info.hiding_bound; + + // Update shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g, + // enforced_degree_bounds + match degree_info.degree_bounds.as_ref() { + None => { + self.shifted_powers_of_beta_g = None; + self.shifted_powers_of_beta_times_gamma_g = None; + self.enforced_degree_bounds = None; + } + Some(enforced_degree_bounds) => { + // Retrieve new highest degree bound + let new_highest_degree_bound = + *enforced_degree_bounds.last().ok_or(anyhow!("No degree bound found"))?; + ensure!(new_highest_degree_bound < supported_degree); + + // Retrieve current degree bounds and shifted powers + let degree_bounds = self.enforced_degree_bounds.take().unwrap_or_default(); + let mut highest_degree_bound = degree_bounds.iter().copied().sorted().last().unwrap_or_default(); + let mut shifted_powers_of_beta_g = self.shifted_powers_of_beta_g.take().unwrap_or_default(); + let mut shifted_powers_of_beta_times_gamma_g = + self.shifted_powers_of_beta_times_gamma_g.take().unwrap_or_default(); + + // We add 1 to any existing upper bound, in congruence with `max_degree + 1` + // below. This is because the proof system assumes we need this + // extra degree. This can optionally be refactored to ensure the + // extra degree is already encoded in degree_bounds. + if highest_degree_bound > 0 { + highest_degree_bound = + highest_degree_bound.checked_add(1).ok_or(error("Overflow highest_degree_bound"))?; + } + + let shifted_ck_time = start_timer!(|| "Constructing `shifted_powers_of_beta_g`"); + match new_highest_degree_bound.cmp(&highest_degree_bound) { + Ordering::Greater => { + let lowest_shift_degree = max_degree - new_highest_degree_bound; + let highest_shift_degree = max_degree + 1 - highest_degree_bound; + let mut new_powers = srs.powers_of_beta_g(lowest_shift_degree, highest_shift_degree)?; + new_powers.extend_from_slice(&shifted_powers_of_beta_g); + shifted_powers_of_beta_g = new_powers; + } + Ordering::Equal => (), // No need to adjust shifted_powers_of_beta_g + Ordering::Less => { + let degree_diff = highest_degree_bound - new_highest_degree_bound; + shifted_powers_of_beta_g = shifted_powers_of_beta_g.drain(degree_diff - 1..).collect_vec(); + } + }; + self.shifted_powers_of_beta_g = Some(shifted_powers_of_beta_g); + end_timer!(shifted_ck_time); + + // because we might batch prove circuits of different sizes, we consider each + // required degree_bound individually + let shifted_ck_time = start_timer!(|| "Constructing `shifted_powers_of_beta_times_gamma_g`"); + let mut new_shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); + for degree_bound in enforced_degree_bounds { + match shifted_powers_of_beta_times_gamma_g.remove(degree_bound) { + Some(found_powers) => { + new_shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, found_powers); + } + None => { + let shift_degree = max_degree - *degree_bound; + let mut powers_for_degree_bound = + Vec::with_capacity((max_degree + 2).saturating_sub(shift_degree)); + for i in 0..=supported_hiding_bound + .checked_add(1) + .ok_or(anyhow!("Overflow supported_hiding_bound"))? + { + // We have an additional degree in `powers_of_beta_times_gamma_g` beyond + // `powers_of_beta_g`. + if shift_degree + i < max_degree + 2 { + powers_for_degree_bound + .push(srs.powers_of_beta_times_gamma_g()[&(shift_degree + i)]); + } + } + new_shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); + } + } + } + self.enforced_degree_bounds = Some(enforced_degree_bounds.iter().copied().collect_vec()); + self.shifted_powers_of_beta_times_gamma_g = Some(new_shifted_powers_of_beta_times_gamma_g); + end_timer!(shifted_ck_time); + } + } + + // Update powers_of_beta_g + let old_supported_degree = self.powers_of_beta_g.len(); + match supported_degree.cmp(&old_supported_degree) { + Ordering::Greater => { + let mut new_powers = srs.powers_of_beta_g(old_supported_degree, supported_degree)?; + self.powers_of_beta_g.append(&mut new_powers); + } + Ordering::Equal => (), // No need to adjust powers_of_beta_g + Ordering::Less => { + let degree_difference = old_supported_degree - supported_degree; + self.powers_of_beta_g.truncate(old_supported_degree - degree_difference); + } + } + + // Set powers_of_beta_times_gamma_g + self.powers_of_beta_times_gamma_g = + (0..=(supported_hiding_bound.checked_add(1).ok_or(anyhow!("Overflow supported_hiding_bound"))?)) + .map(|i| { + srs.powers_of_beta_times_gamma_g().get(&i).copied().ok_or(PCError::HidingBoundToolarge { + hiding_poly_degree: supported_hiding_bound, + num_powers: 0, + }) + }) + .collect::, _>>()?; + + // Update lagrange_bases_at_beta_g + let mut lagrange_bases_at_beta_g = BTreeMap::new(); + if let Some(supported_lagrange_sizes) = degree_info.lagrange_sizes.as_ref() { + let mut old_lagrange_bases_at_beta_g = self.lagrange_bases_at_beta_g.take().unwrap_or_default(); + for size in supported_lagrange_sizes { + let domain = EvaluationDomain::new(*size).unwrap(); + match old_lagrange_bases_at_beta_g.remove(&domain.size()) { + Some(found_lagrange_bases) => { + lagrange_bases_at_beta_g.insert(domain.size(), found_lagrange_bases); + } + None => { + let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); + if !size.is_power_of_two() { + bail!("The Lagrange basis size ({size}) is not a power of two") + } + if *size > srs.max_degree() + 1 { + bail!( + "The Lagrange basis size ({size}) is larger than the supported degree ({})", + srs.max_degree() + 1 + ) + } + let lagrange_basis_at_beta_g = srs.lagrange_basis(domain)?; + assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); + lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); + end_timer!(lagrange_time); + } + } + } + self.lagrange_bases_at_beta_g = Some(lagrange_bases_at_beta_g); + } + end_timer!(trim_time); + Ok(()) + } + /// Obtain powers for the underlying KZG10 construction pub fn powers(&self) -> kzg10::Powers { kzg10::Powers { @@ -310,11 +489,14 @@ impl CommitterKey { /// Obtain elements of the SRS in the lagrange basis powers, for use with /// the underlying KZG10 construction. pub fn lagrange_basis(&self, domain: EvaluationDomain) -> Option> { - self.lagrange_bases_at_beta_g.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { - lagrange_basis_at_beta_g: Cow::Borrowed(basis), - powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), - domain, - }) + if let Some(lagrange_bases) = self.lagrange_bases_at_beta_g.as_ref() { + return lagrange_bases.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { + lagrange_basis_at_beta_g: Cow::Borrowed(basis), + powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), + domain, + }); + } + None } } diff --git a/algorithms/src/polycommit/sonic_pc/mod.rs b/algorithms/src/polycommit/sonic_pc/mod.rs index 0b6b71254b..8f830a3dfb 100644 --- a/algorithms/src/polycommit/sonic_pc/mod.rs +++ b/algorithms/src/polycommit/sonic_pc/mod.rs @@ -610,7 +610,7 @@ mod tests { let lagrange_sizes = Some([lagrange_size(supported_degree)].into_iter().collect()); let degree_info = DegreeInfo::new(supported_degree, supported_degree, None, hiding_bound, lagrange_sizes); - let ck = pp.to_committer_key(°ree_info).unwrap(); + let ck = pp.to_universal_prover(degree_info).unwrap().committer_key; let ck_bytes = ck.to_bytes_le().unwrap(); let ck_recovered: CommitterKey = FromBytes::read_le(&ck_bytes[..]).unwrap(); diff --git a/algorithms/src/polycommit/test_templates.rs b/algorithms/src/polycommit/test_templates.rs index 0eac7b0354..e5ab111882 100644 --- a/algorithms/src/polycommit/test_templates.rs +++ b/algorithms/src/polycommit/test_templates.rs @@ -42,7 +42,7 @@ use rand::{ Rng, distributions::{self, Distribution}, }; -use std::{collections::HashSet, marker::PhantomData}; +use std::{collections::BTreeSet, marker::PhantomData}; #[derive(Default)] struct TestInfo { @@ -79,7 +79,7 @@ pub fn bad_degree_bound_test>() - let mut labels = Vec::new(); let mut polynomials = Vec::new(); - let mut degree_bounds = HashSet::new(); + let mut degree_bounds = BTreeSet::new(); for i in 0..10 { let label = format!("Test{i}"); @@ -91,11 +91,12 @@ pub fn bad_degree_bound_test>() - polynomials.push(LabeledPolynomial::new(label, poly, Some(degree_bound), Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let degree_info = DegreeInfo { max_degree, max_fft_size: supported_degree, - degree_bounds: Some(degree_bounds), + degree_bounds, hiding_bound, lagrange_sizes: None, }; @@ -148,8 +149,8 @@ pub fn lagrange_test_template>() assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); let mut lagrange_polynomials = Vec::new(); - let mut supported_lagrange_sizes = HashSet::new(); - let degree_bounds = HashSet::new(); + let mut supported_lagrange_sizes = BTreeSet::new(); + let degree_bounds = BTreeSet::new(); let hiding_bound = 1; let mut labels = Vec::new(); @@ -175,12 +176,13 @@ pub fn lagrange_test_template>() polynomials.push(LabeledPolynomial::new(label.clone(), poly, degree_bound, Some(hiding_bound))); lagrange_polynomials.push(LabeledPolynomialWithBasis::new_lagrange_basis(label, evals, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); assert_eq!(supported_hiding_bound, 1); let degree_info = DegreeInfo { max_degree, max_fft_size: supported_degree, - degree_bounds: Some(degree_bounds), + degree_bounds, hiding_bound, lagrange_sizes: Some(supported_lagrange_sizes), }; @@ -262,10 +264,10 @@ where for _ in 0..num_iters { let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = - supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); + supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree - 1).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = HashSet::new(); + let mut degree_bounds = BTreeSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -296,12 +298,13 @@ where } polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); assert_eq!(supported_hiding_bound, 1); let degree_info = DegreeInfo { max_degree, max_fft_size: supported_degree, - degree_bounds: Some(degree_bounds), + degree_bounds, hiding_bound, lagrange_sizes: None, }; @@ -382,10 +385,10 @@ fn equation_test_template>( for _ in 0..num_iters { let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = - supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); + supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree - 1).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = HashSet::new(); + let mut degree_bounds = BTreeSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -416,12 +419,13 @@ fn equation_test_template>( } polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); assert_eq!(supported_hiding_bound, 1); let degree_info = DegreeInfo { max_degree, max_fft_size: supported_degree, - degree_bounds: Some(degree_bounds), + degree_bounds, hiding_bound, lagrange_sizes: None, }; diff --git a/algorithms/src/snark/varuna/ahp/ahp.rs b/algorithms/src/snark/varuna/ahp/ahp.rs index d86dd613e1..5bf11bb829 100644 --- a/algorithms/src/snark/varuna/ahp/ahp.rs +++ b/algorithms/src/snark/varuna/ahp/ahp.rs @@ -79,8 +79,7 @@ impl AHPForR1CS { Self::num_formatted_public_inputs_is_admissible(input.len()) } - /// The maximum degree of polynomials produced by the indexer and prover - /// of this protocol. + /// The maximum degree of polynomials the indexer or prover commits to. /// The number of the variables must include the "one" variable. That is, it /// must be with respect to the number of formatted public inputs. pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result { @@ -94,12 +93,11 @@ impl AHPForR1CS { // these should correspond with the bounds set in the .rs files [ - 2 * constraint_domain_size + 2 * zk_bound - 2, - 2 * variable_domain_size + 2 * zk_bound - 2, + 2 * constraint_domain_size + 2 * zk_bound - 2, // h_0 + 2 * variable_domain_size + 2 * zk_bound - 2, // h_1 if SM::ZK { variable_domain_size + 3 } else { 0 }, // mask_poly - variable_domain_size, - constraint_domain_size, - non_zero_domain_size - 1, // non-zero polynomials + variable_domain_size, // g_1 + non_zero_domain_size, // g_M, h_M ] .iter() .max() diff --git a/algorithms/src/snark/varuna/data_structures/mod.rs b/algorithms/src/snark/varuna/data_structures/mod.rs index a28f97cac7..a81c111007 100644 --- a/algorithms/src/snark/varuna/data_structures/mod.rs +++ b/algorithms/src/snark/varuna/data_structures/mod.rs @@ -44,3 +44,6 @@ pub mod test_exports { /// The Varuna universal SRS. pub type UniversalSRS = crate::polycommit::sonic_pc::UniversalParams; + +/// The Varuna universal prover. +pub type UniversalProver = crate::srs::UniversalProver; diff --git a/algorithms/src/snark/varuna/tests.rs b/algorithms/src/snark/varuna/tests.rs index 20884b9171..1fa4fad1f1 100644 --- a/algorithms/src/snark/varuna/tests.rs +++ b/algorithms/src/snark/varuna/tests.rs @@ -16,7 +16,6 @@ #[cfg(any(test, feature = "test"))] mod varuna { use crate::{ - polycommit::kzg10::DegreeInfo, snark::varuna::{ AHPForR1CS, CircuitVerifyingKey, @@ -28,6 +27,7 @@ mod varuna { proof::proof_size, test_circuit::TestCircuit, }, + srs::UniversalProver, traits::{AlgebraicSponge, SNARK}, }; @@ -54,6 +54,7 @@ mod varuna { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = $snark_inst::universal_setup(max_degree).unwrap(); + let mut universal_prover = UniversalProver::default(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); @@ -69,13 +70,10 @@ mod varuna { let mut fake_inputs = public_inputs.clone(); fake_inputs[public_inputs.len() - 1] = random; - let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); - let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - - let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); + let certificate = $snark_inst::prove_vk(&universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circ, &index_vk, &certificate).unwrap()); println!("verified vk"); @@ -84,7 +82,7 @@ mod varuna { } assert_eq!(664, index_vk.to_bytes_le().unwrap().len(), "Update me if serialization has changed"); - let proof = $snark_inst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circ, rng).unwrap(); + let proof = $snark_inst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circ, rng).unwrap(); println!("Called prover"); assert!($snark_inst::verify(universal_verifier, &fs_parameters, &index_vk, varuna_version, public_inputs.clone(), &proof).unwrap()); @@ -100,7 +98,6 @@ mod varuna { println!("running test with circuit_batch_size: {circuit_batch_size} and instance_batch_size: {instance_batch_size}"); let mut constraints = BTreeMap::new(); let mut inputs = BTreeMap::new(); - let mut degree_info = None; for i in 0..circuit_batch_size { let (circuit_batch, input_batch): (Vec<_>, Vec<_>) = (0..instance_batch_size) @@ -117,20 +114,14 @@ mod varuna { let unique_instances = constraints.values().map(|instances| &instances[0]).collect::>(); let index_keys = - $snark_inst::batch_circuit_setup(&universal_srs, unique_instances.as_slice()).unwrap(); + $snark_inst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_instances.as_slice()).unwrap(); println!("Called circuit setup"); let mut pks_to_constraints = BTreeMap::new(); let mut vks_to_inputs = BTreeMap::new(); for (index_pk, index_vk) in index_keys.iter() { - let degree_info_i = index_pk.circuit.index_info.degree_info::().unwrap(); - degree_info = degree_info.map(|i: DegreeInfo|i.union(°ree_info_i)).or(Some(degree_info_i.clone())); - let universal_prover = &universal_srs.to_universal_prover( - degree_info_i - ).unwrap(); - - let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); + let certificate = $snark_inst::prove_vk(&universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); let circuits = constraints[&index_pk.circuit.id].as_slice(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circuits[0], &index_vk, &certificate).unwrap()); pks_to_constraints.insert(index_pk, circuits); @@ -138,12 +129,8 @@ mod varuna { } println!("verified vks"); - let universal_prover = &universal_srs.to_universal_prover( - degree_info.unwrap() - ).unwrap(); - let proof = - $snark_inst::prove_batch(universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap(); + $snark_inst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap(); println!("Called prover"); if varuna_version == VarunaVersion::V2 { @@ -209,7 +196,8 @@ mod varuna { let mul_depth = 1; let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); // Serialize @@ -234,7 +222,8 @@ mod varuna { let mul_depth = 1; let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); // Serialize @@ -366,6 +355,7 @@ mod varuna_hiding { crypto_hash::PoseidonSponge, snark::varuna::{ CircuitVerifyingKey, + UniversalProver, VarunaHidingMode, VarunaSNARK, VarunaVersion, @@ -409,14 +399,13 @@ mod varuna_hiding { let mut fake_inputs = public_inputs.clone(); fake_inputs[public_inputs.len() - 1] = Fr::rand(rng); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); - let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - let proof = - VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); + VarunaInst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); assert!( @@ -469,7 +458,8 @@ mod varuna_hiding { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); // Serialize @@ -489,7 +479,8 @@ mod varuna_hiding { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); // Serialize @@ -591,18 +582,17 @@ mod varuna_hiding { VarunaVersion::V1 => VarunaVersion::V2, VarunaVersion::V2 => VarunaVersion::V1, }; - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); - let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - let proof = - VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); + VarunaInst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); universal_srs.download_powers_for(0..2usize.pow(18)).unwrap(); - let (new_pk, new_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let (new_pk, new_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); assert_eq!(index_pk, new_pk); assert_eq!(index_vk, new_vk); assert!( @@ -657,13 +647,12 @@ mod varuna_hiding { let num_constraints = 2usize.pow(15) - 10; let num_variables = 2usize.pow(15) - 10; let (circuit1, public_inputs1) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &circuit1).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit1).unwrap(); println!("Called circuit setup"); - let degree_info = pk1.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - - let proof1 = VarunaInst::prove(universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); + let proof1 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); println!("Called prover"); assert!( VarunaInst::verify( @@ -686,13 +675,11 @@ mod varuna_hiding { let num_constraints = 2usize.pow(19) - 10; let num_variables = 2usize.pow(19) - 10; let (circuit2, public_inputs2) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &circuit2).unwrap(); + let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit2).unwrap(); println!("Called circuit setup"); - let degree_info = pk2.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - - let proof2 = VarunaInst::prove(universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap(); + let proof2 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap(); println!("Called prover"); assert!( VarunaInst::verify(universal_verifier, &fs_parameters, &vk2, varuna_version, public_inputs2, &proof2) @@ -700,6 +687,14 @@ mod varuna_hiding { ); /*************************************************************************** * * */ + // Indexing, proving, and verifying again for a circuit with 1 << 15 constraints + // and 1 << 15 variables. + let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit1).unwrap(); + println!("Called circuit setup"); + + let proof1 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); + println!("Called prover"); assert!( VarunaInst::verify(universal_verifier, &fs_parameters, &vk1, varuna_version, public_inputs1, &proof1) .unwrap() @@ -711,6 +706,7 @@ mod varuna_test_vectors { use crate::{ fft::EvaluationDomain, snark::varuna::{AHPForR1CS, TestCircuit, VarunaNonHidingMode, VarunaSNARK, VarunaVersion, ahp::verifier}, + srs::UniversalProver, traits::snark::SNARK, }; use snarkvm_curves::bls12_377::{Bls12_377, Fq, Fr}; @@ -842,17 +838,19 @@ mod varuna_test_vectors { let (circ, _) = TestCircuit::generate_circuit_with_fixed_witness(a, b, mul_depth, num_constraints, num_variables); println!("Circuit: {circ:?}"); - let (index_pk, _index_vk) = VarunaSonicInst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, _index_vk) = + VarunaSonicInst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); let mut keys_to_constraints = BTreeMap::new(); keys_to_constraints.insert(index_pk.circuit.deref(), std::slice::from_ref(&circ)); // Begin the Varuna protocol execution. let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); let universal_prover = universal_srs.to_universal_prover(degree_info).unwrap(); - let fft_precomp = &universal_prover.fft_precomputation; - let ifft_precomp = &universal_prover.ifft_precomputation; + let fft_precomp = universal_prover.fft_precomputation.unwrap(); + let ifft_precomp = universal_prover.ifft_precomputation.unwrap(); let prover_state = - AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, fft_precomp, ifft_precomp, rng).unwrap(); + AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, &fft_precomp, &ifft_precomp, rng).unwrap(); let mut prover_state = AHPForR1CS::<_, MM>::prover_first_round(prover_state, rng).unwrap(); let first_round_oracles = Arc::new(prover_state.first_round_oracles.as_ref().unwrap()); diff --git a/algorithms/src/snark/varuna/varuna.rs b/algorithms/src/snark/varuna/varuna.rs index 625dabf805..9519691107 100644 --- a/algorithms/src/snark/varuna/varuna.rs +++ b/algorithms/src/snark/varuna/varuna.rs @@ -66,7 +66,8 @@ impl, SM: SNARKMode> VarunaSNARK /// We commit to each circuit's index polys separately /// Performing a single batch commitment can be a future optimization pub fn batch_circuit_setup>( - universal_srs: &UniversalSRS, + srs: &UniversalSRS, + universal_prover: &mut UniversalProver, circuits: &[&C], ) -> Result, CircuitVerifyingKey)>> { ensure!(!circuits.is_empty()); @@ -86,11 +87,7 @@ impl, SM: SNARKMode> VarunaSNARK } let degree_info = degree_info.unwrap(); - universal_srs - .download_powers_for(0..degree_info.max_degree) - .map_err(|e| anyhow!("Failed to download powers for degree {}: {e}", degree_info.max_degree))?; - - let universal_prover = universal_srs.to_universal_prover(degree_info)?; + universal_prover.update(srs, degree_info)?; let mut circuit_keys = Vec::with_capacity(circuits.len()); for mut indexed_circuit in indexed_circuits { @@ -98,7 +95,7 @@ impl, SM: SNARKMode> VarunaSNARK let setup_rng = None::<&mut dyn RngCore>; // We do not randomize the commitments let (mut circuit_commitments, commitment_randomnesses): (_, _) = SonicKZG10::::commit( - &universal_prover, + universal_prover, indexed_circuit.interpolate_matrix_evals()?.map(Into::into), setup_rng, )?; @@ -220,10 +217,11 @@ where /// Generates the circuit proving and verifying keys. /// This is a deterministic algorithm that anyone can rerun. fn circuit_setup>( - universal_srs: &Self::UniversalSRS, + srs: &Self::UniversalSRS, + universal_prover: &mut Self::UniversalProver, circuit: &C, ) -> Result<(Self::ProvingKey, Self::VerifyingKey)> { - let mut circuit_keys = Self::batch_circuit_setup::(universal_srs, &[circuit])?; + let mut circuit_keys = Self::batch_circuit_setup::(srs, universal_prover, &[circuit])?; ensure!(circuit_keys.len() == 1); Ok(circuit_keys.pop().unwrap()) } @@ -354,10 +352,13 @@ where for (pk, constraints) in keys_to_constraints { circuits_to_constraints.insert(pk.circuit.deref(), *constraints); } - let fft_precomp = &universal_prover.fft_precomputation; - let ifft_precomp = &universal_prover.ifft_precomputation; - let prover_state = - AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, fft_precomp, ifft_precomp, zk_rng)?; + + let prover_state = match (&universal_prover.fft_precomputation, &universal_prover.ifft_precomputation) { + (Some(fft_precomp), Some(ifft_precomp)) => { + AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, fft_precomp, ifft_precomp, zk_rng)? + } + _ => return Err(SNARKError::FFTPrecompNotFound.into()), + }; // extract information from the prover key and state to consume in further // calculations diff --git a/algorithms/src/srs/universal_prover.rs b/algorithms/src/srs/universal_prover.rs index 9d14a22157..80c4de7ac7 100644 --- a/algorithms/src/srs/universal_prover.rs +++ b/algorithms/src/srs/universal_prover.rs @@ -14,11 +14,17 @@ // limitations under the License. use crate::{ - fft::domain::{FFTPrecomputation, IFFTPrecomputation}, - polycommit::sonic_pc::CommitterKey, + fft::{ + EvaluationDomain, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, + polycommit::{kzg10::DegreeInfo, sonic_pc::CommitterKey}, + snark::varuna::UniversalSRS, }; use snarkvm_curves::PairingEngine; +use anyhow::{Result, anyhow}; + /// `UniversalProver` is used to compute evaluation proofs for a given /// commitment. #[derive(Debug)] @@ -26,9 +32,51 @@ pub struct UniversalProver { /// The committer key for the underlying KZG10 scheme. pub committer_key: CommitterKey, /// The precomputed roots for calculating FFTs. - pub fft_precomputation: FFTPrecomputation, + pub fft_precomputation: Option>, /// The precomputed inverse roots for calculating inverse FFTs. - pub ifft_precomputation: IFFTPrecomputation, + pub ifft_precomputation: Option>, /// The maximum degree supported by the universal SRS. pub max_degree: usize, } + +impl Default for UniversalProver { + fn default() -> Self { + let committer_key = CommitterKey { + powers_of_beta_g: Vec::new(), + lagrange_bases_at_beta_g: None, + powers_of_beta_times_gamma_g: Vec::new(), + shifted_powers_of_beta_g: None, + shifted_powers_of_beta_times_gamma_g: None, + enforced_degree_bounds: None, + }; + Self { committer_key, fft_precomputation: None, ifft_precomputation: None, max_degree: 0 } + } +} + +impl UniversalProver { + /// Update the UniversalProver: + /// - set a new max_degree + /// - grow or shrink the committer_key + /// - grow or shrink the FFT and iFFT precomputation + /// + /// Note that this implementation is not atomic. If specialize fails halfway + /// through, we might have an inconsistent UniversalProver. + pub fn update(&mut self, srs: &UniversalSRS, degree_info: DegreeInfo) -> Result<()> { + // Set the new max_degree + self.max_degree = degree_info.max_degree; + // Update the committer key. + self.committer_key.update(srs, °ree_info)?; + // Specialize the FFT and iFFT precomputation. + self.update_fft_precomputation(degree_info.max_fft_size)?; + Ok(()) + } + + /// Returns the FFT and iFFT precomputation for the largest supported + /// domain. + fn update_fft_precomputation(&mut self, max_domain_size: usize) -> Result<()> { + let domain = EvaluationDomain::new(max_domain_size).ok_or(anyhow!("Invalid domain size"))?; + self.fft_precomputation = Some(domain.precompute_fft()); + self.ifft_precomputation = Some(self.fft_precomputation.as_ref().unwrap().to_ifft_precomputation()); + Ok(()) + } +} diff --git a/algorithms/src/traits/snark.rs b/algorithms/src/traits/snark.rs index 33b6253d89..a2c9988dc1 100644 --- a/algorithms/src/traits/snark.rs +++ b/algorithms/src/traits/snark.rs @@ -60,6 +60,7 @@ pub trait SNARK { fn circuit_setup>( srs: &Self::UniversalSRS, + universal_prover: &mut Self::UniversalProver, circuit: &C, ) -> Result<(Self::ProvingKey, Self::VerifyingKey)>; diff --git a/circuit/environment/src/helpers/assignment.rs b/circuit/environment/src/helpers/assignment.rs index d3f362322a..c960f7b140 100644 --- a/circuit/environment/src/helpers/assignment.rs +++ b/circuit/environment/src/helpers/assignment.rs @@ -227,7 +227,12 @@ impl snarkvm_algorithms::r1cs::ConstraintSynthesizer for Assig #[cfg(test)] mod tests { - use snarkvm_algorithms::{AlgebraicSponge, SNARK, r1cs::ConstraintSynthesizer, snark::varuna::VarunaVersion}; + use snarkvm_algorithms::{ + AlgebraicSponge, + SNARK, + r1cs::ConstraintSynthesizer, + snark::varuna::{UniversalProver, VarunaVersion}, + }; use snarkvm_circuit::prelude::*; use snarkvm_curves::bls12_377::Fr; @@ -303,13 +308,12 @@ mod tests { let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &assignment).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - - let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &assignment, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_pp, &index_pk, varuna_version, &assignment, rng).unwrap(); println!("Called prover"); let one = ::BaseField::one(); diff --git a/circuit/environment/src/helpers/converter.rs b/circuit/environment/src/helpers/converter.rs index 660642f23a..d63a9871c5 100644 --- a/circuit/environment/src/helpers/converter.rs +++ b/circuit/environment/src/helpers/converter.rs @@ -201,7 +201,12 @@ impl R1CS { #[cfg(test)] mod tests { - use snarkvm_algorithms::{AlgebraicSponge, SNARK, r1cs::ConstraintSynthesizer, snark::varuna::VarunaVersion}; + use snarkvm_algorithms::{ + AlgebraicSponge, + SNARK, + r1cs::ConstraintSynthesizer, + snark::varuna::{UniversalProver, VarunaVersion}, + }; use snarkvm_circuit::prelude::*; use snarkvm_curves::bls12_377::Fr; @@ -268,13 +273,11 @@ mod tests { let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &Circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &Circuit).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); - let universal_prover = &universal_srs.to_universal_prover(degree_info).unwrap(); - - let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &Circuit, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_pp, &index_pk, varuna_version, &Circuit, rng).unwrap(); println!("Called prover"); assert!( diff --git a/console/network/src/canary_v0.rs b/console/network/src/canary_v0.rs index 4afb3e11a7..c2ef592318 100644 --- a/console/network/src/canary_v0.rs +++ b/console/network/src/canary_v0.rs @@ -15,7 +15,6 @@ use super::*; use crate::TRANSACTION_PREFIX; -use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -346,11 +345,6 @@ impl Network for CanaryV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { - MainnetV0::varuna_universal_prover(degree_info) - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { MainnetV0::varuna_universal_verifier() diff --git a/console/network/src/lib.rs b/console/network/src/lib.rs index 948915c35e..977c69a7cf 100644 --- a/console/network/src/lib.rs +++ b/console/network/src/lib.rs @@ -63,9 +63,8 @@ pub use crate::environment::prelude::*; use snarkvm_algorithms::{ AlgebraicSponge, crypto_hash::PoseidonSponge, - polycommit::kzg10::DegreeInfo, snark::varuna::{CircuitProvingKey, CircuitVerifyingKey, VarunaHidingMode}, - srs::{UniversalProver, UniversalVerifier}, + srs::UniversalVerifier, }; use snarkvm_console_algorithms::{BHP512, BHP1024, Poseidon2, Poseidon4}; use snarkvm_console_collections::merkle_tree::{MerklePath, MerkleTree}; @@ -358,9 +357,6 @@ pub trait Network: /// Returns the scalar multiplication on the generator `G`. fn g_scalar_multiply(scalar: &Scalar) -> Group; - /// Returns the Varuna universal prover. - fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver; - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier; diff --git a/console/network/src/mainnet_v0.rs b/console/network/src/mainnet_v0.rs index f8e0497677..50599211c8 100644 --- a/console/network/src/mainnet_v0.rs +++ b/console/network/src/mainnet_v0.rs @@ -14,7 +14,6 @@ // limitations under the License. use super::*; -use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -351,14 +350,6 @@ impl Network for MainnetV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { - snarkvm_algorithms::polycommit::kzg10::UniversalParams::load() - .expect("Failed to load universal SRS (KZG10).") - .to_universal_prover(degree_info) - .expect("Failed to convert universal SRS (KZG10) to the prover.") - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { static INSTANCE: OnceLock::PairingCurve>> = OnceLock::new(); diff --git a/console/network/src/testnet_v0.rs b/console/network/src/testnet_v0.rs index 02dc3029b9..9b7de9fcd3 100644 --- a/console/network/src/testnet_v0.rs +++ b/console/network/src/testnet_v0.rs @@ -15,7 +15,6 @@ use super::*; use crate::TRANSACTION_PREFIX; -use snarkvm_algorithms::polycommit::kzg10::DegreeInfo; use snarkvm_console_algorithms::{ BHP256, BHP512, @@ -346,11 +345,6 @@ impl Network for TestnetV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover(degree_info: DegreeInfo) -> UniversalProver { - MainnetV0::varuna_universal_prover(degree_info) - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { MainnetV0::varuna_universal_verifier() diff --git a/ledger/puzzle/epoch/Cargo.toml b/ledger/puzzle/epoch/Cargo.toml index 5e60ebcd78..7c13d8dcd7 100644 --- a/ledger/puzzle/epoch/Cargo.toml +++ b/ledger/puzzle/epoch/Cargo.toml @@ -29,7 +29,8 @@ serial = [ "snarkvm-console/serial", "snarkvm-ledger-puzzle/serial" ] locktick = [ "dep:locktick", "snarkvm-ledger-puzzle/locktick", - "snarkvm-synthesizer-process/locktick" + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-program/locktick" ] merkle = [ ] synthesis = [ diff --git a/parameters/examples/inclusion.rs b/parameters/examples/inclusion.rs index 131f35ecf0..2fc22b26b9 100644 --- a/parameters/examples/inclusion.rs +++ b/parameters/examples/inclusion.rs @@ -33,10 +33,14 @@ type LedgerType = snarkvm_ledger_store::helpers::rocksdb::ConsensusDB; use snarkvm_synthesizer::{ VM, process::{InclusionAssignment, InclusionV0Assignment}, - snark::UniversalSRS, + snark::{UniversalProver, UniversalSRS}, }; use anyhow::{Result, anyhow}; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; use rand::thread_rng; use serde_json::{Value, json}; use std::{ @@ -171,17 +175,27 @@ pub fn inclusion>() -> Result<()> { // Load the universal SRS. let universal_srs = UniversalSRS::::load()?; + // Load the universal prover. + let universal_prover = RwLock::new(UniversalProver::::load()?); + // Sample the assignment for the inclusion circuit. let (assignment, state_path, serial_number, is_record_block_height_reached, upgrade_block_height) = sample_assignment::()?; // Synthesize the proving and verifying key. let inclusion_function_name = N::INCLUSION_FUNCTION_NAME; - let (proving_key, verifying_key) = universal_srs.to_circuit_key(inclusion_function_name, &assignment)?; + let (proving_key, verifying_key) = universal_srs.to_circuit_key(&universal_prover, inclusion_function_name, &assignment)?; for varuna_version in [VarunaVersion::V1, VarunaVersion::V2] { // Ensure the proving key and verifying keys are valid. - let proof = proving_key.prove(inclusion_function_name, varuna_version, &assignment, &mut thread_rng())?; + let proof = proving_key.prove( + &universal_srs, + &universal_prover, + inclusion_function_name, + varuna_version, + &assignment, + &mut thread_rng(), + )?; assert!(verifying_key.verify( inclusion_function_name, varuna_version, diff --git a/synthesizer/Cargo.toml b/synthesizer/Cargo.toml index cd1e0812cb..b59ee6eb2d 100644 --- a/synthesizer/Cargo.toml +++ b/synthesizer/Cargo.toml @@ -29,7 +29,9 @@ locktick = [ "snarkvm-ledger-puzzle/locktick", "snarkvm-ledger-puzzle-epoch/locktick", "snarkvm-ledger-store/locktick", - "snarkvm-synthesizer-process/locktick" + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-program/locktick", + "snarkvm-synthesizer-snark/locktick" ] async = [ "snarkvm-ledger-query/async", "snarkvm-synthesizer-process/async" ] cuda = [ "snarkvm-algorithms/cuda" ] diff --git a/synthesizer/benches/kary_merkle_tree.rs b/synthesizer/benches/kary_merkle_tree.rs index a0ace39661..ff84724e52 100644 --- a/synthesizer/benches/kary_merkle_tree.rs +++ b/synthesizer/benches/kary_merkle_tree.rs @@ -27,9 +27,13 @@ use snarkvm_console::{ }, types::Field, }; -use snarkvm_synthesizer_snark::{ProvingKey, UniversalSRS}; +use snarkvm_synthesizer_snark::{ProvingKey, UniversalProver, UniversalSRS}; use criterion::Criterion; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; type CurrentNetwork = MainnetV0; type CurrentAleo = AleoV0; @@ -125,8 +129,13 @@ fn batch_prove(c: &mut Criterion) { // Load the universal srs. let universal_srs = UniversalSRS::::load().unwrap(); + + // Load the universal prover. + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + // Construct the proving key. - let (proving_key, _) = universal_srs.to_circuit_key("KaryMerklePathVerification", &assignment).unwrap(); + let (proving_key, _) = + universal_srs.to_circuit_key(&universal_prover, "KaryMerklePathVerification", &assignment).unwrap(); // Log the current time elapsed. println!(" • Generated the proving key in: {} ms", timer.elapsed().as_millis()); @@ -149,8 +158,15 @@ fn batch_prove(c: &mut Criterion) { let varuna_version = VarunaVersion::V2; c.bench_function(&format!("KaryMerkleTree batch prove {num_assignments} assignments"), |b| { b.iter(|| { - let _proof = - ProvingKey::prove_batch("ProveKaryMerkleTree", varuna_version, &assignments, &mut rng).unwrap(); + let _proof = ProvingKey::prove_batch( + &universal_srs, + &universal_prover, + "ProveKaryMerkleTree", + varuna_version, + &assignments, + &mut rng, + ) + .unwrap(); }) }); } diff --git a/synthesizer/process/Cargo.toml b/synthesizer/process/Cargo.toml index ab5c5a09b4..4c035d68b7 100644 --- a/synthesizer/process/Cargo.toml +++ b/synthesizer/process/Cargo.toml @@ -26,7 +26,12 @@ edition = "2024" [features] default = [ "indexmap/rayon", "rayon" ] async = [ "snarkvm-ledger-query/async" ] -locktick = [ "dep:locktick", "snarkvm-ledger-store/locktick" ] +locktick = [ + "dep:locktick", + "snarkvm-ledger-store/locktick", + "snarkvm-synthesizer-program/locktick", + "snarkvm-synthesizer-snark/locktick" +] rocks = [ "snarkvm-ledger-store/rocks" ] serial = [ "snarkvm-console/serial", diff --git a/synthesizer/process/src/execute.rs b/synthesizer/process/src/execute.rs index 198cffc2af..a4624efb91 100644 --- a/synthesizer/process/src/execute.rs +++ b/synthesizer/process/src/execute.rs @@ -38,7 +38,7 @@ impl Process { // The root request does not have to pass on another request's root_tvk. let root_tvk = None; // Initialize the trace. - let trace = Arc::new(RwLock::new(Trace::new())); + let trace = Arc::new(RwLock::new(Trace::new(self.universal_srs.clone(), Arc::clone(&self.universal_prover)))); // Initialize the call stack. let call_stack = CallStack::execute(authorization, trace.clone())?; lap!(timer, "Initialize call stack"); diff --git a/synthesizer/process/src/lib.rs b/synthesizer/process/src/lib.rs index 0af7d846d4..9b1d7df668 100644 --- a/synthesizer/process/src/lib.rs +++ b/synthesizer/process/src/lib.rs @@ -76,7 +76,7 @@ use snarkvm_synthesizer_program::{ Program, StackTrait, }; -use snarkvm_synthesizer_snark::{ProvingKey, UniversalSRS, VerifyingKey}; +use snarkvm_synthesizer_snark::{ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; use snarkvm_utilities::{defer, dev_println}; use aleo_std::prelude::{finish, lap, timer}; @@ -91,6 +91,8 @@ use std::{collections::HashMap, sync::Arc}; pub struct Process { /// The universal SRS. universal_srs: UniversalSRS, + /// The universal prover. + universal_prover: Arc>>, /// The mapping of program IDs to stacks. stacks: Arc, Arc>>>>, /// The mapping of program IDs to old stacks. @@ -104,8 +106,12 @@ impl Process { let timer = timer!("Process:setup"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -231,8 +237,12 @@ impl Process { let timer = timer!("Process::load"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -271,8 +281,12 @@ impl Process { let timer = timer!("Process::load_v0"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -310,8 +324,12 @@ impl Process { #[cfg(feature = "wasm")] pub fn load_web() -> Result { // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; // Initialize the 'credits.aleo' program. let program = Program::credits()?; @@ -392,6 +410,12 @@ impl Process { &self.universal_srs } + /// Returns the universal prover. + #[inline] + pub const fn universal_prover(&self) -> &Arc>> { + &self.universal_prover + } + /// Returns `true` if the process contains the program with the given ID. #[inline] pub fn contains_program(&self, program_id: &ProgramID) -> bool { diff --git a/synthesizer/process/src/stack/deploy.rs b/synthesizer/process/src/stack/deploy.rs index 8e5a9f5d2b..15a29e840e 100644 --- a/synthesizer/process/src/stack/deploy.rs +++ b/synthesizer/process/src/stack/deploy.rs @@ -41,7 +41,13 @@ impl Stack { lap!(timer, "Retrieve the keys for {function_name}"); // Certify the circuit. - let certificate = Certificate::certify(&function_name.to_string(), &proving_key, &verifying_key)?; + let certificate = Certificate::certify( + &self.universal_srs, + &self.universal_prover, + &function_name.to_string(), + &proving_key, + &verifying_key, + )?; lap!(timer, "Certify the circuit"); // Add the verifying key and certificate to the bundle. diff --git a/synthesizer/process/src/stack/helpers/initialize.rs b/synthesizer/process/src/stack/helpers/initialize.rs index ca4486ef54..6364fc953f 100644 --- a/synthesizer/process/src/stack/helpers/initialize.rs +++ b/synthesizer/process/src/stack/helpers/initialize.rs @@ -53,6 +53,7 @@ impl Stack { register_types: Default::default(), finalize_types: Default::default(), universal_srs: process.universal_srs().clone(), + universal_prover: process.universal_prover().clone(), proving_keys: Default::default(), verifying_keys: Default::default(), program_address: program.id().to_address()?, diff --git a/synthesizer/process/src/stack/helpers/synthesize.rs b/synthesizer/process/src/stack/helpers/synthesize.rs index 0475c08647..993f93bd47 100644 --- a/synthesizer/process/src/stack/helpers/synthesize.rs +++ b/synthesizer/process/src/stack/helpers/synthesize.rs @@ -101,7 +101,8 @@ impl Stack { } // Synthesize the proving and verifying key. - let (proving_key, verifying_key) = self.universal_srs.to_circuit_key(&function_name.to_string(), assignment)?; + let (proving_key, verifying_key) = + self.universal_srs.to_circuit_key(&self.universal_prover, &function_name.to_string(), assignment)?; // Insert the proving key. self.insert_proving_key(function_name, proving_key)?; // Insert the verifying key. diff --git a/synthesizer/process/src/stack/mod.rs b/synthesizer/process/src/stack/mod.rs index 8ad0384c1d..ffaea305bd 100644 --- a/synthesizer/process/src/stack/mod.rs +++ b/synthesizer/process/src/stack/mod.rs @@ -79,7 +79,7 @@ use snarkvm_synthesizer_program::{ RegistersTrait, StackTrait, }; -use snarkvm_synthesizer_snark::{Certificate, ProvingKey, UniversalSRS, VerifyingKey}; +use snarkvm_synthesizer_snark::{Certificate, ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; use aleo_std::prelude::{finish, lap, timer}; use indexmap::IndexMap; @@ -221,6 +221,8 @@ pub struct Stack { finalize_types: Arc, FinalizeTypes>>>, /// The universal SRS. universal_srs: UniversalSRS, + /// The universal prover. + universal_prover: Arc>>, /// The mapping of function name to proving key. proving_keys: Arc, ProvingKey>>>, /// The mapping of function name to verifying key. diff --git a/synthesizer/process/src/tests/test_execute.rs b/synthesizer/process/src/tests/test_execute.rs index ef96c36e5f..59b52058d7 100644 --- a/synthesizer/process/src/tests/test_execute.rs +++ b/synthesizer/process/src/tests/test_execute.rs @@ -32,7 +32,7 @@ use snarkvm_ledger_store::{ helpers::memory::{BlockMemory, FinalizeMemory}, }; use snarkvm_synthesizer_program::{FinalizeGlobalState, FinalizeStoreTrait, Program, StackTrait}; -use snarkvm_synthesizer_snark::UniversalSRS; +use snarkvm_synthesizer_snark::{UniversalProver, UniversalSRS}; use aleo_std::StorageMode; #[cfg(feature = "locktick")] @@ -405,7 +405,7 @@ output r4 as field.private;", assert_eq!(authorization.len(), 1); // Re-run to ensure state continues to work. - let trace = Arc::new(RwLock::new(Trace::new())); + let trace = Arc::new(RwLock::new(Trace::new(process.universal_srs.clone(), process.universal_prover.clone()))); let call_stack = CallStack::execute(authorization, trace).unwrap(); let response = stack.execute_function::(call_stack, None, None, rng).unwrap(); let candidate = response.outputs(); @@ -2486,6 +2486,7 @@ fn test_process_deploy_credits_program() { // Initialize an empty process without the `credits` program. let empty_process = Process { universal_srs: UniversalSRS::::load().unwrap(), + universal_prover: Arc::new(RwLock::new(UniversalProver::::load().unwrap())), stacks: Default::default(), old_stacks: Default::default(), }; diff --git a/synthesizer/process/src/tests/test_serializers.rs b/synthesizer/process/src/tests/test_serializers.rs index ec63e60a32..3f8e1bcf25 100644 --- a/synthesizer/process/src/tests/test_serializers.rs +++ b/synthesizer/process/src/tests/test_serializers.rs @@ -158,7 +158,7 @@ function test_serde_equivalence: let eval_is_ok = res_eval.is_ok(); // Execute the function. - let trace = Trace::new(); + let trace = Trace::new(process.universal_srs.clone(), process.universal_prover.clone()); let res_exec = stack.execute_function::( CallStack::execute(authorization.replicate(), Arc::new(RwLock::new(trace))).unwrap(), None, diff --git a/synthesizer/process/src/trace/mod.rs b/synthesizer/process/src/trace/mod.rs index e37caa3ecd..66ab11cdf9 100644 --- a/synthesizer/process/src/trace/mod.rs +++ b/synthesizer/process/src/trace/mod.rs @@ -27,9 +27,16 @@ use console::{ use snarkvm_algorithms::snark::varuna::VarunaVersion; use snarkvm_ledger_block::{Execution, Fee, Transition}; use snarkvm_ledger_query::QueryTrait; -use snarkvm_synthesizer_snark::{Proof, ProvingKey, VerifyingKey}; - -use std::{collections::HashMap, sync::OnceLock}; +use snarkvm_synthesizer_snark::{Proof, ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; + +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; +use std::{ + collections::HashMap, + sync::{Arc, OnceLock}, +}; use crate::Authorization; @@ -48,11 +55,16 @@ pub struct Trace { inclusion_assignments: OnceLock>>, /// A tracker for the global state root. global_state_root: OnceLock, + + /// The Universal SRS + srs: UniversalSRS, + /// The Universal Prover + universal_prover: Arc>>, } impl Trace { /// Initializes a new trace. - pub fn new() -> Self { + pub fn new(srs: UniversalSRS, universal_prover: Arc>>) -> Self { Self { transitions: Vec::new(), transition_tasks: HashMap::new(), @@ -60,6 +72,8 @@ impl Trace { inclusion_assignments: OnceLock::new(), global_state_root: OnceLock::new(), call_metrics: Vec::new(), + srs, + universal_prover, } } @@ -179,6 +193,8 @@ impl Trace { let proving_tasks = self.transition_tasks.values().cloned().collect(); // Compute the proof. let (global_state_root, proof) = Self::prove_batch::( + &self.srs, + &self.universal_prover, locator, varuna_version, proving_tasks, @@ -217,6 +233,8 @@ impl Trace { let proving_tasks = self.transition_tasks.values().cloned().collect(); // Compute the proof. let (global_state_root, proof) = Self::prove_batch::( + &self.srs, + &self.universal_prover, "credits.aleo/fee (private or public)", varuna_version, proving_tasks, @@ -301,6 +319,8 @@ impl Trace { impl Trace { /// Returns the global state root and proof for the given assignments. fn prove_batch, R: Rng + CryptoRng>( + srs: &UniversalSRS, + universal_prover: &RwLock>, locator: &str, varuna_version: VarunaVersion, mut proving_tasks: Vec<(ProvingKey, Vec>)>, @@ -369,7 +389,7 @@ impl Trace { } // Compute the proof. - let proof = ProvingKey::prove_batch(locator, varuna_version, &proving_tasks, rng)?; + let proof = ProvingKey::prove_batch(srs, universal_prover, locator, varuna_version, &proving_tasks, rng)?; // Return the global state root and proof. Ok((global_state_root, proof)) } diff --git a/synthesizer/program/Cargo.toml b/synthesizer/program/Cargo.toml index eab2c8106e..4a384e6401 100644 --- a/synthesizer/program/Cargo.toml +++ b/synthesizer/program/Cargo.toml @@ -25,6 +25,10 @@ edition = "2024" [features] default = [ ] +locktick = [ + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-snark/locktick", +] serial = [ "snarkvm-console/serial" ] wasm = [ "snarkvm-console/wasm" ] diff --git a/synthesizer/snark/Cargo.toml b/synthesizer/snark/Cargo.toml index 75503780a9..06baaab15e 100644 --- a/synthesizer/snark/Cargo.toml +++ b/synthesizer/snark/Cargo.toml @@ -27,6 +27,7 @@ edition = "2024" default = [ ] cuda = [ "snarkvm-algorithms/cuda" ] dev-print = [ "snarkvm-utilities/dev-print" ] +locktick = [ "dep:locktick" ] serial = [ "snarkvm-console/serial", "snarkvm-algorithms/serial" ] wasm = [ "snarkvm-console/wasm" ] @@ -47,6 +48,13 @@ workspace = true [dependencies.bincode] workspace = true +[dependencies.locktick] +workspace = true +optional = true + +[dependencies.parking_lot] +workspace = true + [dependencies.serde_json] workspace = true features = [ "preserve_order" ] diff --git a/synthesizer/snark/src/certificate/mod.rs b/synthesizer/snark/src/certificate/mod.rs index 0c2675085f..fe86b9deaa 100644 --- a/synthesizer/snark/src/certificate/mod.rs +++ b/synthesizer/snark/src/certificate/mod.rs @@ -19,6 +19,11 @@ mod bytes; mod parse; mod serialize; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; + #[derive(Clone, PartialEq, Eq)] pub struct Certificate { /// The certificate. @@ -33,6 +38,8 @@ impl Certificate { /// Returns the certificate from the proving and verifying key. pub fn certify( + srs: &UniversalSRS, + universal_prover: &RwLock>, _function_name: &str, proving_key: &ProvingKey, verifying_key: &VerifyingKey, @@ -42,11 +49,12 @@ impl Certificate { // Retrieve the proving parameters. let degree_info = proving_key.circuit.index_info.degree_info::()?; - let universal_prover = N::varuna_universal_prover(degree_info); let fiat_shamir = N::varuna_fs_parameters(); + universal_prover.write().update(srs, degree_info)?; + // Compute the certificate. - let certificate = Varuna::::prove_vk(&universal_prover, fiat_shamir, verifying_key, proving_key)?; + let certificate = Varuna::::prove_vk(&universal_prover.read(), fiat_shamir, verifying_key, proving_key)?; #[cfg(feature = "dev-print")] { diff --git a/synthesizer/snark/src/lib.rs b/synthesizer/snark/src/lib.rs index 1953a26624..36c4eaee54 100644 --- a/synthesizer/snark/src/lib.rs +++ b/synthesizer/snark/src/lib.rs @@ -38,7 +38,7 @@ mod proving_key; pub use proving_key::ProvingKey; mod universal_srs; -pub use universal_srs::UniversalSRS; +pub use universal_srs::{UniversalProver, UniversalSRS}; mod verifying_key; pub use verifying_key::VerifyingKey; @@ -46,6 +46,7 @@ pub use verifying_key::VerifyingKey; #[cfg(test)] pub(crate) mod test_helpers { use super::*; + use crate::universal_srs::UniversalProver; use circuit::{ environment::{Assignment, Circuit, Eject, Environment, Inject, Mode, One}, types::Field, @@ -53,6 +54,10 @@ pub(crate) mod test_helpers { use console::{network::MainnetV0, prelude::One as _}; use snarkvm_algorithms::snark::varuna::VarunaVersion; + #[cfg(feature = "locktick")] + use locktick::parking_lot::RwLock; + #[cfg(not(feature = "locktick"))] + use parking_lot::RwLock; use std::sync::OnceLock; type CurrentNetwork = MainnetV0; @@ -104,8 +109,8 @@ pub(crate) mod test_helpers { .get_or_init(|| { let assignment = sample_assignment(); let srs = UniversalSRS::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); - (proving_key, verifying_key) + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap() }) .clone() } @@ -117,7 +122,12 @@ pub(crate) mod test_helpers { .get_or_init(|| { let assignment = sample_assignment(); let (proving_key, _) = sample_keys(); - proving_key.prove("test", VarunaVersion::V2, &assignment, &mut TestRng::default()).unwrap() + let srs = UniversalSRS::::load().unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + + proving_key + .prove(&srs, &universal_prover, "test", VarunaVersion::V2, &assignment, &mut TestRng::default()) + .unwrap() }) .clone() } @@ -128,8 +138,10 @@ pub(crate) mod test_helpers { INSTANCE .get_or_init(|| { let (proving_key, verifying_key) = sample_keys(); + let srs = UniversalSRS::::load().unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); // Return the certificate. - Certificate::certify("test", &proving_key, &verifying_key).unwrap() + Certificate::certify(&srs, &universal_prover, "test", &proving_key, &verifying_key).unwrap() }) .clone() } @@ -138,10 +150,16 @@ pub(crate) mod test_helpers { #[cfg(test)] mod test { use super::*; + use crate::universal_srs::UniversalProver; use circuit::environment::{Circuit, Environment}; use console::network::MainnetV0; use snarkvm_algorithms::snark::varuna::VarunaVersion; + #[cfg(feature = "locktick")] + use locktick::parking_lot::RwLock; + #[cfg(not(feature = "locktick"))] + use parking_lot::RwLock; + type CurrentNetwork = MainnetV0; #[test] @@ -150,11 +168,14 @@ mod test { // Varuna setup, prove, and verify. let srs = UniversalSRS::::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + let (proving_key, verifying_key) = srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let proof = proving_key.prove("test", varuna_version, &assignment, &mut TestRng::default()).unwrap(); + let proof = proving_key + .prove(&srs, &universal_prover, "test", varuna_version, &assignment, &mut TestRng::default()) + .unwrap(); println!("Called prover"); let one = ::BaseField::one(); @@ -194,11 +215,14 @@ mod test { assert_eq!(assignment.num_private(), 3); let srs = UniversalSRS::::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + let (proving_key, verifying_key) = srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let proof = proving_key.prove("test", varuna_version, &assignment, &mut TestRng::default()).unwrap(); + let proof = proving_key + .prove(&srs, &universal_prover, "test", varuna_version, &assignment, &mut TestRng::default()) + .unwrap(); println!("Called prover"); // Should pass. diff --git a/synthesizer/snark/src/proving_key/mod.rs b/synthesizer/snark/src/proving_key/mod.rs index 5ece780a8f..4f0a27c730 100644 --- a/synthesizer/snark/src/proving_key/mod.rs +++ b/synthesizer/snark/src/proving_key/mod.rs @@ -19,6 +19,10 @@ mod bytes; mod parse; mod serialize; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; use std::collections::BTreeMap; #[derive(Clone)] @@ -36,6 +40,8 @@ impl ProvingKey { /// Returns a proof for the given assignment on the circuit. pub fn prove( &self, + srs: &UniversalSRS, + universal_prover: &RwLock>, _function_name: &str, varuna_version: varuna::VarunaVersion, assignment: &circuit::Assignment, @@ -46,12 +52,20 @@ impl ProvingKey { // Retrieve the proving parameters. let degree_info = self.circuit.index_info.degree_info::()?; - let universal_prover = N::varuna_universal_prover(degree_info); let fiat_shamir = N::varuna_fs_parameters(); + // Update the universal_prover and prevent other threads from writing to it. + universal_prover.write().update(srs, degree_info)?; + // Compute the proof. - let proof = - Proof::new(Varuna::::prove(&universal_prover, fiat_shamir, self, varuna_version, assignment, rng)?); + let proof = Proof::new(Varuna::::prove( + &universal_prover.read(), + fiat_shamir, + self, + varuna_version, + assignment, + rng, + )?); #[cfg(feature = "dev-print")] { @@ -65,6 +79,8 @@ impl ProvingKey { /// Returns a proof for the given batch of proving keys and assignments. #[allow(clippy::type_complexity)] pub fn prove_batch( + srs: &UniversalSRS, + universal_prover: &RwLock>, _locator: &str, varuna_version: varuna::VarunaVersion, assignments: &[(ProvingKey, Vec>)], @@ -93,12 +109,19 @@ impl ProvingKey { None => Some(degree_info_i), }; } - let universal_prover = N::varuna_universal_prover(degree_info.unwrap()); let fiat_shamir = N::varuna_fs_parameters(); + // Update the universal_prover and prevent other threads from writing to it + universal_prover.write().update(srs, degree_info.unwrap())?; + // Compute the proof. - let batch_proof = - Proof::new(Varuna::::prove_batch(&universal_prover, fiat_shamir, varuna_version, &instances, rng)?); + let batch_proof = Proof::new(Varuna::::prove_batch( + &universal_prover.read(), + fiat_shamir, + varuna_version, + &instances, + rng, + )?); #[cfg(feature = "dev-print")] { diff --git a/synthesizer/snark/src/universal_srs.rs b/synthesizer/snark/src/universal_srs.rs index ae2c8d153f..8c9f2aea83 100644 --- a/synthesizer/snark/src/universal_srs.rs +++ b/synthesizer/snark/src/universal_srs.rs @@ -15,12 +15,23 @@ use super::*; -#[derive(Clone)] +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; + +#[derive(Clone, Debug)] pub struct UniversalSRS { - /// The universal SRS parameter. + /// The universal SRS parameters. srs: Arc>>, } +impl Default for UniversalSRS { + fn default() -> Self { + Self { srs: Arc::new(OnceLock::new()) } + } +} + impl UniversalSRS { /// Initializes the universal SRS. pub fn load() -> Result { @@ -30,13 +41,16 @@ impl UniversalSRS { /// Returns the circuit proving and verifying key. pub fn to_circuit_key( &self, + universal_prover: &RwLock>, _function_name: &str, assignment: &circuit::Assignment, ) -> Result<(ProvingKey, VerifyingKey)> { #[cfg(feature = "dev-print")] let timer = std::time::Instant::now(); - let (proving_key, verifying_key) = Varuna::::circuit_setup(self, assignment)?; + // Locks the universal_prover. + let (proving_key, verifying_key) = + Varuna::::circuit_setup(self.deref(), &mut universal_prover.write(), assignment)?; #[cfg(feature = "dev-print")] { @@ -89,3 +103,36 @@ impl Deref for UniversalSRS { }) } } + +#[derive(Debug)] +pub struct UniversalProver { + /// The universal prover - trimmed from the SRS + universal_prover: varuna::UniversalProver, +} + +impl Default for UniversalProver { + fn default() -> Self { + Self { universal_prover: varuna::UniversalProver::default() } + } +} + +impl UniversalProver { + /// Initializes the universal Prover. + pub fn load() -> Result { + Ok(Self { universal_prover: varuna::UniversalProver::default() }) + } +} + +impl DerefMut for UniversalProver { + fn deref_mut(&mut self) -> &mut varuna::UniversalProver { + &mut self.universal_prover + } +} + +impl Deref for UniversalProver { + type Target = varuna::UniversalProver; + + fn deref(&self) -> &Self::Target { + &self.universal_prover + } +} From 750e041dad38f8f34ebb394dceff09b7b515999a Mon Sep 17 00:00:00 2001 From: ljedrz Date: Wed, 28 Jan 2026 15:50:33 +0100 Subject: [PATCH 3/5] chore: adjust the lockfile Signed-off-by: ljedrz --- Cargo.lock | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 6c91398ef4..0bdef9aeba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3753,6 +3753,8 @@ name = "snarkvm-synthesizer-snark" version = "4.4.0" dependencies = [ "bincode", + "locktick", + "parking_lot", "serde_json", "snarkvm-algorithms", "snarkvm-circuit", From 793bb968fb40ae410940148598288acc662035c4 Mon Sep 17 00:00:00 2001 From: ljedrz Date: Thu, 29 Jan 2026 10:26:09 +0100 Subject: [PATCH 4/5] fix: make CommitterKey & CircuitProvingKey serialization backward compatible Signed-off-by: ljedrz --- .../polycommit/sonic_pc/data_structures.rs | 8 +++--- .../data_structures/circuit_proving_key.rs | 25 ++++++++++++++++--- algorithms/src/snark/varuna/varuna.rs | 1 + 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index a19be0b9be..703873b744 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -91,9 +91,9 @@ impl FromBytes for CommitterKey { } // Deserialize `lagrange_basis_at_beta`. - let has_lagrange_bases: bool = FromBytes::read_le(&mut reader)?; + let has_lagrange_bases = bool::read_le(&mut reader); let lagrange_bases_at_beta_g = match has_lagrange_bases { - true => { + Ok(true) => { let lagrange_bases_at_beta_len: u32 = FromBytes::read_le(&mut reader)?; let mut lagrange_bases_at_beta_g = BTreeMap::new(); for _ in 0..lagrange_bases_at_beta_len { @@ -107,7 +107,9 @@ impl FromBytes for CommitterKey { } Some(lagrange_bases_at_beta_g) } - false => None, + Ok(false) => None, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None, + Err(e) => return Err(e), }; // Deserialize `powers_of_beta_times_gamma_g`. diff --git a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs index ffc8cbe51c..33c4e37814 100644 --- a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs +++ b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs @@ -13,7 +13,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::snark::varuna::{CircuitVerifyingKey, SNARKMode, ahp::indexer::*}; +use crate::{ + polycommit::sonic_pc, + snark::varuna::{CircuitVerifyingKey, SNARKMode, ahp::indexer::*}, +}; use snarkvm_curves::PairingEngine; use snarkvm_utilities::{FromBytes, ToBytes, serialize::*}; use std::io::{self, Read, Write}; @@ -28,12 +31,22 @@ pub struct CircuitProvingKey { // NOTE: The circuit verifying key's circuit_info and circuit id are also stored in Circuit for convenience. /// The circuit itself. pub circuit: Arc>, + /// The committer key for this index, trimmed from the universal SRS. + /// Note: no longer used, kept for backward compatibility. + pub committer_key: Option>>, } impl ToBytes for CircuitProvingKey { fn write_le(&self, mut writer: W) -> io::Result<()> { CanonicalSerialize::serialize_compressed(&self.circuit_verifying_key, &mut writer)?; - Ok(CanonicalSerialize::serialize_compressed(&self.circuit, &mut writer)?) + CanonicalSerialize::serialize_compressed(&self.circuit, &mut writer)?; + + self.committer_key.is_some().write_le(&mut writer)?; + if let Some(key) = &self.committer_key { + key.write_le(&mut writer)?; + } + + Ok(()) } } @@ -42,8 +55,14 @@ impl FromBytes for CircuitProvingKey { fn read_le(mut reader: R) -> io::Result { let circuit_verifying_key = CanonicalDeserialize::deserialize_compressed(&mut reader)?; let circuit = CanonicalDeserialize::deserialize_compressed(&mut reader)?; + let committer_key = match bool::read_le(&mut reader) { + Ok(true) => Some(Arc::new(FromBytes::read_le(&mut reader)?)), + Ok(false) => None, + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None, + Err(e) => return Err(e), + }; - Ok(Self { circuit_verifying_key, circuit }) + Ok(Self { circuit_verifying_key, circuit, committer_key }) } } diff --git a/algorithms/src/snark/varuna/varuna.rs b/algorithms/src/snark/varuna/varuna.rs index 9519691107..3ba7790c48 100644 --- a/algorithms/src/snark/varuna/varuna.rs +++ b/algorithms/src/snark/varuna/varuna.rs @@ -114,6 +114,7 @@ impl, SM: SNARKMode> VarunaSNARK let circuit_proving_key = CircuitProvingKey { circuit_verifying_key: circuit_verifying_key.clone(), circuit: Arc::new(indexed_circuit), + committer_key: None, }; circuit_keys.push((circuit_proving_key, circuit_verifying_key)); } From 8083480322b888fd91523a655c7608859016fda8 Mon Sep 17 00:00:00 2001 From: ljedrz Date: Thu, 29 Jan 2026 13:16:24 +0100 Subject: [PATCH 5/5] make CommitterKey & CircuitProvingKey objects backward compatible Signed-off-by: ljedrz --- .../polycommit/sonic_pc/data_structures.rs | 77 +++++++++---------- .../data_structures/circuit_proving_key.rs | 16 +--- algorithms/src/snark/varuna/varuna.rs | 2 +- algorithms/src/srs/universal_prover.rs | 2 +- 4 files changed, 43 insertions(+), 54 deletions(-) diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index 703873b744..96c31a5d0c 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -34,6 +34,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, fmt, io, + mem, ops::{AddAssign, MulAssign, SubAssign}, }; @@ -54,8 +55,8 @@ pub struct CommitterKey { pub powers_of_beta_g: Vec, /// The key used to commit to polynomials in Lagrange basis. - /// This is `None` if `self` does not support lagrange bases - pub lagrange_bases_at_beta_g: Option>>, + /// This is empty if `self` does not support lagrange bases + pub lagrange_bases_at_beta_g: BTreeMap>, /// The key used to commit to hiding polynomials. pub powers_of_beta_times_gamma_g: Vec, @@ -74,6 +75,19 @@ pub struct CommitterKey { pub enforced_degree_bounds: Option>, } +impl Default for CommitterKey { + fn default() -> Self { + Self { + powers_of_beta_g: Vec::new(), + lagrange_bases_at_beta_g: Default::default(), + powers_of_beta_times_gamma_g: Vec::new(), + shifted_powers_of_beta_g: None, + shifted_powers_of_beta_times_gamma_g: None, + enforced_degree_bounds: None, + } + } +} + impl FromBytes for CommitterKey { fn read_le(mut reader: R) -> io::Result { // Deserialize `powers`. @@ -91,26 +105,17 @@ impl FromBytes for CommitterKey { } // Deserialize `lagrange_basis_at_beta`. - let has_lagrange_bases = bool::read_le(&mut reader); - let lagrange_bases_at_beta_g = match has_lagrange_bases { - Ok(true) => { - let lagrange_bases_at_beta_len: u32 = FromBytes::read_le(&mut reader)?; - let mut lagrange_bases_at_beta_g = BTreeMap::new(); - for _ in 0..lagrange_bases_at_beta_len { - let size: u32 = FromBytes::read_le(&mut reader)?; - let mut basis = Vec::with_capacity(size as usize); - for _ in 0..size { - let power: E::G1Affine = FromBytes::read_le(&mut reader)?; - basis.push(power); - } - lagrange_bases_at_beta_g.insert(size as usize, basis); - } - Some(lagrange_bases_at_beta_g) + let lagrange_bases_at_beta_len: u32 = FromBytes::read_le(&mut reader)?; + let mut lagrange_bases_at_beta_g = BTreeMap::new(); + for _ in 0..lagrange_bases_at_beta_len { + let size: u32 = FromBytes::read_le(&mut reader)?; + let mut basis = Vec::with_capacity(size as usize); + for _ in 0..size { + let power: E::G1Affine = FromBytes::read_le(&mut reader)?; + basis.push(power); } - Ok(false) => None, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None, - Err(e) => return Err(e), - }; + lagrange_bases_at_beta_g.insert(size as usize, basis); + } // Deserialize `powers_of_beta_times_gamma_g`. let powers_of_beta_times_gamma_g_len: u32 = FromBytes::read_le(&mut reader)?; @@ -223,14 +228,11 @@ impl ToBytes for CommitterKey { } // Serialize `lagrange_bases_at_beta_g`. - self.lagrange_bases_at_beta_g.is_some().write_le(&mut writer)?; - if let Some(lagrange_bases_at_beta_g) = &self.lagrange_bases_at_beta_g { - (lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; - for (size, powers) in lagrange_bases_at_beta_g.iter() { - (*size as u32).write_le(&mut writer)?; - for power in powers { - power.write_le(&mut writer)?; - } + (self.lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; + for (size, powers) in &self.lagrange_bases_at_beta_g { + (*size as u32).write_le(&mut writer)?; + for power in powers { + power.write_le(&mut writer)?; } } @@ -425,7 +427,7 @@ impl CommitterKey { // Update lagrange_bases_at_beta_g let mut lagrange_bases_at_beta_g = BTreeMap::new(); if let Some(supported_lagrange_sizes) = degree_info.lagrange_sizes.as_ref() { - let mut old_lagrange_bases_at_beta_g = self.lagrange_bases_at_beta_g.take().unwrap_or_default(); + let mut old_lagrange_bases_at_beta_g = mem::take(&mut self.lagrange_bases_at_beta_g); for size in supported_lagrange_sizes { let domain = EvaluationDomain::new(*size).unwrap(); match old_lagrange_bases_at_beta_g.remove(&domain.size()) { @@ -450,7 +452,7 @@ impl CommitterKey { } } } - self.lagrange_bases_at_beta_g = Some(lagrange_bases_at_beta_g); + self.lagrange_bases_at_beta_g = lagrange_bases_at_beta_g; } end_timer!(trim_time); Ok(()) @@ -491,14 +493,11 @@ impl CommitterKey { /// Obtain elements of the SRS in the lagrange basis powers, for use with /// the underlying KZG10 construction. pub fn lagrange_basis(&self, domain: EvaluationDomain) -> Option> { - if let Some(lagrange_bases) = self.lagrange_bases_at_beta_g.as_ref() { - return lagrange_bases.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { - lagrange_basis_at_beta_g: Cow::Borrowed(basis), - powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), - domain, - }); - } - None + self.lagrange_bases_at_beta_g.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { + lagrange_basis_at_beta_g: Cow::Borrowed(basis), + powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), + domain, + }) } } diff --git a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs index 33c4e37814..d19bc77924 100644 --- a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs +++ b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs @@ -33,7 +33,7 @@ pub struct CircuitProvingKey { pub circuit: Arc>, /// The committer key for this index, trimmed from the universal SRS. /// Note: no longer used, kept for backward compatibility. - pub committer_key: Option>>, + pub committer_key: Arc>, } impl ToBytes for CircuitProvingKey { @@ -41,12 +41,7 @@ impl ToBytes for CircuitProvingKey { CanonicalSerialize::serialize_compressed(&self.circuit_verifying_key, &mut writer)?; CanonicalSerialize::serialize_compressed(&self.circuit, &mut writer)?; - self.committer_key.is_some().write_le(&mut writer)?; - if let Some(key) = &self.committer_key { - key.write_le(&mut writer)?; - } - - Ok(()) + self.committer_key.write_le(&mut writer) } } @@ -55,12 +50,7 @@ impl FromBytes for CircuitProvingKey { fn read_le(mut reader: R) -> io::Result { let circuit_verifying_key = CanonicalDeserialize::deserialize_compressed(&mut reader)?; let circuit = CanonicalDeserialize::deserialize_compressed(&mut reader)?; - let committer_key = match bool::read_le(&mut reader) { - Ok(true) => Some(Arc::new(FromBytes::read_le(&mut reader)?)), - Ok(false) => None, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => None, - Err(e) => return Err(e), - }; + let committer_key = Arc::new(FromBytes::read_le(&mut reader)?); Ok(Self { circuit_verifying_key, circuit, committer_key }) } diff --git a/algorithms/src/snark/varuna/varuna.rs b/algorithms/src/snark/varuna/varuna.rs index 3ba7790c48..e90a383d25 100644 --- a/algorithms/src/snark/varuna/varuna.rs +++ b/algorithms/src/snark/varuna/varuna.rs @@ -114,7 +114,7 @@ impl, SM: SNARKMode> VarunaSNARK let circuit_proving_key = CircuitProvingKey { circuit_verifying_key: circuit_verifying_key.clone(), circuit: Arc::new(indexed_circuit), - committer_key: None, + committer_key: Default::default(), }; circuit_keys.push((circuit_proving_key, circuit_verifying_key)); } diff --git a/algorithms/src/srs/universal_prover.rs b/algorithms/src/srs/universal_prover.rs index 80c4de7ac7..922fb17801 100644 --- a/algorithms/src/srs/universal_prover.rs +++ b/algorithms/src/srs/universal_prover.rs @@ -43,7 +43,7 @@ impl Default for UniversalProver { fn default() -> Self { let committer_key = CommitterKey { powers_of_beta_g: Vec::new(), - lagrange_bases_at_beta_g: None, + lagrange_bases_at_beta_g: Default::default(), powers_of_beta_times_gamma_g: Vec::new(), shifted_powers_of_beta_g: None, shifted_powers_of_beta_times_gamma_g: None,