Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions routstr/payment/cost_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..core import get_logger
from ..core.db import AsyncSession
from ..core.settings import settings
from .price import sats_usd_price

logger = get_logger(__name__)

Expand Down Expand Up @@ -64,12 +65,63 @@ async def calculate_cost( # todo: can be sync
)
return cost_data

usage_data = response_data["usage"]

usd_cost = 0.0

# Prioritize cost_details.upstream_inference_cost
if "cost_details" in usage_data:
usd_cost = float(
usage_data["cost_details"].get("upstream_inference_cost", 0) or 0
)

# Fallback to cost field if upstream_inference_cost is 0
if usd_cost == 0 and "cost" in usage_data:
try:
usd_cost = float(usage_data.get("cost", 0) or 0)
except Exception:
pass

if usd_cost > 0:
try:
sats_per_usd = 1.0 / sats_usd_price()
cost_in_sats = usd_cost * sats_per_usd
cost_in_msats = math.ceil(cost_in_sats * 1000)

logger.info(
"Using cost from usage data/details",
extra={
"usd_cost": usd_cost,
"cost_in_sats": cost_in_sats,
"cost_in_msats": cost_in_msats,
"model": response_data.get("model", "unknown"),
},
)

return CostData(
base_msats=-1,
input_msats=-1, # Cost field doesn't break down by token type
output_msats=-1,
total_msats=cost_in_msats,
)
except Exception as e:
logger.warning(
"Error calculating cost from usage data",
extra={
"error": str(e),
"usd_cost": usd_cost,
"model": response_data.get("model", "unknown"),
},
)
# Fall through to token-based calculation

MSATS_PER_1K_INPUT_TOKENS: float = (
float(settings.fixed_per_1k_input_tokens) * 1000.0
)
MSATS_PER_1K_OUTPUT_TOKENS: float = (
float(settings.fixed_per_1k_output_tokens) * 1000.0
)
MSATS_PER_1K_IMAGE_COMPLETION_TOKENS: float = 0.0

if not settings.fixed_pricing:
response_model = response_data.get("model", "")
Expand Down Expand Up @@ -104,18 +156,21 @@ async def calculate_cost( # todo: can be sync
try:
mspp = float(model_obj.sats_pricing.prompt)
mspc = float(model_obj.sats_pricing.completion)
mspci = float(getattr(model_obj.sats_pricing, "completion_image", 0.0))
except Exception:
return CostDataError(message="Invalid pricing data", code="pricing_invalid")

MSATS_PER_1K_INPUT_TOKENS = mspp * 1_000_000.0
MSATS_PER_1K_OUTPUT_TOKENS = mspc * 1_000_000.0
MSATS_PER_1K_IMAGE_COMPLETION_TOKENS = mspci * 1_000_000.0

logger.info(
"Applied model-specific pricing",
extra={
"model": response_model,
"input_price_msats_per_1k": MSATS_PER_1K_INPUT_TOKENS,
"output_price_msats_per_1k": MSATS_PER_1K_OUTPUT_TOKENS,
"image_completion_price_msats_per_1k": MSATS_PER_1K_IMAGE_COMPLETION_TOKENS,
},
)

Expand All @@ -129,12 +184,43 @@ async def calculate_cost( # todo: can be sync
)
return cost_data

input_tokens = response_data.get("usage", {}).get("prompt_tokens", 0)
output_tokens = response_data.get("usage", {}).get("completion_tokens", 0)
input_tokens = usage_data.get("prompt_tokens", 0)
output_tokens = usage_data.get("completion_tokens", 0)

# added for response api
input_tokens = (
input_tokens if input_tokens != 0 else usage_data.get("input_tokens", 0)
)
output_tokens = (
output_tokens if output_tokens != 0 else usage_data.get("output_tokens", 0)
)

# Calculate image completion cost
image_completion_msats = 0.0
if MSATS_PER_1K_IMAGE_COMPLETION_TOKENS > 0:
completion_details = usage_data.get("completion_tokens_details", {})
image_tokens = completion_details.get("image_tokens", 0)

if image_tokens > 0:
if output_tokens >= image_tokens:
output_tokens -= image_tokens

image_completion_msats = round(
image_tokens / 1000 * MSATS_PER_1K_IMAGE_COMPLETION_TOKENS, 3
)

logger.info(
"Calculated image completion cost",
extra={
"image_tokens": image_tokens,
"image_completion_msats": image_completion_msats,
},
)

input_msats = round(input_tokens / 1000 * MSATS_PER_1K_INPUT_TOKENS, 3)

output_msats = round(output_tokens / 1000 * MSATS_PER_1K_OUTPUT_TOKENS, 3)
token_based_cost = math.ceil(input_msats + output_msats)
token_based_cost = math.ceil(input_msats + output_msats + image_completion_msats)

logger.info(
"Calculated token-based cost",
Expand All @@ -143,6 +229,7 @@ async def calculate_cost( # todo: can be sync
"output_tokens": output_tokens,
"input_cost_msats": input_msats,
"output_cost_msats": output_msats,
"image_completion_msats": image_completion_msats,
"total_cost_msats": token_based_cost,
"model": response_data.get("model", "unknown"),
},
Expand Down
25 changes: 25 additions & 0 deletions routstr/payment/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Pricing(BaseModel):
completion: float
request: float = 0.0
image: float = 0.0
completion_image: float = 0.0
Comment thread
shroominic marked this conversation as resolved.
Outdated
web_search: float = 0.0
internal_reasoning: float = 0.0
input_cache_read: float = 0.0
Expand All @@ -40,6 +41,13 @@ class Pricing(BaseModel):
max_cost: float = 0.0 # in sats not msats


PRICING_OVERRIDES = {
Comment thread
shroominic marked this conversation as resolved.
Outdated
"gemini-3-pro-image-preview": {"completion_image": 0.00012},
"gemini-2.5-flash-image": {"completion_image": 0.00003},
"gemini-2.0-flash": {"completion_image": 0.00003},
}


class TopProvider(BaseModel):
context_length: int | None = None
max_completion_tokens: int | None = None
Expand Down Expand Up @@ -116,6 +124,16 @@ async def async_fetch_openrouter_models(source_filter: str | None = None) -> lis
if not _has_valid_pricing(model):
continue

# Apply manual pricing overrides
if model_id in PRICING_OVERRIDES:
pricing = model.get("pricing", {})
if pricing:
for k, v in PRICING_OVERRIDES[model_id].items():
pricing[k] = str(
v
) # OpenRouter API returns strings for pricing
model["pricing"] = pricing

models_data.append(model)

return models_data
Expand Down Expand Up @@ -148,6 +166,12 @@ def _row_to_model(
if isinstance(pricing, dict) and float(pricing.get("request", 0.0)) <= 0.0:
pricing["request"] = max(pricing.get("request", 0.0), 0.0)

# Apply defaults for missing fields from manual overrides
if row.id in PRICING_OVERRIDES and isinstance(pricing, dict):
for k, v in PRICING_OVERRIDES[row.id].items():
if k not in pricing:
pricing[k] = v

parsed_pricing = Pricing.parse_obj(pricing)
model = Model(
id=row.id,
Expand Down Expand Up @@ -507,6 +531,7 @@ def _pricing_matches(
"completion",
"request",
"image",
"completion_image",
"web_search",
"internal_reasoning",
]
Expand Down
Loading