Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions docs/index.html

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/mlpa/core/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class UserUpdatePayload(BaseModel):
blocked: bool | None = None


class BudgetUpdatePayload(BaseModel):
"""Payload for updating a user's budget tier."""

service_type: str


# iOS App Attest
class ChallengeResponse(BaseModel):
challenge: str
Expand Down
27 changes: 27 additions & 0 deletions src/mlpa/core/pg_services/litellm_pg_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,33 @@ async def get_user(self, user_id: str):
)
return dict(user) if user else None

async def update_user_budget(self, user_id: str, budget_id: str) -> dict:
"""Update a user's budget by linking them to a different budget tier."""
try:
async with self.pg.acquire() as conn:
async with conn.transaction():
updated_user_record = await conn.fetchrow(
'UPDATE "LiteLLM_EndUserTable" SET "budget_id" = $1 WHERE user_id = $2 RETURNING *',
budget_id,
user_id,
)

if updated_user_record is None:
logger.error(f"User {user_id} not found for budget update.")
raise HTTPException(status_code=404, detail="User not found.")

logger.info(
f"User {user_id} budget updated to {budget_id} successfully."
)
return dict(updated_user_record)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error updating budget for user {user_id}: {e}")
raise HTTPException(
status_code=500, detail={"error": "Error updating user budget"}
)

async def block_user(self, user_id: str, blocked: bool = True) -> dict:
try:
async with self.pg.acquire() as conn:
Expand Down
27 changes: 27 additions & 0 deletions src/mlpa/core/routers/user/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Query

from mlpa.core.classes import BudgetUpdatePayload
from mlpa.core.config import LITELLM_MASTER_AUTH_HEADERS, env
from mlpa.core.http_client import get_http_client
from mlpa.core.logger import logger
Expand Down Expand Up @@ -59,6 +60,32 @@ async def user_info(user_id: str):
return user


@router.post("/{user_id}/budget", tags=["User Management"])
async def update_user_budget(
user_id: str,
payload: BudgetUpdatePayload,
_: Annotated[None, Depends(require_master_key)] = None,
):
"""Update a user's budget tier by service type (e.g. ai-dev for higher limits)."""
if not user_id or user_id.strip() == "":
raise HTTPException(status_code=404, detail="User not found")
if payload.service_type not in env.valid_service_types:
raise HTTPException(
status_code=422,
detail={
"error": f"Unknown service type: {payload.service_type}. "
f"Valid values: {', '.join(env.valid_service_types)}"
},
)
budget_id = env.user_feature_budget[payload.service_type]["budget_id"]
user = await litellm_pg.update_user_budget(user_id, budget_id)
return {
"user_id": user["user_id"],
"budget_id": user["budget_id"],
"service_type": payload.service_type,
}


@router.post("/{user_id}/block", tags=["User Management"])
async def block_user(
user_id: str,
Expand Down
2 changes: 1 addition & 1 deletion src/mlpa/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
{"name": "Mock", "description": "Mock endpoints for testing purposes."},
{
"name": "User Management",
"description": "Endpoints for managing user blocking status.",
"description": "Endpoints for managing user blocking status and budgets.",
},
]

Expand Down
76 changes: 76 additions & 0 deletions src/tests/integration/test_user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,82 @@ def test_block_user_success(mocked_client_integration, mocker):
assert response.json()["user_id"] == TEST_USER_ID


def test_update_user_budget_success(mocked_client_integration, mocker):
"""Test updating a user's budget tier successfully."""
from tests.mocks import MockLiteLLMPGService

mock_litellm_pg = MockLiteLLMPGService()
mock_litellm_pg.store_user(
TEST_USER_ID,
{"user_id": TEST_USER_ID, "blocked": False, "alias": None, "budget_id": None},
)

mocker.patch("mlpa.core.routers.user.user.litellm_pg", mock_litellm_pg)

response = mocked_client_integration.post(
f"/user/{TEST_USER_ID}/budget",
headers={"master_key": f"Bearer {env.MASTER_KEY}"},
json={"service_type": "ai-dev"},
)

assert response.status_code == 200
data = response.json()
assert data["user_id"] == TEST_USER_ID
assert data["budget_id"] == "end-user-budget-ai-dev"
assert data["service_type"] == "ai-dev"


def test_update_user_budget_unauthorized(mocked_client_integration):
response = mocked_client_integration.post(
f"/user/{TEST_USER_ID}/budget",
headers={"master_key": "Bearer invalid-key"},
json={"service_type": "ai-dev"},
)

assert response.status_code == 401
assert "Unauthorized" in str(response.json())


def test_update_user_budget_user_not_found(mocked_client_integration, mocker):
"""Test updating budget for non-existent user returns 404."""
from tests.mocks import MockLiteLLMPGService

mock_litellm_pg = MockLiteLLMPGService()

mocker.patch("mlpa.core.routers.user.user.litellm_pg", mock_litellm_pg)

response = mocked_client_integration.post(
f"/user/{TEST_USER_ID}/budget",
headers={"master_key": f"Bearer {env.MASTER_KEY}"},
json={"service_type": "ai-dev"},
)

assert response.status_code == 404
assert "User not found" in str(response.json())


def test_update_user_budget_invalid_service_type(mocked_client_integration, mocker):
"""Test that invalid service_type returns 422."""
from tests.mocks import MockLiteLLMPGService

mock_litellm_pg = MockLiteLLMPGService()
mock_litellm_pg.store_user(
TEST_USER_ID,
{"user_id": TEST_USER_ID, "blocked": False, "alias": None},
)

mocker.patch("mlpa.core.routers.user.user.litellm_pg", mock_litellm_pg)

response = mocked_client_integration.post(
f"/user/{TEST_USER_ID}/budget",
headers={"master_key": f"Bearer {env.MASTER_KEY}"},
json={"service_type": "invalid-service"},
)

assert response.status_code == 422
assert "Unknown service type" in str(response.json())


def test_block_user_unauthorized(mocked_client_integration):
response = mocked_client_integration.post(
f"/user/{TEST_USER_ID}/block",
Expand Down
12 changes: 12 additions & 0 deletions src/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ async def block_user(self, user_id: str, blocked: bool = True):
self.users[user_id]["blocked"] = blocked
return self.users[user_id]

async def update_user_budget(self, user_id: str, budget_id: str):
"""Mock update_user_budget method for testing."""
logger.debug(
f"mock update_user_budget called with user_id: {user_id}, budget_id: {budget_id}",
)
if user_id not in self.users:
from fastapi import HTTPException

raise HTTPException(status_code=404, detail="User not found")
self.users[user_id]["budget_id"] = budget_id
return self.users[user_id]

async def list_users(self, limit: int = 50, offset: int = 0):
"""Mock list_users method for testing."""
logger.debug(
Expand Down