Skip to content

Commit c9c118f

Browse files
committed
align readiness with client mode
1 parent b3a9d88 commit c9c118f

5 files changed

Lines changed: 108 additions & 26 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
strip_skip_metadata_from_records,
6060
)
6161
from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar
62+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
6263
from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler
6364
from data_designer.engine.processing.processors.base import Processor
6465
from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor
@@ -311,8 +312,13 @@ def build(
311312
Path to the generated dataset directory.
312313
"""
313314
self._reset_run_state()
315+
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
314316

315-
run_readiness_check(self.single_column_configs, self._resource_provider)
317+
run_readiness_check(
318+
self.single_column_configs,
319+
self._resource_provider,
320+
client_concurrency_mode=ClientConcurrencyMode.ASYNC if self._use_async else ClientConcurrencyMode.SYNC,
321+
)
316322

317323
# For IF_POSSIBLE and ALWAYS: check config compatibility before touching the artifact
318324
# directory. _check_resume_config_compatibility() must NOT access base_dataset_path
@@ -382,7 +388,6 @@ def build(
382388
"start a new generation run."
383389
)
384390

385-
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
386391
if self._use_async:
387392
self._build_async(generators, num_records, buffer_size, on_batch_complete, resume=resume)
388393
elif resume == ResumeMode.ALWAYS:
@@ -653,7 +658,12 @@ def _build_with_resume(
653658

654659
def build_preview(self, *, num_records: int) -> pd.DataFrame:
655660
self._reset_run_state()
656-
run_readiness_check(self.single_column_configs, self._resource_provider)
661+
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
662+
run_readiness_check(
663+
self.single_column_configs,
664+
self._resource_provider,
665+
client_concurrency_mode=ClientConcurrencyMode.ASYNC if self._use_async else ClientConcurrencyMode.SYNC,
666+
)
657667

658668
# Set media storage to DATAFRAME mode for preview - base64 stored directly in DataFrame
659669
if self._has_image_columns():
@@ -662,7 +672,6 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame:
662672
generators, self._graph = self._initialize_generators_and_graph()
663673
start_time = time.perf_counter()
664674

665-
self._use_async = flags.DATA_DESIGNER_ASYNC_ENGINE and self._resolve_async_compatibility()
666675
if self._use_async:
667676
dataset = self._build_async_preview(generators, num_records)
668677
else:

packages/data-designer-engine/src/data_designer/engine/readiness.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from collections.abc import Sequence
2626
from typing import TYPE_CHECKING
2727

28-
from data_designer.engine import flags
2928
from data_designer.engine.column_generators.utils.generator_classification import column_type_is_model_generated
3029
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
30+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
3131

3232
if TYPE_CHECKING:
3333
from data_designer.config.column_types import ColumnConfigT
@@ -42,12 +42,14 @@
4242
def run_readiness_check(
4343
column_configs: Sequence[ColumnConfigT],
4444
resource_provider: ResourceProvider,
45+
*,
46+
client_concurrency_mode: ClientConcurrencyMode,
4547
) -> None:
4648
"""Probe every model and MCP tool referenced by ``column_configs``.
4749
4850
For each unique model alias collected from the column configs,
49-
``ModelRegistry.run_health_check`` (or ``arun_health_check`` on the async
50-
engine) sends a tiny ``"Hello!"`` generation. Models whose ``ModelConfig``
51+
``ModelRegistry.run_health_check`` (or ``arun_health_check`` when async
52+
mode is selected) sends a tiny ``"Hello!"`` generation. Models whose ``ModelConfig``
5153
has ``skip_health_check=True`` are skipped by the registry. After the
5254
model pass, every unique MCP tool alias is probed via
5355
``MCPRegistry.run_health_check``.
@@ -58,6 +60,7 @@ def run_readiness_check(
5860
resource_provider: Provides access to the model registry and MCP
5961
registry. ``mcp_registry`` may be ``None`` only if no tool
6062
aliases are referenced.
63+
client_concurrency_mode: Resolved client concurrency mode for this run; selects the sync or async model health check.
6164
6265
Raises:
6366
Typed model errors from ``data_designer.engine.models.errors`` for
@@ -67,13 +70,15 @@ def run_readiness_check(
6770
TimeoutError: If async health-check execution exceeds
6871
``_MODEL_HEALTH_CHECK_TIMEOUT_SECONDS``.
6972
"""
70-
_run_model_health_check(column_configs, resource_provider)
73+
_run_model_health_check(column_configs, resource_provider, client_concurrency_mode=client_concurrency_mode)
7174
_run_mcp_tool_health_check(column_configs, resource_provider)
7275

7376

7477
def _run_model_health_check(
7578
column_configs: Sequence[ColumnConfigT],
7679
resource_provider: ResourceProvider,
80+
*,
81+
client_concurrency_mode: ClientConcurrencyMode,
7782
) -> None:
7883
model_aliases: set[str] = set()
7984
for config in column_configs:
@@ -82,10 +87,9 @@ def _run_model_health_check(
8287
if not model_aliases:
8388
return
8489

85-
if flags.DATA_DESIGNER_ASYNC_ENGINE:
90+
if client_concurrency_mode == ClientConcurrencyMode.ASYNC:
8691
# Defer the async-engine imports to here so users on the legacy sync
87-
# engine never pay the import cost. Mirrors the gating in
88-
# ``dataset_builders.dataset_builder``.
92+
# engine never pay the import cost.
8993
import asyncio
9094

9195
from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop

packages/data-designer-engine/tests/engine/test_readiness.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from data_designer.engine import flags
1818
from data_designer.engine.dataset_builders.errors import DatasetGenerationError
1919
from data_designer.engine.mcp.registry import MCPRegistry
20+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
2021
from data_designer.engine.readiness import run_readiness_check
22+
from data_designer.engine.resources.resource_provider import ResourceProvider
2123

2224

2325
@pytest.fixture(autouse=True)
@@ -52,6 +54,14 @@ def _build_columns(
5254
return builder.build().columns
5355

5456

57+
def _run_sync_readiness(column_configs: Sequence[ColumnConfigT], resource_provider: ResourceProvider) -> None:
58+
run_readiness_check(
59+
column_configs,
60+
resource_provider,
61+
client_concurrency_mode=ClientConcurrencyMode.SYNC,
62+
)
63+
64+
5565
# ---------------------------------------------------------------------------
5666
# Model health check
5767
# ---------------------------------------------------------------------------
@@ -82,7 +92,7 @@ def _gen_with_two_models(row, generator_params, models):
8292

8393
builder.add_column(CustomColumnConfig(name="custom_col", generator_function=_gen_with_two_models))
8494

85-
run_readiness_check(builder.build().columns, stub_resource_provider)
95+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
8696

8797
stub_resource_provider.model_registry.run_health_check.assert_called_once()
8898
(called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args
@@ -99,7 +109,7 @@ def test_run_readiness_check_skips_model_probe_when_no_aliases(
99109

100110
columns = _build_columns(model_configs=stub_model_configs, llm_columns=[])
101111

102-
run_readiness_check(columns, stub_resource_provider)
112+
_run_sync_readiness(columns, stub_resource_provider)
103113

104114
stub_resource_provider.model_registry.run_health_check.assert_not_called()
105115

@@ -117,7 +127,7 @@ def test_run_readiness_check_propagates_model_probe_error(
117127
columns = _build_columns(model_configs=stub_model_configs, llm_columns=[("col", "stub-text")])
118128

119129
with pytest.raises(ModelAuthenticationError, match="bad creds"):
120-
run_readiness_check(columns, stub_resource_provider)
130+
_run_sync_readiness(columns, stub_resource_provider)
121131

122132

123133
# ---------------------------------------------------------------------------
@@ -142,7 +152,7 @@ def test_run_readiness_check_collects_unique_sorted_tool_aliases(
142152
LLMTextColumnConfig(name="c", prompt="x", model_alias="stub-text", tool_alias="alpha") # duplicate
143153
)
144154

145-
run_readiness_check(builder.build().columns, stub_resource_provider)
155+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
146156

147157
mock_mcp_registry.run_health_check.assert_called_once_with(["alpha", "zebra"])
148158

@@ -158,7 +168,7 @@ def test_run_readiness_check_skips_tool_probe_when_no_tool_aliases(
158168

159169
columns = _build_columns(model_configs=stub_model_configs, llm_columns=[("col", "stub-text")])
160170

161-
run_readiness_check(columns, stub_resource_provider)
171+
_run_sync_readiness(columns, stub_resource_provider)
162172

163173
mock_mcp_registry.run_health_check.assert_not_called()
164174

@@ -176,7 +186,7 @@ def test_run_readiness_check_raises_when_tools_referenced_but_no_mcp_registry(
176186
builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="missing-tools"))
177187

178188
with pytest.raises(DatasetGenerationError, match="missing-tools"):
179-
run_readiness_check(builder.build().columns, stub_resource_provider)
189+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
180190

181191

182192
def test_run_readiness_check_propagates_tool_probe_error(
@@ -194,7 +204,7 @@ def test_run_readiness_check_propagates_tool_probe_error(
194204
builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="tools"))
195205

196206
with pytest.raises(RuntimeError, match="mcp down"):
197-
run_readiness_check(builder.build().columns, stub_resource_provider)
207+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
198208

199209

200210
# ---------------------------------------------------------------------------
@@ -218,7 +228,7 @@ def test_run_readiness_check_runs_models_before_tools(
218228
builder.add_column(LLMTextColumnConfig(name="col", prompt="x", model_alias="stub-text", tool_alias="tools"))
219229

220230
with pytest.raises(ModelAuthenticationError):
221-
run_readiness_check(builder.build().columns, stub_resource_provider)
231+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
222232

223233
# The MCP probe must not have been reached.
224234
mock_mcp_registry.run_health_check.assert_not_called()
@@ -235,7 +245,7 @@ def test_run_readiness_check_no_models_no_tools_is_noop(
235245

236246
columns = _build_columns(model_configs=stub_model_configs, llm_columns=[])
237247

238-
run_readiness_check(columns, stub_resource_provider)
248+
_run_sync_readiness(columns, stub_resource_provider)
239249

240250
stub_resource_provider.model_registry.run_health_check.assert_not_called()
241251
mock_mcp_registry.run_health_check.assert_not_called()
@@ -266,7 +276,7 @@ def test_run_readiness_check_collects_image_model_aliases(
266276
builder.add_column(LLMTextColumnConfig(name="caption", prompt="x", model_alias="stub-text"))
267277
builder.add_column(ImageColumnConfig(name="picture", prompt="y", model_alias="stub-image"))
268278

269-
run_readiness_check(builder.build().columns, stub_resource_provider)
279+
_run_sync_readiness(builder.build().columns, stub_resource_provider)
270280

271281
stub_resource_provider.model_registry.run_health_check.assert_called_once()
272282
(called_aliases,), _ = stub_resource_provider.model_registry.run_health_check.call_args
@@ -292,7 +302,7 @@ def test_run_readiness_check_passes_skip_flagged_aliases_to_registry(
292302
llm_columns=[("col", "stub-text")],
293303
)
294304

295-
run_readiness_check(columns, stub_resource_provider)
305+
_run_sync_readiness(columns, stub_resource_provider)
296306

297307
stub_resource_provider.model_registry.run_health_check.assert_called_once_with(["stub-text"])
298308

@@ -331,7 +341,11 @@ def test_run_readiness_check_dispatches_to_async_registry_under_async_engine(
331341
patch("data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop"),
332342
patch("asyncio.run_coroutine_threadsafe", return_value=sentinel_future) as mock_submit,
333343
):
334-
run_readiness_check(columns, stub_resource_provider)
344+
run_readiness_check(
345+
columns,
346+
stub_resource_provider,
347+
client_concurrency_mode=ClientConcurrencyMode.ASYNC,
348+
)
335349

336350
# The async coroutine was created from arun_health_check and submitted to the loop.
337351
stub_resource_provider.model_registry.arun_health_check.assert_called_once_with(["stub-text"])
@@ -362,6 +376,32 @@ def test_run_readiness_check_cancels_future_and_reraises_on_timeout(
362376
patch("asyncio.run_coroutine_threadsafe", return_value=sentinel_future),
363377
pytest.raises(TimeoutError),
364378
):
365-
run_readiness_check(columns, stub_resource_provider)
379+
run_readiness_check(
380+
columns,
381+
stub_resource_provider,
382+
client_concurrency_mode=ClientConcurrencyMode.ASYNC,
383+
)
366384

367385
sentinel_future.cancel.assert_called_once()
386+
387+
388+
def test_run_readiness_check_uses_sync_registry_for_sync_mode_clients(
389+
stub_resource_provider,
390+
stub_model_configs,
391+
monkeypatch: pytest.MonkeyPatch,
392+
) -> None:
393+
"""Readiness follows the explicit client mode, not only the raw async env flag."""
394+
monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", True)
395+
stub_resource_provider.model_registry.run_health_check = Mock()
396+
stub_resource_provider.model_registry.arun_health_check = Mock()
397+
stub_resource_provider.mcp_registry = None
398+
399+
columns = _build_columns(
400+
model_configs=stub_model_configs,
401+
llm_columns=[("col", "stub-text")],
402+
)
403+
404+
run_readiness_check(columns, stub_resource_provider, client_concurrency_mode=ClientConcurrencyMode.SYNC)
405+
406+
stub_resource_provider.model_registry.run_health_check.assert_called_once_with(["stub-text"])
407+
stub_resource_provider.model_registry.arun_health_check.assert_not_called()

packages/data-designer/src/data_designer/interface/data_designer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,11 @@ def check_models(self, config_builder: DataDesignerConfigBuilder) -> None:
578578
TimeoutError: If async health-check execution exceeds 180 seconds.
579579
"""
580580
resource_provider = self._create_resource_provider("check-models", config_builder)
581-
run_readiness_check(config_builder.build().columns, resource_provider)
581+
run_readiness_check(
582+
config_builder.build().columns,
583+
resource_provider,
584+
client_concurrency_mode=self._resolve_client_concurrency_mode(config_builder),
585+
)
582586

583587
def get_default_model_configs(self) -> list[ModelConfig]:
584588
"""Get the default model configurations.

packages/data-designer/tests/interface/test_data_designer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
HuggingFaceSeedSource,
4040
)
4141
from data_designer.engine import flags
42+
from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode
4243
from data_designer.engine.resources.seed_reader import (
4344
FileSystemSeedReader,
4445
SeedReaderError,
@@ -1544,9 +1545,33 @@ def test_check_models_invokes_readiness_check(
15441545
data_designer.check_models(config_builder)
15451546

15461547
assert mock_check.call_count == 1
1547-
(called_columns, called_resource_provider), _ = mock_check.call_args
1548+
(called_columns, called_resource_provider), kwargs = mock_check.call_args
15481549
assert [c.name for c in called_columns] == ["text"]
15491550
assert called_resource_provider is not None
1551+
assert kwargs["client_concurrency_mode"] == ClientConcurrencyMode.ASYNC
1552+
1553+
1554+
def test_check_models_passes_sync_mode_for_sync_fallback(
1555+
stub_artifact_path,
1556+
stub_model_providers,
1557+
stub_managed_assets_path,
1558+
monkeypatch: pytest.MonkeyPatch,
1559+
):
1560+
"""check_models readiness uses the resolved client mode, including allow_resize fallback."""
1561+
monkeypatch.setattr(flags, "DATA_DESIGNER_ASYNC_ENGINE", True)
1562+
config_builder = _builder_with_allow_resize()
1563+
data_designer = DataDesigner(
1564+
artifact_path=stub_artifact_path,
1565+
model_providers=stub_model_providers,
1566+
secret_resolver=PlaintextResolver(),
1567+
managed_assets_path=stub_managed_assets_path,
1568+
)
1569+
1570+
with patch("data_designer.interface.data_designer.run_readiness_check") as mock_check:
1571+
data_designer.check_models(config_builder)
1572+
1573+
_, kwargs = mock_check.call_args
1574+
assert kwargs["client_concurrency_mode"] == ClientConcurrencyMode.SYNC
15501575

15511576

15521577
def test_check_models_propagates_typed_model_error(

0 commit comments

Comments
 (0)