1414from typing import TYPE_CHECKING , Any
1515
1616import data_designer .lazy_heavy_imports as lazy
17+ from data_designer .config .analysis .dataset_profiler import DatasetProfilerResults
1718from data_designer .config .base import ProcessorConfig
1819from data_designer .config .config_builder import BuilderConfig , DataDesignerConfigBuilder
1920from data_designer .config .data_designer_config import DataDesignerConfig
21+ from data_designer .config .dataset_metadata import DatasetMetadata
2022from data_designer .config .errors import InvalidFileFormatError
2123from data_designer .config .seed import IndexRange , PartitionBlock , SamplingStrategy
2224from data_designer .config .seed_source import LocalFileSeedSource
2325from data_designer .config .utils .constants import DEFAULT_NUM_RECORDS
2426from data_designer .config .utils .type_helpers import StrEnum
2527from data_designer .config .version import get_library_version
2628from data_designer .engine .dataset_builders .errors import ArtifactStorageError
29+ from data_designer .engine .storage .artifact_storage import ArtifactStorage , ResumeMode
2730from data_designer .interface .errors import DataDesignerWorkflowError
2831from data_designer .interface .results import (
2932 SUPPORTED_EXPORT_FORMATS ,
3740if TYPE_CHECKING :
3841 import pandas as pd
3942
40- from data_designer .config .analysis .dataset_profiler import DatasetProfilerResults
4143 from data_designer .interface .data_designer import DataDesigner
4244
4345
4446logger = logging .getLogger (__name__ )
4547
4648OnSuccessCallback = Callable [[Path ], Path | str ]
49+ WORKFLOW_METADATA_FILENAME = "workflow-metadata.json"
50+ COMPLETED_STAGE_STATUSES = {"completed" , "completed_empty" }
51+ RESUMABLE_STAGE_STATUSES = {"running" , "failed" }
4752
4853
4954@dataclass (frozen = True )
@@ -221,8 +226,8 @@ def add_stage(
221226 )
222227 return self
223228
224- def run (self ) -> CompositeWorkflowResults :
225- """Run all stages from scratch .
229+ def run (self , * , resume : ResumeMode = ResumeMode . NEVER ) -> CompositeWorkflowResults :
230+ """Run all stages, optionally reusing compatible completed stage outputs .
226231
227232 Each stage writes a deterministic artifact directory under the parent
228233 Data Designer artifact path. Downstream stages are seeded from the
@@ -233,6 +238,7 @@ def run(self) -> CompositeWorkflowResults:
233238
234239 workflow_path = self ._data_designer .artifact_path / self .name
235240 workflow_path .mkdir (parents = True , exist_ok = True )
241+ prior_metadata = _read_prior_workflow_metadata (workflow_path , self .name , resume )
236242 metadata : dict [str , Any ] = {
237243 "name" : self .name ,
238244 "library_version" : get_library_version (),
@@ -245,6 +251,7 @@ def run(self) -> CompositeWorkflowResults:
245251 previous_stage_name : str | None = None
246252 previous_stage_fingerprint : str | None = None
247253 skipped_upstream_stage : str | None = None
254+ force_rerun_downstream = False
248255
249256 for index , stage in enumerate (self ._stages ):
250257 stage_dir_name = _stage_dir_name (index , stage .name )
@@ -288,7 +295,43 @@ def run(self) -> CompositeWorkflowResults:
288295 upstream_fingerprint = previous_stage_fingerprint ,
289296 )
290297 stage_path = workflow_path / stage_dir_name
291- if stage_path .exists ():
298+ prior_stage_metadata = _get_prior_stage_metadata (prior_metadata , index , stage , stage_dir_name )
299+ stage_resume = ResumeMode .NEVER
300+ prior_matches = (
301+ not force_rerun_downstream
302+ and prior_stage_metadata is not None
303+ and prior_stage_metadata .get ("fingerprint" ) == stage_fingerprint
304+ )
305+
306+ if prior_matches and _can_skip_prior_stage (stage , prior_stage_metadata ):
307+ stage_metadata .update (prior_stage_metadata )
308+ output_seed_path = Path (stage_metadata ["output_seed_path" ])
309+ output_records = _count_parquet_records (output_seed_path )
310+ output_result = _stage_result_from_metadata (
311+ workflow_path = workflow_path ,
312+ stage = stage ,
313+ stage_dir_name = stage_dir_name ,
314+ stage_builder = stage_builder ,
315+ )
316+ stage_results [stage .name ] = output_result
317+ stage_output_paths [stage .name ] = output_seed_path
318+ previous_seed_path = output_seed_path
319+ previous_output_records = output_records
320+ previous_stage_name = stage .name
321+ previous_stage_fingerprint = stage_fingerprint
322+ if stage_metadata ["status" ] == "completed_empty" :
323+ skipped_upstream_stage = stage .name
324+ _write_workflow_metadata (workflow_path , metadata )
325+ continue
326+
327+ if prior_matches and prior_stage_metadata .get ("status" ) in RESUMABLE_STAGE_STATUSES and stage_path .exists ():
328+ stage_resume = ResumeMode .ALWAYS
329+ elif resume == ResumeMode .ALWAYS and not force_rerun_downstream :
330+ raise DataDesignerWorkflowError (
331+ f"Cannot resume workflow { self .name !r} : stage { stage .name !r} is not reusable."
332+ )
333+
334+ if stage_resume == ResumeMode .NEVER and stage_path .exists ():
292335 shutil .rmtree (stage_path )
293336
294337 stage_metadata .update (
@@ -310,11 +353,15 @@ def run(self) -> CompositeWorkflowResults:
310353 num_records = num_records ,
311354 dataset_name = stage_dir_name ,
312355 artifact_path = workflow_path ,
356+ resume = stage_resume ,
313357 )
314358 actual_records = result .count_records ()
315359 output_result = result
316360 output_source_result = result
317361 if stage .output_processors :
362+ output_processor_path = stage_path / "output-processors"
363+ if output_processor_path .exists ():
364+ shutil .rmtree (output_processor_path )
318365 output_processor_builder = _output_processor_config_builder (
319366 stage_builder = stage_builder ,
320367 seed_path = result .artifact_storage .final_dataset_path ,
@@ -368,6 +415,7 @@ def run(self) -> CompositeWorkflowResults:
368415 previous_output_records = output_records
369416 previous_stage_name = stage .name
370417 previous_stage_fingerprint = stage_fingerprint
418+ force_rerun_downstream = True
371419 _write_workflow_metadata (workflow_path , metadata )
372420
373421 return CompositeWorkflowResults (
@@ -378,6 +426,123 @@ def run(self) -> CompositeWorkflowResults:
378426 )
379427
380428
429+ def _read_prior_workflow_metadata (
430+ workflow_path : Path ,
431+ workflow_name : str ,
432+ resume : ResumeMode ,
433+ ) -> dict [str , Any ] | None :
434+ if resume == ResumeMode .NEVER :
435+ return None
436+ metadata_path = workflow_path / WORKFLOW_METADATA_FILENAME
437+ if not metadata_path .exists ():
438+ if resume == ResumeMode .ALWAYS :
439+ raise DataDesignerWorkflowError (f"Cannot resume workflow { workflow_name !r} : no workflow metadata found." )
440+ return None
441+ try :
442+ metadata = json .loads (metadata_path .read_text (encoding = "utf-8" ))
443+ except json .JSONDecodeError as exc :
444+ raise DataDesignerWorkflowError (
445+ f"Cannot resume workflow { workflow_name !r} : workflow metadata is corrupt."
446+ ) from exc
447+ except OSError as exc :
448+ raise DataDesignerWorkflowError (
449+ f"Cannot resume workflow { workflow_name !r} : workflow metadata could not be read."
450+ ) from exc
451+ if metadata .get ("name" ) != workflow_name :
452+ raise DataDesignerWorkflowError (
453+ f"Cannot resume workflow { workflow_name !r} : workflow metadata name does not match."
454+ )
455+ return metadata
456+
457+
458+ def _get_prior_stage_metadata (
459+ prior_metadata : dict [str , Any ] | None ,
460+ index : int ,
461+ stage : _WorkflowStage ,
462+ stage_dir_name : str ,
463+ ) -> dict [str , Any ] | None :
464+ if prior_metadata is None :
465+ return None
466+ stages = prior_metadata .get ("stages" )
467+ if not isinstance (stages , list ) or index >= len (stages ):
468+ return None
469+ prior_stage = stages [index ]
470+ if not isinstance (prior_stage , dict ):
471+ return None
472+ if prior_stage .get ("name" ) != stage .name or prior_stage .get ("stage_dir" ) != stage_dir_name :
473+ return None
474+ return prior_stage
475+
476+
477+ def _can_skip_prior_stage (stage : _WorkflowStage , prior_stage_metadata : dict [str , Any ]) -> bool :
478+ if prior_stage_metadata .get ("status" ) not in COMPLETED_STAGE_STATUSES :
479+ return False
480+ if stage .on_success is not None and stage .on_success_version is None :
481+ return False
482+ output_seed_path = prior_stage_metadata .get ("output_seed_path" )
483+ if not isinstance (output_seed_path , str ) or not output_seed_path :
484+ return False
485+ try :
486+ _count_parquet_records (Path (output_seed_path ))
487+ except DataDesignerWorkflowError :
488+ return False
489+ return True
490+
491+
492+ def _stage_result_from_metadata (
493+ * ,
494+ workflow_path : Path ,
495+ stage : _WorkflowStage ,
496+ stage_dir_name : str ,
497+ stage_builder : DataDesignerConfigBuilder ,
498+ ) -> DatasetCreationResults :
499+ main_storage = ArtifactStorage (artifact_path = workflow_path , dataset_name = stage_dir_name , resume = ResumeMode .ALWAYS )
500+ result_storage = main_storage
501+ result_builder = stage_builder
502+ if stage .output_processors :
503+ result_storage = ArtifactStorage (
504+ artifact_path = workflow_path / stage_dir_name ,
505+ dataset_name = "output-processors" ,
506+ resume = ResumeMode .ALWAYS ,
507+ )
508+ result_builder = _output_processor_config_builder (
509+ stage_builder = stage_builder ,
510+ seed_path = main_storage .final_dataset_path ,
511+ output_processors = stage .output_processors ,
512+ )
513+ return DatasetCreationResults (
514+ artifact_storage = result_storage ,
515+ analysis = _load_stage_analysis (result_storage ),
516+ config_builder = result_builder ,
517+ dataset_metadata = DatasetMetadata (),
518+ )
519+
520+
521+ def _load_stage_analysis (artifact_storage : ArtifactStorage ) -> Any :
522+ try :
523+ metadata = artifact_storage .read_metadata ()
524+ except (FileNotFoundError , json .JSONDecodeError , OSError ):
525+ return None
526+ column_statistics = metadata .get ("column_statistics" )
527+ if not column_statistics :
528+ return None
529+ num_records = metadata .get ("actual_num_records" )
530+ if num_records is None :
531+ num_records = _count_parquet_records (artifact_storage .final_dataset_path )
532+ try :
533+ return DatasetProfilerResults .model_validate (
534+ {
535+ "num_records" : num_records ,
536+ "target_num_records" : metadata .get ("target_num_records" , num_records ),
537+ "column_statistics" : column_statistics ,
538+ "side_effect_column_names" : metadata .get ("side_effect_column_names" ),
539+ "column_profiles" : metadata .get ("column_profiles" ),
540+ }
541+ )
542+ except Exception :
543+ return None
544+
545+
381546def _clone_config_builder (config_builder : DataDesignerConfigBuilder ) -> DataDesignerConfigBuilder :
382547 return DataDesignerConfigBuilder .from_config (BuilderConfig (data_designer = config_builder .build ()))
383548
@@ -527,7 +692,7 @@ def _parquet_files(path: Path) -> list[Path]:
527692
528693
529694def _write_workflow_metadata (workflow_path : Path , metadata : dict [str , Any ]) -> None :
530- path = workflow_path / "workflow-metadata.json"
695+ path = workflow_path / WORKFLOW_METADATA_FILENAME
531696 path .write_text (json .dumps (metadata , indent = 2 , sort_keys = True ), encoding = "utf-8" )
532697
533698
0 commit comments