Skip to content

Commit 387d231

Browse files
committed
refactor sumcheck boilerplate code by fix_var
1 parent 7390c27 commit 387d231

File tree

4 files changed

+61
-152
lines changed

4 files changed

+61
-152
lines changed

multilinear_extensions/src/virtual_polys.rs

-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ pub struct VirtualPolynomials<'a, E: ExtensionField> {
1818

1919
impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
2020
pub fn new(num_threads: usize, max_num_variables: usize) -> Self {
21-
println!("ceil_log2(num_threads) {}", ceil_log2(num_threads));
2221
VirtualPolynomials {
2322
num_threads,
2423
polys: (0..num_threads)

sumcheck/src/prover.rs

+58-125
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@ use crossbeam_channel::bounded;
55
use ff_ext::ExtensionField;
66
use itertools::Itertools;
77
use multilinear_extensions::{
8-
mle::{DenseMultilinearExtension, FieldType, MultilinearExtension},
9-
op_mle,
10-
util::largest_even_below,
11-
virtual_poly::VirtualPolynomial,
8+
mle::FieldType, op_mle, util::largest_even_below, virtual_poly::VirtualPolynomial,
129
};
1310
use rayon::{
1411
Scope,
@@ -114,16 +111,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
114111
if let Some(p) = challenge {
115112
prover_state.challenges.push(p);
116113
// fix last challenge to collect final evaluation
117-
prover_state
118-
.poly
119-
.flattened_ml_extensions
120-
.iter_mut()
121-
.for_each(|mle| {
122-
let mle = Arc::get_mut(mle).unwrap();
123-
if mle.num_vars() > 0 {
124-
mle.fix_variables_in_place(&[p.elements]);
125-
}
126-
});
114+
prover_state.fix_var(p.elements);
115+
127116
tx_prover_state
128117
.send(Some((thread_id, prover_state)))
129118
.unwrap();
@@ -183,29 +172,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
183172
if let Some(p) = challenge {
184173
prover_state.challenges.push(p);
185174
// fix last challenge to collect final evaluation
186-
prover_state
187-
.poly
188-
.flattened_ml_extensions
189-
.iter_mut()
190-
.for_each(|mle| {
191-
if num_variables == 1 {
192-
// first time fix variable should be create new instance
193-
if mle.num_vars() > 0 {
194-
*mle = mle.fix_variables(&[p.elements]).into();
195-
} else {
196-
*mle =
197-
Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart(
198-
0,
199-
mle.get_base_field_vec().to_vec(),
200-
))
201-
}
202-
} else {
203-
let mle = Arc::get_mut(mle).unwrap();
204-
if mle.num_vars() > 0 {
205-
mle.fix_variables_in_place(&[p.elements]);
206-
}
207-
}
208-
});
175+
prover_state.fix_var(p.elements);
209176
tx_prover_state
210177
.send(Some((main_thread_id, prover_state)))
211178
.unwrap();
@@ -280,21 +247,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
280247
if let Some(p) = challenge {
281248
prover_state.challenges.push(p);
282249
// fix last challenge to collect final evaluation
283-
prover_state
284-
.poly
285-
.flattened_ml_extensions
286-
.iter_mut()
287-
.for_each(
288-
|mle: &mut Arc<
289-
dyn MultilinearExtension<E, Output = DenseMultilinearExtension<E>>,
290-
>| {
291-
if mle.num_vars() > 0 {
292-
Arc::get_mut(mle)
293-
.unwrap()
294-
.fix_variables_in_place(&[p.elements]);
295-
}
296-
},
297-
);
250+
prover_state.fix_var(p.elements);
298251
};
299252
exit_span!(span);
300253

@@ -330,11 +283,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
330283

331284
let max_degree = polynomial.aux_info.max_degree;
332285
assert!(extrapolation_aux.len() == max_degree - 1);
286+
let num_polys = polynomial.flattened_ml_extensions.len();
333287
Self {
334288
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
335289
round: 0,
336290
poly: polynomial,
337291
extrapolation_aux,
292+
poly_index_fixvar_in_place: vec![false; num_polys],
338293
}
339294
}
340295

@@ -381,30 +336,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
381336
self.challenges.push(chal);
382337
let r = self.challenges[self.round - 1];
383338

384-
if self.challenges.len() == 1 {
385-
self.poly.flattened_ml_extensions.iter_mut().for_each(|f| {
386-
if f.num_vars() > 0 {
387-
*f = Arc::new(f.fix_variables(&[r.elements]));
388-
} else {
389-
panic!("calling sumcheck on constant")
390-
}
391-
});
392-
} else {
393-
self.poly
394-
.flattened_ml_extensions
395-
.iter_mut()
396-
// benchmark result indicate make_mut achieve better performange than get_mut,
397-
// which can be +5% overhead rust docs doen't explain the
398-
// reason
399-
.map(Arc::get_mut)
400-
.for_each(|f| {
401-
if let Some(f) = f {
402-
if f.num_vars() > 0 {
403-
f.fix_variables_in_place(&[r.elements]);
404-
}
405-
}
406-
});
407-
}
339+
self.fix_var(r.elements);
408340
}
409341
exit_span!(span);
410342
// end_timer!(fix_argument);
@@ -485,6 +417,29 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
485417
})
486418
.collect()
487419
}
420+
421+
/// fix_var
422+
pub fn fix_var(&mut self, r: E) {
423+
self.poly_index_fixvar_in_place
424+
.iter_mut()
425+
.zip_eq(self.poly.flattened_ml_extensions.iter_mut())
426+
.for_each(|(has_fixvar_in_place, poly)| {
427+
if *has_fixvar_in_place {
428+
// in place
429+
let poly = Arc::get_mut(poly);
430+
if let Some(f) = poly {
431+
if f.num_vars() > 0 {
432+
f.fix_variables_in_place(&[r])
433+
}
434+
};
435+
} else if poly.num_vars() > 0 {
436+
*poly = Arc::new(poly.fix_variables(&[r]));
437+
*has_fixvar_in_place = true;
438+
} else {
439+
panic!("calling sumcheck on constant")
440+
}
441+
});
442+
}
488443
}
489444

