Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 44 additions & 36 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,18 +433,8 @@ async def a_generate(

results_dict["prompt"] = [cleanup_messages(p) for p in results_dict["prompt"]]

# prepare GenerateOutputs and run rollouts
results = GenerateOutputs(
prompt=results_dict["prompt"],
answer=results_dict["answer"],
task=results_dict["task"],
info=results_dict["info"],
completion=[],
state=[],
reward=[],
metrics={},
)
n = len(results.prompt)
# run rollouts
n = len(results_dict["prompt"])

# Resolve concurrency knobs
gen_limit = max_concurrent_generation
Expand Down Expand Up @@ -475,10 +465,10 @@ async def a_generate(
)

async def run_one(i: int) -> None:
prompt_i = results.prompt[i]
answer_i = results.answer[i]
task_i = results.task[i]
info_i = results.info[i]
prompt_i = results_dict["prompt"][i]
answer_i = results_dict["answer"][i]
task_i = results_dict["task"][i]
info_i = results_dict["info"][i]
# Generation stage
if gen_semaphore is not None:
async with gen_semaphore:
Expand Down Expand Up @@ -541,42 +531,60 @@ async def run_one(i: int) -> None:
*tasks, total=n, desc=f"Running {n} rollouts (interleaved)"
)

results.completion = results_completion # type: ignore[assignment]
results.state = results_state # type: ignore[assignment]
results.reward = rewards
results.metrics = metrics
return results
results_dict["completion"] = results_completion # type: ignore[assignment]
results_dict["state"] = results_state # type: ignore[assignment]
results_dict["reward"] = rewards
results_dict["metrics"] = metrics
return GenerateOutputs(
prompt=results_dict["prompt"],
answer=results_dict["answer"],
task=results_dict["task"],
info=results_dict["info"],
completion=results_dict.get("completion", []),
state=results_dict.get("state", []),
reward=results_dict.get("reward", []),
metrics=results_dict.get("metrics", {}),
)
else:
# Non-interleaved: generate all then score all
rollouts = await self.run_rollouts(
prompts=results.prompt,
answers=results.answer,
tasks=results.task,
infos=results.info,
prompts=results_dict["prompt"],
answers=results_dict["answer"],
tasks=results_dict["task"],
infos=results_dict["info"],
client=client,
model=model,
sampling_args=gen_sampling_args,
max_concurrent=gen_limit if gen_limit is not None else max_concurrent,
**kwargs,
)
results.completion = [rollout[0] for rollout in rollouts]
results.state = [rollout[1] for rollout in rollouts]
results_dict["completion"] = [rollout[0] for rollout in rollouts]
results_dict["state"] = [rollout[1] for rollout in rollouts]
if score_rollouts:
rollout_scores = await self.rubric.score_rollouts(
prompts=results.prompt,
completions=results.completion,
answers=results.answer,
states=results.state,
tasks=results.task,
infos=results.info,
prompts=results_dict["prompt"],
completions=results_dict["completion"],
answers=results_dict["answer"],
states=results_dict["state"],
tasks=results_dict["task"],
infos=results_dict["info"],
max_concurrent=score_limit
if score_limit is not None
else max_concurrent,
apply_weights=True,
)
results.reward = rollout_scores.reward
results.metrics = rollout_scores.metrics
return results
results_dict["reward"] = rollout_scores.reward
results_dict["metrics"] = rollout_scores.metrics
return GenerateOutputs(
prompt=results_dict["prompt"],
answer=results_dict["answer"],
task=results_dict["task"],
info=results_dict["info"],
completion=results_dict.get("completion", []),
state=results_dict.get("state", []),
reward=results_dict.get("reward", []),
metrics=results_dict.get("metrics", {}),
)

def generate(
self,
Expand Down
5 changes: 3 additions & 2 deletions verifiers/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Expand All @@ -22,7 +23,7 @@
FunctionDefinition,
FunctionParameters,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SkipValidation

# typing aliases
ChatMessage = ChatCompletionMessageParam
Expand Down Expand Up @@ -57,7 +58,7 @@ class GenerateOutputs(BaseModel):
prompt: list[Messages]
completion: list[Messages]
answer: list[str]
state: list[State]
state: Annotated[list[State], SkipValidation]
info: list[Info]
task: list[str]
reward: list[float]
Expand Down
Loading