Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions routstr/upstream/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ def transform_model_name(self, model_id: str) -> str:
model_id = fixed_transforms[model_id]
return model_id

def transform_parameters(self, data: dict) -> dict:
"""Transform parameters for Anthropic API compatibility."""
if "reasoning" in data:
reasoning = data.pop("reasoning")
if isinstance(reasoning, dict) and "effort" in reasoning:
effort = reasoning.pop("effort")
if effort == "low":
data["thinking"] = {
"type": "enabled",
"budget_tokens": 8192,
}
elif effort == "medium":
data["thinking"] = {
"type": "enabled",
"budget_tokens": 16384,
}
elif effort == "high":
data["thinking"] = {
"type": "enabled",
"budget_tokens": 32768,
}
elif effort == "none":
data["thinking"] = {
"type": "disabled",
}

return super().transform_parameters(data)

async def fetch_models(self) -> list[Model]:
"""Fetch Anthropic models from OpenRouter API filtered by anthropic source."""
models_data = await async_fetch_openrouter_models(source_filter="anthropic")
Expand Down
61 changes: 49 additions & 12 deletions routstr/upstream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,19 +276,17 @@ def prepare_request_body(

try:
data = json.loads(body)
if isinstance(data, dict) and "model" in data:
if isinstance(data, dict):
original_model = model_obj.id
transformed_model = self.transform_model_name(original_model)
data["model"] = transformed_model
logger.debug(
"Transformed model name in request",
extra={
"original": original_model,
"transformed": transformed_model,
"provider": self.provider_type or self.base_url,
},
)
return json.dumps(data).encode()

data = self.update_parameters_from_model_name(data, original_model)

if "model" in data:
transformed_model = self.transform_model_name(original_model)
data["model"] = transformed_model

data = self.transform_parameters(data)
return json.dumps(data).encode()
except Exception as e:
logger.debug(
"Could not transform request body",
Expand All @@ -300,6 +298,43 @@ def prepare_request_body(

return body

def update_parameters_from_model_name(self, data: dict, model_id: str) -> dict:
"""Extract parameters from model name for provider-specific requirements.

Args:
data: Original request body data

Returns:
Transformed request body data
"""
if model_id.endswith(":thinking"):
model_id = model_id.removesuffix(":thinking")
data["reasoning"] = {"effort": "medium"}
if model_id.endswith("-thinking"):
model_id = model_id.removesuffix("-thinking")
data["reasoning"] = {"effort": "medium"}

return data

def transform_parameters(self, data: dict) -> dict:
"""Transform parameters for provider-specific requirements.

Args:
data: Original request body data

Returns:
Transformed request body data
"""
# generic input to messages transformation
if (
"input" in data
and isinstance(data["input"], list)
and isinstance(data["input"][0], dict)
and "role" in data["input"][0]
):
data["messages"] = data.pop("input")
return data

def _extract_upstream_error_message(
self, body_bytes: bytes
) -> tuple[str, str | None]:
Expand Down Expand Up @@ -1094,7 +1129,9 @@ async def forward_request(

url = f"{self.base_url}/{path}"

print(f"request_body: {request_body[:100]!r}")
transformed_body = self.prepare_request_body(request_body, model_obj)
print(f"transformed_body: {transformed_body[:100]!r}")

logger.info(
"Forwarding request to upstream",
Expand Down
6 changes: 6 additions & 0 deletions routstr/upstream/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def get_provider_metadata(cls) -> dict[str, object]:
"platform_url": cls.platform_url,
}

def transform_parameters(self, data: dict) -> dict:
"""Transform parameters for OpenAI API compatibility."""
if "max_tokens" in data:
data["max_completion_tokens"] = data.pop("max_tokens")
return super().transform_parameters(data)

def transform_model_name(self, model_id: str) -> str:
"""Strip 'openai/' prefix for OpenAI API compatibility."""
return model_id.removeprefix("openai/")
Expand Down
Loading