Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,14 +1811,10 @@ def _zero_getitem(arr: ak.Array, zeroth: slice, rest: tuple[slice, ...]) -> ak.A


def compute_typetracer(dsk: HighLevelGraph, name: str) -> ak.Array:
from dask_awkward.lib.optimize import get_sync

key = (name, 0)
return typetracer_array(
Delayed(
key,
dsk.cull({key}),
layer=name,
).compute()
)
return typetracer_array(get_sync(dsk, [key])[0])


def new_array_object(
Expand Down
61 changes: 60 additions & 1 deletion src/dask_awkward/lib/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import awkward as ak
import dask.config
from awkward.typetracer import touch_data
from dask._task_spec import Alias, DataNode, GraphNode, List, Tuple
from dask.blockwise import Blockwise, fuse_roots, optimize_blockwise
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph
from dask.local import get_sync

from dask_awkward.layers import (
AwkwardBlockwiseLayer,
Expand Down Expand Up @@ -83,6 +83,65 @@ def optimize(dsk: HighLevelGraph, keys: Sequence[Key], **_: Any) -> Mapping:
return dsk


get_cache = {}


class NoKey: ...


def _unwind(llg, arg):
if isinstance(arg, Task):
func = arg.func
args = arg.args
kwargs = arg.kwargs
args = [_unwind(llg, arg) for arg in args]
args = [_ for _ in args if _ is not NoKey]
kwargs = {k: _unwind(llg, v) for k, v in kwargs.items()}
kwargs = {k: v for k, v in kwargs.items() if v is not NoKey}
return func(*args, **kwargs)
if isinstance(arg, Alias):
return _get_sync(llg, arg.target)
if isinstance(arg, List):
out = [_unwind(llg, _) for _ in arg.args]
return [_ for _ in out if _ is not NoKey]
if isinstance(arg, Tuple):
out = tuple(_unwind(llg, _) for _ in arg.args)
return tuple(_ for _ in out if _ is not NoKey)
if isinstance(arg, DataNode):
return arg.value
if isinstance(arg, GraphNode):
# other types to implement
raise ValueError

return arg


def _get_sync(llg, key):
if key not in get_cache:
if key in llg:
task = llg[key]
else:
return NoKey
out = _unwind(llg, task)
try:
if out in llg:
# some things just return another key
out = _get_sync(llg, out)
except TypeError:
pass
get_cache[key] = out

return get_cache[key]


def get_sync(hlg, keys):
get_cache.clear()
llg = dict(hlg)
out = [_get_sync(llg, key) for key in keys]
get_cache.clear()
return out


def _prepare_buffer_projection(
dsk: HighLevelGraph, keys: Sequence[Key]
) -> tuple[dict[str, TypeTracerReport], dict[str, Any]] | None:
Expand Down
Loading