@@ -5,10 +5,7 @@ use crossbeam_channel::bounded;
5
5
use ff_ext:: ExtensionField ;
6
6
use itertools:: Itertools ;
7
7
use multilinear_extensions:: {
8
- mle:: { DenseMultilinearExtension , FieldType , MultilinearExtension } ,
9
- op_mle,
10
- util:: largest_even_below,
11
- virtual_poly:: VirtualPolynomial ,
8
+ mle:: FieldType , op_mle, util:: largest_even_below, virtual_poly:: VirtualPolynomial ,
12
9
} ;
13
10
use rayon:: {
14
11
Scope ,
@@ -114,16 +111,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
114
111
if let Some ( p) = challenge {
115
112
prover_state. challenges . push ( p) ;
116
113
// fix last challenge to collect final evaluation
117
- prover_state
118
- . poly
119
- . flattened_ml_extensions
120
- . iter_mut ( )
121
- . for_each ( |mle| {
122
- let mle = Arc :: get_mut ( mle) . unwrap ( ) ;
123
- if mle. num_vars ( ) > 0 {
124
- mle. fix_variables_in_place ( & [ p. elements ] ) ;
125
- }
126
- } ) ;
114
+ prover_state. fix_var ( p. elements ) ;
115
+
127
116
tx_prover_state
128
117
. send ( Some ( ( thread_id, prover_state) ) )
129
118
. unwrap ( ) ;
@@ -183,29 +172,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
183
172
if let Some ( p) = challenge {
184
173
prover_state. challenges . push ( p) ;
185
174
// fix last challenge to collect final evaluation
186
- prover_state
187
- . poly
188
- . flattened_ml_extensions
189
- . iter_mut ( )
190
- . for_each ( |mle| {
191
- if num_variables == 1 {
192
- // first time fix variable should be create new instance
193
- if mle. num_vars ( ) > 0 {
194
- * mle = mle. fix_variables ( & [ p. elements ] ) . into ( ) ;
195
- } else {
196
- * mle =
197
- Arc :: new ( DenseMultilinearExtension :: from_evaluation_vec_smart (
198
- 0 ,
199
- mle. get_base_field_vec ( ) . to_vec ( ) ,
200
- ) )
201
- }
202
- } else {
203
- let mle = Arc :: get_mut ( mle) . unwrap ( ) ;
204
- if mle. num_vars ( ) > 0 {
205
- mle. fix_variables_in_place ( & [ p. elements ] ) ;
206
- }
207
- }
208
- } ) ;
175
+ prover_state. fix_var ( p. elements ) ;
209
176
tx_prover_state
210
177
. send ( Some ( ( main_thread_id, prover_state) ) )
211
178
. unwrap ( ) ;
@@ -280,21 +247,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
280
247
if let Some ( p) = challenge {
281
248
prover_state. challenges . push ( p) ;
282
249
// fix last challenge to collect final evaluation
283
- prover_state
284
- . poly
285
- . flattened_ml_extensions
286
- . iter_mut ( )
287
- . for_each (
288
- |mle : & mut Arc <
289
- dyn MultilinearExtension < E , Output = DenseMultilinearExtension < E > > ,
290
- > | {
291
- if mle. num_vars ( ) > 0 {
292
- Arc :: get_mut ( mle)
293
- . unwrap ( )
294
- . fix_variables_in_place ( & [ p. elements ] ) ;
295
- }
296
- } ,
297
- ) ;
250
+ prover_state. fix_var ( p. elements ) ;
298
251
} ;
299
252
exit_span ! ( span) ;
300
253
@@ -330,11 +283,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
330
283
331
284
let max_degree = polynomial. aux_info . max_degree ;
332
285
assert ! ( extrapolation_aux. len( ) == max_degree - 1 ) ;
286
+ let num_polys = polynomial. flattened_ml_extensions . len ( ) ;
333
287
Self {
334
288
challenges : Vec :: with_capacity ( polynomial. aux_info . max_num_variables ) ,
335
289
round : 0 ,
336
290
poly : polynomial,
337
291
extrapolation_aux,
292
+ poly_index_fixvar_in_place : vec ! [ false ; num_polys] ,
338
293
}
339
294
}
340
295
@@ -381,30 +336,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
381
336
self . challenges . push ( chal) ;
382
337
let r = self . challenges [ self . round - 1 ] ;
383
338
384
- if self . challenges . len ( ) == 1 {
385
- self . poly . flattened_ml_extensions . iter_mut ( ) . for_each ( |f| {
386
- if f. num_vars ( ) > 0 {
387
- * f = Arc :: new ( f. fix_variables ( & [ r. elements ] ) ) ;
388
- } else {
389
- panic ! ( "calling sumcheck on constant" )
390
- }
391
- } ) ;
392
- } else {
393
- self . poly
394
- . flattened_ml_extensions
395
- . iter_mut ( )
396
- // benchmark result indicate make_mut achieve better performange than get_mut,
397
- // which can be +5% overhead rust docs doen't explain the
398
- // reason
399
- . map ( Arc :: get_mut)
400
- . for_each ( |f| {
401
- if let Some ( f) = f {
402
- if f. num_vars ( ) > 0 {
403
- f. fix_variables_in_place ( & [ r. elements ] ) ;
404
- }
405
- }
406
- } ) ;
407
- }
339
+ self . fix_var ( r. elements ) ;
408
340
}
409
341
exit_span ! ( span) ;
410
342
// end_timer!(fix_argument);
@@ -485,6 +417,29 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
485
417
} )
486
418
. collect ( )
487
419
}
420
+
421
+ /// fix_var
422
+ pub fn fix_var ( & mut self , r : E ) {
423
+ self . poly_index_fixvar_in_place
424
+ . iter_mut ( )
425
+ . zip_eq ( self . poly . flattened_ml_extensions . iter_mut ( ) )
426
+ . for_each ( |( has_fixvar_in_place, poly) | {
427
+ if * has_fixvar_in_place {
428
+ // in place
429
+ let poly = Arc :: get_mut ( poly) ;
430
+ if let Some ( f) = poly {
431
+ if f. num_vars ( ) > 0 {
432
+ f. fix_variables_in_place ( & [ r] )
433
+ }
434
+ } ;
435
+ } else if poly. num_vars ( ) > 0 {
436
+ * poly = Arc :: new ( poly. fix_variables ( & [ r] ) ) ;
437
+ * has_fixvar_in_place = true ;
438
+ } else {
439
+ panic ! ( "calling sumcheck on constant" )
440
+ }
441
+ } ) ;
442
+ }
488
443
}
489
444
490
445
/// parallel version
@@ -538,28 +493,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
538
493
if let Some ( p) = challenge {
539
494
prover_state. challenges . push ( p) ;
540
495
// fix last challenge to collect final evaluation
541
- prover_state
542
- . poly
543
- . flattened_ml_extensions
544
- . par_iter_mut ( )
545
- . for_each ( |mle| {
546
- if num_variables == 1 {
547
- // first time fix variable should be create new instance
548
- if mle. num_vars ( ) > 0 {
549
- * mle = mle. fix_variables ( & [ p. elements ] ) . into ( ) ;
550
- } else {
551
- * mle = Arc :: new ( DenseMultilinearExtension :: from_evaluation_vec_smart (
552
- 0 ,
553
- mle. get_base_field_vec ( ) . to_vec ( ) ,
554
- ) )
555
- }
556
- } else {
557
- let mle = Arc :: get_mut ( mle) . unwrap ( ) ;
558
- if mle. num_vars ( ) > 0 {
559
- mle. fix_variables_in_place ( & [ p. elements ] ) ;
560
- }
561
- }
562
- } ) ;
496
+ prover_state. fix_var_parallel ( p. elements ) ;
563
497
} ;
564
498
exit_span ! ( span) ;
565
499
@@ -588,6 +522,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
588
522
) ;
589
523
590
524
let max_degree = polynomial. aux_info . max_degree ;
525
+ let num_polys = polynomial. flattened_ml_extensions . len ( ) ;
591
526
let prover_state = Self {
592
527
challenges : Vec :: with_capacity ( polynomial. aux_info . max_num_variables ) ,
593
528
round : 0 ,
@@ -599,6 +534,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
599
534
( points, weights)
600
535
} )
601
536
. collect ( ) ,
537
+ poly_index_fixvar_in_place : vec ! [ false ; num_polys] ,
602
538
} ;
603
539
604
540
end_timer ! ( start) ;
@@ -644,33 +580,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
644
580
self . challenges . push ( chal) ;
645
581
let r = self . challenges [ self . round - 1 ] ;
646
582
647
- if self . challenges . len ( ) == 1 {
648
- self . poly
649
- . flattened_ml_extensions
650
- . par_iter_mut ( )
651
- . for_each ( |f| {
652
- if f. num_vars ( ) > 0 {
653
- * f = Arc :: new ( f. fix_variables_parallel ( & [ r. elements ] ) ) ;
654
- } else {
655
- panic ! ( "calling sumcheck on constant" )
656
- }
657
- } ) ;
658
- } else {
659
- self . poly
660
- . flattened_ml_extensions
661
- . par_iter_mut ( )
662
- // benchmark result indicate make_mut achieve better performange than get_mut,
663
- // which can be +5% overhead rust docs doen't explain the
664
- // reason
665
- . map ( Arc :: get_mut)
666
- . for_each ( |f| {
667
- if let Some ( f) = f {
668
- if f. num_vars ( ) > 0 {
669
- f. fix_variables_in_place_parallel ( & [ r. elements ] )
670
- }
671
- }
672
- } ) ;
673
- }
583
+ self . fix_var ( r. elements ) ;
674
584
}
675
585
exit_span ! ( span) ;
676
586
// end_timer!(fix_argument);
@@ -728,4 +638,27 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
728
638
evaluations : products_sum,
729
639
}
730
640
}
641
+
642
+ /// fix_var
643
+ pub fn fix_var_parallel ( & mut self , r : E ) {
644
+ self . poly_index_fixvar_in_place
645
+ . par_iter_mut ( )
646
+ . zip_eq ( self . poly . flattened_ml_extensions . par_iter_mut ( ) )
647
+ . for_each ( |( has_fixvar_in_place, poly) | {
648
+ if * has_fixvar_in_place {
649
+ // in place
650
+ let poly = Arc :: get_mut ( poly) ;
651
+ if let Some ( f) = poly {
652
+ if f. num_vars ( ) > 0 {
653
+ f. fix_variables_in_place_parallel ( & [ r] )
654
+ }
655
+ } ;
656
+ } else if poly. num_vars ( ) > 0 {
657
+ * poly = Arc :: new ( poly. fix_variables_parallel ( & [ r] ) ) ;
658
+ * has_fixvar_in_place = true ;
659
+ } else {
660
+ panic ! ( "calling sumcheck on constant" )
661
+ }
662
+ } ) ;
663
+ }
731
664
}
0 commit comments