diff --git a/routstr/payment/cost_calculation.py b/routstr/payment/cost_calculation.py index f8eb4ffb..03b37362 100644 --- a/routstr/payment/cost_calculation.py +++ b/routstr/payment/cost_calculation.py @@ -5,6 +5,7 @@ from ..core import get_logger from ..core.db import AsyncSession from ..core.settings import settings +from .price import sats_usd_price logger = get_logger(__name__) @@ -64,6 +65,56 @@ async def calculate_cost( # todo: can be sync ) return cost_data + usage_data = response_data["usage"] + + usd_cost = 0.0 + + # Prioritize cost_details.upstream_inference_cost + if "cost_details" in usage_data: + usd_cost = float( + usage_data["cost_details"].get("upstream_inference_cost", 0) or 0 + ) + + # Fallback to cost field if upstream_inference_cost is 0 + if usd_cost == 0 and "cost" in usage_data: + try: + usd_cost = float(usage_data.get("cost", 0) or 0) + except Exception: + pass + + if usd_cost > 0: + try: + sats_per_usd = 1.0 / sats_usd_price() + cost_in_sats = usd_cost * sats_per_usd + cost_in_msats = math.ceil(cost_in_sats * 1000) + + logger.info( + "Using cost from usage data/details", + extra={ + "usd_cost": usd_cost, + "cost_in_sats": cost_in_sats, + "cost_in_msats": cost_in_msats, + "model": response_data.get("model", "unknown"), + }, + ) + + return CostData( + base_msats=-1, + input_msats=-1, # Cost field doesn't break down by token type + output_msats=-1, + total_msats=cost_in_msats, + ) + except Exception as e: + logger.warning( + "Error calculating cost from usage data", + extra={ + "error": str(e), + "usd_cost": usd_cost, + "model": response_data.get("model", "unknown"), + }, + ) + # Fall through to token-based calculation + MSATS_PER_1K_INPUT_TOKENS: float = ( float(settings.fixed_per_1k_input_tokens) * 1000.0 ) @@ -129,10 +180,19 @@ async def calculate_cost( # todo: can be sync ) return cost_data - input_tokens = response_data.get("usage", {}).get("prompt_tokens", 0) - output_tokens = response_data.get("usage", {}).get("completion_tokens", 0) + input_tokens = usage_data.get("prompt_tokens", 0) + output_tokens = usage_data.get("completion_tokens", 0) + + # added for response api + input_tokens = ( + input_tokens if input_tokens != 0 else usage_data.get("input_tokens", 0) + ) + output_tokens = ( + output_tokens if output_tokens != 0 else usage_data.get("output_tokens", 0) + ) input_msats = round(input_tokens / 1000 * MSATS_PER_1K_INPUT_TOKENS, 3) + output_msats = round(output_tokens / 1000 * MSATS_PER_1K_OUTPUT_TOKENS, 3) token_based_cost = math.ceil(input_msats + output_msats)