diff --git a/Cargo.lock b/Cargo.lock index 47fcbec86..380bb0cdd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2701,6 +2701,7 @@ dependencies = [ "p3", "proc-macro2", "quote", + "rand", "sumcheck", "syn 2.0.98", ] diff --git a/multilinear_extensions/src/virtual_polys.rs b/multilinear_extensions/src/virtual_polys.rs index 46d6dad20..13cd70e06 100644 --- a/multilinear_extensions/src/virtual_polys.rs +++ b/multilinear_extensions/src/virtual_polys.rs @@ -18,6 +18,7 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> { impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> { pub fn new(num_threads: usize, max_num_variables: usize) -> Self { + debug_assert!(num_threads > 0); VirtualPolynomials { num_threads, polys: (0..num_threads) diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 30b07a0cd..0518fc016 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -285,6 +285,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { assert!(extrapolation_aux.len() == max_degree - 1); let num_polys = polynomial.flattened_ml_extensions.len(); Self { + max_num_variables: polynomial.aux_info.max_num_variables, challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, @@ -335,7 +336,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let chal = challenge.unwrap(); self.challenges.push(chal); let r = self.challenges[self.round - 1]; - self.fix_var(r.elements); } exit_span!(span); @@ -345,22 +345,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) - // - // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars, - // for it evaluation value we need to times 2^(max_num_vars - num_vars) - // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n - // For i round univariate poly, f^i(x) - // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds - // = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n' - // = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b) - // same applied on f^i[1] - // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value let span = entered_span!("products_sum"); let AdditiveVec(products_sum) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut products_sum, (coefficient, products)| { let span = entered_span!("sum"); - let f = &self.poly.flattened_ml_extensions; let mut sum: Vec = match products.len() { 1 => sumcheck_code_gen!(1, false, |i| &f[products[i]]).to_vec(), @@ -418,12 +407,22 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .collect() } + pub fn expected_numvars_at_round(&self) -> usize { + // first round start from 1 + let num_vars = self.max_num_variables + 1 - self.round; + debug_assert!(num_vars > 0, "make sumcheck work on constant"); + num_vars + } + /// fix_var pub fn fix_var(&mut self, r: E) { + let expected_numvars_at_round = self.expected_numvars_at_round(); self.poly_index_fixvar_in_place .iter_mut() .zip_eq(self.poly.flattened_ml_extensions.iter_mut()) .for_each(|(can_fixvar_in_place, poly)| { + debug_assert!(poly.num_vars() <= expected_numvars_at_round); + debug_assert!(poly.num_vars() > 0); if *can_fixvar_in_place { // in place let poly = Arc::get_mut(poly); @@ -433,8 +432,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } }; } else if poly.num_vars() > 0 { - *poly = Arc::new(poly.fix_variables(&[r])); - *can_fixvar_in_place = true; + if expected_numvars_at_round == poly.num_vars() { + *poly = Arc::new(poly.fix_variables(&[r])); + *can_fixvar_in_place = true; + } } else { panic!("calling sumcheck on constant") } @@ -524,6 +525,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let max_degree = polynomial.aux_info.max_degree; let num_polys = polynomial.flattened_ml_extensions.len(); let prover_state = Self { + max_num_variables: polynomial.aux_info.max_num_variables, challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, @@ -579,7 +581,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let chal = challenge.unwrap(); self.challenges.push(chal); let r = self.challenges[self.round - 1]; - self.fix_var(r.elements); } exit_span!(span); @@ -641,6 +642,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// fix_var pub fn fix_var_parallel(&mut self, r: E) { + let expected_numvars_at_round = self.expected_numvars_at_round(); self.poly_index_fixvar_in_place .par_iter_mut() .zip_eq(self.poly.flattened_ml_extensions.par_iter_mut()) @@ -654,8 +656,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } }; } else if poly.num_vars() > 0 { - *poly = Arc::new(poly.fix_variables_parallel(&[r])); - *can_fixvar_in_place = true; + if expected_numvars_at_round == poly.num_vars() { + *poly = Arc::new(poly.fix_variables_parallel(&[r])); + *can_fixvar_in_place = true; + } } else { panic!("calling sumcheck on constant") } diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index 959eb2766..306726957 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -44,6 +44,7 @@ pub struct IOPProverState<'a, E: ExtensionField> { /// points with precomputed barycentric weights for extrapolating smaller /// degree uni-polys to `max_degree + 1` evaluations. pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, + pub(crate) max_num_variables: usize, /// record poly should fix variable in place or not pub(crate) poly_index_fixvar_in_place: Vec, } diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index f3dd5afe4..4e6cbfac1 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -5,6 +5,7 @@ use crate::{ use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2}; use multilinear_extensions::{ + util::max_usable_threads, virtual_poly::{VPAuxInfo, VirtualPolynomial}, virtual_polys::VirtualPolynomials, }; @@ -13,17 +14,19 @@ use transcript::{BasicTranscript, Transcript}; #[test] fn test_sumcheck_with_different_degree() { - let nv = vec![4, 5]; // test polynomial mixed with different num_var - test_sumcheck_with_different_degree_helper::(nv); + // test polynomial mixed with different num_var + let nv = vec![3, 4, 5]; + let num_polys = nv.len(); + for num_threads in 1..num_polys.min(max_usable_threads()) { + test_sumcheck_with_different_degree_helper::(num_threads, &nv); + } } -fn test_sumcheck_with_different_degree_helper(nv: Vec) { +fn test_sumcheck_with_different_degree_helper(num_threads: usize, nv: &[usize]) { let mut rng = test_rng(); let degree = 2; let num_multiplicands_range = (degree, degree + 1); let num_products = 1; - // TODO investigate error when num_threads > 1 - let num_threads = 1; let mut transcript = BasicTranscript::::new(b"test"); let max_num_variables = *nv.iter().max().unwrap(); @@ -69,10 +72,11 @@ fn test_sumcheck_with_different_degree_helper(nv: Vec) .map(|c| c.elements) .collect::>(); assert_eq!(r.len(), max_num_variables); + // r are right alignment assert!( input_polys .iter() - .map(|(poly, _)| { poly.evaluate(&r[..poly.aux_info.max_num_variables]) }) + .map(|(poly, _)| { poly.evaluate(&r[r.len() - poly.aux_info.max_num_variables..]) }) .sum::() == subclaim.expected_evaluation, "wrong subclaim" diff --git a/sumcheck_macro/Cargo.toml b/sumcheck_macro/Cargo.toml index e1f842693..925ebe9a4 100644 --- a/sumcheck_macro/Cargo.toml +++ b/sumcheck_macro/Cargo.toml @@ -17,6 +17,7 @@ itertools.workspace = true p3 = { path = "../p3" } proc-macro2 = "1.0.92" quote = "1.0" +rand.workspace = true syn = { version = "2.0", features = ["full"] } [dev-dependencies] diff --git a/sumcheck_macro/examples/expand.rs b/sumcheck_macro/examples/expand.rs index 97e3412a3..80641be5e 100644 --- a/sumcheck_macro/examples/expand.rs +++ b/sumcheck_macro/examples/expand.rs @@ -2,22 +2,21 @@ /// ```sh /// cargo expand --example expand /// ``` -use ff_ext::ExtensionField; -use ff_ext::GoldilocksExt2; +use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::{ mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial, }; use p3::field::PrimeCharacteristicRing; +use rand::rngs::OsRng; use sumcheck::util::{AdditiveArray, ceil_log2}; #[derive(Default)] struct Container<'a, E: ExtensionField> { poly: VirtualPolynomial<'a, E>, - round: usize, } fn main() { - let c = Container::::default(); + let c = Container::::new(); c.run(); } @@ -26,4 +25,14 @@ impl Container<'_, E> { let _result: AdditiveArray<_, 4> = sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]); } + + pub fn expected_numvars_at_round(&self) -> usize { + 1 + } + + pub fn new() -> Self { + Self { + poly: VirtualPolynomial::random(3, (4, 5), 2, &mut OsRng).0, + } + } } diff --git a/sumcheck_macro/src/lib.rs b/sumcheck_macro/src/lib.rs index 428500261..df77f9b37 100644 --- a/sumcheck_macro/src/lib.rs +++ b/sumcheck_macro/src/lib.rs @@ -219,33 +219,68 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr }; let iter = if parallalize { - quote! {.into_par_iter().step_by(2).with_min_len(64)} + quote! {.into_par_iter().step_by(2).rev().with_min_len(64)} } else { quote! {.step_by(2).rev()} }; // Generate the final AdditiveArray expression. + + // special case: generate product for polynomial num_var less than current expected num_var + // which happened when we batching sumcheck with different num_vars + let product = mul_exprs( + (1..=degree) + .map(|j: u32| { + let v = ident(format!("v{j}")); + quote! {#v[b]} + }) + .collect(), + ); + let degree_plus_one = (degree + 1) as usize; quote! { - let res = (0..largest_even_below(v1.len())) - #iter - .map(|b| { - #additive_array_items - }) - .sum::>(); - let res = if v1.len() == 1 { - let b = 0; - AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) - } else { - res - }; - let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(v1.len()).max(1) + self.round - 1); - if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity))) + // To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars + // we actually need to have a full sum, times 2^(bh_num_vars - num_vars) to accumulate into univariate computation + // E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n + // For i < n - n', to compute univariate poly, f^i(x), b is i-th round boolean hypercube + // f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} challenge get from prev rounds + // = \sum_b f_1(b) + f_2(r, 0, b) + // = 2^(|b| - |b1|) * \sum_b1 f_1(b1) + \sum_b f_2(r, 0, b) + // b1 is suffix alignment with b + // same applied on f^i[1], f^i[2], ... f^i[degree + 1] + // It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value + + // NOTE: current method work in suffix alignment order + let num_var = ceil_log2(v1.len()); + let expected_numvars_at_round = self.expected_numvars_at_round(); + if num_var < expected_numvars_at_round { + // TODO optimize by caching computed result for later round reuse + // need to figure out how to cache in one place to support base/extension field + let mut sum = (0..largest_even_below(v1.len())).map( + |b| { + #product + }, + ).sum(); + // calculate multiplicity term + // minus one because when expected num of var is n_i, the boolean hypercube dimension only n_i-1 + let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var); + if num_vars_multiplicity > 0 { + sum *= E::BaseField::from_u64(1 << num_vars_multiplicity); + } + AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one]) } else { - res + if v1.len() == 1 { + let b = 0; + AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one]) + } else { + (0..largest_even_below(v1.len())) + #iter + .map(|b| { + #additive_array_items + }) + .sum::>() + } } - } }; @@ -314,7 +349,7 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr // Generate the second match statement that maps f vars to AdditiveArray. out = quote! { { - #out + #out match (#match_input) { #match_arms _ => unreachable!(),