Skip to content

batched sumcheck suffix alignment #870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions multilinear_extensions/src/virtual_polys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> {

impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
pub fn new(num_threads: usize, max_num_variables: usize) -> Self {
debug_assert!(num_threads > 0);
VirtualPolynomials {
num_threads,
polys: (0..num_threads)
Expand Down
38 changes: 21 additions & 17 deletions sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
assert!(extrapolation_aux.len() == max_degree - 1);
let num_polys = polynomial.flattened_ml_extensions.len();
Self {
max_num_variables: polynomial.aux_info.max_num_variables,
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
round: 0,
poly: polynomial,
Expand Down Expand Up @@ -335,7 +336,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
let chal = challenge.unwrap();
self.challenges.push(chal);
let r = self.challenges[self.round - 1];

self.fix_var(r.elements);
}
exit_span!(span);
Expand All @@ -345,22 +345,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {

// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
//
// To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars,
// for it evaluation value we need to times 2^(max_num_vars - num_vars)
// E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n
// For i round univariate poly, f^i(x)
// f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} chanllenge get from prev rounds
// = \sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n - n'
// = 2^(|b| - |b1|) * \sum_b1 f_1(r, 0, b1) + \sum_b f_2(r, 0, b)
// same applied on f^i[1]
// It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value
let span = entered_span!("products_sum");
let AdditiveVec(products_sum) = self.poly.products.iter().fold(
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
|mut products_sum, (coefficient, products)| {
let span = entered_span!("sum");

let f = &self.poly.flattened_ml_extensions;
let mut sum: Vec<E> = match products.len() {
1 => sumcheck_code_gen!(1, false, |i| &f[products[i]]).to_vec(),
Expand Down Expand Up @@ -418,12 +407,22 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
.collect()
}

pub fn expected_numvars_at_round(&self) -> usize {
// first round start from 1
let num_vars = self.max_num_variables + 1 - self.round;
debug_assert!(num_vars > 0, "make sumcheck work on constant");
num_vars
}

/// fix_var
pub fn fix_var(&mut self, r: E) {
let expected_numvars_at_round = self.expected_numvars_at_round();
self.poly_index_fixvar_in_place
.iter_mut()
.zip_eq(self.poly.flattened_ml_extensions.iter_mut())
.for_each(|(can_fixvar_in_place, poly)| {
debug_assert!(poly.num_vars() <= expected_numvars_at_round);
debug_assert!(poly.num_vars() > 0);
if *can_fixvar_in_place {
// in place
let poly = Arc::get_mut(poly);
Expand All @@ -433,8 +432,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
}
};
} else if poly.num_vars() > 0 {
*poly = Arc::new(poly.fix_variables(&[r]));
*can_fixvar_in_place = true;
if expected_numvars_at_round == poly.num_vars() {
*poly = Arc::new(poly.fix_variables(&[r]));
*can_fixvar_in_place = true;
}
} else {
panic!("calling sumcheck on constant")
}
Expand Down Expand Up @@ -524,6 +525,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
let max_degree = polynomial.aux_info.max_degree;
let num_polys = polynomial.flattened_ml_extensions.len();
let prover_state = Self {
max_num_variables: polynomial.aux_info.max_num_variables,
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
round: 0,
poly: polynomial,
Expand Down Expand Up @@ -579,7 +581,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
let chal = challenge.unwrap();
self.challenges.push(chal);
let r = self.challenges[self.round - 1];

self.fix_var(r.elements);
}
exit_span!(span);
Expand Down Expand Up @@ -641,6 +642,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {

/// fix_var
pub fn fix_var_parallel(&mut self, r: E) {
let expected_numvars_at_round = self.expected_numvars_at_round();
self.poly_index_fixvar_in_place
.par_iter_mut()
.zip_eq(self.poly.flattened_ml_extensions.par_iter_mut())
Expand All @@ -654,8 +656,10 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
}
};
} else if poly.num_vars() > 0 {
*poly = Arc::new(poly.fix_variables_parallel(&[r]));
*can_fixvar_in_place = true;
if expected_numvars_at_round == poly.num_vars() {
*poly = Arc::new(poly.fix_variables_parallel(&[r]));
*can_fixvar_in_place = true;
}
} else {
panic!("calling sumcheck on constant")
}
Expand Down
1 change: 1 addition & 0 deletions sumcheck/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub struct IOPProverState<'a, E: ExtensionField> {
/// points with precomputed barycentric weights for extrapolating smaller
/// degree uni-polys to `max_degree + 1` evaluations.
pub(crate) extrapolation_aux: Vec<(Vec<E>, Vec<E>)>,
pub(crate) max_num_variables: usize,
/// record poly should fix variable in place or not
pub(crate) poly_index_fixvar_in_place: Vec<bool>,
}
Expand Down
16 changes: 10 additions & 6 deletions sumcheck/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
use ark_std::{rand::RngCore, test_rng};
use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2};
use multilinear_extensions::{
util::max_usable_threads,
virtual_poly::{VPAuxInfo, VirtualPolynomial},
virtual_polys::VirtualPolynomials,
};
Expand All @@ -13,17 +14,19 @@ use transcript::{BasicTranscript, Transcript};

#[test]
fn test_sumcheck_with_different_degree() {
let nv = vec![4, 5]; // test polynomial mixed with different num_var
test_sumcheck_with_different_degree_helper::<GoldilocksExt2>(nv);
// test polynomial mixed with different num_var
let nv = vec![3, 4, 5];
let num_polys = nv.len();
for num_threads in 1..num_polys.min(max_usable_threads()) {
test_sumcheck_with_different_degree_helper::<GoldilocksExt2>(num_threads, &nv);
}
}

fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(nv: Vec<usize>) {
fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(num_threads: usize, nv: &[usize]) {
let mut rng = test_rng();
let degree = 2;
let num_multiplicands_range = (degree, degree + 1);
let num_products = 1;
// TODO investigate error when num_threads > 1
let num_threads = 1;
let mut transcript = BasicTranscript::<E>::new(b"test");

let max_num_variables = *nv.iter().max().unwrap();
Expand Down Expand Up @@ -69,10 +72,11 @@ fn test_sumcheck_with_different_degree_helper<E: ExtensionField>(nv: Vec<usize>)
.map(|c| c.elements)
.collect::<Vec<_>>();
assert_eq!(r.len(), max_num_variables);
// r are right alignment
assert!(
input_polys
.iter()
.map(|(poly, _)| { poly.evaluate(&r[..poly.aux_info.max_num_variables]) })
.map(|(poly, _)| { poly.evaluate(&r[r.len() - poly.aux_info.max_num_variables..]) })
.sum::<E>()
== subclaim.expected_evaluation,
"wrong subclaim"
Expand Down
1 change: 1 addition & 0 deletions sumcheck_macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ itertools.workspace = true
p3 = { path = "../p3" }
proc-macro2 = "1.0.92"
quote = "1.0"
rand.workspace = true
syn = { version = "2.0", features = ["full"] }

[dev-dependencies]
Expand Down
17 changes: 13 additions & 4 deletions sumcheck_macro/examples/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@
/// ```sh
/// cargo expand --example expand
/// ```
use ff_ext::ExtensionField;
use ff_ext::GoldilocksExt2;
use ff_ext::{ExtensionField, GoldilocksExt2};
use multilinear_extensions::{
mle::FieldType, util::largest_even_below, virtual_poly::VirtualPolynomial,
};
use p3::field::PrimeCharacteristicRing;
use rand::rngs::OsRng;
use sumcheck::util::{AdditiveArray, ceil_log2};

#[derive(Default)]
struct Container<'a, E: ExtensionField> {
poly: VirtualPolynomial<'a, E>,
round: usize,
}

fn main() {
let c = Container::<GoldilocksExt2>::default();
let c = Container::<GoldilocksExt2>::new();
c.run();
}

Expand All @@ -26,4 +25,14 @@ impl<E: ExtensionField> Container<'_, E> {
let _result: AdditiveArray<_, 4> =
sumcheck_macro::sumcheck_code_gen!(3, false, |_| &self.poly.flattened_ml_extensions[0]);
}

pub fn expected_numvars_at_round(&self) -> usize {
1
}

pub fn new() -> Self {
Self {
poly: VirtualPolynomial::random(3, (4, 5), 2, &mut OsRng).0,
}
}
}
73 changes: 54 additions & 19 deletions sumcheck_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,68 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr
};

let iter = if parallalize {
quote! {.into_par_iter().step_by(2).with_min_len(64)}
quote! {.into_par_iter().step_by(2).rev().with_min_len(64)}
} else {
quote! {.step_by(2).rev()}
};

// Generate the final AdditiveArray expression.

// special case: generate product for polynomial num_var less than current expected num_var
// which happened when we batching sumcheck with different num_vars
let product = mul_exprs(
(1..=degree)
.map(|j: u32| {
let v = ident(format!("v{j}"));
quote! {#v[b]}
})
.collect(),
);

let degree_plus_one = (degree + 1) as usize;
quote! {
let res = (0..largest_even_below(v1.len()))
#iter
.map(|b| {
#additive_array_items
})
.sum::<AdditiveArray<_, #degree_plus_one>>();
let res = if v1.len() == 1 {
let b = 0;
AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one])
} else {
res
};
let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(v1.len()).max(1) + self.round - 1);
if num_vars_multiplicity > 0 {
AdditiveArray(res.0.map(|e| e * E::BaseField::from_u64(1 << num_vars_multiplicity)))
// To deal with different num_vars, we exploit a fact that for each product which num_vars < max_num_vars
// we actually need to have a full sum, times 2^(bh_num_vars - num_vars) to accumulate into univariate computation
// E.g. Giving multivariate poly f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'}, X \in {F}^{n}, |X1| := n', |X| = n, n' <= n
// For i < n - n', to compute univariate poly, f^i(x), b is i-th round boolean hypercube
// f^i[0] = \sum_b f(r, 0, b), b \in {0, 1}^{n-i-1}, r \in {F}^{n-i-1} challenge get from prev rounds
// = \sum_b f_1(b) + f_2(r, 0, b)
// = 2^(|b| - |b1|) * \sum_b1 f_1(b1) + \sum_b f_2(r, 0, b)
// b1 is suffix alignment with b
// same applied on f^i[1], f^i[2], ... f^i[degree + 1]
// It imply that, for every evals in f_1, to compute univariate poly, we just need to times a factor 2^(|b| - |b1|) for it evaluation value

// NOTE: current method work in suffix alignment order
let num_var = ceil_log2(v1.len());
let expected_numvars_at_round = self.expected_numvars_at_round();
if num_var < expected_numvars_at_round {
// TODO optimize by caching computed result for later round reuse
// need to figure out how to cache in one place to support base/extension field
let mut sum = (0..largest_even_below(v1.len())).map(
|b| {
#product
},
).sum();
// calculate multiplicity term
// minus one because when expected num of var is n_i, the boolean hypercube dimension only n_i-1
let num_vars_multiplicity = self.expected_numvars_at_round().saturating_sub(1).saturating_sub(num_var);
if num_vars_multiplicity > 0 {
sum *= E::BaseField::from_u64(1 << num_vars_multiplicity);
}
AdditiveArray::<_, #degree_plus_one>([sum; #degree_plus_one])
} else {
res
if v1.len() == 1 {
let b = 0;
AdditiveArray::<_, #degree_plus_one>([#additive_array_first_item ; #degree_plus_one])
} else {
(0..largest_even_below(v1.len()))
#iter
.map(|b| {
#additive_array_items
})
.sum::<AdditiveArray<_, #degree_plus_one>>()
}
}

}
};

Expand Down Expand Up @@ -314,7 +349,7 @@ pub fn sumcheck_code_gen(input: proc_macro::TokenStream) -> proc_macro::TokenStr
// Generate the second match statement that maps f vars to AdditiveArray.
out = quote! {
{
#out
#out
match (#match_input) {
#match_arms
_ => unreachable!(),
Expand Down