diff --git a/routstr/upstream/anthropic.py b/routstr/upstream/anthropic.py
index 3f228e9c..c6fd631a 100644
--- a/routstr/upstream/anthropic.py
+++ b/routstr/upstream/anthropic.py
@@ -61,6 +61,34 @@ def transform_model_name(self, model_id: str) -> str:
             model_id = fixed_transforms[model_id]
         return model_id
 
+    def transform_parameters(self, data: dict) -> dict:
+        """Translate an OpenAI-style ``reasoning`` option into ``thinking``.
+
+        ``reasoning`` is always removed from the body. A recognised
+        ``effort`` sets the matching ``thinking`` block; any other value
+        leaves ``thinking`` untouched.
+
+        Args:
+            data: Request body data; modified in place.
+
+        Returns:
+            The body as returned by the base-class transformation.
+        """
+        # Thinking token budget granted for each effort level.
+        thinking_budgets = {"low": 8192, "medium": 16384, "high": 32768}
+
+        reasoning = data.pop("reasoning", None)
+        effort = reasoning.get("effort") if isinstance(reasoning, dict) else None
+        if effort == "none":
+            data["thinking"] = {"type": "disabled"}
+        elif isinstance(effort, str) and effort in thinking_budgets:
+            data["thinking"] = {
+                "type": "enabled",
+                "budget_tokens": thinking_budgets[effort],
+            }
+
+        return super().transform_parameters(data)
+
     async def fetch_models(self) -> list[Model]:
         """Fetch Anthropic models from OpenRouter API filtered by anthropic source."""
         models_data = await async_fetch_openrouter_models(source_filter="anthropic")
diff --git a/routstr/upstream/base.py b/routstr/upstream/base.py
index 1c42da8c..dd26bfad 100644
--- a/routstr/upstream/base.py
+++ b/routstr/upstream/base.py
@@ -276,19 +276,17 @@ def prepare_request_body(
 
         try:
             data = json.loads(body)
-            if isinstance(data, dict) and "model" in data:
+            if isinstance(data, dict):
                 original_model = model_obj.id
-                transformed_model = self.transform_model_name(original_model)
-                data["model"] = transformed_model
-                logger.debug(
-                    "Transformed model name in request",
-                    extra={
-                        "original": original_model,
-                        "transformed": transformed_model,
-                        "provider": self.provider_type or self.base_url,
-                    },
-                )
-                return json.dumps(data).encode()
+
+                data = self.update_parameters_from_model_name(data, original_model)
+
+                if "model" in data:
+                    transformed_model = self.transform_model_name(original_model)
+                    data["model"] = transformed_model
+
+                data = self.transform_parameters(data)
+                return json.dumps(data).encode()
         except Exception as e:
             logger.debug(
                 "Could not transform request body",
@@ -300,6 +298,47 @@ def prepare_request_body(
 
         return body
 
+    def update_parameters_from_model_name(self, data: dict, model_id: str) -> dict:
+        """Derive request parameters from a suffix of the model name.
+
+        A model id ending in ``:thinking`` or ``-thinking`` sets
+        ``reasoning`` to medium effort, replacing any value already in the body.
+
+        Args:
+            data: Original request body data; modified in place.
+            model_id: Model id whose suffix is inspected.
+
+        Returns:
+            Transformed request body data
+        """
+        if model_id.endswith((":thinking", "-thinking")):
+            data["reasoning"] = {"effort": "medium"}
+
+        return data
+
+    def transform_parameters(self, data: dict) -> dict:
+        """Transform parameters for provider-specific requirements.
+
+        A non-empty ``input`` list whose first item is a dict with a
+        ``role`` key is moved to ``messages``.
+
+        Args:
+            data: Original request body data; modified in place.
+
+        Returns:
+            Transformed request body data
+        """
+        # generic input to messages transformation
+        items = data.get("input")
+        if (
+            isinstance(items, list)
+            and items  # an empty list has no first item to inspect
+            and isinstance(items[0], dict)
+            and "role" in items[0]
+        ):
+            data["messages"] = data.pop("input")
+        return data
+
     def _extract_upstream_error_message(
         self, body_bytes: bytes
     ) -> tuple[str, str | None]:
diff --git a/routstr/upstream/openai.py b/routstr/upstream/openai.py
index 11cc4336..d4fa9bfb 100644
--- a/routstr/upstream/openai.py
+++ b/routstr/upstream/openai.py
@@ -38,6 +38,18 @@ def get_provider_metadata(cls) -> dict[str, object]:
             "platform_url": cls.platform_url,
         }
 
+    def transform_parameters(self, data: dict) -> dict:
+        """Transform parameters for OpenAI API compatibility.
+
+        ``max_tokens`` is renamed to ``max_completion_tokens``; a
+        ``max_completion_tokens`` already in the body is kept.
+        """
+        if "max_tokens" in data:
+            max_tokens = data.pop("max_tokens")
+            # Never clobber an explicit max_completion_tokens.
+            data.setdefault("max_completion_tokens", max_tokens)
+        return super().transform_parameters(data)
+
     def transform_model_name(self, model_id: str) -> str:
         """Strip 'openai/' prefix for OpenAI API compatibility."""
         return model_id.removeprefix("openai/")