From 0a38b92003526268c6495ee0f1727c8353651732 Mon Sep 17 00:00:00 2001 From: Htoo Pyae Date: Sun, 23 Nov 2025 19:19:08 +0700 Subject: [PATCH] feat(users): add token usage in user response --- src/any_llm/gateway/routes/users.py | 77 +++++++++++++++++++++++++++++ tests/gateway/conftest.py | 25 +++++++++- tests/gateway/test_users.py | 40 +++++++++++++++ 3 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 tests/gateway/test_users.py diff --git a/src/any_llm/gateway/routes/users.py b/src/any_llm/gateway/routes/users.py index 619f9e45..2fcafe53 100644 --- a/src/any_llm/gateway/routes/users.py +++ b/src/any_llm/gateway/routes/users.py @@ -1,8 +1,10 @@ +from dataclasses import dataclass from datetime import UTC, datetime from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field +from sqlalchemy import func from sqlalchemy.orm import Session from any_llm.gateway.auth import verify_master_key @@ -35,6 +37,9 @@ class UserResponse(BaseModel): created_at: str updated_at: str metadata: dict[str, Any] + total_input_tokens: int + total_output_tokens: int + total_tokens: int class UpdateUserRequest(BaseModel): @@ -64,6 +69,57 @@ class UsageLogResponse(BaseModel): error_message: str | None +@dataclass +class TokenUsage: + input: int + output: int + total: int + + +def _get_token_usage_by_user_id(db: Session, user_id: str) -> TokenUsage: + """Fetches token usage for a single user.""" + token_usage_result = ( + db.query( + func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("input"), + func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("output"), + func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total"), + ) + .filter(UsageLog.user_id == user_id) + .first() + ) + return TokenUsage( + input=int(token_usage_result.input), + output=int(token_usage_result.output), + total=int(token_usage_result.total), + ) + + +def _get_token_usage_by_user_ids(db: Session, user_ids: list[str]) -> dict[str, TokenUsage]: + """Fetches token usage for multiple users.""" + token_usage_results = ( + db.query( + UsageLog.user_id, + func.coalesce(func.sum(UsageLog.prompt_tokens), 0).label("input"), + func.coalesce(func.sum(UsageLog.completion_tokens), 0).label("output"), + func.coalesce(func.sum(UsageLog.total_tokens), 0).label("total"), + ) + .filter(UsageLog.user_id.in_(user_ids)) + .group_by(UsageLog.user_id) + .all() + ) + + token_usage_map = { + result.user_id: TokenUsage( + input=int(result.input), + output=int(result.output), + total=int(result.total), + ) + for result in token_usage_results + } + + return token_usage_map + + @router.post("", dependencies=[Depends(verify_master_key)]) async def create_user( request: CreateUserRequest, @@ -113,6 +169,9 @@ async def create_user( created_at=user.created_at.isoformat(), updated_at=user.updated_at.isoformat(), metadata=dict(user.metadata_) if user.metadata_ else {}, + total_input_tokens=0, + total_output_tokens=0, + total_tokens=0, ) @@ -125,6 +184,11 @@ async def list_users( """List all users with pagination.""" users = db.query(User).offset(skip).limit(limit).all() + user_ids = [user.user_id for user in users] + token_usage_map = _get_token_usage_by_user_ids(db, user_ids) + + default_token_usage = TokenUsage(0, 0, 0) + return [ UserResponse( user_id=user.user_id, @@ -137,6 +201,9 @@ async def list_users( created_at=user.created_at.isoformat(), updated_at=user.updated_at.isoformat(), metadata=dict(user.metadata_) if user.metadata_ else {}, + total_input_tokens=(token_usage := token_usage_map.get(user.user_id, default_token_usage)).input, + total_output_tokens=token_usage.output, + total_tokens=token_usage.total, ) for user in users ] @@ -156,6 +223,8 @@ async def get_user( detail=f"User with id '{user_id}' not found", ) + token_usage = _get_token_usage_by_user_id(db, user_id) + return UserResponse( user_id=user.user_id, alias=user.alias, @@ -167,6 +236,9 @@ async def get_user( created_at=user.created_at.isoformat(), updated_at=user.updated_at.isoformat(), metadata=dict(user.metadata_) if user.metadata_ else {}, + total_input_tokens=token_usage.input, + total_output_tokens=token_usage.output, + total_tokens=token_usage.total, ) @@ -210,6 +282,8 @@ async def update_user( db.commit() db.refresh(user) + token_usage = _get_token_usage_by_user_id(db, user.user_id) + return UserResponse( user_id=user.user_id, alias=user.alias, @@ -221,6 +295,9 @@ async def update_user( created_at=user.created_at.isoformat(), updated_at=user.updated_at.isoformat(), metadata=dict(user.metadata_) if user.metadata_ else {}, + total_input_tokens=token_usage.input, + total_output_tokens=token_usage.output, + total_tokens=token_usage.total, ) diff --git a/tests/gateway/conftest.py b/tests/gateway/conftest.py index 6d186c61..5e1b6bca 100644 --- a/tests/gateway/conftest.py +++ b/tests/gateway/conftest.py @@ -12,7 +12,7 @@ from testcontainers.postgres import PostgresContainer from any_llm.gateway.config import API_KEY_HEADER, GatewayConfig -from any_llm.gateway.db import Base, get_db +from any_llm.gateway.db import Base, get_db, User, UsageLog from any_llm.gateway.server import create_app MODEL_NAME = "gemini:gemini-2.5-flash" @@ -169,3 +169,26 @@ def model_pricing(client: TestClient, master_key_header: dict[str, str]) -> dict assert response.status_code == 200 result: dict[str, Any] = response.json() return result + + +@pytest.fixture +def user_factory(test_db: Session): + """Factory for creating users.""" + def _user_factory(**kwargs): + user = User(**kwargs) + test_db.add(user) + test_db.commit() + test_db.refresh(user) + return user + return _user_factory + +@pytest.fixture +def usage_log_factory(test_db: Session): + """Factory for creating usage logs.""" + def _usage_log_factory(**kwargs): + log = UsageLog(**kwargs) + test_db.add(log) + test_db.commit() + test_db.refresh(log) + return log + return _usage_log_factory diff --git a/tests/gateway/test_users.py b/tests/gateway/test_users.py new file mode 100644 index 00000000..68014a09 --- /dev/null +++ b/tests/gateway/test_users.py @@ -0,0 +1,40 @@ +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +def test_get_user_with_token_usage(client: TestClient, db: Session, user_factory, usage_log_factory): + """Test retrieving a user includes token usage.""" + user = user_factory(user_id="test_user_with_usage") + usage_log_factory( + user_id=user.user_id, + prompt_tokens=100, + completion_tokens=200, + total_tokens=300, + ) + usage_log_factory( + user_id=user.user_id, + prompt_tokens=50, + completion_tokens=150, + total_tokens=200, + ) + + response = client.get(f"/v1/users/{user.user_id}") + + assert response.status_code == 200 + + user_data = response.json() + assert user_data["total_input_tokens"] == 150 + assert user_data["total_output_tokens"] == 350 + assert user_data["total_tokens"] == 500 + +def test_get_user_without_token_usage(client: TestClient, db: Session, user_factory): + """Test retrieving a user with no token usage returns zero values.""" + user = user_factory(user_id="test_user_no_usage") + + response = client.get(f"/v1/users/{user.user_id}") + + assert response.status_code == 200 + + user_data = response.json() + assert user_data["total_input_tokens"] == 0 + assert user_data["total_output_tokens"] == 0 + assert user_data["total_tokens"] == 0