Skip to content

Commit bcb2cda

Browse files
author
cancai
committed
feat: Implement shuffle map-side spill support (not Flight shuffle)
1 parent 0956446 commit bcb2cda

File tree

22 files changed

+980
-252
lines changed

22 files changed

+980
-252
lines changed

daft/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ def set_execution_config(
186186
maintain_order: bool | None = None,
187187
enable_dynamic_batching: bool | None = None,
188188
dynamic_batching_strategy: str | None = None,
189+
shuffle_spill_threshold: int | None = None,
189190
) -> DaftContext:
190191
"""Globally sets various configuration parameters which control various aspects of Daft execution.
191192
@@ -229,6 +230,7 @@ def set_execution_config(
229230
maintain_order: Whether to maintain order during execution. Defaults to True. Some blocking sink operators (e.g. write_parquet) won't respect this flag and will always keep maintain_order as false, and propagate to child operators. It's useful to set this to False for running df.collect() when no ordering is required.
230231
enable_dynamic_batching: Whether to enable dynamic batching. Defaults to False.
231232
dynamic_batching_strategy: The strategy to use for dynamic batching. Defaults to 'auto'.
233+
shuffle_spill_threshold: Memory threshold in bytes for shuffle spill. Defaults to None (no spill).
232234
"""
233235
# Replace values in the DaftExecutionConfig with user-specified overrides
234236
ctx = get_context()
@@ -265,6 +267,7 @@ def set_execution_config(
265267
maintain_order=maintain_order,
266268
enable_dynamic_batching=enable_dynamic_batching,
267269
dynamic_batching_strategy=dynamic_batching_strategy,
270+
shuffle_spill_threshold=shuffle_spill_threshold,
268271
)
269272

270273
ctx._ctx._daft_execution_config = new_daft_execution_config

daft/daft/__init__.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,18 +2032,20 @@ class LocalPhysicalPlan:
20322032
def from_logical_plan_builder(builder: LogicalPlanBuilder) -> LocalPhysicalPlan: ...
20332033

20342034
class RayPartitionRef:
2035-
object_ref: ray.ObjectRef
2035+
object_refs: list[ray.ObjectRef]
20362036
num_rows: int
20372037
size_bytes: int
20382038

2039-
def __init__(self, object_ref: ray.ObjectRef, num_rows: int, size_bytes: int): ...
2039+
def __init__(self, object_refs: list[ray.ObjectRef], num_rows: int, size_bytes: int): ...
20402040

20412041
class RaySwordfishTask:
20422042
def name(self) -> str: ...
2043+
def num_partitions(self) -> int: ...
20432044
def plan(self) -> LocalPhysicalPlan: ...
20442045
def psets(self) -> dict[str, list[RayPartitionRef]]: ...
20452046
def config(self) -> PyDaftExecutionConfig: ...
20462047
def context(self) -> dict[str, str]: ...
2048+
def is_into_batches(self) -> bool: ...
20472049

20482050
class RayTaskResult:
20492051
@staticmethod
@@ -2136,6 +2138,7 @@ class PyDaftExecutionConfig:
21362138
maintain_order: bool | None = None,
21372139
enable_dynamic_batching: bool | None = None,
21382140
dynamic_batching_strategy: str | None = None,
2141+
shuffle_spill_threshold: int | None = None,
21392142
) -> PyDaftExecutionConfig: ...
21402143
@property
21412144
def enable_scan_task_split_and_merge(self) -> bool: ...

