Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/art/rewards/ruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def ruler(
judge_model: str = "openai/o3",
extra_litellm_params: dict | None = None,
rubric: str = DEFAULT_RUBRIC,
context_window: int | None = None,
*,
debug: bool = False,
) -> list[TrajectoryScore]:
Expand All @@ -80,6 +81,9 @@ async def ruler(
- "anthropic/claude-3-opus-20240229" - Alternative judge
extra_litellm_params: Additional parameters to pass to LiteLLM completion.
Can include temperature, max_tokens, etc.
context_window: Optional context window override (e.g., Ollama `num_ctx`).
If provided, it is passed to LiteLLM as both `num_ctx` and `max_input_tokens`;
either key already present in extra_litellm_params keeps its supplied value.
rubric: The grading rubric. The default rubric works well for most tasks.
debug: If True, pretty-print the judge's reasoning to help understand scores.

Expand Down Expand Up @@ -172,12 +176,17 @@ async def ruler(
{"role": "user", "content": user_text},
]

litellm_params = dict(extra_litellm_params) if extra_litellm_params else {}
if context_window is not None:
litellm_params.setdefault("num_ctx", context_window)
litellm_params.setdefault("max_input_tokens", context_window)

response = await acompletion(
model=judge_model,
messages=messages,
response_format=Response,
caching=False,
**extra_litellm_params if extra_litellm_params else {},
**litellm_params,
)
assert isinstance(response, ModelResponse)

Expand Down Expand Up @@ -222,6 +231,7 @@ async def ruler_score_group(
judge_model: str = "openai/o3",
extra_litellm_params: dict | None = None,
rubric: str = DEFAULT_RUBRIC,
context_window: int | None = None,
*,
swallow_exceptions: bool = False,
debug: bool = False,
Expand All @@ -242,6 +252,8 @@ async def ruler_score_group(
group: A TrajectoryGroup containing trajectories to score.
judge_model: The model to use for judging. See `ruler` for options.
extra_litellm_params: Additional parameters to pass to LiteLLM completion.
context_window: Optional context window override (e.g., Ollama `num_ctx`).
Sets LiteLLM `num_ctx`/`max_input_tokens` unless already present in extra_litellm_params.
rubric: Custom rubric or use the default which works well for most tasks.
swallow_exceptions: If True, returns None on errors instead of raising.
This is recommended for production to handle API failures gracefully.
Expand Down Expand Up @@ -298,6 +310,7 @@ async def ruler_score_group(
message_lists,
judge_model=judge_model,
extra_litellm_params=extra_litellm_params,
context_window=context_window,
rubric=rubric,
debug=debug,
)
Expand Down