From 6b9917c0bd167877204672ffd9d4c63327313382 Mon Sep 17 00:00:00 2001 From: Shroominic Date: Wed, 24 Dec 2025 12:08:39 +0100 Subject: [PATCH] max cost filtered models --- routstr/payment/models.py | 85 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/routstr/payment/models.py b/routstr/payment/models.py index d6eccf5d..24cb0c1f 100644 --- a/routstr/payment/models.py +++ b/routstr/payment/models.py @@ -3,14 +3,15 @@ import random import httpx -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from pydantic.v1 import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from ..core.db import ModelRow, create_session, get_session +from ..core.db import ApiKey, ModelRow, create_session, get_session from ..core.logging import get_logger from ..core.settings import settings +from ..wallet import deserialize_token_from_string from .price import sats_usd_price logger = get_logger(__name__) @@ -521,11 +522,89 @@ def _pricing_matches( return True + +async def _get_request_balance(request: Request, session: AsyncSession) -> int | None: + """Get the balance from the request headers if authentication is provided.""" + headers = request.headers + token: str | None = None + + if x_cashu := headers.get("x-cashu"): + token = x_cashu + elif auth := headers.get("authorization"): + parts = auth.split(" ") + if len(parts) > 1: + token = parts[1] + + if not token: + return None + + # Handle API keys (sk-*) + if token.startswith("sk-"): + try: + # sk- keys use the part after "sk-" as the ID + key_id = token[3:] + key = await session.get(ApiKey, key_id) + if key: + return key.balance - key.reserved_balance + except Exception as e: + logger.warning(f"Error checking API key balance: {e}") + return None + + # Handle Cashu tokens + try: + token_obj = deserialize_token_from_string(token) + amount_msat = ( + token_obj.amount if token_obj.unit == "msat" else token_obj.amount * 1000 + ) + return amount_msat + except Exception as e: + logger.debug(f"Failed to deserialize cashu token for balance check: {e}") + return None + + @models_router.get("/v1/models") @models_router.get("/models", include_in_schema=False) -async def models(session: AsyncSession = Depends(get_session)) -> dict: +async def models( + request: Request, session: AsyncSession = Depends(get_session) +) -> dict: """Get all available models from all providers with database overrides applied.""" from ..proxy import get_unique_models items = get_unique_models() + + # Optional: Filter by user balance if authenticated + user_balance = await _get_request_balance(request, session) + if user_balance is not None: + filtered_items = [] + + # Calculate tolerance factor once + tol_factor = 1.0 + if not settings.fixed_pricing and settings.tolerance_percentage > 0: + tol_factor = 1.0 - (settings.tolerance_percentage / 100.0) + + fixed_cost = settings.fixed_cost_per_request * 1000 + min_cost = settings.min_request_msat + + for model in items: + model_cost_msats = 0 + + if settings.fixed_pricing: + model_cost_msats = max(min_cost, fixed_cost) + elif model.sats_pricing: + # Use model specific pricing + max_cost = ( + model.sats_pricing.max_cost + * 1000 + * tol_factor + ) + model_cost_msats = max(min_cost, int(max_cost)) + else: + # Fallback if no pricing found + model_cost_msats = max(min_cost, fixed_cost) + + if model_cost_msats <= user_balance: + filtered_items.append(model) + + items = filtered_items + return {"data": items}