daft/runners/flotilla.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ async def run_plan(
7575
from daft.daft import PyDaftContext
7676

7777
with profile():
78-
psets = {k: await asyncio.gather(*v) for k, v in psets.items()}
79-
psets_mp = {k: [v._micropartition for v in v] for k, v in psets.items()}
78+
psets_mp: dict[str, list[PyMicroPartition]] = {
79+
k: [ray.get(r)._micropartition for r in refs] for k, refs in psets.items()
80+
}
8081

8182
metas = []
8283
native_executor = NativeExecutor()
@@ -131,6 +132,8 @@ class RaySwordfishTaskHandle:
131132

132133
result_handle: ray.ObjectRef
133134
actor_handle: ray.actor.ActorHandle
135+
num_partitions: int
136+
is_into_batches: bool
134137
task: asyncio.Task[RayTaskResult] | None = None
135138

136139
async def _get_result(self) -> RayTaskResult:
@@ -142,11 +145,34 @@ async def _get_result(self) -> RayTaskResult:
142145
task_metadata: SwordfishTaskMetadata = await metadata_ref
143146
assert len(results) == len(task_metadata.partition_metadatas)
144147

148+
# Pack the results into partitions
149+
num_partitions = self.num_partitions
150+
partition_refs = []
151+
152+
# We rely on the task metadata for now because IntoBatches is the only operator
153+
# that dynamically generates partitions without a fixed mapping.
154+
is_into_batches = self.is_into_batches
155+
if is_into_batches:
156+
for res, meta in zip(results, task_metadata.partition_metadatas):
157+
partition_refs.append(RayPartitionRef([res], meta.num_rows, meta.size_bytes or 0))
158+
else:
159+
packed_results: list[list[ray.ObjectRef]] = [[] for _ in range(num_partitions)]
160+
packed_metadatas: list[list[PartitionMetadata]] = [[] for _ in range(num_partitions)]
161+
162+
for i, (res, meta) in enumerate(zip(results, task_metadata.partition_metadatas)):
163+
part_idx = i % num_partitions
164+
packed_results[part_idx].append(res)
165+
packed_metadatas[part_idx].append(meta)
166+
167+
for i in range(num_partitions):
168+
chunks = packed_results[i]
169+
metas = packed_metadatas[i]
170+
total_rows = sum(m.num_rows for m in metas)
171+
total_bytes = sum(m.size_bytes or 0 for m in metas)
172+
partition_refs.append(RayPartitionRef(chunks, total_rows, total_bytes))
173+
145174
return RayTaskResult.success(
146-
[
147-
RayPartitionRef(result, metadata.num_rows, metadata.size_bytes or 0)
148-
for result, metadata in zip(results, task_metadata.partition_metadatas)
149-
],
175+
partition_refs,
150176
task_metadata.stats,
151177
)
152178
except (ray.exceptions.ActorDiedError, ray.exceptions.ActorUnschedulableError):
@@ -179,13 +205,18 @@ def __init__(
179205
self.actor_handle = actor_handle
180206

181207
def submit_task(self, task: RaySwordfishTask) -> RaySwordfishTaskHandle:
182-
psets = {k: [v.object_ref for v in v] for k, v in task.psets().items()}
183-
result_handle = self.actor_handle.run_plan.options(name=task.name()).remote(
184-
task.plan(), task.config(), psets, task.context()
208+
psets = {k: [obj_ref for p in v for obj_ref in p.object_refs] for k, v in task.psets().items()}
209+
result_handle = self.actor_handle.run_plan.remote(
210+
task.plan(),
211+
task.config(),
212+
psets,
213+
task.context(),
185214
)
186215
return RaySwordfishTaskHandle(
187-
result_handle,
188-
self.actor_handle,
216+
result_handle=result_handle,
217+
actor_handle=self.actor_handle,
218+
num_partitions=task.num_partitions(),
219+
is_into_batches=task.is_into_batches(),
189220
)
190221

191222
def shutdown(self) -> None:
@@ -259,12 +290,17 @@ def __init__(self, dashboard_url: str | None = None) -> None:
259290
def run_plan(
260291
self,
261292
plan: DistributedPhysicalPlan,
262-
partition_sets: dict[str, PartitionSet[ray.ObjectRef]],
293+
partition_sets: dict[str, PartitionSet[list[ray.ObjectRef]]],
263294
) -> None:
264-
psets = {
265-
k: [RayPartitionRef(v.partition(), v.metadata().num_rows, v.metadata().size_bytes or 0) for v in v.values()]
266-
for k, v in partition_sets.items()
267-
}
295+
psets = {}
296+
for k, v in partition_sets.items():
297+
partition_refs = []
298+
for val in v.values():
299+
partition_refs.append(
300+
RayPartitionRef(val.partition(), val.metadata().num_rows, val.metadata().size_bytes or 0)
301+
)
302+
psets[k] = partition_refs
303+
268304
self.curr_plans[plan.idx()] = plan
269305
self.curr_result_gens[plan.idx()] = self.plan_runner.run_plan(plan, psets)
270306

@@ -289,7 +325,7 @@ async def get_next_partition(self, plan_id: str) -> RayMaterializedResult | Reco
289325
[PartitionMetadata(next_partition_ref.num_rows, next_partition_ref.size_bytes)]
290326
)
291327
materialized_result = RayMaterializedResult(
292-
partition=next_partition_ref.object_ref,
328+
partition=next_partition_ref.object_refs,
293329
metadatas=metadata_accessor,
294330
metadata_idx=0,
295331
)
@@ -373,7 +409,7 @@ def __init__(self) -> None:
373409
def stream_plan(
374410
self,
375411
plan: DistributedPhysicalPlan,
376-
partition_sets: dict[str, PartitionSet[RayMaterializedResult]],
412+
partition_sets: dict[str, PartitionSet[list[ray.ObjectRef]]],
377413
) -> Generator[RayMaterializedResult, None, RecordBatch]:
378414
plan_id = plan.idx()
379415
ray.get(self.runner.run_plan.remote(plan, partition_sets))

daft/runners/ray_runner.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,14 @@ def _to_pandas_ref(df: pd.DataFrame | ray.ObjectRef) -> ray.ObjectRef:
282282
raise ValueError(f"Expected a Ray object ref or a Pandas DataFrame, got {type(df)}")
283283

284284

285-
class RayPartitionSet(PartitionSet[ray.ObjectRef]):
285+
class RayPartitionSet(PartitionSet[list[ray.ObjectRef]]):
286286
_results: dict[PartID, RayMaterializedResult]
287287

288288
def __init__(self) -> None:
289289
super().__init__()
290290
self._results = {}
291291

292-
def items(self) -> list[tuple[PartID, MaterializedResult[ray.ObjectRef]]]:
292+
def items(self) -> list[tuple[PartID, MaterializedResult[list[ray.ObjectRef]]]]:
293293
return [(pid, result) for pid, result in sorted(self._results.items())]
294294

295295
def _get_merged_micropartition(self, schema: Schema) -> MicroPartition:
@@ -298,22 +298,30 @@ def _get_merged_micropartition(self, schema: Schema) -> MicroPartition:
298298
assert ids_and_partitions[0][0] == 0
299299
assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions)
300300

301-
all_partitions = ray.get([part.partition() for id, part in ids_and_partitions])
302-
return MicroPartition.concat_or_empty(all_partitions, schema)
301+
all_refs = []
302+
for _, part in ids_and_partitions:
303+
all_refs.extend(part.partition())
304+
305+
all_micropartitions = ray.get(all_refs)
306+
return MicroPartition.concat_or_empty(all_micropartitions, schema)
303307

304308
def _get_preview_micropartitions(self, num_rows: int) -> list[MicroPartition]:
305309
ids_and_partitions = self.items()
306310
preview_parts = []
307311
for _, mat_result in ids_and_partitions:
308-
ref: ray.ObjectRef = mat_result.partition()
309-
part: MicroPartition = ray.get(ref)
310-
part_len = len(part)
311-
if part_len >= num_rows: # if this part has enough rows, take what we need and break
312-
preview_parts.append(part.slice(0, num_rows))
312+
refs: list[ray.ObjectRef] = mat_result.partition()
313+
parts: list[MicroPartition] = ray.get(refs)
314+
for part in parts:
315+
part_len = len(part)
316+
if part_len >= num_rows:
317+
preview_parts.append(part.slice(0, num_rows))
318+
num_rows = 0
319+
break
320+
else:
321+
num_rows -= part_len
322+
preview_parts.append(part)
323+
if num_rows == 0:
313324
break
314-
else: # otherwise, take the whole part and keep going
315-
num_rows -= part_len
316-
preview_parts.append(part)
317325
return preview_parts
318326

319327
def to_ray_dataset(self) -> RayDataset:
@@ -350,9 +358,9 @@ def _make_dask_dataframe_partition_from_micropartition(partition: MicroPartition
350358
return cast("dd.DataFrame", dd.from_delayed(ddf_parts, meta=meta))
351359

352360
def get_partition(self, idx: PartID) -> RayMaterializedResult:
353-
return self._results[idx].partition()
361+
return self._results[idx]
354362

355-
def set_partition(self, idx: PartID, result: MaterializedResult[ray.ObjectRef]) -> None:
363+
def set_partition(self, idx: PartID, result: MaterializedResult[list[ray.ObjectRef]]) -> None:
356364
assert isinstance(result, RayMaterializedResult)
357365
self._results[idx] = result
358366

@@ -377,8 +385,11 @@ def num_partitions(self) -> int:
377385
return len(self._results)
378386

379387
def wait(self) -> None:
380-
deduped_object_refs = {r.partition() for r in self._results.values()}
381-
ray.wait(list(deduped_object_refs), fetch_local=False, num_returns=len(deduped_object_refs))
388+
all_refs = []
389+
for r in self._results.values():
390+
all_refs.extend(r.partition())
391+
deduped_object_refs = list(set(all_refs))
392+
ray.wait(deduped_object_refs, fetch_local=False, num_returns=len(deduped_object_refs))
382393

383394

384395
def _from_arrow_type_with_ray_data_extensions(arrow_type: pa.DataType) -> DataType:
@@ -444,7 +455,7 @@ def partition_set_from_ray_dataset(
444455
pset = RayPartitionSet()
445456

446457
for i, obj in enumerate(daft_micropartitions):
447-
pset.set_partition(i, RayMaterializedResult(obj))
458+
pset.set_partition(i, RayMaterializedResult([obj]))
448459
return (
449460
pset,
450461
daft_schema,
@@ -476,7 +487,7 @@ def partition_set_from_dask_dataframe(
476487
pset = RayPartitionSet()
477488

478489
for i, obj in enumerate(daft_micropartitions):
479-
pset.set_partition(i, RayMaterializedResult(obj))
490+
pset.set_partition(i, RayMaterializedResult([obj]))
480491
return (
481492
pset,
482493
schemas[0],
@@ -487,7 +498,7 @@ def partition_set_from_dask_dataframe(
487498

488499

489500
@ray.remote # type: ignore[untyped-decorator]
490-
def get_metas(*partitions: MicroPartition) -> list[PartitionMetadata]:
501+
def get_metas(partitions: list[MicroPartition]) -> list[PartitionMetadata]:
491502
return [PartitionMetadata.from_table(partition) for partition in partitions]
492503

493504

@@ -665,7 +676,7 @@ def run_iter_tables(
665676
self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None
666677
) -> Iterator[MicroPartition]:
667678
for result in self.run_iter(builder, results_buffer_size=results_buffer_size):
668-
yield ray.get(result.partition())
679+
yield result.micropartition()
669680

670681
def _collect_into_cache(
671682
self, results_iter: Generator[RayMaterializedResult, None, RecordBatch]
@@ -689,48 +700,62 @@ def run(self, builder: LogicalPlanBuilder) -> tuple[PartitionCacheEntry, RecordB
689700
results_iter = self.run_iter(builder)
690701
return self._collect_into_cache(results_iter)
691702

692-
def put_partition_set_into_cache(self, pset: PartitionSet[ray.ObjectRef]) -> PartitionCacheEntry:
703+
def put_partition_set_into_cache(self, pset: PartitionSet[list[ray.ObjectRef]]) -> PartitionCacheEntry:
693704
if isinstance(pset, LocalPartitionSet):
694705
new_pset = RayPartitionSet()
695706
metadata_accessor = PartitionMetadataAccessor.from_metadata_list([v.metadata() for v in pset.values()])
696707
for i, (pid, py_mat_result) in enumerate(pset.items()):
697-
new_pset.set_partition(
698-
pid, RayMaterializedResult(ray.put(py_mat_result.partition()), metadata_accessor, i)
699-
)
708+
part = py_mat_result.partition()
709+
new_pset.set_partition(pid, RayMaterializedResult([ray.put(part)], metadata_accessor, i))
700710
pset = new_pset
701711
return self._part_set_cache.put_partition_set(pset=pset)
702712

703713
def runner_io(self) -> RayRunnerIO:
704714
return RayRunnerIO()
705715

706716

707-
class RayMaterializedResult(MaterializedResult[ray.ObjectRef]):
717+
class RayMaterializedResult(MaterializedResult[list[ray.ObjectRef]]):
708718
def __init__(
709719
self,
710-
partition: ray.ObjectRef[Any],
720+
partition: list[ray.ObjectRef[Any]],
711721
metadatas: PartitionMetadataAccessor | None = None,
712722
metadata_idx: int | None = None,
713723
):
724+
assert isinstance(partition, list)
714725
self._partition = partition
715726
if metadatas is None:
716727
assert metadata_idx is None
717728
metadatas = PartitionMetadataAccessor(get_metas.remote(self._partition))
718-
metadata_idx = 0
729+
719730
self._metadatas = metadatas
720731
self._metadata_idx = metadata_idx
721732

722-
def partition(self) -> ray.ObjectRef:
733+
def partition(self) -> list[ray.ObjectRef]:
723734
return self._partition
724735

725736
def micropartition(self) -> MicroPartition:
726-
return ray.get(self._partition)
737+
parts = ray.get(self._partition)
738+
return MicroPartition.concat(parts)
727739

728740
def metadata(self) -> PartitionMetadata:
729-
assert self._metadata_idx is not None
730-
return self._metadatas.get_index(self._metadata_idx)
741+
all_metas = self._metadatas._get_metadatas()
742+
743+
if self._metadata_idx is not None:
744+
return all_metas[self._metadata_idx]
745+
746+
total_rows = sum(m.num_rows for m in all_metas)
747+
total_bytes = 0
748+
for m in all_metas:
749+
if m.size_bytes is not None:
750+
total_bytes += m.size_bytes
751+
return PartitionMetadata(
752+
num_rows=total_rows,
753+
size_bytes=total_bytes if total_bytes > 0 else None,
754+
)
731755

732756
def cancel(self) -> None:
733-
return ray.cancel(self._partition)
757+
for p in self._partition:
758+
ray.cancel(p)
734759

735760
def _noop(self, _: ray.ObjectRef) -> None:
736761
return None

0 commit comments

Comments (0)