Skip to content

Commit abe7667

Browse files
authored
feat: check_models external readiness check (#712)
1 parent 0ed833f commit abe7667

19 files changed

Lines changed: 1418 additions & 202 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 21 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
from data_designer.config.utils.type_helpers import StrEnum
2929
from data_designer.config.utils.warning_helpers import warn_at_caller
3030
from data_designer.config.version import get_library_version
31+
from data_designer.engine import flags
3132
from data_designer.engine.column_generators.generators.base import (
3233
ColumnGenerator,
3334
ColumnGeneratorWithModel,
3435
GenerationStrategy,
3536
)
36-
from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
3737
from data_designer.engine.compiler import compile_data_designer_config
3838
from data_designer.engine.context import current_row_group, current_row_group_start_offset
3939
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
@@ -59,9 +59,11 @@
5959
strip_skip_metadata_from_records,
6060
)
6161
from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
62+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
6263
from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
6364
from data_designer.engine.processing.processors.base import Processor
6465
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
66+
from data_designer.engine.readiness import run_readiness_check
6567
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
6668
from data_designer.engine.resources.resource_provider import ResourceProvider
6769
from data_designer.engine.storage.artifact_storage import (
@@ -82,12 +84,12 @@
8284

8385
logger = logging.getLogger(__name__)
8486

85-
# Async engine is the default execution path. Set ``DATA_DESIGNER_ASYNC_ENGINE=0``
86-
# to opt back into the legacy sync engine for one transitional release; the sync
87-
# path is scheduled for removal afterwards.
88-
DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1"
87+
# The async-engine flag now lives in ``data_designer.engine.flags`` so the
88+
# engine, the public interface, and the readiness module can share one source
89+
# of truth. Always read ``flags.DATA_DESIGNER_ASYNC_ENGINE`` rather than caching
90+
# a local copy so monkeypatches in tests are visible.
8991

90-
if DATA_DESIGNER_ASYNC_ENGINE:
92+
if flags.DATA_DESIGNER_ASYNC_ENGINE:
9193
import asyncio
9294

9395
from data_designer.engine.dataset_builders.async_scheduler import (
@@ -193,7 +195,7 @@ def __init__(
193195
self._task_traces: list[TaskTrace] = []
194196
self._registry = registry or DataDesignerRegistry()
195197
self._graph: ExecutionGraph | None = None
196-
self._use_async: bool = DATA_DESIGNER_ASYNC_ENGINE
198+
self._use_async: bool = flags.DATA_DESIGNER_ASYNC_ENGINE
197199
# Structured signal: set by _build_async if the scheduler hit early shutdown.
198200
# Stays at defaults for sync-engine and successful async runs. Reset at
199201
# the start of each public run path so reused builder instances don't
@@ -275,10 +277,6 @@ def single_column_configs(self) -> list[ColumnConfigT]:
275277
def single_column_config_by_name(self) -> dict[str, ColumnConfigT]:
276278
return {config.name: config for config in self.single_column_configs}
277279

278-
@functools.cached_property
279-
def llm_generated_column_configs(self) -> list[ColumnConfigT]:
280-
return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
281-
282280
def build(
283281
self,
284282
*,
@@ -314,9 +312,13 @@ def build(
314312
Path to the generated dataset directory.
315313
"""
316314
self._reset_run_state()
315+
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
317316

318-
self._run_model_health_check_if_needed()
319-
self._run_mcp_tool_check_if_needed()
317+
run_readiness_check(
318+
self.single_column_configs,
319+
self._resource_provider,
320+
client_concurrency_mode=ClientConcurrencyMode.ASYNC if self._use_async else ClientConcurrencyMode.SYNC,
321+
)
320322

321323
# For IF_POSSIBLE and ALWAYS: check config compatibility before touching the artifact
322324
# directory. _check_resume_config_compatibility() must NOT access base_dataset_path
@@ -386,7 +388,6 @@ def build(
386388
"start a new generation run."
387389
)
388390

389-
self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
390391
if self._use_async:
391392
self._build_async(generators, num_records, buffer_size, on_batch_complete, resume=resume)
392393
elif resume == ResumeMode.ALWAYS:
@@ -657,8 +658,12 @@ def _build_with_resume(
657658

658659
def build_preview(self, *, num_records: int) -> pd.DataFrame:
659660
self._reset_run_state()
660-
self._run_model_health_check_if_needed()
661-
self._run_mcp_tool_check_if_needed()
661+
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
662+
run_readiness_check(
663+
self.single_column_configs,
664+
self._resource_provider,
665+
client_concurrency_mode=ClientConcurrencyMode.ASYNC if self._use_async else ClientConcurrencyMode.SYNC,
666+
)
662667

663668
# Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame
664669
if self._has_image_columns():
@@ -667,7 +672,6 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
667672
generators, self._graph = self._initialize_generators_and_graph()
668673
start_time = time.perf_counter()
669674

670-
self._use_async = DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
671675
if self._use_async:
672676
dataset = self._build_async_preview(generators, num_records)
673677
else:
@@ -1407,38 +1411,6 @@ def _merge_skipped_and_generated(
14071411
batch.append(gen_result)
14081412
return batch
14091413

1410-
def _run_model_health_check_if_needed(self) -> None:
1411-
model_aliases: set[str] = set()
1412-
for config in self.single_column_configs:
1413-
model_aliases.update(config.get_model_aliases())
1414-
1415-
if not model_aliases:
1416-
return
1417-
1418-
if DATA_DESIGNER_ASYNC_ENGINE:
1419-
loop = ensure_async_engine_loop()
1420-
future = asyncio.run_coroutine_threadsafe(
1421-
self._resource_provider.model_registry.arun_health_check(list(model_aliases)),
1422-
loop,
1423-
)
1424-
try:
1425-
future.result(timeout=180)
1426-
except TimeoutError:
1427-
future.cancel()
1428-
raise
1429-
else:
1430-
self._resource_provider.model_registry.run_health_check(list(model_aliases))
1431-
1432-
def _run_mcp_tool_check_if_needed(self) -> None:
1433-
tool_aliases = sorted(
1434-
{config.tool_alias for config in self.llm_generated_column_configs if getattr(config, "tool_alias", None)}
1435-
)
1436-
if not tool_aliases:
1437-
return
1438-
if self._resource_provider.mcp_registry is None:
1439-
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
1440-
self._resource_provider.mcp_registry.run_health_check(tool_aliases)
1441-
14421414
def _setup_fan_out(
14431415
self,
14441416
generator: ColumnGeneratorWithModelRegistry,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Engine-wide feature flags read from environment variables.
5+
6+
This module exists so the engine, the public interface, and the readiness
7+
module can share a single source of truth for runtime mode flags without
8+
forming an import cycle. Tests patch values here to flip behavior for a
9+
single test scope.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import os
15+
16+
# Async engine is the default execution path. Set ``DATA_DESIGNER_ASYNC_ENGINE=0``
17+
# to opt back into the legacy sync engine for one transitional release; the sync
18+
# path is scheduled for removal afterwards.
19+
# Only the exact string "1" (also the default when the variable is unset)
# enables the async engine; every other value evaluates to False.
DATA_DESIGNER_ASYNC_ENGINE: bool = os.getenv("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1"

packages/data-designer-engine/src/data_designer/engine/models/registry.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from data_designer.config.models import GenerationType, ModelConfig
1010
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
11+
from data_designer.engine.models.errors import ModelGenerationValidationFailureError
12+
from data_designer.engine.models.parsers.errors import ParserException
1113
from data_designer.engine.models.usage import ModelUsageStats, RequestUsageStats, TokenCountSource, TokenUsageStats
1214
from data_designer.engine.secret_resolver import SecretResolver
1315
from data_designer.logging import LOG_INDENT
@@ -27,6 +29,21 @@
2729
logger = logging.getLogger(__name__)
2830

2931

32+
def _parse_health_check_chat_response(response: str) -> str:
33+
if not isinstance(response, str) or not response:
34+
raise ParserException("Health check response must be non-empty text.")
35+
return response
36+
37+
38+
def _validate_health_check_embedding_response(vectors: list[list[float]], *, model_alias: str) -> None:
39+
if not isinstance(vectors, list) or len(vectors) != 1 or not isinstance(vectors[0], list) or not vectors[0]:
40+
raise ModelGenerationValidationFailureError(
41+
f"Health check for model alias {model_alias!r} returned an invalid embedding response.",
42+
detail="Expected exactly one non-empty embedding vector.",
43+
failure_kind="validation_error",
44+
)
45+
46+
3047
def format_reasoning_token_count(reasoning_token_count: int, source: TokenCountSource | str | None) -> str:
3148
if source == TokenCountSource.ESTIMATED or source == TokenCountSource.ESTIMATED.value:
3249
return f"{reasoning_token_count} (estimated)"
@@ -241,15 +258,16 @@ def run_health_check(self, model_aliases: list[str]) -> None:
241258
)
242259
try:
243260
if model.model_generation_type == GenerationType.EMBEDDING:
244-
model.generate_text_embeddings(
261+
vectors = model.generate_text_embeddings(
245262
input_texts=["Hello!"],
246263
skip_usage_tracking=True,
247264
purpose="running health checks",
248265
)
266+
_validate_health_check_embedding_response(vectors, model_alias=model_alias)
249267
elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
250268
model.generate(
251269
prompt="Hello!",
252-
parser=lambda x: x,
270+
parser=_parse_health_check_chat_response,
253271
system_prompt="You are a helpful assistant.",
254272
max_correction_steps=0,
255273
max_conversation_restarts=0,
@@ -286,15 +304,16 @@ async def arun_health_check(self, model_aliases: list[str]) -> None:
286304
)
287305
try:
288306
if model.model_generation_type == GenerationType.EMBEDDING:
289-
await model.agenerate_text_embeddings(
307+
vectors = await model.agenerate_text_embeddings(
290308
input_texts=["Hello!"],
291309
skip_usage_tracking=True,
292310
purpose="running health checks",
293311
)
312+
_validate_health_check_embedding_response(vectors, model_alias=model_alias)
294313
elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
295314
await model.agenerate(
296315
prompt="Hello!",
297-
parser=lambda x: x,
316+
parser=_parse_health_check_chat_response,
298317
system_prompt="You are a helpful assistant.",
299318
max_correction_steps=0,
300319
max_conversation_restarts=0,
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""External-readiness checks for a DataDesigner workload.
5+
6+
A "readiness" check is a pre-flight probe of every external resource a
7+
configuration depends on: each referenced model alias is sent a tiny
8+
generation request, and every referenced MCP tool alias is contacted to
9+
confirm its server is reachable.
10+
11+
This module hosts the shared logic invoked from two places:
12+
13+
- ``DatasetBuilder.build`` / ``DatasetBuilder.build_preview`` — at the start
14+
of a workload, to fail fast before any expensive work begins.
15+
- ``DataDesigner.check_models`` — exposed publicly so users can verify
16+
external dependencies are responsive without triggering a workload.
17+
18+
The two callers must use the same code path here so the standalone method
19+
cannot drift from the workload-startup gate.
20+
"""
21+
22+
from __future__ import annotations
23+
24+
import logging
25+
from collections.abc import Sequence
26+
from typing import TYPE_CHECKING
27+
28+
from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
29+
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
30+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
31+
32+
if TYPE_CHECKING:
33+
from data_designer.config.column_types import ColumnConfigT
34+
from data_designer.engine.resources.resource_provider import ResourceProvider
35+
36+
logger = logging.getLogger(__name__)
37+
38+
# Match the timeout the dataset builder's startup gate has always used.
39+
_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS = 180
40+
41+
42+
def run_readiness_check(
    column_configs: Sequence[ColumnConfigT],
    resource_provider: ResourceProvider,
    *,
    client_concurrency_mode: ClientConcurrencyMode,
) -> None:
    """Run the pre-flight health checks for models and MCP tools.

    The model pass runs first and probes every unique alias reported by the
    configs' ``get_model_aliases()``. It uses ``ModelRegistry.arun_health_check``
    when ``client_concurrency_mode`` is ``ClientConcurrencyMode.ASYNC`` and
    ``ModelRegistry.run_health_check`` otherwise. The MCP pass follows and probes
    every unique ``tool_alias`` found on model-generated column configs through
    ``MCPRegistry.run_health_check``.

    Args:
        column_configs: Column configs that determine which model and tool
            aliases are probed.
        resource_provider: Source of ``model_registry`` and ``mcp_registry``.
            ``mcp_registry`` may be ``None`` only when no tool alias is
            referenced.
        client_concurrency_mode: The client mode already resolved for this run.

    Raises:
        DatasetGenerationError: If a tool alias is referenced but the resource
            provider has no MCP registry.
        TimeoutError: If the async model health check does not finish within
            ``_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS``.

    Any exception raised by the registries' own health checks propagates
    unchanged.
    """
    # Order matters: a failing model probe aborts before any MCP server is contacted.
    _run_model_health_check(
        column_configs,
        resource_provider,
        client_concurrency_mode=client_concurrency_mode,
    )
    _run_mcp_tool_health_check(
        column_configs,
        resource_provider,
    )
75+
76+
77+
def _run_model_health_check(
78+
column_configs: Sequence[ColumnConfigT],
79+
resource_provider: ResourceProvider,
80+
*,
81+
client_concurrency_mode: ClientConcurrencyMode,
82+
) -> None:
83+
model_aliases: set[str] = set()
84+
for config in column_configs:
85+
model_aliases.update(config.get_model_aliases())
86+
87+
if not model_aliases:
88+
return
89+
90+
if client_concurrency_mode == ClientConcurrencyMode.ASYNC:
91+
# Defer the async-engine imports to here so users on the legacy sync
92+
# engine never pay the import cost.
93+
import asyncio
94+
95+
from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop
96+
97+
loop = ensure_async_engine_loop()
98+
future = asyncio.run_coroutine_threadsafe(
99+
resource_provider.model_registry.arun_health_check(list(model_aliases)),
100+
loop,
101+
)
102+
try:
103+
future.result(timeout=_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS)
104+
except TimeoutError:
105+
future.cancel()
106+
raise
107+
else:
108+
resource_provider.model_registry.run_health_check(list(model_aliases))
109+
110+
111+
def _run_mcp_tool_health_check(
112+
column_configs: Sequence[ColumnConfigT],
113+
resource_provider: ResourceProvider,
114+
) -> None:
115+
# Tool aliases are only meaningful on model-generated column configs.
116+
tool_aliases = sorted(
117+
{
118+
config.tool_alias
119+
for config in column_configs
120+
if column_type_is_model_generated(config.column_type) and getattr(config, "tool_alias", None)
121+
}
122+
)
123+
if not tool_aliases:
124+
return
125+
if resource_provider.mcp_registry is None:
126+
raise DatasetGenerationError(f"Tool alias(es) {tool_aliases!r} specified but no MCPRegistry configured.")
127+
resource_provider.mcp_registry.run_health_check(tool_aliases)

packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import os
76
from typing import TYPE_CHECKING
87

98
from data_designer.config.base import ConfigBase
@@ -13,6 +12,7 @@
1312
from data_designer.config.run_config import RunConfig
1413
from data_designer.config.seed_source import SeedSource
1514
from data_designer.config.utils.type_helpers import StrEnum
15+
from data_designer.engine import flags
1616
from data_designer.engine.mcp.factory import create_mcp_registry
1717
from data_designer.engine.mcp.registry import MCPRegistry
1818
from data_designer.engine.model_provider import (
@@ -148,9 +148,7 @@ def create_resource_provider(
148148
# default for backward compatibility.
149149
if client_concurrency_mode is None:
150150
client_concurrency_mode = (
151-
ClientConcurrencyMode.ASYNC
152-
if os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "1") == "1"
153-
else ClientConcurrencyMode.SYNC
151+
ClientConcurrencyMode.ASYNC if flags.DATA_DESIGNER_ASYNC_ENGINE else ClientConcurrencyMode.SYNC
154152
)
155153

156154
effective_run_config = run_config or RunConfig()

0 commit comments

Comments
 (0)