Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions rapidfireai/evals/scheduling/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,15 @@ def run_multi_pipeline_inference(
f"Busy actors: {status['busy_actors']}, "
f"Gen: {status['current_generation']}"
)
time.sleep(0.5)
# Block on any in-flight future so a dead actor's failed futures
# surface here and get reaped on the next iteration.
all_pending = []
for task_info in active_tasks.values():
all_pending.extend(task_info["futures"])
if all_pending:
ray.wait(all_pending, num_returns=1, timeout=0.5)
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
else:
time.sleep(0.5)
continue

# Execute schedule
Expand Down Expand Up @@ -1534,35 +1542,54 @@ def run_multi_pipeline_inference(

self.logger.debug(f"Initialized actor {actor_id} for pipeline {pipeline_id} ({pipeline_name})")

futures = []
preprocess_fn = pipeline_config.get("preprocess_fn")
postprocess_fn = pipeline_config.get("postprocess_fn")
compute_metrics_fn = pipeline_config.get("compute_metrics_fn")
accumulate_metrics_fn = pipeline_config.get("accumulate_metrics_fn")
for batch in batches:
future = actor.process_batch.remote(
batch,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
compute_metrics_fn=compute_metrics_fn if accumulate_metrics_fn else None,
)
futures.append(future)

# Track task
task_start_time = time.time()
active_tasks[actor_id] = {
"futures": futures,
"pipeline_id": pipeline_id,
"shard_id": shard_id,
"task_id": task_id,
"batch_count": len(batches),
"start_time": task_start_time,
}
# Mirror the init-failure handler: if dispatch or bookkeeping
# raises, free the actor via remove_pipeline so it doesn't leak
# busy state.
try:
futures = []
preprocess_fn = pipeline_config.get("preprocess_fn")
postprocess_fn = pipeline_config.get("postprocess_fn")
compute_metrics_fn = pipeline_config.get("compute_metrics_fn")
accumulate_metrics_fn = pipeline_config.get("accumulate_metrics_fn")
for batch in batches:
future = actor.process_batch.remote(
batch,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
compute_metrics_fn=compute_metrics_fn if accumulate_metrics_fn else None,
)
futures.append(future)

# Track task
task_start_time = time.time()
active_tasks[actor_id] = {
"futures": futures,
"pipeline_id": pipeline_id,
"shard_id": shard_id,
"task_id": task_id,
"batch_count": len(batches),
"start_time": task_start_time,
}

# Update task status to in-progress
db.set_actor_task_start_time(task_id, task_start_time)
db.set_actor_task_status(task_id, TaskStatus.IN_PROGRESS)
db.set_pipeline_current_shard(pipeline_id, shard_id)
# Update task status to in-progress
db.set_actor_task_start_time(task_id, task_start_time)
db.set_actor_task_status(task_id, TaskStatus.IN_PROGRESS)
db.set_pipeline_current_shard(pipeline_id, shard_id)
except Exception as dispatch_err:
error_msg = str(dispatch_err)
self.logger.exception(
f"Pipeline {pipeline_id} ({pipeline_name}) failed during batch "
f"dispatch on actor {actor_id}: {error_msg}"
)
db.set_actor_task_status(task_id, TaskStatus.FAILED)
db.set_actor_task_error(task_id, error_msg)
db.set_pipeline_status(pipeline_id, PipelineStatus.FAILED)
db.set_pipeline_error(pipeline_id, error_msg)
active_tasks.pop(actor_id, None)
scheduler.remove_pipeline(pipeline_id)
if progress_display:
progress_display.update_pipeline(pipeline_id, status="FAILED")
continue

# PHASE 8: Compute final metrics for each pipeline (including dynamically cloned ones).
# pipeline_id_to_config contains all pipelines (originals + clones added via _handle_clone).
Expand Down
91 changes: 91 additions & 0 deletions tests/test_pipeline_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for PipelineScheduler bookkeeping.

Invariant: an actor marked busy by schedule() must be freed by either
set_completed_task(actor_id) or remove_pipeline(pipeline_id). If neither
is called, the controller's busy-loop wedges.
"""

from rapidfireai.evals.scheduling.pipeline_scheduler import PipelineScheduler


class TestPipelineSchedulerBookkeeping:
    """Exercise the busy/free bookkeeping contract of PipelineScheduler.

    Contract under test: once schedule() marks an actor busy, exactly one of
    set_completed_task(actor_id) or remove_pipeline(pipeline_id) must free it,
    otherwise the controller's scheduling loop can never make progress.
    """

    def test_schedule_marks_actor_busy(self):
        sched = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=3)
        assignment = sched.schedule()
        aid = assignment["actor_id"]
        pid = assignment["pipeline_id"]
        # A fresh scheduler must hand out a real actor/pipeline pair...
        assert aid in (0, 1)
        assert pid in (1, 2)
        # ...and record the pairing in its busy map.
        assert sched.actor_current_pipeline[aid] == pid

    def test_set_completed_task_frees_actor_and_advances_progress(self):
        sched = PipelineScheduler(pipeline_ids=[1], num_actors=1, num_shards=2)
        aid = sched.schedule()["actor_id"]
        # Busy after scheduling, no shards finished yet.
        assert sched.actor_current_pipeline[aid] == 1
        assert sched.pipeline_shards_completed[1] == 0

        sched.set_completed_task(aid)

        # Completion both frees the actor and bumps the shard counter.
        assert sched.actor_current_pipeline[aid] == -1
        assert sched.pipeline_shards_completed[1] == 1

    def test_remove_pipeline_frees_actor_on_dispatch_failure(self):
        sched = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=2)

        assignment = sched.schedule()
        aid = assignment["actor_id"]
        pid = assignment["pipeline_id"]
        assert sched.actor_current_pipeline[aid] == pid

        # Simulates the controller's dispatch-failure path.
        sched.remove_pipeline(pid)

        assert sched.actor_current_pipeline[aid] == -1
        assert pid not in sched.pipeline_ids

        # The freed actor must be usable for the surviving pipeline.
        followup = sched.schedule()
        assert followup["pipeline_id"] != -1
        assert followup["actor_id"] != -1

    def test_dispatch_failure_does_not_wedge_remaining_pipelines(self):
        sched = PipelineScheduler(pipeline_ids=[1, 2], num_actors=1, num_shards=1)

        doomed = sched.schedule()
        doomed_pid = doomed["pipeline_id"]
        aid = doomed["actor_id"]
        sched.remove_pipeline(doomed_pid)
        assert sched.actor_current_pipeline[aid] == -1

        # The other pipeline still gets scheduled and can run to completion.
        nxt = sched.schedule()
        assert nxt["pipeline_id"] != -1
        assert nxt["pipeline_id"] != doomed_pid
        sched.set_completed_task(nxt["actor_id"])

        # All work drained: the scheduler signals completion.
        assert sched.schedule()["pipeline_id"] is None

    def test_all_actors_busy_returns_busy_sentinel(self):
        sched = PipelineScheduler(pipeline_ids=[1, 2], num_actors=2, num_shards=4)
        sched.schedule()
        sched.schedule()
        # With every actor occupied the scheduler yields the -1 sentinel triple.
        assert sched.schedule() == {"pipeline_id": -1, "actor_id": -1, "shard_id": -1}

    def test_set_completed_task_idempotent_on_free_actor(self):
        sched = PipelineScheduler(pipeline_ids=[1], num_actors=1, num_shards=1)
        assert sched.actor_current_pipeline[0] == -1
        # Completing an already-free actor must be a no-op, not corrupt progress.
        sched.set_completed_task(0)
        assert sched.actor_current_pipeline[0] == -1
        assert sched.pipeline_shards_completed[1] == 0

    def test_actor_leaks_busy_when_neither_completion_nor_removal_called(self):
        """Regression: if dispatch fails and the controller forgets remove_pipeline,
        schedule() returns the busy sentinel indefinitely. This is the wedge the
        controller fix in run_multi_pipeline_inference prevents."""
        sched = PipelineScheduler(pipeline_ids=[1, 2], num_actors=1, num_shards=1)
        sched.schedule()

        sentinel = {"pipeline_id": -1, "actor_id": -1, "shard_id": -1}
        # Without a completion/removal call the scheduler never frees the actor.
        for _ in range(20):
            assert sched.schedule() == sentinel