@@ -282,14 +282,14 @@ def _to_pandas_ref(df: pd.DataFrame | ray.ObjectRef) -> ray.ObjectRef:
282282 raise ValueError (f"Expected a Ray object ref or a Pandas DataFrame, got { type (df )} " )
283283
284284
285- class RayPartitionSet (PartitionSet [ray .ObjectRef ]):
285+ class RayPartitionSet (PartitionSet [list [ ray .ObjectRef ] ]):
286286 _results : dict [PartID , RayMaterializedResult ]
287287
288288 def __init__ (self ) -> None :
289289 super ().__init__ ()
290290 self ._results = {}
291291
292- def items (self ) -> list [tuple [PartID , MaterializedResult [ray .ObjectRef ]]]:
292+ def items (self ) -> list [tuple [PartID , MaterializedResult [list [ ray .ObjectRef ] ]]]:
293293 return [(pid , result ) for pid , result in sorted (self ._results .items ())]
294294
295295 def _get_merged_micropartition (self , schema : Schema ) -> MicroPartition :
@@ -298,22 +298,30 @@ def _get_merged_micropartition(self, schema: Schema) -> MicroPartition:
298298 assert ids_and_partitions [0 ][0 ] == 0
299299 assert ids_and_partitions [- 1 ][0 ] + 1 == len (ids_and_partitions )
300300
301- all_partitions = ray .get ([part .partition () for id , part in ids_and_partitions ])
302- return MicroPartition .concat_or_empty (all_partitions , schema )
301+ all_refs = []
302+ for _ , part in ids_and_partitions :
303+ all_refs .extend (part .partition ())
304+
305+ all_micropartitions = ray .get (all_refs )
306+ return MicroPartition .concat_or_empty (all_micropartitions , schema )
303307
304308 def _get_preview_micropartitions (self , num_rows : int ) -> list [MicroPartition ]:
305309 ids_and_partitions = self .items ()
306310 preview_parts = []
307311 for _ , mat_result in ids_and_partitions :
308- ref : ray .ObjectRef = mat_result .partition ()
309- part : MicroPartition = ray .get (ref )
310- part_len = len (part )
311- if part_len >= num_rows : # if this part has enough rows, take what we need and break
312- preview_parts .append (part .slice (0 , num_rows ))
312+ refs : list [ray .ObjectRef ] = mat_result .partition ()
313+ parts : list [MicroPartition ] = ray .get (refs )
314+ for part in parts :
315+ part_len = len (part )
316+ if part_len >= num_rows :
317+ preview_parts .append (part .slice (0 , num_rows ))
318+ num_rows = 0
319+ break
320+ else :
321+ num_rows -= part_len
322+ preview_parts .append (part )
323+ if num_rows == 0 :
313324 break
314- else : # otherwise, take the whole part and keep going
315- num_rows -= part_len
316- preview_parts .append (part )
317325 return preview_parts
318326
319327 def to_ray_dataset (self ) -> RayDataset :
@@ -350,9 +358,9 @@ def _make_dask_dataframe_partition_from_micropartition(partition: MicroPartition
350358 return cast ("dd.DataFrame" , dd .from_delayed (ddf_parts , meta = meta ))
351359
352360 def get_partition (self , idx : PartID ) -> RayMaterializedResult :
353- return self ._results [idx ]. partition ()
361+ return self ._results [idx ]
354362
355- def set_partition (self , idx : PartID , result : MaterializedResult [ray .ObjectRef ]) -> None :
363+ def set_partition (self , idx : PartID , result : MaterializedResult [list [ ray .ObjectRef ] ]) -> None :
356364 assert isinstance (result , RayMaterializedResult )
357365 self ._results [idx ] = result
358366
@@ -377,8 +385,11 @@ def num_partitions(self) -> int:
377385 return len (self ._results )
378386
379387 def wait (self ) -> None :
380- deduped_object_refs = {r .partition () for r in self ._results .values ()}
381- ray .wait (list (deduped_object_refs ), fetch_local = False , num_returns = len (deduped_object_refs ))
388+ all_refs = []
389+ for r in self ._results .values ():
390+ all_refs .extend (r .partition ())
391+ deduped_object_refs = list (set (all_refs ))
392+ ray .wait (deduped_object_refs , fetch_local = False , num_returns = len (deduped_object_refs ))
382393
383394
384395def _from_arrow_type_with_ray_data_extensions (arrow_type : pa .DataType ) -> DataType :
@@ -444,7 +455,7 @@ def partition_set_from_ray_dataset(
444455 pset = RayPartitionSet ()
445456
446457 for i , obj in enumerate (daft_micropartitions ):
447- pset .set_partition (i , RayMaterializedResult (obj ))
458+ pset .set_partition (i , RayMaterializedResult ([ obj ] ))
448459 return (
449460 pset ,
450461 daft_schema ,
@@ -476,7 +487,7 @@ def partition_set_from_dask_dataframe(
476487 pset = RayPartitionSet ()
477488
478489 for i , obj in enumerate (daft_micropartitions ):
479- pset .set_partition (i , RayMaterializedResult (obj ))
490+ pset .set_partition (i , RayMaterializedResult ([ obj ] ))
480491 return (
481492 pset ,
482493 schemas [0 ],
@@ -487,7 +498,7 @@ def partition_set_from_dask_dataframe(
487498
488499
489500@ray .remote # type: ignore[untyped-decorator]
490- def get_metas (* partitions : MicroPartition ) -> list [PartitionMetadata ]:
501+ def get_metas (partitions : list [ MicroPartition ] ) -> list [PartitionMetadata ]:
491502 return [PartitionMetadata .from_table (partition ) for partition in partitions ]
492503
493504
@@ -665,7 +676,7 @@ def run_iter_tables(
665676 self , builder : LogicalPlanBuilder , results_buffer_size : int | None = None
666677 ) -> Iterator [MicroPartition ]:
667678 for result in self .run_iter (builder , results_buffer_size = results_buffer_size ):
668- yield ray . get ( result .partition () )
679+ yield result .micropartition ( )
669680
670681 def _collect_into_cache (
671682 self , results_iter : Generator [RayMaterializedResult , None , RecordBatch ]
@@ -689,48 +700,62 @@ def run(self, builder: LogicalPlanBuilder) -> tuple[PartitionCacheEntry, RecordB
689700 results_iter = self .run_iter (builder )
690701 return self ._collect_into_cache (results_iter )
691702
692- def put_partition_set_into_cache (self , pset : PartitionSet [ray .ObjectRef ]) -> PartitionCacheEntry :
703+ def put_partition_set_into_cache (self , pset : PartitionSet [list [ ray .ObjectRef ] ]) -> PartitionCacheEntry :
693704 if isinstance (pset , LocalPartitionSet ):
694705 new_pset = RayPartitionSet ()
695706 metadata_accessor = PartitionMetadataAccessor .from_metadata_list ([v .metadata () for v in pset .values ()])
696707 for i , (pid , py_mat_result ) in enumerate (pset .items ()):
697- new_pset .set_partition (
698- pid , RayMaterializedResult (ray .put (py_mat_result .partition ()), metadata_accessor , i )
699- )
708+ part = py_mat_result .partition ()
709+ new_pset .set_partition (pid , RayMaterializedResult ([ray .put (part )], metadata_accessor , i ))
700710 pset = new_pset
701711 return self ._part_set_cache .put_partition_set (pset = pset )
702712
703713 def runner_io (self ) -> RayRunnerIO :
704714 return RayRunnerIO ()
705715
706716
707- class RayMaterializedResult (MaterializedResult [ray .ObjectRef ]):
717+ class RayMaterializedResult (MaterializedResult [list [ ray .ObjectRef ] ]):
708718 def __init__ (
709719 self ,
710- partition : ray .ObjectRef [Any ],
720+ partition : list [ ray .ObjectRef [Any ] ],
711721 metadatas : PartitionMetadataAccessor | None = None ,
712722 metadata_idx : int | None = None ,
713723 ):
724+ assert isinstance (partition , list )
714725 self ._partition = partition
715726 if metadatas is None :
716727 assert metadata_idx is None
717728 metadatas = PartitionMetadataAccessor (get_metas .remote (self ._partition ))
718- metadata_idx = 0
729+
719730 self ._metadatas = metadatas
720731 self ._metadata_idx = metadata_idx
721732
722- def partition (self ) -> ray .ObjectRef :
733+ def partition (self ) -> list [ ray .ObjectRef ] :
723734 return self ._partition
724735
725736 def micropartition (self ) -> MicroPartition :
726- return ray .get (self ._partition )
737+ parts = ray .get (self ._partition )
738+ return MicroPartition .concat (parts )
727739
728740 def metadata (self ) -> PartitionMetadata :
729- assert self ._metadata_idx is not None
730- return self ._metadatas .get_index (self ._metadata_idx )
741+ all_metas = self ._metadatas ._get_metadatas ()
742+
743+ if self ._metadata_idx is not None :
744+ return all_metas [self ._metadata_idx ]
745+
746+ total_rows = sum (m .num_rows for m in all_metas )
747+ total_bytes = 0
748+ for m in all_metas :
749+ if m .size_bytes is not None :
750+ total_bytes += m .size_bytes
751+ return PartitionMetadata (
752+ num_rows = total_rows ,
753+ size_bytes = total_bytes if total_bytes > 0 else None ,
754+ )
731755
732756 def cancel (self ) -> None :
733- return ray .cancel (self ._partition )
757+ for p in self ._partition :
758+ ray .cancel (p )
734759
735760 def _noop (self , _ : ray .ObjectRef ) -> None :
736761 return None
0 commit comments