Skip to content

Commit 5615847

Browse files
committed
fix: block model config writes when legacy aliases omit provider
Raise LegacyModelConfigMigrationError from ModelRepository.load() when on-disk model_configs.yaml entries lack a required provider field, so CLI and agent flows fail fast instead of treating the file as empty and overwriting legacy aliases on add.
1 parent 62e111d commit 5615847

5 files changed

Lines changed: 110 additions & 6 deletions

File tree

packages/data-designer/src/data_designer/cli/controllers/model_controller.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import TYPE_CHECKING
88

99
from data_designer.cli.forms.model_builder import ModelFormBuilder
10-
from data_designer.cli.repositories.model_repository import ModelRepository
10+
from data_designer.cli.repositories.model_repository import LegacyModelConfigMigrationError, ModelRepository
1111
from data_designer.cli.repositories.provider_repository import ProviderRepository
1212
from data_designer.cli.services.model_service import ModelService
1313
from data_designer.cli.services.provider_service import ProviderService
@@ -54,7 +54,11 @@ def run(self) -> None:
5454
console.print()
5555

5656
# Check for existing configuration
57-
models = self.model_service.list_all()
57+
try:
58+
models = self.model_service.list_all()
59+
except LegacyModelConfigMigrationError as e:
60+
print_error(str(e))
61+
return
5862

5963
if models:
6064
self._show_existing_config()

packages/data-designer/src/data_designer/cli/repositories/model_repository.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
from pathlib import Path
7+
from typing import Any
78

89
from pydantic import BaseModel
910

@@ -13,12 +14,41 @@
1314
from data_designer.config.utils.io_helpers import load_config_file, save_config_file
1415

1516

17+
class LegacyModelConfigMigrationError(ValueError):
18+
"""Raised when on-disk model configs omit the required ``provider`` field."""
19+
20+
1621
class ModelConfigRegistry(BaseModel):
1722
"""Registry for model configurations."""
1823

1924
model_configs: list[ModelConfig]
2025

2126

27+
def _aliases_missing_provider(config_dict: Any) -> list[str]:
28+
if not isinstance(config_dict, dict):
29+
return []
30+
entries = config_dict.get("model_configs")
31+
if not isinstance(entries, list):
32+
return []
33+
missing: list[str] = []
34+
for entry in entries:
35+
if not isinstance(entry, dict):
36+
continue
37+
if entry.get("provider") is None:
38+
alias = entry.get("alias")
39+
missing.append(str(alias) if alias is not None else "<unknown>")
40+
return missing
41+
42+
43+
def _format_missing_provider_message(aliases: list[str]) -> str:
44+
alias_list = ", ".join(f"'{alias}'" for alias in aliases)
45+
return (
46+
f"model_configs.yaml contains model alias(es) missing a required 'provider' field: {alias_list}. "
47+
"Add an explicit provider name to each alias before saving changes "
48+
"(edit the file directly or run `data-designer config models` after updating it)."
49+
)
50+
51+
2252
class ModelRepository(ConfigRepository[ModelConfigRegistry]):
2353
"""Repository for model configurations."""
2454

@@ -34,6 +64,14 @@ def load(self) -> ModelConfigRegistry | None:
3464

3565
try:
3666
config_dict = load_config_file(self.config_file)
67+
except Exception:
68+
return None
69+
70+
missing_providers = _aliases_missing_provider(config_dict)
71+
if missing_providers:
72+
raise LegacyModelConfigMigrationError(_format_missing_provider_message(missing_providers))
73+
74+
try:
3775
return ModelConfigRegistry.model_validate(config_dict)
3876
except Exception:
3977
return None

