Skip to content

Commit 95e213c

Browse files
committed
Be more strict with health check responses
Signed-off-by: Mike Knepper <mknepper@nvidia.com>
1 parent c9c118f commit 95e213c

2 files changed

Lines changed: 70 additions & 5 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/models/registry.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from data_designer.config.models import GenerationType, ModelConfig
1010
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
11+
from data_designer.engine.models.errors import ModelGenerationValidationFailureError
12+
from data_designer.engine.models.parsers.errors import ParserException
1113
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenCountSource, TokenUsageStats
1214
from data_designer.engine.secret_resolver import SecretResolver
1315
from data_designer.logging import LOG_INDENT
@@ -27,6 +29,21 @@
2729
logger = logging.getLogger(__name__)
2830

2931

32+
def _parse_health_check_chat_response(response: str) -> str:
33+
if not isinstance(response, str) or not response:
34+
raise ParserException("Health check response must be non-empty text.")
35+
return response
36+
37+
38+
def _validate_health_check_embedding_response(vectors: list[list[float]], *, model_alias: str) -> None:
39+
if not isinstance(vectors, list) or len(vectors) != 1 or not isinstance(vectors[0], list) or not vectors[0]:
40+
raise ModelGenerationValidationFailureError(
41+
f"Health check for model alias {model_alias!r} returned an invalid embedding response.",
42+
detail="Expected exactly one non-empty embedding vector.",
43+
failure_kind="validation_error",
44+
)
45+
46+
3047
def format_reasoning_token_count(reasoning_token_count: int, source: TokenCountSource | str | None) -> str:
3148
if source == TokenCountSource.ESTIMATED or source == TokenCountSource.ESTIMATED.value:
3249
return f"{reasoning_token_count} (estimated)"
@@ -241,15 +258,16 @@ def run_health_check(self, model_aliases: list[str]) -> None:
241258
)
242259
try:
243260
if model.model_generation_type == GenerationType.EMBEDDING:
244-
model.generate_text_embeddings(
261+
vectors = model.generate_text_embeddings(
245262
input_texts=["Hello!"],
246263
skip_usage_tracking=True,
247264
purpose="running health checks",
248265
)
266+
_validate_health_check_embedding_response(vectors, model_alias=model_alias)
249267
elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
250268
model.generate(
251269
prompt="Hello!",
252-
parser=lambda x: x,
270+
parser=_parse_health_check_chat_response,
253271
system_prompt="You are a helpful assistant.",
254272
max_correction_steps=0,
255273
max_conversation_restarts=0,
@@ -286,15 +304,16 @@ async def arun_health_check(self, model_aliases: list[str]) -> None:
286304
)
287305
try:
288306
if model.model_generation_type == GenerationType.EMBEDDING:
289-
await model.agenerate_text_embeddings(
307+
vectors = await model.agenerate_text_embeddings(
290308
input_texts=["Hello!"],
291309
skip_usage_tracking=True,
292310
purpose="running health checks",
293311
)
312+
_validate_health_check_embedding_response(vectors, model_alias=model_alias)
294313
elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
295314
await model.agenerate(
296315
prompt="Hello!",
297-
parser=lambda x: x,
316+
parser=_parse_health_check_chat_response,
298317
system_prompt="You are a helpful assistant.",
299318
max_correction_steps=0,
300319
max_conversation_restarts=0,

packages/data-designer-engine/tests/engine/models/test_model_registry.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
99
from data_designer.config.run_config import RequestAdmissionTuningConfig, RunConfig
10-
from data_designer.engine.models.errors import ModelAuthenticationError
10+
from data_designer.engine.models.errors import ModelAuthenticationError, ModelGenerationValidationFailureError
1111
from data_designer.engine.models.facade import ModelFacade
1212
from data_designer.engine.models.factory import create_model_registry
1313
from data_designer.engine.models.registry import ModelRegistry
1414
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenCountSource, TokenUsageStats
15+
from data_designer.engine.testing import make_stub_completion_response
1516
from data_designer.logging import LOG_INDENT
1617

1718

@@ -332,6 +333,8 @@ def test_run_health_check_success(
332333
mock_generate_image: object,
333334
stub_model_registry: ModelRegistry,
334335
) -> None:
336+
mock_completion.return_value = make_stub_completion_response(content="Hello!")
337+
mock_generate_text_embeddings.return_value = [[0.1]]
335338
model_aliases = ["stub-text", "stub-reasoning", "stub-embedding", "stub-image"]
336339
stub_model_registry.run_health_check(model_aliases)
337340
assert mock_completion.call_count == 2
@@ -365,6 +368,7 @@ def test_run_health_check_embedding_authentication_error(
365368
stub_model_registry: ModelRegistry,
366369
) -> None:
367370
auth_error = ModelAuthenticationError("Invalid API key for embedding model")
371+
mock_completion.return_value = make_stub_completion_response(content="Hello!")
368372
mock_generate_text_embeddings.side_effect = auth_error
369373
model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"]
370374

@@ -375,12 +379,39 @@ def test_run_health_check_embedding_authentication_error(
375379
mock_generate_text_embeddings.assert_called_once()
376380

377381

382+
@patch.object(ModelFacade, "completion", autospec=True)
def test_run_health_check_rejects_empty_completion_response(
    mock_completion: object,
    stub_model_registry: ModelRegistry,
) -> None:
    """run_health_check fails when a chat model answers with empty content."""
    # Simulate a completion call that succeeds but carries no text.
    empty_completion = make_stub_completion_response(content="")
    mock_completion.return_value = empty_completion
    expected_message = "Health check response must be non-empty text"
    aliases_under_test = ["stub-text"]

    with pytest.raises(ModelGenerationValidationFailureError, match=expected_message):
        stub_model_registry.run_health_check(aliases_under_test)

    # The completion call was made exactly once before the failure surfaced.
    mock_completion.assert_called_once()
393+
394+
395+
@patch.object(ModelFacade, "generate_text_embeddings", autospec=True)
def test_run_health_check_rejects_empty_embedding_vector(
    mock_generate_text_embeddings: object,
    stub_model_registry: ModelRegistry,
) -> None:
    """run_health_check fails when the embedding call yields an empty vector."""
    # One vector comes back, but it has no components.
    single_empty_vector = [[]]
    mock_generate_text_embeddings.return_value = single_empty_vector
    expected_message = "invalid embedding response"
    aliases_under_test = ["stub-embedding"]

    with pytest.raises(ModelGenerationValidationFailureError, match=expected_message):
        stub_model_registry.run_health_check(aliases_under_test)

    # The embedding call was made exactly once before the failure surfaced.
    mock_generate_text_embeddings.assert_called_once()
406+
407+
378408
@patch.object(ModelFacade, "completion", autospec=True)
379409
def test_run_health_check_skip_health_check_flag(
380410
mock_completion: object,
381411
stub_secrets_resolver: object,
382412
stub_model_provider_registry: object,
383413
) -> None:
414+
mock_completion.return_value = make_stub_completion_response(content="Hello!")
384415
# Create model configs: one with skip_health_check=True, others with default (False)
385416
model_configs = [
386417
ModelConfig(
@@ -436,6 +467,7 @@ async def test_arun_health_check_success(
436467
mock_agenerate_image: AsyncMock,
437468
stub_model_registry: ModelRegistry,
438469
) -> None:
470+
mock_agenerate_text_embeddings.return_value = [[0.1]]
439471
model_aliases = ["stub-text", "stub-reasoning", "stub-embedding", "stub-image"]
440472
await stub_model_registry.arun_health_check(model_aliases)
441473
assert mock_agenerate.call_count == 2
@@ -461,6 +493,20 @@ async def test_arun_health_check_authentication_error(
461493
mock_agenerate_text_embeddings.assert_not_awaited()
462494

463495

496+
@patch.object(ModelFacade, "agenerate_text_embeddings", new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_arun_health_check_rejects_empty_embedding_vector(
    mock_agenerate_text_embeddings: AsyncMock,
    stub_model_registry: ModelRegistry,
) -> None:
    """arun_health_check fails when the async embedding call yields an empty vector."""
    # One vector comes back, but it has no components.
    single_empty_vector = [[]]
    mock_agenerate_text_embeddings.return_value = single_empty_vector
    expected_message = "invalid embedding response"
    aliases_under_test = ["stub-embedding"]

    with pytest.raises(ModelGenerationValidationFailureError, match=expected_message):
        await stub_model_registry.arun_health_check(aliases_under_test)

    # The async embedding call was awaited exactly once before the failure surfaced.
    mock_agenerate_text_embeddings.assert_awaited_once()
508+
509+
464510
def test_get_aggregate_max_parallel_requests(stub_model_registry: ModelRegistry) -> None:
465511
"""get_aggregate_max_parallel_requests returns the sum across all model configs."""
466512
total = stub_model_registry.get_aggregate_max_parallel_requests()

0 commit comments

Comments
 (0)