490445
/// parallel version
@@ -538,28 +493,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
538493
if let Some(p) = challenge {
539494
prover_state.challenges.push(p);
540495
// fix last challenge to collect final evaluation
541-
prover_state
542-
.poly
543-
.flattened_ml_extensions
544-
.par_iter_mut()
545-
.for_each(|mle| {
546-
if num_variables == 1 {
547-
// first time fix variable should be create new instance
548-
if mle.num_vars() > 0 {
549-
*mle = mle.fix_variables(&[p.elements]).into();
550-
} else {
551-
*mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart(
552-
0,
553-
mle.get_base_field_vec().to_vec(),
554-
))
555-
}
556-
} else {
557-
let mle = Arc::get_mut(mle).unwrap();
558-
if mle.num_vars() > 0 {
559-
mle.fix_variables_in_place(&[p.elements]);
560-
}
561-
}
562-
});
496+
prover_state.fix_var_parallel(p.elements);
563497
};
564498
exit_span!(span);
565499

@@ -588,6 +522,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
588522
);
589523

590524
let max_degree = polynomial.aux_info.max_degree;
525+
let num_polys = polynomial.flattened_ml_extensions.len();
591526
let prover_state = Self {
592527
challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables),
593528
round: 0,
@@ -599,6 +534,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
599534
(points, weights)
600535
})
601536
.collect(),
537+
poly_index_fixvar_in_place: vec![false; num_polys],
602538
};
603539

604540
end_timer!(start);
@@ -644,33 +580,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
644580
self.challenges.push(chal);
645581
let r = self.challenges[self.round - 1];
646582

