diff --git a/rapidfireai/evals/scheduling/controller.py b/rapidfireai/evals/scheduling/controller.py index c37c46fc..62241bc8 100644 --- a/rapidfireai/evals/scheduling/controller.py +++ b/rapidfireai/evals/scheduling/controller.py @@ -1400,7 +1400,19 @@ def run_multi_pipeline_inference( f"Busy actors: {status['busy_actors']}, " f"Gen: {status['current_generation']}" ) - time.sleep(0.5) + # Block until something new happens so a dead actor's failed + # futures surface here. ray.wait must be filtered to not-yet-ready + # futures: any already-resolved batch ref would satisfy + # num_returns=1 and turn this into a tight spin. + all_futures = [] + for task_info in active_tasks.values(): + all_futures.extend(task_info["futures"]) + if all_futures: + _ready, not_ready = ray.wait(all_futures, num_returns=len(all_futures), timeout=0) + if not_ready: + ray.wait(not_ready, num_returns=1, timeout=0.5) + else: + time.sleep(0.5) continue # Execute schedule @@ -1534,35 +1546,54 @@ def run_multi_pipeline_inference( self.logger.debug(f"Initialized actor {actor_id} for pipeline {pipeline_id} ({pipeline_name})") - futures = [] - preprocess_fn = pipeline_config.get("preprocess_fn") - postprocess_fn = pipeline_config.get("postprocess_fn") - compute_metrics_fn = pipeline_config.get("compute_metrics_fn") - accumulate_metrics_fn = pipeline_config.get("accumulate_metrics_fn") - for batch in batches: - future = actor.process_batch.remote( - batch, - preprocess_fn=preprocess_fn, - postprocess_fn=postprocess_fn, - compute_metrics_fn=compute_metrics_fn if accumulate_metrics_fn else None, - ) - futures.append(future) - - # Track task - task_start_time = time.time() - active_tasks[actor_id] = { - "futures": futures, - "pipeline_id": pipeline_id, - "shard_id": shard_id, - "task_id": task_id, - "batch_count": len(batches), - "start_time": task_start_time, - } + # Mirror the init-failure handler: if dispatch or bookkeeping + # raises, free the actor via remove_pipeline so it doesn't leak + # busy 
state. + try: + futures = [] + preprocess_fn = pipeline_config.get("preprocess_fn") + postprocess_fn = pipeline_config.get("postprocess_fn") + compute_metrics_fn = pipeline_config.get("compute_metrics_fn") + accumulate_metrics_fn = pipeline_config.get("accumulate_metrics_fn") + for batch in batches: + future = actor.process_batch.remote( + batch, + preprocess_fn=preprocess_fn, + postprocess_fn=postprocess_fn, + compute_metrics_fn=compute_metrics_fn if accumulate_metrics_fn else None, + ) + futures.append(future) + + # Track task + task_start_time = time.time() + active_tasks[actor_id] = { + "futures": futures, + "pipeline_id": pipeline_id, + "shard_id": shard_id, + "task_id": task_id, + "batch_count": len(batches), + "start_time": task_start_time, + } - # Update task status to in-progress - db.set_actor_task_start_time(task_id, task_start_time) - db.set_actor_task_status(task_id, TaskStatus.IN_PROGRESS) - db.set_pipeline_current_shard(pipeline_id, shard_id) + # Update task status to in-progress + db.set_actor_task_start_time(task_id, task_start_time) + db.set_actor_task_status(task_id, TaskStatus.IN_PROGRESS) + db.set_pipeline_current_shard(pipeline_id, shard_id) + except Exception as dispatch_err: + error_msg = str(dispatch_err) + self.logger.exception( + f"Pipeline {pipeline_id} ({pipeline_name}) failed during batch " + f"dispatch on actor {actor_id}: {error_msg}" + ) + db.set_actor_task_status(task_id, TaskStatus.FAILED) + db.set_actor_task_error(task_id, error_msg) + db.set_pipeline_status(pipeline_id, PipelineStatus.FAILED) + db.set_pipeline_error(pipeline_id, error_msg) + active_tasks.pop(actor_id, None) + scheduler.remove_pipeline(pipeline_id) + if progress_display: + progress_display.update_pipeline(pipeline_id, status="FAILED") + continue # PHASE 8: Compute final metrics for each pipeline (including dynamically cloned ones). # pipeline_id_to_config contains all pipelines (originals + clones added via _handle_clone). 
diff --git a/tests/test_pipeline_scheduler.py b/tests/test_pipeline_scheduler.py
new file mode 100644
index 00000000..f6a80520
--- /dev/null
+++ b/tests/test_pipeline_scheduler.py
@@ -0,0 +1,106 @@
+"""Tests for PipelineScheduler bookkeeping.
+
+Invariant: an actor marked busy by schedule() must be freed by either
+set_completed_task(actor_id) or remove_pipeline(pipeline_id). If neither
+is called, the controller's busy-loop wedges.
+"""
+
+from rapidfireai.evals.scheduling.pipeline_scheduler import PipelineScheduler
+
+# What schedule() returns while every actor is occupied.
+BUSY_SENTINEL = {"pipeline_id": -1, "actor_id": -1, "shard_id": -1}
+
+
+class TestPipelineSchedulerBookkeeping:
+    """Exercise the busy/free actor bookkeeping of PipelineScheduler."""
+
+    def test_schedule_marks_actor_busy(self):
+        """schedule() records the chosen pipeline against the chosen actor."""
+        scheduler = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=3)
+        result = scheduler.schedule()
+        actor_id = result["actor_id"]
+        pipeline_id = result["pipeline_id"]
+        assert actor_id in (0, 1)
+        assert pipeline_id in (1, 2)
+        assert scheduler.actor_current_pipeline[actor_id] == pipeline_id
+
+    def test_set_completed_task_frees_actor_and_advances_progress(self):
+        """Completing a task frees the actor and bumps the shard counter."""
+        scheduler = PipelineScheduler(pipeline_ids=[1], num_actors=1, num_shards=2)
+        result = scheduler.schedule()
+        actor_id = result["actor_id"]
+        assert scheduler.actor_current_pipeline[actor_id] == 1
+        assert scheduler.pipeline_shards_completed[1] == 0
+
+        scheduler.set_completed_task(actor_id)
+
+        assert scheduler.actor_current_pipeline[actor_id] == -1
+        assert scheduler.pipeline_shards_completed[1] == 1
+
+    def test_remove_pipeline_frees_actor_on_dispatch_failure(self):
+        """remove_pipeline() frees the actor so the other pipeline can be scheduled."""
+        scheduler = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=2)
+
+        first = scheduler.schedule()
+        actor_id = first["actor_id"]
+        pipeline_id = first["pipeline_id"]
+        assert scheduler.actor_current_pipeline[actor_id] == pipeline_id
+
+        scheduler.remove_pipeline(pipeline_id)
+
+        assert scheduler.actor_current_pipeline[actor_id] == -1
+        assert pipeline_id not in scheduler.pipeline_ids
+
+        second = scheduler.schedule()
+        # Positive checks: "!= -1" alone would also accept the None that
+        # schedule() reports once nothing is left to run.
+        assert second["pipeline_id"] in {1, 2} - {pipeline_id}
+        assert second["actor_id"] in (0, 1)
+
+    def test_dispatch_failure_does_not_wedge_remaining_pipelines(self):
+        """With one actor, removing a failed pipeline lets the survivor run to the end."""
+        scheduler = PipelineScheduler(pipeline_ids=[1, 2], num_actors=1, num_shards=1)
+
+        first = scheduler.schedule()
+        failed_pipeline = first["pipeline_id"]
+        actor_id = first["actor_id"]
+        scheduler.remove_pipeline(failed_pipeline)
+        assert scheduler.actor_current_pipeline[actor_id] == -1
+
+        second = scheduler.schedule()
+        survivor = second["pipeline_id"]
+        # Must be the other real pipeline on the only actor, not a -1/None sentinel.
+        assert survivor in {1, 2} - {failed_pipeline}
+        assert second["actor_id"] == actor_id
+        scheduler.set_completed_task(second["actor_id"])
+
+        terminal = scheduler.schedule()
+        assert terminal["pipeline_id"] is None
+
+    def test_all_actors_busy_returns_busy_sentinel(self):
+        """A third schedule() with both actors occupied yields the busy sentinel."""
+        scheduler = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=4)
+        scheduler.schedule()
+        scheduler.schedule()
+        result = scheduler.schedule()
+        assert result == BUSY_SENTINEL
+
+    def test_set_completed_task_idempotent_on_free_actor(self):
+        """set_completed_task() on an idle actor changes nothing."""
+        scheduler = PipelineScheduler(pipeline_ids=[1], num_actors=1, num_shards=1)
+        assert scheduler.actor_current_pipeline[0] == -1
+        scheduler.set_completed_task(0)
+        assert scheduler.actor_current_pipeline[0] == -1
+        assert scheduler.pipeline_shards_completed[1] == 0
+
+    def test_actor_leaks_busy_when_neither_completion_nor_removal_called(self):
+        """Regression: if dispatch fails and the controller forgets remove_pipeline,
+        schedule() returns the busy sentinel indefinitely. This is the wedge the
+        controller fix in run_multi_pipeline_inference prevents.
+        """
+        scheduler = PipelineScheduler(pipeline_ids=[1, 2], num_actors=1, num_shards=1)
+        scheduler.schedule()
+
+        for _ in range(20):
+            result = scheduler.schedule()
+            assert result == BUSY_SENTINEL