@@ -4,9 +4,9 @@ use ark_std::iterable::Iterable;
4
4
use ff_ext:: ExtensionField ;
5
5
use itertools:: Itertools ;
6
6
use multilinear_extensions:: {
7
- commutative_op_mle_pair_pool ,
7
+ commutative_op_mle_pair ,
8
8
mle:: { DenseMultilinearExtension , FieldType , IntoMLE } ,
9
- op_mle_xa_b_pool , op_mle3_range_pool ,
9
+ op_mle_xa_b , op_mle3_range ,
10
10
util:: { ceil_log2, max_usable_threads} ,
11
11
virtual_poly_v2:: ArcMultilinearExtension ,
12
12
} ;
@@ -238,28 +238,12 @@ pub(crate) fn infer_tower_product_witness<E: ExtensionField>(
238
238
wit_layers
239
239
}
240
240
241
- fn try_recycle_arcpoly < E : ExtensionField , PF1 : Fn ( ) -> Vec < E > , PF2 : Fn ( ) -> Vec < E :: BaseField > > (
241
+ fn optional_arcpoly_unwrap_pushback < E : ExtensionField > (
242
242
poly : Cow < ArcMultilinearExtension < ' _ , E > > ,
243
- pool_e : & mut SimpleVecPool < Vec < E > , PF1 > ,
244
- pool_b : & mut SimpleVecPool < Vec < E :: BaseField > , PF2 > ,
243
+ pool_e : & mut SimpleVecPool < Vec < E > , impl Fn ( ) -> Vec < E > > ,
244
+ pool_b : & mut SimpleVecPool < Vec < E :: BaseField > , impl Fn ( ) -> Vec < E :: BaseField > > ,
245
245
pool_expected_size_vec : usize ,
246
246
) {
247
- // fn downcast_arc<E: ExtensionField>(
248
- // arc: ArcMultilinearExtension<'_, E>,
249
- // ) -> DenseMultilinearExtension<E> {
250
- // unsafe {
251
- // // get the raw pointer from the Arc
252
- // assert_eq!(Arc::strong_count(&arc), 1);
253
- // let raw = Arc::into_raw(arc);
254
- // // cast the raw pointer to the desired concrete type
255
- // let typed_ptr = raw as *const DenseMultilinearExtension<E>;
256
- // // manually drop the Arc without dropping the value
257
- // Arc::decrement_strong_count(raw);
258
- // // reconstruct the Arc with the concrete type
259
- // // Move the value out
260
- // ptr::read(typed_ptr)
261
- // }
262
- // }
263
247
let len = poly. evaluations ( ) . len ( ) ;
264
248
if len == pool_expected_size_vec {
265
249
match poly {
@@ -348,7 +332,7 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
348
332
& |cow_a, cow_b, pool_e, pool_b| {
349
333
let ( a, b) = ( cow_a. as_ref ( ) , cow_b. as_ref ( ) ) ;
350
334
let poly =
351
- commutative_op_mle_pair_pool ! (
335
+ commutative_op_mle_pair ! (
352
336
|a, b, res| {
353
337
match ( a. len( ) , b. len( ) ) {
354
338
( 1 , 1 ) => {
@@ -401,14 +385,14 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
401
385
pool_e,
402
386
pool_b
403
387
) ;
404
- try_recycle_arcpoly ( cow_a, pool_e, pool_b, len) ;
405
- try_recycle_arcpoly ( cow_b, pool_e, pool_b, len) ;
388
+ optional_arcpoly_unwrap_pushback ( cow_a, pool_e, pool_b, len) ;
389
+ optional_arcpoly_unwrap_pushback ( cow_b, pool_e, pool_b, len) ;
406
390
poly
407
391
} ,
408
392
& |cow_a, cow_b, pool_e, pool_b| {
409
393
let ( a, b) = ( cow_a. as_ref ( ) , cow_b. as_ref ( ) ) ;
410
394
let poly =
411
- commutative_op_mle_pair_pool ! (
395
+ commutative_op_mle_pair ! (
412
396
|a, b, res| {
413
397
match ( a. len( ) , b. len( ) ) {
414
398
( 1 , 1 ) => {
@@ -464,13 +448,13 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
464
448
pool_e,
465
449
pool_b
466
450
) ;
467
- try_recycle_arcpoly ( cow_a, pool_e, pool_b, len) ;
468
- try_recycle_arcpoly ( cow_b, pool_e, pool_b, len) ;
451
+ optional_arcpoly_unwrap_pushback ( cow_a, pool_e, pool_b, len) ;
452
+ optional_arcpoly_unwrap_pushback ( cow_b, pool_e, pool_b, len) ;
469
453
poly
470
454
} ,
471
455
& |cow_x, cow_a, cow_b, pool_e, pool_b| {
472
456
let ( x, a, b) = ( cow_x. as_ref ( ) , cow_a. as_ref ( ) , cow_b. as_ref ( ) ) ;
473
- let poly = op_mle_xa_b_pool ! (
457
+ let poly = op_mle_xa_b ! (
474
458
|x, a, b, res| {
475
459
let res = SyncUnsafeCell :: new( res) ;
476
460
assert_eq!( a. len( ) , 1 ) ;
@@ -490,9 +474,9 @@ pub(crate) fn wit_infer_by_expr_pool<'a, E: ExtensionField, const N: usize>(
490
474
pool_e,
491
475
pool_b
492
476
) ;
493
- try_recycle_arcpoly ( cow_a, pool_e, pool_b, len) ;
494
- try_recycle_arcpoly ( cow_b, pool_e, pool_b, len) ;
495
- try_recycle_arcpoly ( cow_x, pool_e, pool_b, len) ;
477
+ optional_arcpoly_unwrap_pushback ( cow_a, pool_e, pool_b, len) ;
478
+ optional_arcpoly_unwrap_pushback ( cow_b, pool_e, pool_b, len) ;
479
+ optional_arcpoly_unwrap_pushback ( cow_x, pool_e, pool_b, len) ;
496
480
poly
497
481
} ,
498
482
pool_e,
0 commit comments