647-
if self.challenges.len() == 1 {
648-
self.poly
649-
.flattened_ml_extensions
650-
.par_iter_mut()
651-
.for_each(|f| {
652-
if f.num_vars() > 0 {
653-
*f = Arc::new(f.fix_variables_parallel(&[r.elements]));
654-
} else {
655-
panic!("calling sumcheck on constant")
656-
}
657-
});
658-
} else {
659-
self.poly
660-
.flattened_ml_extensions
661-
.par_iter_mut()
662-
// benchmark result indicate make_mut achieve better performange than get_mut,
663-
// which can be +5% overhead rust docs doen't explain the
664-
// reason
665-
.map(Arc::get_mut)
666-
.for_each(|f| {
667-
if let Some(f) = f {
668-
if f.num_vars() > 0 {
669-
f.fix_variables_in_place_parallel(&[r.elements])
670-
}
671-
}
672-
});
673-
}
583+
self.fix_var(r.elements);
674584
}
675585
exit_span!(span);
676586
// end_timer!(fix_argument);
@@ -728,4 +638,27 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
728638
evaluations: products_sum,
729639
}
730640
}
641+
642+
/// fix_var
643+
pub fn fix_var_parallel(&mut self, r: E) {
644+
self.poly_index_fixvar_in_place
645+
.par_iter_mut()
646+
.zip_eq(self.poly.flattened_ml_extensions.par_iter_mut())
647+
.for_each(|(has_fixvar_in_place, poly)| {
648+
if *has_fixvar_in_place {
649+
// in place
650+
let poly = Arc::get_mut(poly);
651+
if let Some(f) = poly {
652+
if f.num_vars() > 0 {
653+
f.fix_variables_in_place_parallel(&[r])
654+
}
655+
};
656+
} else if poly.num_vars() > 0 {
657+
*poly = Arc::new(poly.fix_variables_parallel(&[r]));
658+
*has_fixvar_in_place = true;
659+
} else {
660+
panic!("calling sumcheck on constant")
661+
}
662+
});
663+
}
731664
}

sumcheck/src/structs.rs

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ pub struct IOPProverState<'a, E: ExtensionField> {
4444
/// points with precomputed barycentric weights for extrapolating smaller
4545
/// degree uni-polys to `max_degree + 1` evaluations.
4646
pub(crate) extrapolation_aux: Vec<(Vec<E>, Vec<E>)>,
47+
/// record poly should fix variable in place or not
48+
pub(crate) poly_index_fixvar_in_place: Vec<bool>,
4749
}
4850

4951
/// Prover State of a PolyIOP

sumcheck/src/test.rs

+1-26
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
use std::sync::Arc;
2-
31
use crate::{
42
structs::{IOPProverState, IOPVerifierState},
53
util::interpolate_uni_poly,
64
};
75
use ark_std::{rand::RngCore, test_rng};
86
use ff_ext::{ExtensionField, FromUniformBytes, GoldilocksExt2};
97
use multilinear_extensions::{
10-
mle::DenseMultilinearExtension,
118
virtual_poly::{VPAuxInfo, VirtualPolynomial},
129
virtual_polys::VirtualPolynomials,
1310
};
1411
use p3_field::PrimeCharacteristicRing;
15-
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
1612
use transcript::{BasicTranscript, Transcript};
1713

1814
#[test]
@@ -144,28 +140,7 @@ fn test_sumcheck_internal<E: ExtensionField>(
144140
if let Some(p) = challenge {
145141
prover_state.challenges.push(p);
146142
// fix last challenge to collect final evaluation
147-
prover_state
148-
.poly
149-
.flattened_ml_extensions
150-
.par_iter_mut()
151-
.for_each(|mle| {
152-
if num_variables == 1 {
153-
// first time fix variable should be create new instance
154-
if mle.num_vars() > 0 {
155-
*mle = mle.fix_variables(&[p.elements]).into();
156-
} else {
157-
*mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart(
158-
0,
159-
mle.get_base_field_vec().to_vec(),
160-
))
161-
}
162-
} else {
163-
let mle = Arc::get_mut(mle).unwrap();
164-
if mle.num_vars() > 0 {
165-
mle.fix_variables_in_place(&[p.elements]);
166-
}
167-
}
168-
});
143+
prover_state.fix_var(p.elements);
169144
};
170145
let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum);
171146
assert!(

0 commit comments

Comments
 (0)