diff --git a/Cargo.lock b/Cargo.lock index 6c91398ef4..0bdef9aeba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3753,6 +3753,8 @@ name = "snarkvm-synthesizer-snark" version = "4.4.0" dependencies = [ "bincode", + "locktick", + "parking_lot", "serde_json", "snarkvm-algorithms", "snarkvm-circuit", diff --git a/algorithms/benches/snark/varuna.rs b/algorithms/benches/snark/varuna.rs index 8ccba91ff7..2b7a20fdda 100644 --- a/algorithms/benches/snark/varuna.rs +++ b/algorithms/benches/snark/varuna.rs @@ -20,12 +20,21 @@ use snarkvm_algorithms::{ AlgebraicSponge, SNARK, crypto_hash::PoseidonSponge, - snark::varuna::{CircuitVerifyingKey, TestCircuit, VarunaHidingMode, VarunaSNARK, VarunaVersion, ahp::AHPForR1CS}, + snark::varuna::{ + CircuitVerifyingKey, + TestCircuit, + UniversalProver, + VarunaHidingMode, + VarunaSNARK, + VarunaVersion, + ahp::AHPForR1CS, + }, }; use snarkvm_curves::bls12_377::{Bls12_377, Fq, Fr}; use snarkvm_utilities::{CanonicalDeserialize, CanonicalSerialize, TestRng}; use criterion::Criterion; +use itertools::Itertools; use std::{collections::BTreeMap, time::Duration}; type VarunaInst = VarunaSNARK; @@ -53,7 +62,8 @@ fn snark_circuit_setup(c: &mut Criterion) { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); c.bench_function(&format!("snark_circuit_setup_{size}"), |b| { - b.iter(|| VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap()) + let mut universal_prover = UniversalProver::default(); + b.iter(|| VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap()) }); } } @@ -67,25 +77,21 @@ fn snark_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000, 1000, 1000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let params = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); c.bench_function("snark_prove_v1", |b| { let varuna_version = VarunaVersion::V1; - b.iter(|| { - VarunaInst::prove(universal_prover, &fs_parameters, ¶ms.0, varuna_version, &circuit, rng).unwrap() - }) + b.iter(|| VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); c.bench_function("snark_prove_v2", |b| { let varuna_version = VarunaVersion::V2; - b.iter(|| { - VarunaInst::prove(universal_prover, &fs_parameters, ¶ms.0, varuna_version, &circuit, rng).unwrap() - }) + b.iter(|| VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap()) }); } @@ -99,16 +105,16 @@ fn snark_batch_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000000, 1000000, 1000000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let circuit_batch_size = 5; let instance_batch_size = 5; - let mut pks = Vec::with_capacity(circuit_batch_size); let mut all_circuits = Vec::with_capacity(circuit_batch_size); let mut keys_to_constraints = BTreeMap::new(); + let mut universal_prover = UniversalProver::default(); + for i in 0..circuit_batch_size { let num_constraints = num_constraints_base + i; let num_variables = num_variables_base + i; @@ -119,18 +125,20 @@ fn snark_batch_prove(c: &mut Criterion) { let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); circuits.push(circuit); } - let (pk, _) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); - pks.push(pk); all_circuits.push(circuits); } + let unique_circuits = all_circuits.iter().map(|c| &c[0]).collect_vec(); + let circuit_keys = + VarunaInst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_circuits.as_slice()).unwrap(); + for i in 0..circuit_batch_size { - keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); + keys_to_constraints.insert(&circuit_keys[i].0, all_circuits[i].as_slice()); } let varuna_version = VarunaVersion::V2; b.iter(|| { - VarunaInst::prove_batch(universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) + VarunaInst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) .unwrap() }) }); @@ -146,16 +154,16 @@ fn snark_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100, 100, 100).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let (circuit, public_inputs) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let varuna_version = VarunaVersion::V2; - let proof = VarunaInst::prove(universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_parameters, &pk, varuna_version, &circuit, rng).unwrap(); b.iter(|| { let verification = VarunaInst::verify( universal_verifier, @@ -180,15 +188,13 @@ fn snark_batch_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(1000, 1000, 100).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); + let mut universal_prover = UniversalProver::default(); let fs_parameters = FS::sample_parameters(); let circuit_batch_size = 5; let instance_batch_size = 5; - let mut pks = Vec::with_capacity(circuit_batch_size); - let mut vks = Vec::with_capacity(circuit_batch_size); let mut all_circuits = Vec::with_capacity(circuit_batch_size); let mut all_inputs = Vec::with_capacity(circuit_batch_size); let mut keys_to_constraints = BTreeMap::new(); @@ -204,21 +210,22 @@ fn snark_batch_verify(c: &mut Criterion) { circuits.push(circuit); inputs.push(public_inputs); } - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuits[0]).unwrap(); - pks.push(pk); - vks.push(vk); all_circuits.push(circuits); all_inputs.push(inputs); } + let unique_circuits = all_circuits.iter().map(|c| &c[0]).collect_vec(); + let circuit_keys = + VarunaInst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_circuits.as_slice()).unwrap(); + for i in 0..circuit_batch_size { - keys_to_constraints.insert(&pks[i], all_circuits[i].as_slice()); - keys_to_inputs.insert(&vks[i], all_inputs[i].as_slice()); + keys_to_constraints.insert(&circuit_keys[i].0, all_circuits[i].as_slice()); + keys_to_inputs.insert(&circuit_keys[i].1, all_inputs[i].as_slice()); } let varuna_version = VarunaVersion::V2; let proof = - VarunaInst::prove_batch(universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) + VarunaInst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &keys_to_constraints, rng) .unwrap(); b.iter(|| { let verification = @@ -246,7 +253,8 @@ fn snark_vk_serialize(c: &mut Criterion) { let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let mut bytes = Vec::with_capacity(10000); group.bench_function(name, |b| { b.iter(|| { @@ -281,7 +289,8 @@ fn snark_vk_deserialize(c: &mut Criterion) { let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); let mut bytes = Vec::with_capacity(10000); vk.serialize_with_mode(&mut bytes, compress).unwrap(); group.bench_function(name, |b| { @@ -300,7 +309,6 @@ fn snark_certificate_prove(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100000, 100000, 100000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let fs_parameters = FS::sample_parameters(); let fs_p = &fs_parameters; @@ -309,10 +317,11 @@ fn snark_certificate_prove(c: &mut Criterion) { let num_variables = size; let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); c.bench_function(&format!("snark_certificate_prove_{size}"), |b| { - b.iter(|| VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap()) + b.iter(|| VarunaInst::prove_vk(&universal_prover, fs_p, &vk, &pk).unwrap()) }); } } @@ -322,7 +331,6 @@ fn snark_certificate_verify(c: &mut Criterion) { let max_degree = AHPForR1CS::::max_degree(100_000, 100_000, 100_000).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let fs_p = &fs_parameters; @@ -332,8 +340,9 @@ fn snark_certificate_verify(c: &mut Criterion) { let num_variables = size; let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); - let certificate = VarunaInst::prove_vk(universal_prover, fs_p, &vk, &pk).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk, vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); + let certificate = VarunaInst::prove_vk(&universal_prover, fs_p, &vk, &pk).unwrap(); c.bench_function(&format!("snark_certificate_verify_{size}"), |b| { b.iter(|| VarunaInst::verify_vk(universal_verifier, fs_p, &circuit, &vk, &certificate).unwrap()) diff --git a/algorithms/src/errors.rs b/algorithms/src/errors.rs index edbf8faa31..f5d9d7acaf 100644 --- a/algorithms/src/errors.rs +++ b/algorithms/src/errors.rs @@ -44,6 +44,9 @@ pub enum SNARKError { #[error("Public input size was different from the circuit")] PublicInputSizeMismatch, + + #[error("FFT precomputation not found")] + FFTPrecompNotFound, } impl From for SNARKError { diff --git a/algorithms/src/polycommit/kzg10/data_structures.rs b/algorithms/src/polycommit/kzg10/data_structures.rs index bdd0488f99..67b9131707 100644 --- a/algorithms/src/polycommit/kzg10/data_structures.rs +++ b/algorithms/src/polycommit/kzg10/data_structures.rs @@ -13,9 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::DegreeInfo; use crate::{ AlgebraicSponge, fft::{DensePolynomial, EvaluationDomain}, + snark::varuna::UniversalProver, + srs::UniversalVerifier, }; use snarkvm_curves::{AffineCurve, PairingCurve, PairingEngine, ProjectiveCurve}; use snarkvm_fields::{ConstraintFieldError, ToConstraintField, Zero}; @@ -28,7 +31,6 @@ use snarkvm_utilities::{ serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate}, }; -use crate::srs::{UniversalProver, UniversalVerifier}; use anyhow::Result; use core::ops::{Add, AddAssign}; use rand::RngCore; @@ -98,8 +100,10 @@ impl UniversalParams { self.powers.max_num_powers() - 1 } - pub fn to_universal_prover(&self) -> Result> { - Ok(UniversalProver:: { max_degree: self.max_degree(), _unused: None }) + pub fn to_universal_prover(&self, degree_info: DegreeInfo) -> Result> { + let mut universal_prover = UniversalProver::default(); + universal_prover.update(self, degree_info)?; + Ok(universal_prover) } pub fn to_universal_verifier(&self) -> Result> { diff --git a/algorithms/src/polycommit/kzg10/degree_info.rs b/algorithms/src/polycommit/kzg10/degree_info.rs new file mode 100644 index 0000000000..385c235a63 --- /dev/null +++ b/algorithms/src/polycommit/kzg10/degree_info.rs @@ -0,0 +1,62 @@ +// Copyright (c) 2019-2026 Provable Inc. +// This file is part of the snarkVM library. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::BTreeSet; + +#[derive(Clone, Debug, Default)] +pub struct DegreeInfo { + /// The maximum degree of the required SRS to commit to the polynomials. + pub max_degree: usize, + /// The maximum IOP poly degree used for (i)fft_precomputation. + pub max_fft_size: usize, + /// The degree bounds on IOP polynomials. + pub degree_bounds: Option>, + /// The hiding bound for polynomial queries. + pub hiding_bound: usize, + /// The supported sizes for the lagrange-basis SRS. + pub lagrange_sizes: Option>, +} + +impl DegreeInfo { + /// Initializes a new degree info. + pub const fn new( + max_degree: usize, + max_fft_size: usize, + degree_bounds: Option>, + hiding_bound: usize, + lagrange_sizes: Option>, + ) -> Self { + Self { max_degree, max_fft_size, degree_bounds, hiding_bound, lagrange_sizes } + } +} + +impl DegreeInfo { + pub fn union(self, other: &Self) -> Self { + let max_degree = self.max_degree.max(other.max_degree); + let max_fft_size = self.max_fft_size.max(other.max_fft_size); + let degree_bounds = match (&self.degree_bounds, &other.degree_bounds) { + (Some(a), Some(b)) => Some(a | b), + (Some(a), None) | (None, Some(a)) => Some(a.clone()), + (None, None) => None, + }; + let hiding_bound = self.hiding_bound.max(other.hiding_bound); + let lagrange_sizes = match (&self.lagrange_sizes, &other.lagrange_sizes) { + (Some(a), Some(b)) => Some(a | b), + (Some(a), None) | (None, Some(a)) => Some(a.clone()), + (None, None) => None, + }; + Self::new(max_degree, max_fft_size, degree_bounds, hiding_bound, lagrange_sizes) + } +} diff --git a/algorithms/src/polycommit/kzg10/mod.rs b/algorithms/src/polycommit/kzg10/mod.rs index 8b21c3a9e7..a40cad2b4f 100644 --- a/algorithms/src/polycommit/kzg10/mod.rs +++ b/algorithms/src/polycommit/kzg10/mod.rs @@ -41,6 +41,9 @@ use rayon::prelude::*; mod data_structures; pub use data_structures::*; +mod degree_info; +pub use degree_info::*; + use super::sonic_pc::LabeledPolynomialWithBasis; #[derive(Debug, PartialEq, Eq)] diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index f2c09c876b..96c31a5d0c 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -14,18 +14,27 @@ // limitations under the License. use super::{LabeledPolynomial, PolynomialInfo}; -use crate::{crypto_hash::sha256::sha256, fft::EvaluationDomain, polycommit::kzg10}; +use crate::{ + crypto_hash::sha256::sha256, + fft::EvaluationDomain, + polycommit::kzg10::{self, DegreeInfo}, + prelude::PCError, +}; use snarkvm_curves::PairingEngine; use snarkvm_fields::{ConstraintFieldError, Field, PrimeField, ToConstraintField}; use snarkvm_parameters::mainnet::MAX_NUM_POWERS; use snarkvm_utilities::{FromBytes, ToBytes, error, into_io_error, serialize::*}; +use anyhow::{Result, anyhow, bail, ensure}; use hashbrown::HashMap; +use itertools::Itertools; use std::{ borrow::{Borrow, Cow}, + cmp::Ordering, collections::{BTreeMap, BTreeSet}, fmt, io, + mem, ops::{AddAssign, MulAssign, SubAssign}, }; @@ -46,6 +55,7 @@ pub struct CommitterKey { pub powers_of_beta_g: Vec, /// The key used to commit to polynomials in Lagrange basis. + /// This is empty if `self` does not support lagrange bases pub lagrange_bases_at_beta_g: BTreeMap>, /// The key used to commit to hiding polynomials. @@ -65,6 +75,19 @@ pub struct CommitterKey { pub enforced_degree_bounds: Option>, } +impl Default for CommitterKey { + fn default() -> Self { + Self { + powers_of_beta_g: Vec::new(), + lagrange_bases_at_beta_g: Default::default(), + powers_of_beta_times_gamma_g: Vec::new(), + shifted_powers_of_beta_g: None, + shifted_powers_of_beta_times_gamma_g: None, + enforced_degree_bounds: None, + } + } +} + impl FromBytes for CommitterKey { fn read_le(mut reader: R) -> io::Result { // Deserialize `powers`. @@ -198,13 +221,13 @@ impl FromBytes for CommitterKey { impl ToBytes for CommitterKey { fn write_le(&self, mut writer: W) -> io::Result<()> { - // Serialize `powers`. + // Serialize `powers_of_beta_g`. (self.powers_of_beta_g.len() as u32).write_le(&mut writer)?; for power in &self.powers_of_beta_g { power.write_le(&mut writer)?; } - // Serialize `powers`. + // Serialize `lagrange_bases_at_beta_g`. (self.lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; for (size, powers) in &self.lagrange_bases_at_beta_g { (*size as u32).write_le(&mut writer)?; @@ -275,44 +298,171 @@ impl ToBytes for CommitterKey { } impl CommitterKey { - fn len(&self) -> usize { - if self.shifted_powers_of_beta_g.is_some() { self.shifted_powers_of_beta_g.as_ref().unwrap().len() } else { 0 } - } -} - -/// `CommitterUnionKey` is a union of `CommitterKey`s, useful for multi-circuit -/// batch proofs. -#[derive(Debug)] -pub struct CommitterUnionKey<'a, E: PairingEngine> { - /// The key used to commit to polynomials. - pub powers_of_beta_g: Option<&'a Vec>, - - /// The key used to commit to polynomials in Lagrange basis. - pub lagrange_bases_at_beta_g: BTreeMap>, - - /// The key used to commit to hiding polynomials. - pub powers_of_beta_times_gamma_g: Option<&'a Vec>, + /// Update the committer_key + /// The `powers_of_beta_g` and `shifted_powers_of_beta_g` may grow or shrink + /// as per the new supported_degree The `enforced_degree_bounds` and + /// `shifted_powers_of_beta_times_gamma_g` will adjust to the new + /// degree_bounds The `lagrange_bases_at_beta_g` will optionally adjust + /// to the new degree_bounds The `powers_of_beta_times_gamma_g` is only + /// dependent on the `hiding_bound` so doesn't change This only works if + /// the SRS max_degree is fixed across specializations Note that this + /// implementation is not atomic. If specialize fails halfway through, + pub fn update(&mut self, srs: &UniversalParams, degree_info: &DegreeInfo) -> Result<()> { + let trim_time = start_timer!(|| "Trimming public parameters"); + + // The maximum degree required by the circuit polynomials + let supported_degree = degree_info.max_degree; + // The maximum degree of the entire SRS + let max_degree = srs.max_degree(); + // The hiding bound for the circuit polynomials + let supported_hiding_bound = degree_info.hiding_bound; + + // Update shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g, + // enforced_degree_bounds + match degree_info.degree_bounds.as_ref() { + None => { + self.shifted_powers_of_beta_g = None; + self.shifted_powers_of_beta_times_gamma_g = None; + self.enforced_degree_bounds = None; + } + Some(enforced_degree_bounds) => { + // Retrieve new highest degree bound + let new_highest_degree_bound = + *enforced_degree_bounds.last().ok_or(anyhow!("No degree bound found"))?; + ensure!(new_highest_degree_bound < supported_degree); + + // Retrieve current degree bounds and shifted powers + let degree_bounds = self.enforced_degree_bounds.take().unwrap_or_default(); + let mut highest_degree_bound = degree_bounds.iter().copied().sorted().last().unwrap_or_default(); + let mut shifted_powers_of_beta_g = self.shifted_powers_of_beta_g.take().unwrap_or_default(); + let mut shifted_powers_of_beta_times_gamma_g = + self.shifted_powers_of_beta_times_gamma_g.take().unwrap_or_default(); + + // We add 1 to any existing upper bound, in congruence with `max_degree + 1` + // below. This is because the proof system assumes we need this + // extra degree. This can optionally be refactored to ensure the + // extra degree is already encoded in degree_bounds. + if highest_degree_bound > 0 { + highest_degree_bound = + highest_degree_bound.checked_add(1).ok_or(error("Overflow highest_degree_bound"))?; + } - /// The powers used to commit to shifted polynomials. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub shifted_powers_of_beta_g: Option<&'a Vec>, + let shifted_ck_time = start_timer!(|| "Constructing `shifted_powers_of_beta_g`"); + match new_highest_degree_bound.cmp(&highest_degree_bound) { + Ordering::Greater => { + let lowest_shift_degree = max_degree - new_highest_degree_bound; + let highest_shift_degree = max_degree + 1 - highest_degree_bound; + let mut new_powers = srs.powers_of_beta_g(lowest_shift_degree, highest_shift_degree)?; + new_powers.extend_from_slice(&shifted_powers_of_beta_g); + shifted_powers_of_beta_g = new_powers; + } + Ordering::Equal => (), // No need to adjust shifted_powers_of_beta_g + Ordering::Less => { + let degree_diff = highest_degree_bound - new_highest_degree_bound; + shifted_powers_of_beta_g = shifted_powers_of_beta_g.drain(degree_diff - 1..).collect_vec(); + } + }; + self.shifted_powers_of_beta_g = Some(shifted_powers_of_beta_g); + end_timer!(shifted_ck_time); + + // because we might batch prove circuits of different sizes, we consider each + // required degree_bound individually + let shifted_ck_time = start_timer!(|| "Constructing `shifted_powers_of_beta_times_gamma_g`"); + let mut new_shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); + for degree_bound in enforced_degree_bounds { + match shifted_powers_of_beta_times_gamma_g.remove(degree_bound) { + Some(found_powers) => { + new_shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, found_powers); + } + None => { + let shift_degree = max_degree - *degree_bound; + let mut powers_for_degree_bound = + Vec::with_capacity((max_degree + 2).saturating_sub(shift_degree)); + for i in 0..=supported_hiding_bound + .checked_add(1) + .ok_or(anyhow!("Overflow supported_hiding_bound"))? + { + // We have an additional degree in `powers_of_beta_times_gamma_g` beyond + // `powers_of_beta_g`. + if shift_degree + i < max_degree + 2 { + powers_for_degree_bound + .push(srs.powers_of_beta_times_gamma_g()[&(shift_degree + i)]); + } + } + new_shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); + } + } + } + self.enforced_degree_bounds = Some(enforced_degree_bounds.iter().copied().collect_vec()); + self.shifted_powers_of_beta_times_gamma_g = Some(new_shifted_powers_of_beta_times_gamma_g); + end_timer!(shifted_ck_time); + } + } - /// The powers used to commit to shifted hiding polynomials. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub shifted_powers_of_beta_times_gamma_g: Option>>, + // Update powers_of_beta_g + let old_supported_degree = self.powers_of_beta_g.len(); + match supported_degree.cmp(&old_supported_degree) { + Ordering::Greater => { + let mut new_powers = srs.powers_of_beta_g(old_supported_degree, supported_degree)?; + self.powers_of_beta_g.append(&mut new_powers); + } + Ordering::Equal => (), // No need to adjust powers_of_beta_g + Ordering::Less => { + let degree_difference = old_supported_degree - supported_degree; + self.powers_of_beta_g.truncate(old_supported_degree - degree_difference); + } + } - /// The degree bounds that are supported by `self`. - /// Sorted in ascending order from smallest bound to largest bound. - /// This is `None` if `self` does not support enforcing any degree bounds. - pub enforced_degree_bounds: Option>, -} + // Set powers_of_beta_times_gamma_g + self.powers_of_beta_times_gamma_g = + (0..=(supported_hiding_bound.checked_add(1).ok_or(anyhow!("Overflow supported_hiding_bound"))?)) + .map(|i| { + srs.powers_of_beta_times_gamma_g().get(&i).copied().ok_or(PCError::HidingBoundToolarge { + hiding_poly_degree: supported_hiding_bound, + num_powers: 0, + }) + }) + .collect::, _>>()?; + + // Update lagrange_bases_at_beta_g + let mut lagrange_bases_at_beta_g = BTreeMap::new(); + if let Some(supported_lagrange_sizes) = degree_info.lagrange_sizes.as_ref() { + let mut old_lagrange_bases_at_beta_g = mem::take(&mut self.lagrange_bases_at_beta_g); + for size in supported_lagrange_sizes { + let domain = EvaluationDomain::new(*size).unwrap(); + match old_lagrange_bases_at_beta_g.remove(&domain.size()) { + Some(found_lagrange_bases) => { + lagrange_bases_at_beta_g.insert(domain.size(), found_lagrange_bases); + } + None => { + let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); + if !size.is_power_of_two() { + bail!("The Lagrange basis size ({size}) is not a power of two") + } + if *size > srs.max_degree() + 1 { + bail!( + "The Lagrange basis size ({size}) is larger than the supported degree ({})", + srs.max_degree() + 1 + ) + } + let lagrange_basis_at_beta_g = srs.lagrange_basis(domain)?; + assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); + lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); + end_timer!(lagrange_time); + } + } + } + self.lagrange_bases_at_beta_g = lagrange_bases_at_beta_g; + } + end_timer!(trim_time); + Ok(()) + } -impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { /// Obtain powers for the underlying KZG10 construction pub fn powers(&self) -> kzg10::Powers { kzg10::Powers { - powers_of_beta_g: self.powers_of_beta_g.unwrap().as_slice().into(), - powers_of_beta_times_gamma_g: self.powers_of_beta_times_gamma_g.unwrap().as_slice().into(), + powers_of_beta_g: self.powers_of_beta_g.as_slice().into(), + powers_of_beta_times_gamma_g: self.powers_of_beta_times_gamma_g.as_slice().into(), } } @@ -345,55 +495,10 @@ impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { pub fn lagrange_basis(&self, domain: EvaluationDomain) -> Option> { self.lagrange_bases_at_beta_g.get(&domain.size()).map(|basis| kzg10::LagrangeBasis { lagrange_basis_at_beta_g: Cow::Borrowed(basis), - powers_of_beta_times_gamma_g: Cow::Borrowed(self.powers_of_beta_times_gamma_g.unwrap()), + powers_of_beta_times_gamma_g: Cow::Borrowed(&self.powers_of_beta_times_gamma_g), domain, }) } - - pub fn union>>(committer_keys: T) -> Self { - let mut ck_union = CommitterUnionKey:: { - powers_of_beta_g: None, - lagrange_bases_at_beta_g: BTreeMap::new(), - powers_of_beta_times_gamma_g: None, - shifted_powers_of_beta_g: None, - shifted_powers_of_beta_times_gamma_g: None, - enforced_degree_bounds: None, - }; - let mut enforced_degree_bounds = vec![]; - let mut biggest_ck: Option<&CommitterKey> = None; - let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); - for ck in committer_keys { - if biggest_ck.is_none() || biggest_ck.unwrap().len() < ck.len() { - biggest_ck = Some(ck); - } - let lagrange_bases = &ck.lagrange_bases_at_beta_g; - for (bound_base, bases) in lagrange_bases.iter() { - ck_union.lagrange_bases_at_beta_g.entry(*bound_base).or_insert(bases); - } - if let Some(shifted_powers) = ck.shifted_powers_of_beta_times_gamma_g.as_ref() { - for (bound_power, powers) in shifted_powers.iter() { - shifted_powers_of_beta_times_gamma_g.entry(*bound_power).or_insert(powers); - } - } - if let Some(degree_bounds) = &ck.enforced_degree_bounds { - enforced_degree_bounds.append(&mut degree_bounds.clone()); - } - } - - let biggest_ck = biggest_ck.unwrap(); - ck_union.powers_of_beta_g = Some(&biggest_ck.powers_of_beta_g); - ck_union.powers_of_beta_times_gamma_g = Some(&biggest_ck.powers_of_beta_times_gamma_g); - ck_union.shifted_powers_of_beta_g = biggest_ck.shifted_powers_of_beta_g.as_ref(); - - if !enforced_degree_bounds.is_empty() { - enforced_degree_bounds.sort(); - enforced_degree_bounds.dedup(); - ck_union.enforced_degree_bounds = Some(enforced_degree_bounds); - ck_union.shifted_powers_of_beta_times_gamma_g = Some(shifted_powers_of_beta_times_gamma_g); - } - - ck_union - } } /// Evaluation proof at a query set. diff --git a/algorithms/src/polycommit/sonic_pc/mod.rs b/algorithms/src/polycommit/sonic_pc/mod.rs index 251ae6b4f3..8f830a3dfb 100644 --- a/algorithms/src/polycommit/sonic_pc/mod.rs +++ b/algorithms/src/polycommit/sonic_pc/mod.rs @@ -61,109 +61,6 @@ impl> SonicKZG10 { kzg10::KZG10::load_srs(max_degree) } - pub fn trim( - pp: &UniversalParams, - supported_degree: usize, - supported_lagrange_sizes: impl IntoIterator, - supported_hiding_bound: usize, - enforced_degree_bounds: Option<&[usize]>, - ) -> Result<(CommitterKey, UniversalVerifier)> { - let trim_time = start_timer!(|| "Trimming public parameters"); - let max_degree = pp.max_degree(); - - let enforced_degree_bounds = enforced_degree_bounds.map(|bounds| { - let mut v = bounds.to_vec(); - v.sort_unstable(); - v.dedup(); - v - }); - - let (shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g) = if let Some(enforced_degree_bounds) = - enforced_degree_bounds.as_ref() - { - if enforced_degree_bounds.is_empty() { - (None, None) - } else { - let highest_enforced_degree_bound = *enforced_degree_bounds.last().unwrap(); - if highest_enforced_degree_bound > supported_degree { - bail!( - "The highest enforced degree bound {highest_enforced_degree_bound} is larger than the supported degree {supported_degree}" - ); - } - - let lowest_shift_degree = max_degree - highest_enforced_degree_bound; - - let shifted_ck_time = start_timer!(|| format!( - "Constructing `shifted_powers_of_beta_g` of size {}", - max_degree - lowest_shift_degree + 1 - )); - - let shifted_powers_of_beta_g = pp.powers_of_beta_g(lowest_shift_degree, pp.max_degree() + 1)?; - let mut shifted_powers_of_beta_times_gamma_g = BTreeMap::new(); - // Also add degree 0. - for degree_bound in enforced_degree_bounds { - let shift_degree = max_degree - degree_bound; - // We have an additional degree in `powers_of_beta_times_gamma_g` beyond - // `powers_of_beta_g`. - let powers_for_degree_bound = pp - .powers_of_beta_times_gamma_g() - .range(shift_degree..max_degree.min(shift_degree + supported_hiding_bound) + 2) - .map(|(_k, v)| *v) - .collect(); - shifted_powers_of_beta_times_gamma_g.insert(*degree_bound, powers_for_degree_bound); - } - - end_timer!(shifted_ck_time); - - (Some(shifted_powers_of_beta_g), Some(shifted_powers_of_beta_times_gamma_g)) - } - } else { - (None, None) - }; - - let powers_of_beta_g = pp.powers_of_beta_g(0, supported_degree + 1)?; - let powers_of_beta_times_gamma_g = pp - .powers_of_beta_times_gamma_g() - .range(0..=(supported_hiding_bound + 1)) - .map(|(_k, v)| *v) - .collect::>(); - if powers_of_beta_times_gamma_g.len() != supported_hiding_bound + 2 { - return Err( - PCError::HidingBoundToolarge { hiding_poly_degree: supported_hiding_bound, num_powers: 0 }.into() - ); - } - - let mut lagrange_bases_at_beta_g = BTreeMap::new(); - for size in supported_lagrange_sizes { - let lagrange_time = start_timer!(|| format!("Constructing `lagrange_bases` of size {size}")); - if !size.is_power_of_two() { - bail!("The Lagrange basis size ({size}) is not a power of two") - } - if size > pp.max_degree() + 1 { - bail!("The Lagrange basis size ({size}) is larger than the supported degree ({})", pp.max_degree() + 1) - } - let domain = crate::fft::EvaluationDomain::new(size).unwrap(); - let lagrange_basis_at_beta_g = pp.lagrange_basis(domain)?; - assert!(lagrange_basis_at_beta_g.len().is_power_of_two()); - lagrange_bases_at_beta_g.insert(domain.size(), lagrange_basis_at_beta_g); - end_timer!(lagrange_time); - } - - let ck = CommitterKey { - powers_of_beta_g, - lagrange_bases_at_beta_g, - powers_of_beta_times_gamma_g, - shifted_powers_of_beta_g, - shifted_powers_of_beta_times_gamma_g, - enforced_degree_bounds, - }; - - let vk = pp.to_universal_verifier()?; - - end_timer!(trim_time); - Ok((ck, vk)) - } - /// Outputs commitments to `polynomials`. /// /// If `polynomials[i].is_hiding()`, then the `i`-th commitment is hiding @@ -180,12 +77,12 @@ impl> SonicKZG10 { #[allow(clippy::format_push_string)] pub fn commit<'b>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, polynomials: impl IntoIterator>, rng: Option<&mut dyn RngCore>, ) -> Result<(Vec>>, Vec>), PCError> { let rng = &mut OptionalRng(rng); let commit_time = start_timer!(|| "Committing to polynomials"); + let ck = &universal_prover.committer_key; let mut pool = snarkvm_utilities::ExecutionPool::>::new(); for p in polynomials { @@ -262,7 +159,6 @@ impl> SonicKZG10 { pub fn combine_for_open<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, labeled_polynomials: impl ExactSizeIterator>, rands: impl ExactSizeIterator>, fs_rng: &mut S, @@ -273,6 +169,7 @@ impl> SonicKZG10 { { ensure!(labeled_polynomials.len() == rands.len()); let mut to_combine = Vec::with_capacity(labeled_polynomials.len()); + let ck = &universal_prover.committer_key; for (p, r) in labeled_polynomials.zip_eq(rands) { let enforced_degree_bounds: Option<&[usize]> = ck.enforced_degree_bounds.as_deref(); @@ -290,7 +187,6 @@ impl> SonicKZG10 { /// set. pub fn batch_open<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, labeled_polynomials: impl ExactSizeIterator>, query_set: &QuerySet, rands: impl ExactSizeIterator>, @@ -301,6 +197,7 @@ impl> SonicKZG10 { Commitment: 'a, { ensure!(labeled_polynomials.len() == rands.len()); + let ck = &universal_prover.committer_key; let poly_rand: HashMap<_, _> = labeled_polynomials.into_iter().zip_eq(rands).map(|(poly, r)| (poly.label(), (poly, r))).collect(); @@ -330,7 +227,7 @@ impl> SonicKZG10 { query_rands.push(*rand); } let (polynomial, rand) = - Self::combine_for_open(universal_prover, ck, query_polys.into_iter(), query_rands.into_iter(), fs_rng)?; + Self::combine_for_open(universal_prover, query_polys.into_iter(), query_rands.into_iter(), fs_rng)?; let _randomizer = fs_rng.squeeze_short_nonnative_field_element::(); pool.add_job(move || { @@ -417,7 +314,6 @@ impl> SonicKZG10 { pub fn open_combinations<'a>( universal_prover: &UniversalProver, - ck: &CommitterUnionKey, linear_combinations: impl IntoIterator>, polynomials: impl IntoIterator>, rands: impl IntoIterator>, @@ -444,7 +340,7 @@ impl> SonicKZG10 { let num_polys = lc.len(); // We filter out l.is_one() entries because those constants are not committed to - // and used directly by the verifier. + // but used directly by the verifier. for (coeff, label) in lc.iter().filter(|(_, l)| !l.is_one()) { let label: &String = label.try_into().expect("cannot be one!"); let (cur_poly, cur_rand) = @@ -472,8 +368,7 @@ impl> SonicKZG10 { lc_info.push((lc_label, degree_bound)); } - let proof = - Self::batch_open(universal_prover, ck, lc_polynomials.iter(), query_set, lc_randomness.iter(), fs_rng)?; + let proof = Self::batch_open(universal_prover, lc_polynomials.iter(), query_set, lc_randomness.iter(), fs_rng)?; Ok(BatchLCProof { proof }) } @@ -691,7 +586,10 @@ mod tests { #![allow(non_camel_case_types)] use super::{CommitterKey, SonicKZG10}; - use crate::{crypto_hash::PoseidonSponge, polycommit::test_templates::*}; + use crate::{ + crypto_hash::PoseidonSponge, + polycommit::{kzg10::DegreeInfo, test_templates::*}, + }; use snarkvm_curves::bls12_377::{Bls12_377, Fq}; use snarkvm_utilities::{FromBytes, ToBytes, rand::TestRng}; @@ -704,13 +602,15 @@ mod tests { fn test_committer_key_serialization() { let rng = &mut TestRng::default(); let max_degree = rand::distributions::Uniform::from(8..=64).sample(rng); - let supported_degree = rand::distributions::Uniform::from(1..=max_degree).sample(rng); + let pp = PC_Bls12_377::load_srs(max_degree).unwrap(); + let supported_degree = rand::distributions::Uniform::from(1..=max_degree).sample(rng); + let hiding_bound = 0; let lagrange_size = |d: usize| if d.is_power_of_two() { d } else { d.next_power_of_two() >> 1 }; + let lagrange_sizes = Some([lagrange_size(supported_degree)].into_iter().collect()); + let degree_info = DegreeInfo::new(supported_degree, supported_degree, None, hiding_bound, lagrange_sizes); - let pp = PC_Bls12_377::load_srs(max_degree).unwrap(); - - let (ck, _vk) = PC_Bls12_377::trim(&pp, supported_degree, [lagrange_size(supported_degree)], 0, None).unwrap(); + let ck = pp.to_universal_prover(degree_info).unwrap().committer_key; let ck_bytes = ck.to_bytes_le().unwrap(); let ck_recovered: CommitterKey = FromBytes::read_le(&ck_bytes[..]).unwrap(); diff --git a/algorithms/src/polycommit/test_templates.rs b/algorithms/src/polycommit/test_templates.rs index d47c1c30ba..e5ab111882 100644 --- a/algorithms/src/polycommit/test_templates.rs +++ b/algorithms/src/polycommit/test_templates.rs @@ -17,7 +17,6 @@ use super::sonic_pc::{ BatchLCProof, BatchProof, Commitment, - CommitterUnionKey, Evaluations, LabeledCommitment, QuerySet, @@ -29,6 +28,7 @@ use crate::{ fft::DensePolynomial, polycommit::{ PCError, + kzg10::DegreeInfo, sonic_pc::{LabeledPolynomial, LabeledPolynomialWithBasis, LinearCombination}, }, srs::UniversalVerifier, @@ -42,7 +42,7 @@ use rand::{ Rng, distributions::{self, Distribution}, }; -use std::marker::PhantomData; +use std::{collections::BTreeSet, marker::PhantomData}; #[derive(Default)] struct TestInfo { @@ -56,7 +56,7 @@ struct TestInfo { } pub struct TestComponents> { - pub verification_key: UniversalVerifier, + pub universal_verifier: UniversalVerifier, pub commitments: Vec>>, pub query_set: QuerySet, pub evaluations: Evaluations, @@ -69,8 +69,9 @@ pub struct TestComponents> { pub fn bad_degree_bound_test>() -> Result<(), PCError> { let rng = &mut TestRng::default(); let max_degree = 100; + let hiding_bound = 1; let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); + let universal_verifier = pp.to_universal_verifier().unwrap(); for _ in 0..10 { let supported_degree = distributions::Uniform::from(1..=max_degree).sample(rng); @@ -78,7 +79,7 @@ pub fn bad_degree_bound_test>() - let mut labels = Vec::new(); let mut polynomials = Vec::new(); - let mut degree_bounds = Vec::new(); + let mut degree_bounds = BTreeSet::new(); for i in 0..10 { let label = format!("Test{i}"); @@ -86,22 +87,23 @@ pub fn bad_degree_bound_test>() - let poly = DensePolynomial::rand(supported_degree, rng); let degree_bound = 1usize; - let hiding_bound = Some(1); - degree_bounds.push(degree_bound); + degree_bounds.insert(degree_bound); - polynomials.push(LabeledPolynomial::new(label, poly, Some(degree_bound), hiding_bound)) + polynomials.push(LabeledPolynomial::new(label, poly, Some(degree_bound), Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); - println!("supported degree: {supported_degree:?}"); - let (ck, vk) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_degree, Some(degree_bounds.as_slice())) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds, + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; let mut query_set = QuerySet::new(); let mut values = Evaluations::new(); @@ -116,14 +118,14 @@ pub fn bad_degree_bound_test>() - let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; assert!(result, "proof was incorrect, Query set: {query_set:#?}"); } Ok(()) @@ -141,17 +143,17 @@ pub fn lagrange_test_template>() let rng = &mut TestRng::default(); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); let mut lagrange_polynomials = Vec::new(); - let mut supported_lagrange_sizes = Vec::new(); - let degree_bounds = None; + let mut supported_lagrange_sizes = BTreeSet::new(); + let degree_bounds = BTreeSet::new(); + let hiding_bound = 1; let mut labels = Vec::new(); - println!("Sampled supported degree"); // Generate polynomials let num_points_in_query_set = distributions::Uniform::from(1..=max_num_queries).sample(rng); @@ -166,34 +168,27 @@ pub fn lagrange_test_template>() let domain = crate::fft::EvaluationDomain::new(evals.len()).unwrap(); let evals = crate::fft::Evaluations::from_vec_and_domain(evals, domain); let poly = evals.interpolate_by_ref(); - supported_lagrange_sizes.push(domain.size()); + supported_lagrange_sizes.insert(domain.size()); assert_eq!(poly.evaluate_over_domain_by_ref(domain), evals); let degree_bound = None; - let hiding_bound = Some(1); - polynomials.push(LabeledPolynomial::new(label.clone(), poly, degree_bound, hiding_bound)); - lagrange_polynomials.push(LabeledPolynomialWithBasis::new_lagrange_basis(label, evals, hiding_bound)) + polynomials.push(LabeledPolynomial::new(label.clone(), poly, degree_bound, Some(hiding_bound))); + lagrange_polynomials.push(LabeledPolynomialWithBasis::new_lagrange_basis(label, evals, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - let (ck, vk_orig) = SonicKZG10::::trim( - &pp, - supported_degree, - supported_lagrange_sizes, - supported_hiding_bound, + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, degree_bounds, - ) - .unwrap(); - println!("Trimmed"); + hiding_bound, + lagrange_sizes: Some(supported_lagrange_sizes), + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); - - let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, lagrange_polynomials, Some(rng)).unwrap(); + let (comms, rands) = SonicKZG10::::commit(universal_prover, lagrange_polynomials, Some(rng)).unwrap(); // Construct query set let mut query_set = QuerySet::new(); @@ -212,14 +207,14 @@ pub fn lagrange_test_template>() let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; if !result { println!("Failed with {num_polynomials} polynomials, num_points_in_query_set: {num_points_in_query_set:?}"); println!("Degree of polynomials:"); @@ -230,7 +225,7 @@ pub fn lagrange_test_template>() assert!(result, "proof was incorrect, Query set: {query_set:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, @@ -263,15 +258,16 @@ where let rng = &mut TestRng::default(); let max_degree = max_degree.unwrap_or_else(|| distributions::Uniform::from(8..=64).sample(rng)); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); let supported_degree_bounds = [1 << 10, 1 << 15, 1 << 20, 1 << 25, 1 << 30]; + let hiding_bound = 1; for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = - supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); + supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree - 1).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = if enforce_degree_bounds { Some(Vec::new()) } else { None }; + let mut degree_bounds = BTreeSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -290,40 +286,32 @@ where .filter(|x| *x >= degree && *x < supported_degree) .collect::>(); - let degree_bound = if let Some(degree_bounds) = &mut degree_bounds { - if !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen() { - let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); - let idx = range.sample(rng); - - let degree_bound = supported_degree_bounds_after_trimmed[idx]; - degree_bounds.push(degree_bound); - Some(degree_bound) - } else { - None - } - } else { - None - }; - - let hiding_bound = Some(1); - println!("Hiding bound: {hiding_bound:?}"); - - polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, hiding_bound)) + let compute_degree_bound = + enforce_degree_bounds && !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen(); + let degree_bound = compute_degree_bound.then(|| { + let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); + let idx = range.sample(rng); + supported_degree_bounds_after_trimmed[idx] + }); + if let Some(degree_bound) = degree_bound { + degree_bounds.insert(degree_bound); + } + polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - let (ck, vk_orig) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_hiding_bound, degree_bounds.as_deref()) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds, + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; // Construct query set let mut query_set = QuerySet::new(); @@ -342,14 +330,14 @@ where let mut sponge_for_open = S::new(); let proof = SonicKZG10::batch_open( universal_prover, - &ck, polynomials.iter(), &query_set, rands.iter(), &mut sponge_for_open, )?; let mut sponge_for_check = S::new(); - let result = SonicKZG10::batch_check(&vk, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; + let result = + SonicKZG10::batch_check(&universal_verifier, &comms, &query_set, &values, &proof, &mut sponge_for_check)?; if !result { println!("Failed with {num_polynomials} polynomials, num_points_in_query_set: {num_points_in_query_set:?}"); println!("Degree of polynomials:"); @@ -360,7 +348,7 @@ where assert!(result, "proof was incorrect, Query set: {query_set:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, @@ -391,15 +379,16 @@ fn equation_test_template>( let rng = &mut TestRng::default(); let max_degree = max_degree.unwrap_or_else(|| distributions::Uniform::from(8..=64).sample(rng)); let pp = SonicKZG10::::load_srs(max_degree)?; - let universal_prover = &pp.to_universal_prover().unwrap(); let supported_degree_bounds = [1 << 10, 1 << 15, 1 << 20, 1 << 25, 1 << 30]; + let hiding_bound = 1; for _ in 0..num_iters { + let universal_verifier = pp.to_universal_verifier().unwrap(); let supported_degree = - supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree).sample(rng)); + supported_degree.unwrap_or_else(|| distributions::Uniform::from(4..=max_degree - 1).sample(rng)); assert!(max_degree >= supported_degree, "max_degree < supported_degree"); let mut polynomials = Vec::new(); - let mut degree_bounds = if enforce_degree_bounds { Some(Vec::new()) } else { None }; + let mut degree_bounds = BTreeSet::new(); let mut labels = Vec::new(); println!("Sampled supported degree"); @@ -418,44 +407,32 @@ fn equation_test_template>( .filter(|x| *x >= degree && *x < supported_degree) .collect::>(); - let degree_bound = if let Some(degree_bounds) = &mut degree_bounds { - if !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen() { - let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); - let idx = range.sample(rng); - - let degree_bound = supported_degree_bounds_after_trimmed[idx]; - degree_bounds.push(degree_bound); - Some(degree_bound) - } else { - None - } - } else { - None - }; - - let hiding_bound = Some(1); - println!("Hiding bound: {hiding_bound:?}"); - - polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, hiding_bound)) + let compute_degree_bound = + enforce_degree_bounds && !supported_degree_bounds_after_trimmed.is_empty() && rng.r#gen(); + let degree_bound = compute_degree_bound.then(|| { + let range = distributions::Uniform::from(0..supported_degree_bounds_after_trimmed.len()); + let idx = range.sample(rng); + supported_degree_bounds_after_trimmed[idx] + }); + if let Some(degree_bound) = degree_bound { + degree_bounds.insert(degree_bound); + } + polynomials.push(LabeledPolynomial::new(label, poly, degree_bound, Some(hiding_bound))) } + let degree_bounds = (!degree_bounds.is_empty()).then_some(degree_bounds); let supported_hiding_bound = polynomials.iter().map(|p| p.hiding_bound().unwrap_or(0)).max().unwrap_or(0); - println!("supported degree: {supported_degree:?}"); - println!("supported hiding bound: {supported_hiding_bound:?}"); - println!("num_points_in_query_set: {num_points_in_query_set:?}"); - println!("{degree_bounds:?}"); - println!("{num_polynomials}"); - println!("{enforce_degree_bounds}"); - - let (ck, vk_orig) = - SonicKZG10::::trim(&pp, supported_degree, None, supported_hiding_bound, degree_bounds.as_deref()) - .unwrap(); - println!("Trimmed"); - - let ck = CommitterUnionKey::union(std::iter::once(&ck)); - let vk = pp.to_universal_verifier().unwrap(); + assert_eq!(supported_hiding_bound, 1); + let degree_info = DegreeInfo { + max_degree, + max_fft_size: supported_degree, + degree_bounds, + hiding_bound, + lagrange_sizes: None, + }; + let universal_prover = &pp.to_universal_prover(degree_info).unwrap(); let (comms, rands) = - SonicKZG10::::commit(universal_prover, &ck, polynomials.iter().map(Into::into), Some(rng))?; + SonicKZG10::::commit(universal_prover, polynomials.iter().map(Into::into), Some(rng))?; // Let's construct our equations let mut linear_combinations = Vec::new(); @@ -497,13 +474,10 @@ fn equation_test_template>( if linear_combinations.is_empty() { continue; } - println!("Generated query set"); - println!("Linear combinations: {linear_combinations:?}"); let mut sponge_for_open = S::new(); let proof = SonicKZG10::open_combinations( universal_prover, - &ck, &linear_combinations, polynomials, &rands, @@ -513,7 +487,7 @@ fn equation_test_template>( println!("Generated proof"); let mut sponge_for_check = S::new(); let result = SonicKZG10::check_combinations( - &vk, + &universal_verifier, &linear_combinations, &comms, &query_set, @@ -527,7 +501,7 @@ fn equation_test_template>( assert!(result, "proof was incorrect, equations: {linear_combinations:#?}"); test_components.push(TestComponents { - verification_key: vk_orig, + universal_verifier, commitments: comms, query_set, evaluations: values, diff --git a/algorithms/src/snark/varuna/ahp/ahp.rs b/algorithms/src/snark/varuna/ahp/ahp.rs index ca368316be..5bf11bb829 100644 --- a/algorithms/src/snark/varuna/ahp/ahp.rs +++ b/algorithms/src/snark/varuna/ahp/ahp.rs @@ -14,10 +14,7 @@ // limitations under the License. use crate::{ - fft::{ - EvaluationDomain, - domain::{FFTPrecomputation, IFFTPrecomputation}, - }, + fft::EvaluationDomain, polycommit::sonic_pc::{LCTerm, LabeledPolynomial, LinearCombination}, r1cs::SynthesisError, snark::varuna::{ @@ -82,8 +79,7 @@ impl AHPForR1CS { Self::num_formatted_public_inputs_is_admissible(input.len()) } - /// The maximum degree of polynomials produced by the indexer and prover - /// of this protocol. + /// The maximum degree of polynomials the indexer or prover commits to. /// The number of the variables must include the "one" variable. That is, it /// must be with respect to the number of formatted public inputs. pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result { @@ -97,12 +93,11 @@ impl AHPForR1CS { // these should correspond with the bounds set in the .rs files [ - 2 * constraint_domain_size + 2 * zk_bound - 2, - 2 * variable_domain_size + 2 * zk_bound - 2, + 2 * constraint_domain_size + 2 * zk_bound - 2, // h_0 + 2 * variable_domain_size + 2 * zk_bound - 2, // h_1 if SM::ZK { variable_domain_size + 3 } else { 0 }, // mask_poly - variable_domain_size, - constraint_domain_size, - non_zero_domain_size - 1, // non-zero polynomials + variable_domain_size, // g_1 + non_zero_domain_size, // g_M, h_M ] .iter() .max() @@ -110,20 +105,6 @@ impl AHPForR1CS { .ok_or(anyhow!("Could not find max_degree")) } - /// Get all the strict degree bounds enforced in the AHP. - pub fn get_degree_bounds(info: &CircuitInfo) -> Result<[usize; 4]> { - let num_variables = info.num_public_and_private_variables; - let num_non_zero_a = info.num_non_zero_a; - let num_non_zero_b = info.num_non_zero_b; - let num_non_zero_c = info.num_non_zero_c; - Ok([ - EvaluationDomain::::compute_size_of_domain(num_variables).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_a).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_b).ok_or(SynthesisError::PolyTooLarge)? - 2, - EvaluationDomain::::compute_size_of_domain(num_non_zero_c).ok_or(SynthesisError::PolyTooLarge)? - 2, - ]) - } - pub(crate) fn cmp_non_zero_domains( info: &CircuitInfo, max_candidate: Option>, @@ -144,29 +125,6 @@ impl AHPForR1CS { Ok(NonZeroDomains { max_non_zero_domain, domain_a, domain_b, domain_c }) } - pub fn fft_precomputation( - constraint_domain_size: usize, - variable_domain_size: usize, - non_zero_a_domain_size: usize, - non_zero_b_domain_size: usize, - non_zero_c_domain_size: usize, - ) -> Option<(FFTPrecomputation, IFFTPrecomputation)> { - let largest_domain_size = [ - 2 * constraint_domain_size, - 2 * variable_domain_size, - 2 * non_zero_a_domain_size, - 2 * non_zero_b_domain_size, - 2 * non_zero_c_domain_size, - ] - .into_iter() - .max()?; - let largest_mul_domain = EvaluationDomain::new(largest_domain_size)?; - - let fft_precomputation = largest_mul_domain.precompute_fft(); - let ifft_precomputation = fft_precomputation.to_ifft_precomputation(); - Some((fft_precomputation, ifft_precomputation)) - } - /// Construct the linear combinations that are checked by the AHP. /// Public input should be unformatted. /// We construct the linear combinations as per section 5 of our protocol diff --git a/algorithms/src/snark/varuna/ahp/indexer/circuit.rs b/algorithms/src/snark/varuna/ahp/indexer/circuit.rs index bd7cded702..97a08c1c56 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/circuit.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/circuit.rs @@ -16,21 +16,10 @@ use core::marker::PhantomData; use crate::{ - fft::{ - EvaluationDomain, - domain::{FFTPrecomputation, IFFTPrecomputation}, - }, polycommit::sonic_pc::LabeledPolynomial, - snark::varuna::{ - AHPForR1CS, - CircuitInfo, - Matrix, - SNARKMode, - ahp::matrices::MatrixEvals, - matrices::MatrixArithmetization, - }, + snark::varuna::{CircuitInfo, Matrix, SNARKMode, ahp::matrices::MatrixEvals, matrices::MatrixArithmetization}, }; -use anyhow::{Result, anyhow}; +use anyhow::Result; use blake2::Digest; use hex::FromHex; use snarkvm_fields::PrimeField; @@ -81,8 +70,6 @@ pub struct Circuit { pub b_arith: MatrixEvals, pub c_arith: MatrixEvals, - pub fft_precomputation: FFTPrecomputation, - pub ifft_precomputation: IFFTPrecomputation, pub(crate) _mode: PhantomData, pub(crate) id: CircuitId, } @@ -121,25 +108,6 @@ impl Circuit { Ok(CircuitId(blake2.finalize().into())) } - /// The maximum degree required to represent polynomials of this index. - pub fn max_degree(&self) -> Result { - self.index_info.max_degree::() - } - - /// The size of the constraint (i. e. row) domain in this R1CS instance. - pub fn constraint_domain_size(&self) -> Result { - Ok(crate::fft::EvaluationDomain::::new(self.index_info.num_constraints) - .ok_or(anyhow!("Cannot create EvaluationDomain"))? - .size()) - } - - /// The size of the variable (i. e. column) domain in this R1CS instance. - pub fn variable_domain_size(&self) -> Result { - Ok(crate::fft::EvaluationDomain::::new(self.index_info.num_public_and_private_variables) - .ok_or(anyhow!("Cannot create EvaluationDomain"))? - .size()) - } - /// Compute the row, col, rowcol and rowcolval polynomials of the three /// matrices in this R1CS instance. pub fn interpolate_matrix_evals(&self) -> Result>> { @@ -200,26 +168,6 @@ impl CanonicalDeserialize for Circuit { validate: Validate, ) -> Result { let index_info: CircuitInfo = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; - let constraint_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_constraints) - .ok_or(SerializationError::InvalidData)?; - let variable_domain_size = - EvaluationDomain::::compute_size_of_domain(index_info.num_public_and_private_variables) - .ok_or(SerializationError::InvalidData)?; - let non_zero_a_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_a) - .ok_or(SerializationError::InvalidData)?; - let non_zero_b_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_b) - .ok_or(SerializationError::InvalidData)?; - let non_zero_c_domain_size = EvaluationDomain::::compute_size_of_domain(index_info.num_non_zero_c) - .ok_or(SerializationError::InvalidData)?; - - let (fft_precomputation, ifft_precomputation) = AHPForR1CS::::fft_precomputation( - variable_domain_size, - constraint_domain_size, - non_zero_a_domain_size, - non_zero_b_domain_size, - non_zero_c_domain_size, - ) - .ok_or(SerializationError::InvalidData)?; let a = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; let b = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; let c = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; @@ -232,8 +180,6 @@ impl CanonicalDeserialize for Circuit { a_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, b_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, c_arith: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, - fft_precomputation, - ifft_precomputation, _mode: PhantomData, id, }) diff --git a/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs b/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs index 7f40a0783b..51018aec16 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/circuit_info.rs @@ -13,7 +13,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::snark::varuna::{SNARKMode, ahp::AHPForR1CS}; +use crate::{ + fft::EvaluationDomain, + polycommit::kzg10::DegreeInfo, + snark::varuna::{SNARKMode, ahp::AHPForR1CS}, +}; use anyhow::Result; use snarkvm_fields::PrimeField; use snarkvm_utilities::{ToBytes, serialize::*}; @@ -44,12 +48,51 @@ pub struct CircuitInfo { } impl CircuitInfo { + /// The max degrees and bounds required to represent polynomials of this + /// circuit. + pub fn degree_info(&self) -> Result { + let max_degree = self.max_degree::()?; + let max_fft_size = self.max_fft_size::(); + let degree_bounds = Some(self.degree_bounds::().into_iter().collect()); + let hiding_bound = AHPForR1CS::::zk_bound().unwrap_or(0); + Ok(DegreeInfo::new(max_degree, max_fft_size, degree_bounds, hiding_bound, None)) + } + /// The maximum degree of polynomial required to represent this index in the /// AHP. pub fn max_degree(&self) -> Result { let max_non_zero = self.num_non_zero_a.max(self.num_non_zero_b).max(self.num_non_zero_c); AHPForR1CS::::max_degree(self.num_constraints, self.num_public_and_private_variables, max_non_zero) } + + /// The maximum size of polynomials we need to do (i)fft for + pub fn max_fft_size(&self) -> usize { + let size = [ + 2 * self.num_constraints, // zerocheck poly degree + 2 * self.num_public_and_private_variables, // lineval sumcheck poly degree + 2 * self.num_non_zero_a, // matrix sumcheck poly degree + 2 * self.num_non_zero_b, // matrix sumcheck poly degree + 2 * self.num_non_zero_c, // matrix sumcheck poly degree + ] + .into_iter() + .max() + .unwrap(); + crate::fft::EvaluationDomain::::compute_size_of_domain(size).unwrap() + } + + /// Get all the strict degree bounds enforced in the AHP. + pub fn degree_bounds(&self) -> [usize; 4] { + [ + // The degree bound for g_1 + EvaluationDomain::::compute_size_of_domain(self.num_public_and_private_variables).unwrap() - 2, + // The degree bound for g_A + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_a).unwrap() - 2, + // The degree bound for g_B + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_b).unwrap() - 2, + // The degree bound for g_C + EvaluationDomain::::compute_size_of_domain(self.num_non_zero_c).unwrap() - 2, + ] + } } impl ToBytes for CircuitInfo { diff --git a/algorithms/src/snark/varuna/ahp/indexer/indexer.rs b/algorithms/src/snark/varuna/ahp/indexer/indexer.rs index ae7d4b9869..de9c44dc4a 100644 --- a/algorithms/src/snark/varuna/ahp/indexer/indexer.rs +++ b/algorithms/src/snark/varuna/ahp/indexer/indexer.rs @@ -43,51 +43,10 @@ use super::Matrix; impl AHPForR1CS { /// Generate the index polynomials for this constraint system. pub fn index>(c: &C) -> Result> { - let IndexerState { - constraint_domain, - variable_domain, + let IndexerState { a, a_arith, b, b_arith, c, c_arith, index_info, id, .. } = + Self::index_helper(c).map_err(|e| anyhow!("{e:?}"))?; - a, - non_zero_a_domain, - a_arith, - - b, - non_zero_b_domain, - b_arith, - - c, - non_zero_c_domain, - c_arith, - - index_info, - id, - } = Self::index_helper(c).map_err(|e| anyhow!("{e:?}"))?; - - let fft_precomp_time = start_timer!(|| format!("Precomputing roots of unity {id}")); - - let (fft_precomputation, ifft_precomputation) = Self::fft_precomputation( - constraint_domain.size(), - variable_domain.size(), - non_zero_a_domain.size(), - non_zero_b_domain.size(), - non_zero_c_domain.size(), - ) - .ok_or(anyhow!("The polynomial degree is too large"))?; - end_timer!(fft_precomp_time); - - Ok(Circuit { - index_info, - a, - b, - c, - a_arith, - b_arith, - c_arith, - fft_precomputation, - ifft_precomputation, - id, - _mode: PhantomData, - }) + Ok(Circuit { index_info, a, b, c, a_arith, b_arith, c_arith, id, _mode: PhantomData }) } pub fn index_polynomial_info<'a>( @@ -207,9 +166,6 @@ impl AHPForR1CS { let id = Circuit::::hash(&index_info, &a, &b, &c)?; let result = Ok(IndexerState { - constraint_domain, - variable_domain, - a, non_zero_a_domain, a_arith, @@ -264,11 +220,6 @@ impl AHPForR1CS { } pub(crate) struct IndexerState { - // R_i in the Varuna spec - constraint_domain: EvaluationDomain, - // C_i in the Varuna spec - variable_domain: EvaluationDomain, - a: Matrix, non_zero_a_domain: EvaluationDomain, a_arith: MatrixEvals, diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs index f26d3df5c9..e1ae7c8619 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/first.rs @@ -14,10 +14,15 @@ // limitations under the License. use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain, SparsePolynomial}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + SparsePolynomial, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, polycommit::sonic_pc::{LabeledPolynomial, PolynomialInfo, PolynomialLabel}, snark::varuna::{ - Circuit, CircuitId, SNARKMode, ahp::{AHPError, AHPForR1CS}, @@ -63,6 +68,8 @@ impl AHPForR1CS { rng: &mut R, ) -> Result, AHPError> { let round_time = start_timer!(|| "AHP::Prover::FirstRound"); + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity(state.total_instances); for (circuit, circuit_state) in state.circuit_specific_states.iter_mut() { let batch_size = circuit_state.batch_size; @@ -77,7 +84,9 @@ impl AHPForR1CS { for (j, (private_vars, x_poly)) in itertools::izip!(private_variables, x_polys).enumerate() { let w_label = witness_label(circuit.id, "w", j); - job_pool.add_job(move || Self::calculate_w(w_label, private_vars, x_poly, v_domain, i_domain, circuit)); + job_pool.add_job(move || { + Self::calculate_w(w_label, private_vars, x_poly, v_domain, i_domain, fft_precomp, ifft_precomp) + }); } } let mut batches = @@ -125,13 +134,14 @@ impl AHPForR1CS { } // Compute the shifted witness \overline{w}(X) - fn calculate_w( + fn calculate_w<'a>( label: String, private_variables: Vec, x_poly: DensePolynomial, variable_domain: EvaluationDomain, input_domain: EvaluationDomain, - circuit: &Circuit, + fft_precomp: &'a FFTPrecomputation, + ifft_precomp: &'a IFFTPrecomputation, ) -> Witness { let mut w_extended = private_variables; let ratio = variable_domain.size() / input_domain.size(); @@ -141,7 +151,7 @@ impl AHPForR1CS { let x_evals = { let mut coeffs = x_poly.coeffs; coeffs.resize(variable_domain.size(), F::zero()); - variable_domain.in_order_fft_in_place_with_pc(&mut coeffs, &circuit.fft_precomputation); + variable_domain.in_order_fft_in_place_with_pc(&mut coeffs, fft_precomp); coeffs }; @@ -158,8 +168,8 @@ impl AHPForR1CS { // Interpolating \widetilde{z} - \widetilde{x} and dividing by the // vanishing polynomial over variable_domain. - let w_poly = EvaluationsOnDomain::from_vec_and_domain(w_poly_evals, variable_domain) - .interpolate_with_pc(&circuit.ifft_precomputation); + let w_poly = + EvaluationsOnDomain::from_vec_and_domain(w_poly_evals, variable_domain).interpolate_with_pc(ifft_precomp); let (w_poly, remainder) = w_poly.divide_by_vanishing_poly(input_domain).unwrap(); assert!(remainder.is_zero()); diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs index 673d8366fc..a337e47948 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/fourth.rs @@ -83,6 +83,8 @@ impl AHPForR1CS { let verifier::SecondMessage { alpha, .. } = second_message; let verifier::ThirdMessage { beta } = third_message; + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; let mut pool = ExecutionPool::with_capacity(3 * state.circuit_specific_states.len()); @@ -109,8 +111,8 @@ impl AHPForR1CS { *beta, v_R_i_alpha_v_C_i_beta, max_non_zero_domain_size, - &circuit.fft_precomputation, - &circuit.ifft_precomputation, + fft_precomp, + ifft_precomp, ); (circuit, result) }); diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs index 57c6d8a2df..275ed94df3 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/mod.rs @@ -14,6 +14,7 @@ // limitations under the License. use crate::{ + fft::domain::{FFTPrecomputation, IFFTPrecomputation}, r1cs::ConstraintSynthesizer, snark::varuna::{ SNARKMode, @@ -41,6 +42,8 @@ impl AHPForR1CS { /// Initialize the AHP prover. pub fn init_prover<'a, C: ConstraintSynthesizer, R: Rng + CryptoRng>( circuits_to_constraints: &BTreeMap<&'a Circuit, &[C]>, + fft_precomp: &'a FFTPrecomputation, + ifft_precomp: &'a IFFTPrecomputation, rng: &mut R, ) -> Result, AHPError> { let init_time = start_timer!(|| "AHP::Prover::Init"); @@ -162,7 +165,7 @@ impl AHPForR1CS { }) .collect::, Vec>>, AHPError>>()?; - let state = prover::State::initialize(indices_and_assignments)?; + let state = prover::State::initialize(indices_and_assignments, fft_precomp, ifft_precomp)?; end_timer!(init_time); Ok(state) diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs index a90fdb735b..f824c9bba7 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/prepare_third.rs @@ -82,11 +82,11 @@ impl AHPForR1CS { let total_instances = num_instances.iter().sum::(); let matrix_labels = ["a", "b", "c"]; - let fft_precomputations = state - .circuit_specific_states - .keys() - .map(|circuit| (circuit.fft_precomputation.clone(), circuit.ifft_precomputation.clone())) - .collect_vec(); + let fft_precomputation = &state.fft_precomputation; + let ifft_precomputation = &state.ifft_precomputation; + + let fft_precomputations = + state.circuit_specific_states.keys().map(|_| (fft_precomputation, ifft_precomputation)).collect_vec(); // Compute lineval sumcheck witnesses let mut job_pool = ExecutionPool::with_capacity(total_instances * 3); @@ -108,8 +108,8 @@ impl AHPForR1CS { label, &circuit_specific_state.constraint_domain, &circuit_specific_state.variable_domain, - &precomp.0, - &precomp.1, + precomp.0, + precomp.1, assignment, matrix_transpose, *alpha, diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs index 7d93446d94..b2d0284b29 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/second.rs @@ -16,10 +16,15 @@ use std::collections::BTreeMap; use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain, polynomial::PolyMultiplier}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + domain::IFFTPrecomputation, + polynomial::PolyMultiplier, + }, polycommit::sonic_pc::{LabeledPolynomial, PolynomialInfo, PolynomialLabel}, snark::varuna::{ - Circuit, CircuitId, SNARKMode, ahp::{AHPForR1CS, verifier}, @@ -80,6 +85,9 @@ impl AHPForR1CS { let mut job_pool = ExecutionPool::with_capacity(state.circuit_specific_states.len()); let max_constraint_domain = state.max_constraint_domain; + let fft_precomp = state.fft_precomputation; + let ifft_precomp = state.ifft_precomputation; + for (circuit, circuit_specific_state) in state.circuit_specific_states.iter_mut() { let z_a = circuit_specific_state.z_a.take().unwrap(); let z_b = circuit_specific_state.z_b.take().unwrap(); @@ -88,8 +96,6 @@ impl AHPForR1CS { let circuit_combiner = batch_combiners[&circuit.id].circuit_combiner; let instance_combiners = batch_combiners[&circuit.id].instance_combiners.clone(); let constraint_domain = circuit_specific_state.constraint_domain; - let fft_precomputation = &circuit.fft_precomputation; - let ifft_precomputation = &circuit.ifft_precomputation; let _circuit_id = &circuit.id; // seems like a compiler bug marks this as unused @@ -101,11 +107,11 @@ impl AHPForR1CS { let za_label = witness_label(circuit.id, "z_a", j); let zb_label = witness_label(circuit.id, "z_b", j); let zc_label = witness_label(circuit.id, "z_c", j); - let z_a = Self::calculate_z_m(za_label, z_a, constraint_domain, circuit); - let z_b = Self::calculate_z_m(zb_label, z_b, constraint_domain, circuit); - let z_c = Self::calculate_z_m(zc_label, z_c, constraint_domain, circuit); + let z_a = Self::calculate_z_m(za_label, z_a, constraint_domain, ifft_precomp); + let z_b = Self::calculate_z_m(zb_label, z_b, constraint_domain, ifft_precomp); + let z_c = Self::calculate_z_m(zc_label, z_c, constraint_domain, ifft_precomp); let mut multiplier_2 = PolyMultiplier::new(); - multiplier_2.add_precomputation(fft_precomputation, ifft_precomputation); + multiplier_2.add_precomputation(fft_precomp, ifft_precomp); multiplier_2.add_polynomial(z_a, "z_a"); multiplier_2.add_polynomial(z_b, "z_b"); let mut rowcheck = multiplier_2.multiply().unwrap(); @@ -145,13 +151,13 @@ impl AHPForR1CS { label: impl ToString, evaluations: Vec, constraint_domain: EvaluationDomain, - circuit: &Circuit, + ifft_precomp: &IFFTPrecomputation, ) -> DensePolynomial { let label = label.to_string(); let poly_time = start_timer!(|| format!("Computing {label}")); let evals = EvaluationsOnDomain::from_vec_and_domain(evaluations, constraint_domain); - let poly = evals.interpolate_with_pc_by_ref(&circuit.ifft_precomputation); + let poly = evals.interpolate_with_pc_by_ref(ifft_precomp); debug_assert!( poly.evaluate_over_domain_by_ref(constraint_domain) diff --git a/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs b/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs index cd4bdb7b14..9dc675a658 100644 --- a/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs +++ b/algorithms/src/snark/varuna/ahp/prover/round_functions/third.rs @@ -154,12 +154,15 @@ impl AHPForR1CS { let matrix_labels = ["a", "b", "c"]; let matrix_combiners = [F::one(), *eta_b, *eta_c]; + let fft_precomputation = &state.fft_precomputation; + let ifft_precomputation = &state.ifft_precomputation; + // Compute lineval sumcheck witnesses let mut job_pool = ExecutionPool::with_capacity(total_instances * 3); // Iterate for each circuit in the batch. - for ((((circuit, circuit_specific_state), batch_combiner), assignments_i), matrix_transposes_i) in state + for (((circuit_specific_state, batch_combiner), assignments_i), matrix_transposes_i) in state .circuit_specific_states - .iter_mut() + .values_mut() .zip_eq(third_round_batch_combiners.values()) .zip_eq(assignments.values()) .zip_eq(matrix_transposes.values()) @@ -168,8 +171,6 @@ impl AHPForR1CS { let instance_combiners = &batch_combiner.instance_combiners; let constraint_domain = &circuit_specific_state.constraint_domain; let variable_domain = &circuit_specific_state.variable_domain; - let fft_precomputation = &circuit.fft_precomputation; - let ifft_precomputation = &circuit.ifft_precomputation; // Iterate for each instance in the batch. for (instance_combiner, assignment) in itertools::izip!(instance_combiners, assignments_i) { diff --git a/algorithms/src/snark/varuna/ahp/prover/state.rs b/algorithms/src/snark/varuna/ahp/prover/state.rs index 4f3684854f..4428b15393 100644 --- a/algorithms/src/snark/varuna/ahp/prover/state.rs +++ b/algorithms/src/snark/varuna/ahp/prover/state.rs @@ -16,7 +16,12 @@ use std::collections::{BTreeMap, VecDeque}; use crate::{ - fft::{DensePolynomial, EvaluationDomain, Evaluations as EvaluationsOnDomain}, + fft::{ + DensePolynomial, + EvaluationDomain, + Evaluations as EvaluationsOnDomain, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, polycommit::sonic_pc::LabeledPolynomial, r1cs::{SynthesisError, SynthesisResult}, snark::varuna::{AHPError, AHPForR1CS, Circuit, SNARKMode}, @@ -91,6 +96,10 @@ pub struct State<'a, F: PrimeField, SM: SNARKMode> { pub(in crate::snark) max_variable_domain: EvaluationDomain, /// The total number of instances we're proving in the batch. pub(in crate::snark) total_instances: usize, + /// The precomputed roots for calculating FFTs. + pub fft_precomputation: &'a FFTPrecomputation, + /// The precomputed inverse roots for calculating inverse FFTs. + pub ifft_precomputation: &'a IFFTPrecomputation, } /// The public inputs for a single instance. @@ -116,6 +125,8 @@ pub(super) struct Assignments( impl<'a, F: PrimeField, SM: SNARKMode> State<'a, F, SM> { pub(super) fn initialize( indices_and_assignments: BTreeMap<&'a Circuit, Vec>>, + fft_precomputation: &'a FFTPrecomputation, + ifft_precomputation: &'a IFFTPrecomputation, ) -> Result { let mut max_non_zero_domain: Option> = None; let mut max_num_constraints = 0; @@ -195,6 +206,8 @@ impl<'a, F: PrimeField, SM: SNARKMode> State<'a, F, SM> { circuit_specific_states, total_instances, first_round_oracles: None, + fft_precomputation, + ifft_precomputation, }) } diff --git a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs index eef12185ef..d19bc77924 100644 --- a/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs +++ b/algorithms/src/snark/varuna/data_structures/circuit_proving_key.rs @@ -32,6 +32,7 @@ pub struct CircuitProvingKey { /// The circuit itself. pub circuit: Arc>, /// The committer key for this index, trimmed from the universal SRS. + /// Note: no longer used, kept for backward compatibility. pub committer_key: Arc>, } diff --git a/algorithms/src/snark/varuna/data_structures/mod.rs b/algorithms/src/snark/varuna/data_structures/mod.rs index a28f97cac7..a81c111007 100644 --- a/algorithms/src/snark/varuna/data_structures/mod.rs +++ b/algorithms/src/snark/varuna/data_structures/mod.rs @@ -44,3 +44,6 @@ pub mod test_exports { /// The Varuna universal SRS. pub type UniversalSRS = crate::polycommit::sonic_pc::UniversalParams; + +/// The Varuna universal prover. +pub type UniversalProver = crate::srs::UniversalProver; diff --git a/algorithms/src/snark/varuna/tests.rs b/algorithms/src/snark/varuna/tests.rs index f1d564e284..1fa4fad1f1 100644 --- a/algorithms/src/snark/varuna/tests.rs +++ b/algorithms/src/snark/varuna/tests.rs @@ -27,6 +27,7 @@ mod varuna { proof::proof_size, test_circuit::TestCircuit, }, + srs::UniversalProver, traits::{AlgebraicSponge, SNARK}, }; @@ -53,7 +54,7 @@ mod varuna { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = $snark_inst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); + let mut universal_prover = UniversalProver::default(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); @@ -69,10 +70,10 @@ mod varuna { let mut fake_inputs = public_inputs.clone(); fake_inputs[public_inputs.len() - 1] = random; - let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); - let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); + let certificate = $snark_inst::prove_vk(&universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circ, &index_vk, &certificate).unwrap()); println!("verified vk"); @@ -81,7 +82,7 @@ mod varuna { } assert_eq!(664, index_vk.to_bytes_le().unwrap().len(), "Update me if serialization has changed"); - let proof = $snark_inst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circ, rng).unwrap(); + let proof = $snark_inst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circ, rng).unwrap(); println!("Called prover"); assert!($snark_inst::verify(universal_verifier, &fs_parameters, &index_vk, varuna_version, public_inputs.clone(), &proof).unwrap()); @@ -113,14 +114,14 @@ mod varuna { let unique_instances = constraints.values().map(|instances| &instances[0]).collect::>(); let index_keys = - $snark_inst::batch_circuit_setup(&universal_srs, unique_instances.as_slice()).unwrap(); + $snark_inst::batch_circuit_setup(&universal_srs, &mut universal_prover, unique_instances.as_slice()).unwrap(); println!("Called circuit setup"); let mut pks_to_constraints = BTreeMap::new(); let mut vks_to_inputs = BTreeMap::new(); for (index_pk, index_vk) in index_keys.iter() { - let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); + let certificate = $snark_inst::prove_vk(&universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap(); let circuits = constraints[&index_pk.circuit.id].as_slice(); assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circuits[0], &index_vk, &certificate).unwrap()); pks_to_constraints.insert(index_pk, circuits); @@ -129,7 +130,7 @@ mod varuna { println!("verified vks"); let proof = - $snark_inst::prove_batch(universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap(); + $snark_inst::prove_batch(&universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap(); println!("Called prover"); if varuna_version == VarunaVersion::V2 { @@ -195,7 +196,8 @@ mod varuna { let mul_depth = 1; let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); // Serialize @@ -220,7 +222,8 @@ mod varuna { let mul_depth = 1; let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); println!("Called circuit setup"); // Serialize @@ -244,8 +247,8 @@ mod varuna { fn prove_and_verify_with_tall_matrix_big() { let num_constraints = 100; let num_variables = 25; - let pk_size_zk = 91971; - let pk_size_posw = 91633; + let pk_size_zk = 53767; + let pk_size_posw = 53623; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -265,8 +268,8 @@ mod varuna { fn prove_and_verify_with_tall_matrix_small() { let num_constraints = 26; let num_variables = 25; - let pk_size_zk = 25428; - let pk_size_posw = 25090; + let pk_size_zk = 15463; + let pk_size_posw = 15319; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -286,8 +289,8 @@ mod varuna { fn prove_and_verify_with_squat_matrix_big() { let num_constraints = 25; let num_variables = 100; - let pk_size_zk = 53523; - let pk_size_posw = 53185; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -307,8 +310,8 @@ mod varuna { fn prove_and_verify_with_squat_matrix_small() { let num_constraints = 25; let num_variables = 26; - let pk_size_zk = 25284; - let pk_size_posw = 24946; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -328,8 +331,8 @@ mod varuna { fn prove_and_verify_with_square_matrix() { let num_constraints = 25; let num_variables = 25; - let pk_size_zk = 25284; - let pk_size_posw = 24946; + let pk_size_zk = 15319; + let pk_size_posw = 15175; let mut rng = TestRng::default(); SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng); @@ -352,6 +355,7 @@ mod varuna_hiding { crypto_hash::PoseidonSponge, snark::varuna::{ CircuitVerifyingKey, + UniversalProver, VarunaHidingMode, VarunaSNARK, VarunaVersion, @@ -381,7 +385,6 @@ mod varuna_hiding { ) { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); @@ -396,11 +399,13 @@ mod varuna_hiding { let mut fake_inputs = public_inputs.clone(); fake_inputs[public_inputs.len() - 1] = Fr::rand(rng); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); let proof = - VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); + VarunaInst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); assert!( @@ -453,7 +458,8 @@ mod varuna_hiding { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); // Serialize @@ -473,7 +479,8 @@ mod varuna_hiding { let mul_depth = 1; let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); // Serialize @@ -568,7 +575,6 @@ mod varuna_hiding { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); for varuna_version in [VarunaVersion::V1, VarunaVersion::V2] { @@ -576,15 +582,17 @@ mod varuna_hiding { VarunaVersion::V1 => VarunaVersion::V2, VarunaVersion::V2 => VarunaVersion::V1, }; - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); println!("Called circuit setup"); let proof = - VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); + VarunaInst::prove(&universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap(); println!("Called prover"); universal_srs.download_powers_for(0..2usize.pow(18)).unwrap(); - let (new_pk, new_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap(); + let (new_pk, new_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit).unwrap(); assert_eq!(index_pk, new_pk); assert_eq!(index_vk, new_vk); assert!( @@ -629,7 +637,6 @@ mod varuna_hiding { let max_degree = AHPForR1CS::::max_degree(100, 25, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_parameters = FS::sample_parameters(); let varuna_version = VarunaVersion::V2; @@ -640,10 +647,12 @@ mod varuna_hiding { let num_constraints = 2usize.pow(15) - 10; let num_variables = 2usize.pow(15) - 10; let (circuit1, public_inputs1) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &circuit1).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit1).unwrap(); println!("Called circuit setup"); - let proof1 = VarunaInst::prove(universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); + let proof1 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); println!("Called prover"); assert!( VarunaInst::verify( @@ -666,10 +675,11 @@ mod varuna_hiding { let num_constraints = 2usize.pow(19) - 10; let num_variables = 2usize.pow(19) - 10; let (circuit2, public_inputs2) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng); - let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &circuit2).unwrap(); + let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit2).unwrap(); println!("Called circuit setup"); - let proof2 = VarunaInst::prove(universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap(); + let proof2 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap(); println!("Called prover"); assert!( VarunaInst::verify(universal_verifier, &fs_parameters, &vk2, varuna_version, public_inputs2, &proof2) @@ -677,6 +687,14 @@ mod varuna_hiding { ); /*************************************************************************** * * */ + // Indexing, proving, and verifying again for a circuit with 1 << 15 constraints + // and 1 << 15 variables. + let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &circuit1).unwrap(); + println!("Called circuit setup"); + + let proof1 = + VarunaInst::prove(&universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap(); + println!("Called prover"); assert!( VarunaInst::verify(universal_verifier, &fs_parameters, &vk1, varuna_version, public_inputs1, &proof1) .unwrap() @@ -688,6 +706,7 @@ mod varuna_test_vectors { use crate::{ fft::EvaluationDomain, snark::varuna::{AHPForR1CS, TestCircuit, VarunaNonHidingMode, VarunaSNARK, VarunaVersion, ahp::verifier}, + srs::UniversalProver, traits::snark::SNARK, }; use snarkvm_curves::bls12_377::{Bls12_377, Fq, Fr}; @@ -819,12 +838,19 @@ mod varuna_test_vectors { let (circ, _) = TestCircuit::generate_circuit_with_fixed_witness(a, b, mul_depth, num_constraints, num_variables); println!("Circuit: {circ:?}"); - let (index_pk, _index_vk) = VarunaSonicInst::circuit_setup(&universal_srs, &circ).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, _index_vk) = + VarunaSonicInst::circuit_setup(&universal_srs, &mut universal_prover, &circ).unwrap(); let mut keys_to_constraints = BTreeMap::new(); keys_to_constraints.insert(index_pk.circuit.deref(), std::slice::from_ref(&circ)); // Begin the Varuna protocol execution. - let prover_state = AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, rng).unwrap(); + let degree_info = index_pk.circuit.index_info.degree_info::().unwrap(); + let universal_prover = universal_srs.to_universal_prover(degree_info).unwrap(); + let fft_precomp = universal_prover.fft_precomputation.unwrap(); + let ifft_precomp = universal_prover.ifft_precomputation.unwrap(); + let prover_state = + AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, &fft_precomp, &ifft_precomp, rng).unwrap(); let mut prover_state = AHPForR1CS::<_, MM>::prover_first_round(prover_state, rng).unwrap(); let first_round_oracles = Arc::new(prover_state.first_round_oracles.as_ref().unwrap()); diff --git a/algorithms/src/snark/varuna/varuna.rs b/algorithms/src/snark/varuna/varuna.rs index 24d554041a..e90a383d25 100644 --- a/algorithms/src/snark/varuna/varuna.rs +++ b/algorithms/src/snark/varuna/varuna.rs @@ -19,14 +19,9 @@ use crate::{ SNARK, SNARKError, fft::EvaluationDomain, - polycommit::sonic_pc::{ - Commitment, - CommitterUnionKey, - Evaluations, - LabeledCommitment, - QuerySet, - Randomness, - SonicKZG10, + polycommit::{ + kzg10::DegreeInfo, + sonic_pc::{Commitment, Evaluations, LabeledCommitment, QuerySet, Randomness, SonicKZG10}, }, r1cs::{ConstraintSynthesizer, SynthesisError}, snark::varuna::{ @@ -67,56 +62,50 @@ impl, SM: SNARKMode> VarunaSNARK /// Used to personalize the Fiat-Shamir RNG. pub const PROTOCOL_NAME: &'static [u8] = b"VARUNA-2023"; - // TODO: implement optimizations resulting from batching - // (e.g. computing a common set of Lagrange powers, FFT precomputations, - // etc) + /// Creates a set of proving and verifying keys for a set of circuits + /// We commit to each circuit's index polys separately + /// Performing a single batch commitment can be a future optimization pub fn batch_circuit_setup>( - universal_srs: &UniversalSRS, + srs: &UniversalSRS, + universal_prover: &mut UniversalProver, circuits: &[&C], ) -> Result, CircuitVerifyingKey)>> { - let index_time = start_timer!(|| "Varuna::CircuitSetup"); + ensure!(!circuits.is_empty()); - let universal_prover = &universal_srs.to_universal_prover()?; + let index_time = start_timer!(|| "Varuna::CircuitSetup"); - let mut circuit_keys = Vec::with_capacity(circuits.len()); + let mut indexed_circuits = Vec::with_capacity(circuits.len()); + let mut degree_info = None::; for circuit in circuits { - let mut indexed_circuit = AHPForR1CS::<_, SM>::index(*circuit)?; - // TODO: Add check that c is in the correct mode. - // Ensure the universal SRS supports the circuit size. - universal_srs.download_powers_for(0..indexed_circuit.max_degree()?).map_err(|e| { - anyhow!("Failed to download powers for degree {}: {e}", indexed_circuit.max_degree().unwrap()) - })?; - let coefficient_support = AHPForR1CS::::get_degree_bounds(&indexed_circuit.index_info)?; - - // Varuna only needs degree 2 random polynomials. - let supported_hiding_bound = 1; - let supported_lagrange_sizes = [].into_iter(); // TODO: consider removing lagrange_bases_at_beta_g from CommitterKey - let (committer_key, _) = SonicKZG10::::trim( - universal_srs, - indexed_circuit.max_degree()?, - supported_lagrange_sizes, - supported_hiding_bound, - Some(coefficient_support.as_slice()), - )?; + let indexed_circuit = AHPForR1CS::<_, SM>::index(*circuit)?; + let degree_info_i = indexed_circuit.index_info.degree_info::()?; + degree_info = match degree_info { + Some(degree_info) => Some(degree_info.union(°ree_info_i)), + None => Some(degree_info_i), + }; + indexed_circuits.push(indexed_circuit); + } + let degree_info = degree_info.unwrap(); - let ck = CommitterUnionKey::union(std::iter::once(&committer_key)); + universal_prover.update(srs, degree_info)?; + let mut circuit_keys = Vec::with_capacity(circuits.len()); + for mut indexed_circuit in indexed_circuits { let commit_time = start_timer!(|| format!("Commit to index polynomials for {}", indexed_circuit.id)); let setup_rng = None::<&mut dyn RngCore>; // We do not randomize the commitments let (mut circuit_commitments, commitment_randomnesses): (_, _) = SonicKZG10::::commit( universal_prover, - &ck, indexed_circuit.interpolate_matrix_evals()?.map(Into::into), setup_rng, )?; + indexed_circuit.prune_row_col_evals(); // row_col were only needed during indexing let empty_randomness = Randomness::::empty(); ensure!(commitment_randomnesses.iter().all(|r| r == &empty_randomness)); end_timer!(commit_time); circuit_commitments.sort_by(|c1, c2| c1.label().cmp(c2.label())); let circuit_commitments = circuit_commitments.into_iter().map(|c| *c.commitment()).collect(); - indexed_circuit.prune_row_col_evals(); let circuit_verifying_key = CircuitVerifyingKey { circuit_info: indexed_circuit.index_info, circuit_commitments, @@ -125,7 +114,7 @@ impl, SM: SNARKMode> VarunaSNARK let circuit_proving_key = CircuitProvingKey { circuit_verifying_key: circuit_verifying_key.clone(), circuit: Arc::new(indexed_circuit), - committer_key: Arc::new(committer_key), + committer_key: Default::default(), }; circuit_keys.push((circuit_proving_key, circuit_verifying_key)); } @@ -229,10 +218,11 @@ where /// Generates the circuit proving and verifying keys. /// This is a deterministic algorithm that anyone can rerun. fn circuit_setup>( - universal_srs: &Self::UniversalSRS, + srs: &Self::UniversalSRS, + universal_prover: &mut Self::UniversalProver, circuit: &C, ) -> Result<(Self::ProvingKey, Self::VerifyingKey)> { - let mut circuit_keys = Self::batch_circuit_setup::(universal_srs, &[circuit])?; + let mut circuit_keys = Self::batch_circuit_setup::(srs, universal_prover, &[circuit])?; ensure!(circuit_keys.len() == 1); Ok(circuit_keys.pop().unwrap()) } @@ -267,14 +257,13 @@ where } let query_set = QuerySet::from_iter([("circuit_check".into(), ("challenge".into(), point))]); - let committer_key = CommitterUnionKey::union(std::iter::once(proving_key.committer_key.as_ref())); let empty_randomness = vec![Randomness::::empty(); 12]; + let index_polynomials = proving_key.circuit.interpolate_matrix_evals()?; let certificate = SonicKZG10::::open_combinations( universal_prover, - &committer_key, &[lc], - proving_key.circuit.interpolate_matrix_evals()?, + index_polynomials, &empty_randomness, &query_set, &mut sponge, @@ -364,7 +353,13 @@ where for (pk, constraints) in keys_to_constraints { circuits_to_constraints.insert(pk.circuit.deref(), *constraints); } - let prover_state = AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, zk_rng)?; + + let prover_state = match (&universal_prover.fft_precomputation, &universal_prover.ifft_precomputation) { + (Some(fft_precomp), Some(ifft_precomp)) => { + AHPForR1CS::<_, SM>::init_prover(&circuits_to_constraints, fft_precomp, ifft_precomp, zk_rng)? + } + _ => return Err(SNARKError::FFTPrecompNotFound.into()), + }; // extract information from the prover key and state to consume in further // calculations @@ -391,8 +386,6 @@ where } ensure!(prover_state.total_instances == total_instances); - let committer_key = CommitterUnionKey::union(keys_to_constraints.keys().map(|pk| pk.committer_key.deref())); - let circuit_commitments = keys_to_constraints.keys().map(|pk| pk.circuit_verifying_key.circuit_commitments.as_slice()); @@ -408,7 +401,6 @@ where let first_round_oracles = prover_state.first_round_oracles.as_ref().unwrap(); SonicKZG10::::commit( universal_prover, - &committer_key, first_round_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )? @@ -436,7 +428,6 @@ where let second_round_comm_time = start_timer!(|| "Committing to second round polys"); let (second_commitments, second_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, second_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -496,7 +487,6 @@ where let third_round_comm_time = start_timer!(|| "Committing to third round polys"); let (third_commitments, third_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, third_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -544,7 +534,6 @@ where let fourth_round_comm_time = start_timer!(|| "Committing to fourth round polys"); let (fourth_commitments, fourth_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, fourth_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -570,7 +559,6 @@ where let fifth_round_comm_time = start_timer!(|| "Committing to fifth round polys"); let (fifth_commitments, fifth_commitment_randomnesses) = SonicKZG10::::commit( universal_prover, - &committer_key, fifth_oracles.iter().map(Into::into), SM::ZK.then_some(zk_rng), )?; @@ -672,7 +660,6 @@ where let pc_proof = SonicKZG10::::open_combinations( universal_prover, - &committer_key, lc_s.values(), polynomials, &commitment_randomnesses, diff --git a/algorithms/src/srs/universal_prover.rs b/algorithms/src/srs/universal_prover.rs index 3f0e1702c5..922fb17801 100644 --- a/algorithms/src/srs/universal_prover.rs +++ b/algorithms/src/srs/universal_prover.rs @@ -13,13 +13,70 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::{ + fft::{ + EvaluationDomain, + domain::{FFTPrecomputation, IFFTPrecomputation}, + }, + polycommit::{kzg10::DegreeInfo, sonic_pc::CommitterKey}, + snark::varuna::UniversalSRS, +}; use snarkvm_curves::PairingEngine; +use anyhow::{Result, anyhow}; + /// `UniversalProver` is used to compute evaluation proofs for a given /// commitment. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Debug)] pub struct UniversalProver { + /// The committer key for the underlying KZG10 scheme. + pub committer_key: CommitterKey, + /// The precomputed roots for calculating FFTs. + pub fft_precomputation: Option>, + /// The precomputed inverse roots for calculating inverse FFTs. + pub ifft_precomputation: Option>, /// The maximum degree supported by the universal SRS. pub max_degree: usize, - pub _unused: Option, +} + +impl Default for UniversalProver { + fn default() -> Self { + let committer_key = CommitterKey { + powers_of_beta_g: Vec::new(), + lagrange_bases_at_beta_g: Default::default(), + powers_of_beta_times_gamma_g: Vec::new(), + shifted_powers_of_beta_g: None, + shifted_powers_of_beta_times_gamma_g: None, + enforced_degree_bounds: None, + }; + Self { committer_key, fft_precomputation: None, ifft_precomputation: None, max_degree: 0 } + } +} + +impl UniversalProver { + /// Update the UniversalProver: + /// - set a new max_degree + /// - grow or shrink the committer_key + /// - grow or shrink the FFT and iFFT precomputation + /// + /// Note that this implementation is not atomic. If specialize fails halfway + /// through, we might have an inconsistent UniversalProver. + pub fn update(&mut self, srs: &UniversalSRS, degree_info: DegreeInfo) -> Result<()> { + // Set the new max_degree + self.max_degree = degree_info.max_degree; + // Update the committer key. + self.committer_key.update(srs, °ree_info)?; + // Specialize the FFT and iFFT precomputation. + self.update_fft_precomputation(degree_info.max_fft_size)?; + Ok(()) + } + + /// Returns the FFT and iFFT precomputation for the largest supported + /// domain. + fn update_fft_precomputation(&mut self, max_domain_size: usize) -> Result<()> { + let domain = EvaluationDomain::new(max_domain_size).ok_or(anyhow!("Invalid domain size"))?; + self.fft_precomputation = Some(domain.precompute_fft()); + self.ifft_precomputation = Some(self.fft_precomputation.as_ref().unwrap().to_ifft_precomputation()); + Ok(()) + } } diff --git a/algorithms/src/traits/snark.rs b/algorithms/src/traits/snark.rs index 33b6253d89..a2c9988dc1 100644 --- a/algorithms/src/traits/snark.rs +++ b/algorithms/src/traits/snark.rs @@ -60,6 +60,7 @@ pub trait SNARK { fn circuit_setup>( srs: &Self::UniversalSRS, + universal_prover: &mut Self::UniversalProver, circuit: &C, ) -> Result<(Self::ProvingKey, Self::VerifyingKey)>; diff --git a/circuit/environment/src/helpers/assignment.rs b/circuit/environment/src/helpers/assignment.rs index 3c42a4209d..c960f7b140 100644 --- a/circuit/environment/src/helpers/assignment.rs +++ b/circuit/environment/src/helpers/assignment.rs @@ -227,7 +227,12 @@ impl snarkvm_algorithms::r1cs::ConstraintSynthesizer for Assig #[cfg(test)] mod tests { - use snarkvm_algorithms::{AlgebraicSponge, SNARK, r1cs::ConstraintSynthesizer, snark::varuna::VarunaVersion}; + use snarkvm_algorithms::{ + AlgebraicSponge, + SNARK, + r1cs::ConstraintSynthesizer, + snark::varuna::{UniversalProver, VarunaVersion}, + }; use snarkvm_circuit::prelude::*; use snarkvm_curves::bls12_377::Fr; @@ -300,15 +305,15 @@ mod tests { let max_degree = AHPForR1CS::::max_degree(200, 200, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &assignment).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = + VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - - let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &assignment, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_pp, &index_pk, varuna_version, &assignment, rng).unwrap(); println!("Called prover"); let one = ::BaseField::one(); diff --git a/circuit/environment/src/helpers/converter.rs b/circuit/environment/src/helpers/converter.rs index 5d3ad59a0f..d63a9871c5 100644 --- a/circuit/environment/src/helpers/converter.rs +++ b/circuit/environment/src/helpers/converter.rs @@ -201,7 +201,12 @@ impl R1CS { #[cfg(test)] mod tests { - use snarkvm_algorithms::{AlgebraicSponge, SNARK, r1cs::ConstraintSynthesizer, snark::varuna::VarunaVersion}; + use snarkvm_algorithms::{ + AlgebraicSponge, + SNARK, + r1cs::ConstraintSynthesizer, + snark::varuna::{UniversalProver, VarunaVersion}, + }; use snarkvm_circuit::prelude::*; use snarkvm_curves::bls12_377::Fr; @@ -265,15 +270,14 @@ mod tests { let max_degree = AHPForR1CS::::max_degree(200, 200, 300).unwrap(); let universal_srs = VarunaInst::universal_setup(max_degree).unwrap(); - let universal_prover = &universal_srs.to_universal_prover().unwrap(); let universal_verifier = &universal_srs.to_universal_verifier().unwrap(); let fs_pp = FS::sample_parameters(); - let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &Circuit).unwrap(); + let mut universal_prover = UniversalProver::default(); + let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &mut universal_prover, &Circuit).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - - let proof = VarunaInst::prove(universal_prover, &fs_pp, &index_pk, varuna_version, &Circuit, rng).unwrap(); + let proof = VarunaInst::prove(&universal_prover, &fs_pp, &index_pk, varuna_version, &Circuit, rng).unwrap(); println!("Called prover"); assert!( diff --git a/console/network/src/canary_v0.rs b/console/network/src/canary_v0.rs index c39b3bb73a..c2ef592318 100644 --- a/console/network/src/canary_v0.rs +++ b/console/network/src/canary_v0.rs @@ -345,11 +345,6 @@ impl Network for CanaryV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - MainnetV0::varuna_universal_prover() - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { MainnetV0::varuna_universal_verifier() diff --git a/console/network/src/lib.rs b/console/network/src/lib.rs index 2d3f17a784..977c69a7cf 100644 --- a/console/network/src/lib.rs +++ b/console/network/src/lib.rs @@ -64,7 +64,7 @@ use snarkvm_algorithms::{ AlgebraicSponge, crypto_hash::PoseidonSponge, snark::varuna::{CircuitProvingKey, CircuitVerifyingKey, VarunaHidingMode}, - srs::{UniversalProver, UniversalVerifier}, + srs::UniversalVerifier, }; use snarkvm_console_algorithms::{BHP512, BHP1024, Poseidon2, Poseidon4}; use snarkvm_console_collections::merkle_tree::{MerklePath, MerkleTree}; @@ -357,9 +357,6 @@ pub trait Network: /// Returns the scalar multiplication on the generator `G`. fn g_scalar_multiply(scalar: &Scalar) -> Group; - /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver; - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier; diff --git a/console/network/src/mainnet_v0.rs b/console/network/src/mainnet_v0.rs index 819deed383..50599211c8 100644 --- a/console/network/src/mainnet_v0.rs +++ b/console/network/src/mainnet_v0.rs @@ -350,17 +350,6 @@ impl Network for MainnetV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - static INSTANCE: OnceLock::PairingCurve>> = OnceLock::new(); - INSTANCE.get_or_init(|| { - snarkvm_algorithms::polycommit::kzg10::UniversalParams::load() - .expect("Failed to load universal SRS (KZG10).") - .to_universal_prover() - .expect("Failed to convert universal SRS (KZG10) to the prover.") - }) - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { static INSTANCE: OnceLock::PairingCurve>> = OnceLock::new(); diff --git a/console/network/src/testnet_v0.rs b/console/network/src/testnet_v0.rs index 0043ce7030..9b7de9fcd3 100644 --- a/console/network/src/testnet_v0.rs +++ b/console/network/src/testnet_v0.rs @@ -345,11 +345,6 @@ impl Network for TestnetV0 { .sum() } - /// Returns the Varuna universal prover. - fn varuna_universal_prover() -> &'static UniversalProver { - MainnetV0::varuna_universal_prover() - } - /// Returns the Varuna universal verifier. fn varuna_universal_verifier() -> &'static UniversalVerifier { MainnetV0::varuna_universal_verifier() diff --git a/ledger/puzzle/epoch/Cargo.toml b/ledger/puzzle/epoch/Cargo.toml index 5e60ebcd78..7c13d8dcd7 100644 --- a/ledger/puzzle/epoch/Cargo.toml +++ b/ledger/puzzle/epoch/Cargo.toml @@ -29,7 +29,8 @@ serial = [ "snarkvm-console/serial", "snarkvm-ledger-puzzle/serial" ] locktick = [ "dep:locktick", "snarkvm-ledger-puzzle/locktick", - "snarkvm-synthesizer-process/locktick" + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-program/locktick" ] merkle = [ ] synthesis = [ diff --git a/parameters/examples/inclusion.rs b/parameters/examples/inclusion.rs index 131f35ecf0..2fc22b26b9 100644 --- a/parameters/examples/inclusion.rs +++ b/parameters/examples/inclusion.rs @@ -33,10 +33,14 @@ type LedgerType = snarkvm_ledger_store::helpers::rocksdb::ConsensusDB; use snarkvm_synthesizer::{ VM, process::{InclusionAssignment, InclusionV0Assignment}, - snark::UniversalSRS, + snark::{UniversalProver, UniversalSRS}, }; use anyhow::{Result, anyhow}; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; use rand::thread_rng; use serde_json::{Value, json}; use std::{ @@ -171,17 +175,27 @@ pub fn inclusion>() -> Result<()> { // Load the universal SRS. let universal_srs = UniversalSRS::::load()?; + // Load the universal prover. + let universal_prover = RwLock::new(UniversalProver::::load()?); + // Sample the assignment for the inclusion circuit. let (assignment, state_path, serial_number, is_record_block_height_reached, upgrade_block_height) = sample_assignment::()?; // Synthesize the proving and verifying key. let inclusion_function_name = N::INCLUSION_FUNCTION_NAME; - let (proving_key, verifying_key) = universal_srs.to_circuit_key(inclusion_function_name, &assignment)?; + let (proving_key, verifying_key) = universal_srs.to_circuit_key(&universal_prover, inclusion_function_name, &assignment)?; for varuna_version in [VarunaVersion::V1, VarunaVersion::V2] { // Ensure the proving key and verifying keys are valid. - let proof = proving_key.prove(inclusion_function_name, varuna_version, &assignment, &mut thread_rng())?; + let proof = proving_key.prove( + &universal_srs, + &universal_prover, + inclusion_function_name, + varuna_version, + &assignment, + &mut thread_rng(), + )?; assert!(verifying_key.verify( inclusion_function_name, varuna_version, diff --git a/synthesizer/Cargo.toml b/synthesizer/Cargo.toml index cd1e0812cb..b59ee6eb2d 100644 --- a/synthesizer/Cargo.toml +++ b/synthesizer/Cargo.toml @@ -29,7 +29,9 @@ locktick = [ "snarkvm-ledger-puzzle/locktick", "snarkvm-ledger-puzzle-epoch/locktick", "snarkvm-ledger-store/locktick", - "snarkvm-synthesizer-process/locktick" + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-program/locktick", + "snarkvm-synthesizer-snark/locktick" ] async = [ "snarkvm-ledger-query/async", "snarkvm-synthesizer-process/async" ] cuda = [ "snarkvm-algorithms/cuda" ] diff --git a/synthesizer/benches/kary_merkle_tree.rs b/synthesizer/benches/kary_merkle_tree.rs index a0ace39661..ff84724e52 100644 --- a/synthesizer/benches/kary_merkle_tree.rs +++ b/synthesizer/benches/kary_merkle_tree.rs @@ -27,9 +27,13 @@ use snarkvm_console::{ }, types::Field, }; -use snarkvm_synthesizer_snark::{ProvingKey, UniversalSRS}; +use snarkvm_synthesizer_snark::{ProvingKey, UniversalProver, UniversalSRS}; use criterion::Criterion; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; type CurrentNetwork = MainnetV0; type CurrentAleo = AleoV0; @@ -125,8 +129,13 @@ fn batch_prove(c: &mut Criterion) { // Load the universal srs. let universal_srs = UniversalSRS::::load().unwrap(); + + // Load the universal prover. + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + // Construct the proving key. - let (proving_key, _) = universal_srs.to_circuit_key("KaryMerklePathVerification", &assignment).unwrap(); + let (proving_key, _) = + universal_srs.to_circuit_key(&universal_prover, "KaryMerklePathVerification", &assignment).unwrap(); // Log the current time elapsed. println!(" • Generated the proving key in: {} ms", timer.elapsed().as_millis()); @@ -149,8 +158,15 @@ fn batch_prove(c: &mut Criterion) { let varuna_version = VarunaVersion::V2; c.bench_function(&format!("KaryMerkleTree batch prove {num_assignments} assignments"), |b| { b.iter(|| { - let _proof = - ProvingKey::prove_batch("ProveKaryMerkleTree", varuna_version, &assignments, &mut rng).unwrap(); + let _proof = ProvingKey::prove_batch( + &universal_srs, + &universal_prover, + "ProveKaryMerkleTree", + varuna_version, + &assignments, + &mut rng, + ) + .unwrap(); }) }); } diff --git a/synthesizer/process/Cargo.toml b/synthesizer/process/Cargo.toml index ab5c5a09b4..4c035d68b7 100644 --- a/synthesizer/process/Cargo.toml +++ b/synthesizer/process/Cargo.toml @@ -26,7 +26,12 @@ edition = "2024" [features] default = [ "indexmap/rayon", "rayon" ] async = [ "snarkvm-ledger-query/async" ] -locktick = [ "dep:locktick", "snarkvm-ledger-store/locktick" ] +locktick = [ + "dep:locktick", + "snarkvm-ledger-store/locktick", + "snarkvm-synthesizer-program/locktick", + "snarkvm-synthesizer-snark/locktick" +] rocks = [ "snarkvm-ledger-store/rocks" ] serial = [ "snarkvm-console/serial", diff --git a/synthesizer/process/src/execute.rs b/synthesizer/process/src/execute.rs index 198cffc2af..a4624efb91 100644 --- a/synthesizer/process/src/execute.rs +++ b/synthesizer/process/src/execute.rs @@ -38,7 +38,7 @@ impl Process { // The root request does not have to pass on another request's root_tvk. let root_tvk = None; // Initialize the trace. - let trace = Arc::new(RwLock::new(Trace::new())); + let trace = Arc::new(RwLock::new(Trace::new(self.universal_srs.clone(), Arc::clone(&self.universal_prover)))); // Initialize the call stack. let call_stack = CallStack::execute(authorization, trace.clone())?; lap!(timer, "Initialize call stack"); diff --git a/synthesizer/process/src/lib.rs b/synthesizer/process/src/lib.rs index 0af7d846d4..9b1d7df668 100644 --- a/synthesizer/process/src/lib.rs +++ b/synthesizer/process/src/lib.rs @@ -76,7 +76,7 @@ use snarkvm_synthesizer_program::{ Program, StackTrait, }; -use snarkvm_synthesizer_snark::{ProvingKey, UniversalSRS, VerifyingKey}; +use snarkvm_synthesizer_snark::{ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; use snarkvm_utilities::{defer, dev_println}; use aleo_std::prelude::{finish, lap, timer}; @@ -91,6 +91,8 @@ use std::{collections::HashMap, sync::Arc}; pub struct Process { /// The universal SRS. universal_srs: UniversalSRS, + /// The universal prover. + universal_prover: Arc>>, /// The mapping of program IDs to stacks. stacks: Arc, Arc>>>>, /// The mapping of program IDs to old stacks. @@ -104,8 +106,12 @@ impl Process { let timer = timer!("Process:setup"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -231,8 +237,12 @@ impl Process { let timer = timer!("Process::load"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -271,8 +281,12 @@ impl Process { let timer = timer!("Process::load_v0"); // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; lap!(timer, "Initialize process"); // Initialize the 'credits.aleo' program. @@ -310,8 +324,12 @@ impl Process { #[cfg(feature = "wasm")] pub fn load_web() -> Result { // Initialize the process. - let mut process = - Self { universal_srs: UniversalSRS::load()?, stacks: Default::default(), old_stacks: Default::default() }; + let mut process = Self { + universal_srs: UniversalSRS::load()?, + universal_prover: Arc::new(RwLock::new(UniversalProver::load()?)), + stacks: Default::default(), + old_stacks: Default::default(), + }; // Initialize the 'credits.aleo' program. let program = Program::credits()?; @@ -392,6 +410,12 @@ impl Process { &self.universal_srs } + /// Returns the universal prover. + #[inline] + pub const fn universal_prover(&self) -> &Arc>> { + &self.universal_prover + } + /// Returns `true` if the process contains the program with the given ID. #[inline] pub fn contains_program(&self, program_id: &ProgramID) -> bool { diff --git a/synthesizer/process/src/stack/deploy.rs b/synthesizer/process/src/stack/deploy.rs index 8e5a9f5d2b..15a29e840e 100644 --- a/synthesizer/process/src/stack/deploy.rs +++ b/synthesizer/process/src/stack/deploy.rs @@ -41,7 +41,13 @@ impl Stack { lap!(timer, "Retrieve the keys for {function_name}"); // Certify the circuit. - let certificate = Certificate::certify(&function_name.to_string(), &proving_key, &verifying_key)?; + let certificate = Certificate::certify( + &self.universal_srs, + &self.universal_prover, + &function_name.to_string(), + &proving_key, + &verifying_key, + )?; lap!(timer, "Certify the circuit"); // Add the verifying key and certificate to the bundle. diff --git a/synthesizer/process/src/stack/helpers/initialize.rs b/synthesizer/process/src/stack/helpers/initialize.rs index ca4486ef54..6364fc953f 100644 --- a/synthesizer/process/src/stack/helpers/initialize.rs +++ b/synthesizer/process/src/stack/helpers/initialize.rs @@ -53,6 +53,7 @@ impl Stack { register_types: Default::default(), finalize_types: Default::default(), universal_srs: process.universal_srs().clone(), + universal_prover: process.universal_prover().clone(), proving_keys: Default::default(), verifying_keys: Default::default(), program_address: program.id().to_address()?, diff --git a/synthesizer/process/src/stack/helpers/synthesize.rs b/synthesizer/process/src/stack/helpers/synthesize.rs index 0475c08647..993f93bd47 100644 --- a/synthesizer/process/src/stack/helpers/synthesize.rs +++ b/synthesizer/process/src/stack/helpers/synthesize.rs @@ -101,7 +101,8 @@ impl Stack { } // Synthesize the proving and verifying key. - let (proving_key, verifying_key) = self.universal_srs.to_circuit_key(&function_name.to_string(), assignment)?; + let (proving_key, verifying_key) = + self.universal_srs.to_circuit_key(&self.universal_prover, &function_name.to_string(), assignment)?; // Insert the proving key. self.insert_proving_key(function_name, proving_key)?; // Insert the verifying key. diff --git a/synthesizer/process/src/stack/mod.rs b/synthesizer/process/src/stack/mod.rs index 8ad0384c1d..ffaea305bd 100644 --- a/synthesizer/process/src/stack/mod.rs +++ b/synthesizer/process/src/stack/mod.rs @@ -79,7 +79,7 @@ use snarkvm_synthesizer_program::{ RegistersTrait, StackTrait, }; -use snarkvm_synthesizer_snark::{Certificate, ProvingKey, UniversalSRS, VerifyingKey}; +use snarkvm_synthesizer_snark::{Certificate, ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; use aleo_std::prelude::{finish, lap, timer}; use indexmap::IndexMap; @@ -221,6 +221,8 @@ pub struct Stack { finalize_types: Arc, FinalizeTypes>>>, /// The universal SRS. universal_srs: UniversalSRS, + /// The universal prover. + universal_prover: Arc>>, /// The mapping of function name to proving key. proving_keys: Arc, ProvingKey>>>, /// The mapping of function name to verifying key. diff --git a/synthesizer/process/src/tests/test_execute.rs b/synthesizer/process/src/tests/test_execute.rs index ef96c36e5f..59b52058d7 100644 --- a/synthesizer/process/src/tests/test_execute.rs +++ b/synthesizer/process/src/tests/test_execute.rs @@ -32,7 +32,7 @@ use snarkvm_ledger_store::{ helpers::memory::{BlockMemory, FinalizeMemory}, }; use snarkvm_synthesizer_program::{FinalizeGlobalState, FinalizeStoreTrait, Program, StackTrait}; -use snarkvm_synthesizer_snark::UniversalSRS; +use snarkvm_synthesizer_snark::{UniversalProver, UniversalSRS}; use aleo_std::StorageMode; #[cfg(feature = "locktick")] @@ -405,7 +405,7 @@ output r4 as field.private;", assert_eq!(authorization.len(), 1); // Re-run to ensure state continues to work. - let trace = Arc::new(RwLock::new(Trace::new())); + let trace = Arc::new(RwLock::new(Trace::new(process.universal_srs.clone(), process.universal_prover.clone()))); let call_stack = CallStack::execute(authorization, trace).unwrap(); let response = stack.execute_function::(call_stack, None, None, rng).unwrap(); let candidate = response.outputs(); @@ -2486,6 +2486,7 @@ fn test_process_deploy_credits_program() { // Initialize an empty process without the `credits` program. let empty_process = Process { universal_srs: UniversalSRS::::load().unwrap(), + universal_prover: Arc::new(RwLock::new(UniversalProver::::load().unwrap())), stacks: Default::default(), old_stacks: Default::default(), }; diff --git a/synthesizer/process/src/tests/test_serializers.rs b/synthesizer/process/src/tests/test_serializers.rs index ec63e60a32..3f8e1bcf25 100644 --- a/synthesizer/process/src/tests/test_serializers.rs +++ b/synthesizer/process/src/tests/test_serializers.rs @@ -158,7 +158,7 @@ function test_serde_equivalence: let eval_is_ok = res_eval.is_ok(); // Execute the function. - let trace = Trace::new(); + let trace = Trace::new(process.universal_srs.clone(), process.universal_prover.clone()); let res_exec = stack.execute_function::( CallStack::execute(authorization.replicate(), Arc::new(RwLock::new(trace))).unwrap(), None, diff --git a/synthesizer/process/src/trace/mod.rs b/synthesizer/process/src/trace/mod.rs index e37caa3ecd..66ab11cdf9 100644 --- a/synthesizer/process/src/trace/mod.rs +++ b/synthesizer/process/src/trace/mod.rs @@ -27,9 +27,16 @@ use console::{ use snarkvm_algorithms::snark::varuna::VarunaVersion; use snarkvm_ledger_block::{Execution, Fee, Transition}; use snarkvm_ledger_query::QueryTrait; -use snarkvm_synthesizer_snark::{Proof, ProvingKey, VerifyingKey}; - -use std::{collections::HashMap, sync::OnceLock}; +use snarkvm_synthesizer_snark::{Proof, ProvingKey, UniversalProver, UniversalSRS, VerifyingKey}; + +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; +use std::{ + collections::HashMap, + sync::{Arc, OnceLock}, +}; use crate::Authorization; @@ -48,11 +55,16 @@ pub struct Trace { inclusion_assignments: OnceLock>>, /// A tracker for the global state root. global_state_root: OnceLock, + + /// The Universal SRS + srs: UniversalSRS, + /// The Universal Prover + universal_prover: Arc>>, } impl Trace { /// Initializes a new trace. - pub fn new() -> Self { + pub fn new(srs: UniversalSRS, universal_prover: Arc>>) -> Self { Self { transitions: Vec::new(), transition_tasks: HashMap::new(), @@ -60,6 +72,8 @@ impl Trace { inclusion_assignments: OnceLock::new(), global_state_root: OnceLock::new(), call_metrics: Vec::new(), + srs, + universal_prover, } } @@ -179,6 +193,8 @@ impl Trace { let proving_tasks = self.transition_tasks.values().cloned().collect(); // Compute the proof. let (global_state_root, proof) = Self::prove_batch::( + &self.srs, + &self.universal_prover, locator, varuna_version, proving_tasks, @@ -217,6 +233,8 @@ impl Trace { let proving_tasks = self.transition_tasks.values().cloned().collect(); // Compute the proof. let (global_state_root, proof) = Self::prove_batch::( + &self.srs, + &self.universal_prover, "credits.aleo/fee (private or public)", varuna_version, proving_tasks, @@ -301,6 +319,8 @@ impl Trace { impl Trace { /// Returns the global state root and proof for the given assignments. fn prove_batch, R: Rng + CryptoRng>( + srs: &UniversalSRS, + universal_prover: &RwLock>, locator: &str, varuna_version: VarunaVersion, mut proving_tasks: Vec<(ProvingKey, Vec>)>, @@ -369,7 +389,7 @@ impl Trace { } // Compute the proof. - let proof = ProvingKey::prove_batch(locator, varuna_version, &proving_tasks, rng)?; + let proof = ProvingKey::prove_batch(srs, universal_prover, locator, varuna_version, &proving_tasks, rng)?; // Return the global state root and proof. Ok((global_state_root, proof)) } diff --git a/synthesizer/program/Cargo.toml b/synthesizer/program/Cargo.toml index eab2c8106e..4a384e6401 100644 --- a/synthesizer/program/Cargo.toml +++ b/synthesizer/program/Cargo.toml @@ -25,6 +25,10 @@ edition = "2024" [features] default = [ ] +locktick = [ + "snarkvm-synthesizer-process/locktick", + "snarkvm-synthesizer-snark/locktick", +] serial = [ "snarkvm-console/serial" ] wasm = [ "snarkvm-console/wasm" ] diff --git a/synthesizer/snark/Cargo.toml b/synthesizer/snark/Cargo.toml index 75503780a9..06baaab15e 100644 --- a/synthesizer/snark/Cargo.toml +++ b/synthesizer/snark/Cargo.toml @@ -27,6 +27,7 @@ edition = "2024" default = [ ] cuda = [ "snarkvm-algorithms/cuda" ] dev-print = [ "snarkvm-utilities/dev-print" ] +locktick = [ "dep:locktick" ] serial = [ "snarkvm-console/serial", "snarkvm-algorithms/serial" ] wasm = [ "snarkvm-console/wasm" ] @@ -47,6 +48,13 @@ workspace = true [dependencies.bincode] workspace = true +[dependencies.locktick] +workspace = true +optional = true + +[dependencies.parking_lot] +workspace = true + [dependencies.serde_json] workspace = true features = [ "preserve_order" ] diff --git a/synthesizer/snark/src/certificate/mod.rs b/synthesizer/snark/src/certificate/mod.rs index f963cc3629..fe86b9deaa 100644 --- a/synthesizer/snark/src/certificate/mod.rs +++ b/synthesizer/snark/src/certificate/mod.rs @@ -19,6 +19,11 @@ mod bytes; mod parse; mod serialize; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; + #[derive(Clone, PartialEq, Eq)] pub struct Certificate { /// The certificate. @@ -33,6 +38,8 @@ impl Certificate { /// Returns the certificate from the proving and verifying key. pub fn certify( + srs: &UniversalSRS, + universal_prover: &RwLock>, _function_name: &str, proving_key: &ProvingKey, verifying_key: &VerifyingKey, @@ -41,11 +48,13 @@ impl Certificate { let timer = std::time::Instant::now(); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let degree_info = proving_key.circuit.index_info.degree_info::()?; let fiat_shamir = N::varuna_fs_parameters(); + universal_prover.write().update(srs, degree_info)?; + // Compute the certificate. - let certificate = Varuna::::prove_vk(universal_prover, fiat_shamir, verifying_key, proving_key)?; + let certificate = Varuna::::prove_vk(&universal_prover.read(), fiat_shamir, verifying_key, proving_key)?; #[cfg(feature = "dev-print")] { diff --git a/synthesizer/snark/src/lib.rs b/synthesizer/snark/src/lib.rs index c01de3c343..36c4eaee54 100644 --- a/synthesizer/snark/src/lib.rs +++ b/synthesizer/snark/src/lib.rs @@ -21,7 +21,7 @@ extern crate snarkvm_circuit as circuit; extern crate snarkvm_console as console; use console::network::{FiatShamir, prelude::*}; -use snarkvm_algorithms::{snark::varuna, traits::SNARK}; +use snarkvm_algorithms::{polycommit::kzg10::DegreeInfo, snark::varuna, traits::SNARK}; use snarkvm_utilities::dev_println; use std::sync::{Arc, OnceLock}; @@ -38,7 +38,7 @@ mod proving_key; pub use proving_key::ProvingKey; mod universal_srs; -pub use universal_srs::UniversalSRS; +pub use universal_srs::{UniversalProver, UniversalSRS}; mod verifying_key; pub use verifying_key::VerifyingKey; @@ -46,6 +46,7 @@ pub use verifying_key::VerifyingKey; #[cfg(test)] pub(crate) mod test_helpers { use super::*; + use crate::universal_srs::UniversalProver; use circuit::{ environment::{Assignment, Circuit, Eject, Environment, Inject, Mode, One}, types::Field, @@ -53,6 +54,10 @@ pub(crate) mod test_helpers { use console::{network::MainnetV0, prelude::One as _}; use snarkvm_algorithms::snark::varuna::VarunaVersion; + #[cfg(feature = "locktick")] + use locktick::parking_lot::RwLock; + #[cfg(not(feature = "locktick"))] + use parking_lot::RwLock; use std::sync::OnceLock; type CurrentNetwork = MainnetV0; @@ -104,8 +109,8 @@ pub(crate) mod test_helpers { .get_or_init(|| { let assignment = sample_assignment(); let srs = UniversalSRS::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); - (proving_key, verifying_key) + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap() }) .clone() } @@ -117,7 +122,12 @@ pub(crate) mod test_helpers { .get_or_init(|| { let assignment = sample_assignment(); let (proving_key, _) = sample_keys(); - proving_key.prove("test", VarunaVersion::V2, &assignment, &mut TestRng::default()).unwrap() + let srs = UniversalSRS::::load().unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + + proving_key + .prove(&srs, &universal_prover, "test", VarunaVersion::V2, &assignment, &mut TestRng::default()) + .unwrap() }) .clone() } @@ -128,8 +138,10 @@ pub(crate) mod test_helpers { INSTANCE .get_or_init(|| { let (proving_key, verifying_key) = sample_keys(); + let srs = UniversalSRS::::load().unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); // Return the certificate. - Certificate::certify("test", &proving_key, &verifying_key).unwrap() + Certificate::certify(&srs, &universal_prover, "test", &proving_key, &verifying_key).unwrap() }) .clone() } @@ -138,10 +150,16 @@ pub(crate) mod test_helpers { #[cfg(test)] mod test { use super::*; + use crate::universal_srs::UniversalProver; use circuit::environment::{Circuit, Environment}; use console::network::MainnetV0; use snarkvm_algorithms::snark::varuna::VarunaVersion; + #[cfg(feature = "locktick")] + use locktick::parking_lot::RwLock; + #[cfg(not(feature = "locktick"))] + use parking_lot::RwLock; + type CurrentNetwork = MainnetV0; #[test] @@ -150,11 +168,14 @@ mod test { // Varuna setup, prove, and verify. let srs = UniversalSRS::::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + let (proving_key, verifying_key) = srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let proof = proving_key.prove("test", varuna_version, &assignment, &mut TestRng::default()).unwrap(); + let proof = proving_key + .prove(&srs, &universal_prover, "test", varuna_version, &assignment, &mut TestRng::default()) + .unwrap(); println!("Called prover"); let one = ::BaseField::one(); @@ -194,11 +215,14 @@ mod test { assert_eq!(assignment.num_private(), 3); let srs = UniversalSRS::::load().unwrap(); - let (proving_key, verifying_key) = srs.to_circuit_key("test", &assignment).unwrap(); + let universal_prover = RwLock::new(UniversalProver::::load().unwrap()); + let (proving_key, verifying_key) = srs.to_circuit_key(&universal_prover, "test", &assignment).unwrap(); let varuna_version = VarunaVersion::V2; println!("Called circuit setup"); - let proof = proving_key.prove("test", varuna_version, &assignment, &mut TestRng::default()).unwrap(); + let proof = proving_key + .prove(&srs, &universal_prover, "test", varuna_version, &assignment, &mut TestRng::default()) + .unwrap(); println!("Called prover"); // Should pass. diff --git a/synthesizer/snark/src/proving_key/mod.rs b/synthesizer/snark/src/proving_key/mod.rs index c95f7bb27d..4f0a27c730 100644 --- a/synthesizer/snark/src/proving_key/mod.rs +++ b/synthesizer/snark/src/proving_key/mod.rs @@ -19,6 +19,10 @@ mod bytes; mod parse; mod serialize; +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; use std::collections::BTreeMap; #[derive(Clone)] @@ -36,6 +40,8 @@ impl ProvingKey { /// Returns a proof for the given assignment on the circuit. pub fn prove( &self, + srs: &UniversalSRS, + universal_prover: &RwLock>, _function_name: &str, varuna_version: varuna::VarunaVersion, assignment: &circuit::Assignment, @@ -45,12 +51,21 @@ impl ProvingKey { let timer = std::time::Instant::now(); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let degree_info = self.circuit.index_info.degree_info::()?; let fiat_shamir = N::varuna_fs_parameters(); + // Update the universal_prover and prevent other threads from writing to it. + universal_prover.write().update(srs, degree_info)?; + // Compute the proof. - let proof = - Proof::new(Varuna::::prove(universal_prover, fiat_shamir, self, varuna_version, assignment, rng)?); + let proof = Proof::new(Varuna::::prove( + &universal_prover.read(), + fiat_shamir, + self, + varuna_version, + assignment, + rng, + )?); #[cfg(feature = "dev-print")] { @@ -64,11 +79,16 @@ impl ProvingKey { /// Returns a proof for the given batch of proving keys and assignments. #[allow(clippy::type_complexity)] pub fn prove_batch( + srs: &UniversalSRS, + universal_prover: &RwLock>, _locator: &str, varuna_version: varuna::VarunaVersion, assignments: &[(ProvingKey, Vec>)], rng: &mut R, ) -> Result> { + // Ensure that the assignments are not empty. + ensure!(!assignments.is_empty(), "Cannot prove an empty batch"); + #[cfg(feature = "dev-print")] let timer = std::time::Instant::now(); @@ -81,12 +101,27 @@ impl ProvingKey { ensure!(instances.len() == num_expected_instances, "Incorrect number of proving keys for batch proof"); // Retrieve the proving parameters. - let universal_prover = N::varuna_universal_prover(); + let mut degree_info: Option = None; + for (pk, _) in assignments { + let degree_info_i = pk.circuit.index_info.degree_info::()?; + degree_info = match degree_info { + Some(degree_info) => Some(degree_info.union(°ree_info_i)), + None => Some(degree_info_i), + }; + } let fiat_shamir = N::varuna_fs_parameters(); + // Update the universal_prover and prevent other threads from writing to it + universal_prover.write().update(srs, degree_info.unwrap())?; + // Compute the proof. - let batch_proof = - Proof::new(Varuna::::prove_batch(universal_prover, fiat_shamir, varuna_version, &instances, rng)?); + let batch_proof = Proof::new(Varuna::::prove_batch( + &universal_prover.read(), + fiat_shamir, + varuna_version, + &instances, + rng, + )?); #[cfg(feature = "dev-print")] { diff --git a/synthesizer/snark/src/universal_srs.rs b/synthesizer/snark/src/universal_srs.rs index ae2c8d153f..8c9f2aea83 100644 --- a/synthesizer/snark/src/universal_srs.rs +++ b/synthesizer/snark/src/universal_srs.rs @@ -15,12 +15,23 @@ use super::*; -#[derive(Clone)] +#[cfg(feature = "locktick")] +use locktick::parking_lot::RwLock; +#[cfg(not(feature = "locktick"))] +use parking_lot::RwLock; + +#[derive(Clone, Debug)] pub struct UniversalSRS { - /// The universal SRS parameter. + /// The universal SRS parameters. srs: Arc>>, } +impl Default for UniversalSRS { + fn default() -> Self { + Self { srs: Arc::new(OnceLock::new()) } + } +} + impl UniversalSRS { /// Initializes the universal SRS. pub fn load() -> Result { @@ -30,13 +41,16 @@ impl UniversalSRS { /// Returns the circuit proving and verifying key. pub fn to_circuit_key( &self, + universal_prover: &RwLock>, _function_name: &str, assignment: &circuit::Assignment, ) -> Result<(ProvingKey, VerifyingKey)> { #[cfg(feature = "dev-print")] let timer = std::time::Instant::now(); - let (proving_key, verifying_key) = Varuna::::circuit_setup(self, assignment)?; + // Locks the universal_prover. + let (proving_key, verifying_key) = + Varuna::::circuit_setup(self.deref(), &mut universal_prover.write(), assignment)?; #[cfg(feature = "dev-print")] { @@ -89,3 +103,36 @@ impl Deref for UniversalSRS { }) } } + +#[derive(Debug)] +pub struct UniversalProver { + /// The universal prover - trimmed from the SRS + universal_prover: varuna::UniversalProver, +} + +impl Default for UniversalProver { + fn default() -> Self { + Self { universal_prover: varuna::UniversalProver::default() } + } +} + +impl UniversalProver { + /// Initializes the universal Prover. + pub fn load() -> Result { + Ok(Self { universal_prover: varuna::UniversalProver::default() }) + } +} + +impl DerefMut for UniversalProver { + fn deref_mut(&mut self) -> &mut varuna::UniversalProver { + &mut self.universal_prover + } +} + +impl Deref for UniversalProver { + type Target = varuna::UniversalProver; + + fn deref(&self) -> &Self::Target { + &self.universal_prover + } +}