diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 68a3925..ab57432 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -18,7 +18,11 @@ ) from mlpa.core.http_client import get_http_client from mlpa.core.logger import logger -from mlpa.core.prometheus_metrics import PrometheusResult, metrics +from mlpa.core.prometheus_metrics import ( + PrometheusRejectionReason, + PrometheusResult, + metrics, +) from mlpa.core.utils import is_context_window_error, is_rate_limit_error, raise_and_log # Global default tokenizer - initialized once at module load time @@ -39,6 +43,15 @@ def get_default_tokenizer() -> tiktoken.Encoding: return _global_default_tokenizer +_RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str]] = { + ERROR_CODE_BUDGET_LIMIT_EXCEEDED: ( + PrometheusRejectionReason.BUDGET_EXCEEDED, + "86400", + ), + ERROR_CODE_RATE_LIMIT_EXCEEDED: (PrometheusRejectionReason.RATE_LIMITED, "60"), +} + + def _parse_rate_limit_error(error_text: str, user: str) -> int | None: """ Parse error response to detect budget or rate limit errors. @@ -61,20 +74,12 @@ def _parse_rate_limit_error(error_text: str, user: str) -> int | None: return None -def _handle_rate_limit_error(error_text: str, user: str) -> None: - error_code = _parse_rate_limit_error(error_text, user) - if error_code == ERROR_CODE_BUDGET_LIMIT_EXCEEDED: - raise HTTPException( - status_code=429, - detail={"error": ERROR_CODE_BUDGET_LIMIT_EXCEEDED}, - headers={"Retry-After": "86400"}, - ) - if error_code == ERROR_CODE_RATE_LIMIT_EXCEEDED: - raise HTTPException( - status_code=429, - detail={"error": ERROR_CODE_RATE_LIMIT_EXCEEDED}, - headers={"Retry-After": "60"}, - ) +def _record_rejection( + req: AuthorizedChatRequest, reason: PrometheusRejectionReason +) -> None: + metrics.chat_request_rejections.labels( + reason=reason, model=req.model, service_type=req.service_type + ).inc() def _tool_names_from_request(tools: list) -> list[str]: @@ -162,7 +167,9 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): error_code = _parse_rate_limit_error( error_text_str, authorized_chat_request.user ) - if error_code is not None: + if error_code in _RATE_LIMIT_REJECTION: + reason, _ = _RATE_LIMIT_REJECTION[error_code] + _record_rejection(authorized_chat_request, reason) yield f'data: {{"error": {error_code}}}\n\n'.encode() return @@ -173,6 +180,10 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): logger.warning( f"Context window exceeded for user {authorized_chat_request.user}" ) + _record_rejection( + authorized_chat_request, + PrometheusRejectionReason.PAYLOAD_TOO_LARGE, + ) yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode() return @@ -303,12 +314,25 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): except httpx.HTTPStatusError as e: error_text = e.response.text if e.response.status_code in {400, 429}: - _handle_rate_limit_error(error_text, authorized_chat_request.user) + error_code = _parse_rate_limit_error( + error_text, authorized_chat_request.user + ) + if error_code in _RATE_LIMIT_REJECTION: + reason, retry_after = _RATE_LIMIT_REJECTION[error_code] + _record_rejection(authorized_chat_request, reason) + raise HTTPException( + status_code=429, + detail={"error": error_code}, + headers={"Retry-After": retry_after}, + ) # Context window exceeded: detect by error text or upstream 413 if e.response.status_code == 413 or is_context_window_error(error_text): logger.warning( f"Context window exceeded for user {authorized_chat_request.user}" ) + _record_rejection( + authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE + ) raise HTTPException( status_code=413, detail={"error": ERROR_CODE_REQUEST_TOO_LARGE}, diff --git a/src/mlpa/core/prometheus_metrics.py b/src/mlpa/core/prometheus_metrics.py index 49978bc..4a214f5 100644 --- a/src/mlpa/core/prometheus_metrics.py +++ b/src/mlpa/core/prometheus_metrics.py @@ -9,6 +9,12 @@ class PrometheusResult(StrEnum): ERROR = "error" +class PrometheusRejectionReason(StrEnum): + BUDGET_EXCEEDED = "budget_exceeded" + RATE_LIMITED = "rate_limited" + PAYLOAD_TOO_LARGE = "payload_too_large" + + BUCKETS_FAST_AUTH = (0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, float("inf")) BUCKETS_AUTH = (0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf")) BUCKETS_FXA = (0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, float("inf")) @@ -76,6 +82,7 @@ class PrometheusMetrics: chat_completions_with_tools: Counter chat_tool_calls_per_completion: Histogram chat_requests_with_tools: Counter + chat_request_rejections: Counter metrics = PrometheusMetrics( @@ -171,4 +178,9 @@ class PrometheusMetrics: "Number of chat requests that included a tools payload.", ["tool_name", "model", "service_type"], ), + chat_request_rejections=Counter( + "mlpa_chat_request_rejections_total", + "Number of chat requests rejected due to budget, rate limit, or payload size.", + ["reason", "model", "service_type"], + ), ) diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index 89afc43..1908299 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -15,7 +15,7 @@ LITELLM_COMPLETIONS_URL, env, ) -from mlpa.core.prometheus_metrics import PrometheusResult +from mlpa.core.prometheus_metrics import PrometheusRejectionReason, PrometheusResult from tests.consts import SAMPLE_REQUEST, SUCCESSFUL_CHAT_RESPONSE @@ -254,6 +254,13 @@ async def test_get_completion_budget_limit_exceeded_429(mocker): assert exc_info.value.detail == {"error": 1} # ERROR_CODE_BUDGET_LIMIT_EXCEEDED assert exc_info.value.headers == {"Retry-After": "86400"} + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.BUDGET_EXCEEDED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() + # Verify latency metric was observed with ERROR mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, @@ -299,6 +306,13 @@ async def test_get_completion_budget_limit_exceeded_400(mocker): assert exc_info.value.detail == {"error": 1} # ERROR_CODE_BUDGET_LIMIT_EXCEEDED assert exc_info.value.headers == {"Retry-After": "86400"} + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.BUDGET_EXCEEDED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() + async def test_get_completion_rate_limit_exceeded(mocker): """ @@ -337,6 +351,13 @@ async def test_get_completion_rate_limit_exceeded(mocker): assert exc_info.value.detail == {"error": 2} # ERROR_CODE_RATE_LIMIT_EXCEEDED assert exc_info.value.headers == {"Retry-After": "60"} + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.RATE_LIMITED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() + async def test_get_completion_400_non_rate_limit_error(mocker): """ @@ -438,6 +459,12 @@ async def test_get_completion_context_window_exceeded(mocker): assert exc_info.value.detail == {"error": ERROR_CODE_REQUEST_TOO_LARGE} mock_logger.warning.assert_called_once() assert "Context window exceeded" in str(mock_logger.warning.call_args) + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.PAYLOAD_TOO_LARGE, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, model=SAMPLE_REQUEST.model, @@ -508,6 +535,12 @@ async def test_stream_completion_budget_limit_exceeded_429( ) mock_logger.warning.assert_called_once() assert "Budget limit exceeded" in str(mock_logger.warning.call_args) + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.BUDGET_EXCEEDED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, model=SAMPLE_REQUEST.model, @@ -549,6 +582,12 @@ async def test_stream_completion_budget_limit_exceeded_400( ) mock_logger.warning.assert_called_once() assert "Budget limit exceeded" in str(mock_logger.warning.call_args) + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.BUDGET_EXCEEDED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, model=SAMPLE_REQUEST.model, @@ -588,6 +627,12 @@ async def test_stream_completion_rate_limit_exceeded(httpx_mock: HTTPXMock, mock ) mock_logger.warning.assert_called_once() assert "Rate limit exceeded" in str(mock_logger.warning.call_args) + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.RATE_LIMITED, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, model=SAMPLE_REQUEST.model, @@ -622,6 +667,12 @@ async def test_stream_completion_context_window_exceeded(httpx_mock: HTTPXMock, ) mock_logger.warning.assert_called_once() assert "Context window exceeded" in str(mock_logger.warning.call_args) + mock_metrics.chat_request_rejections.labels.assert_called_once_with( + reason=PrometheusRejectionReason.PAYLOAD_TOO_LARGE, + model=SAMPLE_REQUEST.model, + service_type=SAMPLE_REQUEST.service_type, + ) + mock_metrics.chat_request_rejections.labels().inc.assert_called_once() mock_metrics.chat_completion_latency.labels.assert_called_once_with( result=PrometheusResult.ERROR, model=SAMPLE_REQUEST.model,