Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions src/mlpa/core/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
)
from mlpa.core.http_client import get_http_client
from mlpa.core.logger import logger
from mlpa.core.prometheus_metrics import PrometheusResult, metrics
from mlpa.core.prometheus_metrics import (
PrometheusRejectionReason,
PrometheusResult,
metrics,
)
from mlpa.core.utils import is_context_window_error, is_rate_limit_error, raise_and_log

# Global default tokenizer - initialized once at module load time
Expand All @@ -39,6 +43,15 @@ def get_default_tokenizer() -> tiktoken.Encoding:
return _global_default_tokenizer


# Maps an upstream error code to the Prometheus rejection reason recorded for
# it and the Retry-After header value (seconds, as a string) returned to the
# client: budget exhaustion asks clients to retry after a day, per-minute rate
# limiting after a minute.
_RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str]] = {
    ERROR_CODE_BUDGET_LIMIT_EXCEEDED: (
        PrometheusRejectionReason.BUDGET_EXCEEDED,
        "86400",
    ),
    ERROR_CODE_RATE_LIMIT_EXCEEDED: (PrometheusRejectionReason.RATE_LIMITED, "60"),
}


def _parse_rate_limit_error(error_text: str, user: str) -> int | None:
"""
Parse error response to detect budget or rate limit errors.
Expand All @@ -61,20 +74,12 @@ def _parse_rate_limit_error(error_text: str, user: str) -> int | None:
return None


def _handle_rate_limit_error(error_text: str, user: str) -> None:
error_code = _parse_rate_limit_error(error_text, user)
if error_code == ERROR_CODE_BUDGET_LIMIT_EXCEEDED:
raise HTTPException(
status_code=429,
detail={"error": ERROR_CODE_BUDGET_LIMIT_EXCEEDED},
headers={"Retry-After": "86400"},
)
if error_code == ERROR_CODE_RATE_LIMIT_EXCEEDED:
raise HTTPException(
status_code=429,
detail={"error": ERROR_CODE_RATE_LIMIT_EXCEEDED},
headers={"Retry-After": "60"},
)
def _record_rejection(
    req: AuthorizedChatRequest, reason: PrometheusRejectionReason
) -> None:
    """Increment the chat-request rejection counter for *req*.

    The counter is labelled with the rejection *reason* plus the request's
    model and service type, so rejections can be broken down per model.
    """
    rejections = metrics.chat_request_rejections
    rejections.labels(
        reason=reason,
        model=req.model,
        service_type=req.service_type,
    ).inc()


def _tool_names_from_request(tools: list) -> list[str]:
Expand Down Expand Up @@ -162,7 +167,9 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
error_code = _parse_rate_limit_error(
error_text_str, authorized_chat_request.user
)
if error_code is not None:
if error_code in _RATE_LIMIT_REJECTION:
reason, _ = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
yield f'data: {{"error": {error_code}}}\n\n'.encode()
return

Expand All @@ -173,6 +180,10 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request,
PrometheusRejectionReason.PAYLOAD_TOO_LARGE,
)
yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode()
return

Expand Down Expand Up @@ -303,12 +314,25 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest):
except httpx.HTTPStatusError as e:
error_text = e.response.text
if e.response.status_code in {400, 429}:
_handle_rate_limit_error(error_text, authorized_chat_request.user)
error_code = _parse_rate_limit_error(
error_text, authorized_chat_request.user
)
if error_code in _RATE_LIMIT_REJECTION:
reason, retry_after = _RATE_LIMIT_REJECTION[error_code]
_record_rejection(authorized_chat_request, reason)
raise HTTPException(
status_code=429,
detail={"error": error_code},
headers={"Retry-After": retry_after},
)
# Context window exceeded: detect by error text or upstream 413
if e.response.status_code == 413 or is_context_window_error(error_text):
logger.warning(
f"Context window exceeded for user {authorized_chat_request.user}"
)
_record_rejection(
authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE
)
raise HTTPException(
status_code=413,
detail={"error": ERROR_CODE_REQUEST_TOO_LARGE},
Expand Down
12 changes: 12 additions & 0 deletions src/mlpa/core/prometheus_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ class PrometheusResult(StrEnum):
ERROR = "error"


class PrometheusRejectionReason(StrEnum):
    """Values for the ``reason`` label on the chat-request rejection counter."""

    # Caller's spending budget is exhausted.
    BUDGET_EXCEEDED = "budget_exceeded"
    # Caller exceeded the request rate limit.
    RATE_LIMITED = "rate_limited"
    # Request exceeded the model's context window / payload size.
    PAYLOAD_TOO_LARGE = "payload_too_large"


BUCKETS_FAST_AUTH = (0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, float("inf"))
BUCKETS_AUTH = (0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf"))
BUCKETS_FXA = (0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, float("inf"))
Expand Down Expand Up @@ -76,6 +82,7 @@ class PrometheusMetrics:
chat_completions_with_tools: Counter
chat_tool_calls_per_completion: Histogram
chat_requests_with_tools: Counter
chat_request_rejections: Counter


metrics = PrometheusMetrics(
Expand Down Expand Up @@ -171,4 +178,9 @@ class PrometheusMetrics:
"Number of chat requests that included a tools payload.",
["tool_name", "model", "service_type"],
),
chat_request_rejections=Counter(
"mlpa_chat_request_rejections_total",
"Number of chat requests rejected due to budget, rate limit, or payload size.",
["reason", "model", "service_type"],
),
)
53 changes: 52 additions & 1 deletion src/tests/unit/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
LITELLM_COMPLETIONS_URL,
env,
)
from mlpa.core.prometheus_metrics import PrometheusResult
from mlpa.core.prometheus_metrics import PrometheusRejectionReason, PrometheusResult
from tests.consts import SAMPLE_REQUEST, SUCCESSFUL_CHAT_RESPONSE


Expand Down Expand Up @@ -254,6 +254,13 @@ async def test_get_completion_budget_limit_exceeded_429(mocker):
assert exc_info.value.detail == {"error": 1} # ERROR_CODE_BUDGET_LIMIT_EXCEEDED
assert exc_info.value.headers == {"Retry-After": "86400"}

mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.BUDGET_EXCEEDED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()

# Verify latency metric was observed with ERROR
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
Expand Down Expand Up @@ -299,6 +306,13 @@ async def test_get_completion_budget_limit_exceeded_400(mocker):
assert exc_info.value.detail == {"error": 1} # ERROR_CODE_BUDGET_LIMIT_EXCEEDED
assert exc_info.value.headers == {"Retry-After": "86400"}

mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.BUDGET_EXCEEDED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()


async def test_get_completion_rate_limit_exceeded(mocker):
"""
Expand Down Expand Up @@ -337,6 +351,13 @@ async def test_get_completion_rate_limit_exceeded(mocker):
assert exc_info.value.detail == {"error": 2} # ERROR_CODE_RATE_LIMIT_EXCEEDED
assert exc_info.value.headers == {"Retry-After": "60"}

mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.RATE_LIMITED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()


async def test_get_completion_400_non_rate_limit_error(mocker):
"""
Expand Down Expand Up @@ -438,6 +459,12 @@ async def test_get_completion_context_window_exceeded(mocker):
assert exc_info.value.detail == {"error": ERROR_CODE_REQUEST_TOO_LARGE}
mock_logger.warning.assert_called_once()
assert "Context window exceeded" in str(mock_logger.warning.call_args)
mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.PAYLOAD_TOO_LARGE,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
model=SAMPLE_REQUEST.model,
Expand Down Expand Up @@ -508,6 +535,12 @@ async def test_stream_completion_budget_limit_exceeded_429(
)
mock_logger.warning.assert_called_once()
assert "Budget limit exceeded" in str(mock_logger.warning.call_args)
mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.BUDGET_EXCEEDED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
model=SAMPLE_REQUEST.model,
Expand Down Expand Up @@ -549,6 +582,12 @@ async def test_stream_completion_budget_limit_exceeded_400(
)
mock_logger.warning.assert_called_once()
assert "Budget limit exceeded" in str(mock_logger.warning.call_args)
mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.BUDGET_EXCEEDED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
model=SAMPLE_REQUEST.model,
Expand Down Expand Up @@ -588,6 +627,12 @@ async def test_stream_completion_rate_limit_exceeded(httpx_mock: HTTPXMock, mock
)
mock_logger.warning.assert_called_once()
assert "Rate limit exceeded" in str(mock_logger.warning.call_args)
mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.RATE_LIMITED,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
model=SAMPLE_REQUEST.model,
Expand Down Expand Up @@ -622,6 +667,12 @@ async def test_stream_completion_context_window_exceeded(httpx_mock: HTTPXMock,
)
mock_logger.warning.assert_called_once()
assert "Context window exceeded" in str(mock_logger.warning.call_args)
mock_metrics.chat_request_rejections.labels.assert_called_once_with(
reason=PrometheusRejectionReason.PAYLOAD_TOO_LARGE,
model=SAMPLE_REQUEST.model,
service_type=SAMPLE_REQUEST.service_type,
)
mock_metrics.chat_request_rejections.labels().inc.assert_called_once()
mock_metrics.chat_completion_latency.labels.assert_called_once_with(
result=PrometheusResult.ERROR,
model=SAMPLE_REQUEST.model,
Expand Down
Loading