From f0f62d4a6c86d32b66a5feca81f8835dedfdd6e1 Mon Sep 17 00:00:00 2001
From: Ansh-info
Date: Thu, 11 Dec 2025 08:55:07 +0100
Subject: [PATCH] feat: add an explicit, configurable context-window override
 for RULER to avoid the implicit 8k cap when using Ollama or other
 long-context backends

Co-authored-by: Apoorva Gupta
---
 src/art/rewards/ruler.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py
index 2ea33312..e5856857 100644
--- a/src/art/rewards/ruler.py
+++ b/src/art/rewards/ruler.py
@@ -55,6 +55,7 @@ async def ruler(
     judge_model: str = "openai/o3",
     extra_litellm_params: dict | None = None,
     rubric: str = DEFAULT_RUBRIC,
+    context_window: int | None = None,
     *,
     debug: bool = False,
 ) -> list[TrajectoryScore]:
@@ -80,6 +81,9 @@
             - "anthropic/claude-3-opus-20240229" - Alternative judge
         extra_litellm_params: Additional parameters to pass to LiteLLM completion.
             Can include temperature, max_tokens, etc.
+        context_window: Optional context window override (e.g., Ollama `num_ctx`).
+            If provided, it sets litellm `num_ctx` and `max_input_tokens` unless
+            already supplied in extra_litellm_params.
         rubric: The grading rubric. The default rubric works well for most tasks.
         debug: If True, pretty-print the judge's reasoning to help understand
             scores.
@@ -172,12 +176,17 @@
         {"role": "user", "content": user_text},
     ]
 
+    litellm_params = dict(extra_litellm_params) if extra_litellm_params else {}
+    if context_window is not None:
+        litellm_params.setdefault("num_ctx", context_window)
+        litellm_params.setdefault("max_input_tokens", context_window)
+
     response = await acompletion(
         model=judge_model,
         messages=messages,
         response_format=Response,
         caching=False,
-        **extra_litellm_params if extra_litellm_params else {},
+        **litellm_params,
     )
     assert isinstance(response, ModelResponse)
 
@@ -222,6 +231,7 @@ async def ruler_score_group(
     judge_model: str = "openai/o3",
     extra_litellm_params: dict | None = None,
     rubric: str = DEFAULT_RUBRIC,
+    context_window: int | None = None,
     *,
     swallow_exceptions: bool = False,
     debug: bool = False,
@@ -242,6 +252,8 @@
         group: A TrajectoryGroup containing trajectories to score.
         judge_model: The model to use for judging. See `ruler` for options.
         extra_litellm_params: Additional parameters to pass to LiteLLM completion.
+        context_window: Optional context window override (e.g., Ollama `num_ctx`).
+            Sets litellm `num_ctx`/`max_input_tokens` if not already set.
         rubric: Custom rubric or use the default which works well for most tasks.
         swallow_exceptions: If True, returns None on errors instead of raising.
             This is recommended for production to handle API failures gracefully.
@@ -298,6 +310,7 @@
             message_lists,
             judge_model=judge_model,
             extra_litellm_params=extra_litellm_params,
+            context_window=context_window,
             rubric=rubric,
             debug=debug,
         )
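
Usage sketch (illustrative, not part of the diff above): how the new
parameter is intended to be called, assuming the exported import path
`art.rewards.ruler_score_group`, a LiteLLM-routed Ollama judge, and that
`ruler_score_group` returns the re-scored TrajectoryGroup (or None, per
the `swallow_exceptions` docstring). The helper name, the model name, and
the 32768 value are placeholders.

    import art
    # Import path assumed from the patched file, src/art/rewards/ruler.py.
    from art.rewards import ruler_score_group


    async def score_with_local_judge(
        group: art.TrajectoryGroup,
    ) -> art.TrajectoryGroup | None:
        # Judge with a long-context local model routed through LiteLLM's
        # Ollama integration (model name is a placeholder).
        return await ruler_score_group(
            group,
            judge_model="ollama_chat/qwen3:14b",
            # Forwarded as LiteLLM `num_ctx` and `max_input_tokens`, lifting
            # the implicit 8k cap, unless either key is already present in
            # extra_litellm_params.
            context_window=32768,
            swallow_exceptions=True,  # return None instead of raising on API errors
        )

Because the override uses setdefault, explicit LiteLLM params win: passing
extra_litellm_params={"num_ctx": 16384} alongside context_window=32768
keeps the caller's 16384 for num_ctx.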