@@ -401,27 +401,28 @@ def optimize(
401401 keys : Hashable | list [Hashable ] | set [Hashable ],
402402 ** kwargs : Any ,
403403) -> Mapping :
404- if not isinstance (keys , (list , set )):
405- keys = [keys ]
406- keys = list (flatten (keys ))
404+ keys = tuple (flatten (keys ))
407405
408406 if not isinstance (dsk , HighLevelGraph ):
409- dsk = HighLevelGraph .from_collections (id (dsk ), dsk , dependencies = ())
407+ dsk = HighLevelGraph .from_collections (str ( id (dsk ) ), dsk , dependencies = ())
410408
409+ dsk = optimize_blockwise (dsk , keys = keys )
410+ dsk = fuse_roots (dsk , keys = keys ) # type: ignore
411+ dsk = dsk .cull (set (keys )) # type: ignore
412+ return dsk
413+
414+
415+ def _get_optimization_function ():
411416 # Here we try to run optimizations from dask-awkward (if we detect
412417 # that dask-awkward has been imported). There is no cost to
413418 # running this optimization even in cases where it's unncessary
414- # because if no AwkwardInputLayers from daks -awkward are not
419+ # because if no AwkwardInputLayers from dask -awkward are
415420 # detected then the original graph is returned unchanged.
416421 if dask .config .get ("awkward" , default = False ):
417- from dask_awkward .lib .optimize import optimize
422+ from dask_awkward .lib .optimize import all_optimizations
418423
419- dsk = optimize (dsk , keys = keys ) # type: ignore[arg-type]
420-
421- dsk = optimize_blockwise (dsk , keys = keys )
422- dsk = fuse_roots (dsk , keys = keys ) # type: ignore
423- dsk = dsk .cull (set (keys )) # type: ignore
424- return dsk
424+ return all_optimizations
425+ return optimize
425426
426427
427428class AggHistogram (DaskMethodsMixin ):
@@ -479,7 +480,7 @@ def __dask_postpersist__(self) -> Any:
479480 return self ._rebuild , ()
480481
481482 __dask_optimize__ = globalmethod (
482- optimize , key = "histogram_optimize" , falsey = dont_optimize
483+ _get_optimization_function () , key = "histogram_optimize" , falsey = dont_optimize
483484 )
484485
485486 __dask_scheduler__ = staticmethod (tget )
@@ -706,7 +707,7 @@ def _rebuild(self, dsk: Any, *, rename: Any = None) -> Any:
706707 return type (self )(dsk , name , self .npartitions , self .histref )
707708
708709 __dask_optimize__ = globalmethod (
709- optimize , key = "histogram_optimize" , falsey = dont_optimize
710+ _get_optimization_function () , key = "histogram_optimize" , falsey = dont_optimize
710711 )
711712
712713 __dask_scheduler__ = staticmethod (tget )
0 commit comments