packages/data-designer/src/data_designer/cli/utils/agent_introspection.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import data_designer.config as dd
1414
from data_designer.cli.agent_command_defs import AGENT_COMMANDS
15-
from data_designer.cli.repositories.model_repository import ModelRepository
15+
from data_designer.cli.repositories.model_repository import LegacyModelConfigMigrationError, ModelRepository
1616
from data_designer.cli.repositories.persona_repository import PersonaRepository
1717
from data_designer.cli.repositories.provider_repository import ProviderRepository
1818
from data_designer.cli.services.download_service import DownloadService
@@ -295,7 +295,14 @@ def _get_source_file(cls: type) -> str:
295295
def _load_registry(repo: Any) -> Any:
296296
if not repo.exists():
297297
return None
298-
registry = repo.load()
298+
try:
299+
registry = repo.load()
300+
except LegacyModelConfigMigrationError as e:
301+
raise AgentIntrospectionError(
302+
code="legacy_model_config",
303+
message=str(e),
304+
details={"config_file": str(repo.config_file)},
305+
) from e
299306
if registry is None:
300307
raise AgentIntrospectionError(
301308
code="invalid_registry",

packages/data-designer/tests/cli/repositories/test_model_repository.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44
from pathlib import Path
55

6-
from data_designer.cli.repositories.model_repository import ModelConfigRegistry, ModelRepository
6+
import pytest
7+
8+
from data_designer.cli.repositories.model_repository import (
9+
LegacyModelConfigMigrationError,
10+
ModelConfigRegistry,
11+
ModelRepository,
12+
)
713
from data_designer.config.models import ModelConfig
814
from data_designer.config.utils.constants import MODEL_CONFIGS_FILE_NAME
915
from data_designer.config.utils.io_helpers import save_config_file
@@ -34,3 +40,22 @@ def test_save(tmp_path: Path, stub_model_configs: list[ModelConfig]):
3440
repository.save(ModelConfigRegistry(model_configs=stub_model_configs))
3541
assert repository.load() is not None
3642
assert repository.load().model_configs == stub_model_configs
43+
44+
45+
def test_load_legacy_missing_provider_raises(tmp_path: Path) -> None:
46+
model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME
47+
save_config_file(
48+
model_configs_file_path,
49+
{
50+
"model_configs": [
51+
{
52+
"alias": "legacy-alias",
53+
"model": "test-model",
54+
"inference_parameters": {"generation_type": "chat-completion"},
55+
}
56+
]
57+
},
58+
)
59+
repository = ModelRepository(tmp_path)
60+
with pytest.raises(LegacyModelConfigMigrationError, match="legacy-alias"):
61+
repository.load()

packages/data-designer/tests/cli/services/test_model_service.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55

66
import pytest
77

8-
from data_designer.cli.repositories.model_repository import ModelRepository
8+
from data_designer.cli.repositories.model_repository import LegacyModelConfigMigrationError, ModelRepository
99
from data_designer.cli.services.model_service import ModelService
1010
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
11+
from data_designer.config.utils.constants import MODEL_CONFIGS_FILE_NAME
12+
from data_designer.config.utils.io_helpers import load_config_file, save_config_file
1113

1214

1315
def test_list_all(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
@@ -136,3 +138,31 @@ def test_delete_by_aliases_no_registry(tmp_path: Path):
136138
service = ModelService(ModelRepository(tmp_path))
137139
with pytest.raises(ValueError, match="No models configured"):
138140
service.delete_by_aliases(["test-alias-1"])
141+
142+
143+
def test_add_blocks_when_legacy_config_missing_provider(
144+
tmp_path: Path,
145+
stub_new_model_config: ModelConfig,
146+
) -> None:
147+
"""Legacy aliases without provider must block writes instead of overwriting the file."""
148+
model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME
149+
save_config_file(
150+
model_configs_file_path,
151+
{
152+
"model_configs": [
153+
{
154+
"alias": "legacy-alias",
155+
"model": "legacy-model",
156+
"inference_parameters": {"generation_type": "chat-completion"},
157+
}
158+
]
159+
},
160+
)
161+
162+
service = ModelService(ModelRepository(tmp_path))
163+
with pytest.raises(LegacyModelConfigMigrationError, match="legacy-alias"):
164+
service.add(stub_new_model_config)
165+
166+
saved = load_config_file(model_configs_file_path)
167+
assert saved["model_configs"][0]["alias"] == "legacy-alias"
168+
assert "provider" not in saved["model_configs"][0]

0 commit comments

Comments
 (0)