Skip to content

Commit 28e4c29

Browse files
committed
test(config): cover from_config round trips
Signed-off-by: Johnny Greco <jogreco@nvidia.com>
1 parent 3630301 commit 28e4c29

1 file changed

Lines changed: 68 additions & 1 deletion

File tree

packages/data-designer-config/tests/config/test_config_builder.py

Lines changed: 68 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,11 +33,12 @@
3333
InvalidColumnTypeError,
3434
InvalidConfigError,
3535
)
36+
from data_designer.config.mcp import ToolConfig
3637
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
3738
from data_designer.config.processors import DropColumnsProcessorConfig, SchemaTransformProcessorConfig
3839
from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
3940
from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams
40-
from data_designer.config.seed import SamplingStrategy
41+
from data_designer.config.seed import IndexRange, SamplingStrategy
4142
from data_designer.config.seed_source import HuggingFaceSeedSource
4243
from data_designer.config.seed_source_dataframe import DataFrameSeedSource
4344
from data_designer.config.utils.code_lang import CodeLang
@@ -112,6 +113,72 @@ def test_from_config_restores_processors_and_profilers(stub_model_configs: list[
112113
assert loaded_builder.get_profilers() == [profiler_config]
113114

114115

116+
def test_from_config_restores_drop_column_processor_side_effects(stub_model_configs: list[ModelConfig]) -> None:
    """Check that reloading via from_config keeps the drop-columns processor and its column flags.

    Three UUID sampler columns are registered together with a drop-columns
    processor whose pattern is ``col_*``. After a round trip through
    ``get_builder_config`` / ``from_config`` the processor must still be
    present, and only the two ``col_``-prefixed columns may report
    ``drop is True``.
    """
    # Insertion order doubles as the order in which the columns are added and checked.
    expected_drop_flags = {"col_a": True, "col_b": True, "other": False}

    source_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
    for name in expected_drop_flags:
        source_builder.add_column(SamplerColumnConfig(name=name, sampler_type="uuid", params=UUIDSamplerParams()))
    cleanup_processor = DropColumnsProcessorConfig(name="cleanup", column_names=["col_*"])
    source_builder.add_processor(cleanup_processor)

    restored_builder = DataDesignerConfigBuilder.from_config(source_builder.get_builder_config())

    assert restored_builder.get_processor_configs() == [cleanup_processor]
    for name, should_drop in expected_drop_flags.items():
        assert restored_builder.get_column_config(name).drop is should_drop
130+
131+
132+
def test_from_config_round_trips_all_data_designer_config_fields(stub_model_configs: list[ModelConfig]) -> None:
    """Check that every DataDesignerConfig field survives a from_config round trip.

    The builder is populated with a tool config, a seed dataset, sampler
    columns, constraints, processors, and a profiler. The test then asserts
    that every field declared on ``DataDesignerConfig`` is explicitly set in
    both the original and the reloaded config, and that each field compares
    equal between the two.
    """
    expected_fields = set(DataDesignerConfig.model_fields)

    lookup_tool = ToolConfig(
        tool_alias="lookup-tools",
        providers=["local-tools"],
        allow_tools=["lookup"],
        max_tool_call_turns=2,
        timeout_sec=1.0,
    )
    source_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs, tool_configs=[lookup_tool])

    seed_source = HuggingFaceSeedSource(path="datasets/test-repo/data/*.parquet", token="test-token")
    source_builder.with_seed_dataset(
        seed_source,
        sampling_strategy=SamplingStrategy.SHUFFLE,
        selection_strategy=IndexRange(start=1, end=3),
    )

    # Two uniform sampler columns added through the keyword-argument form of add_column.
    uniform_columns = (
        ("age", {"low": 1, "high": 100}),
        ("height", {"low": 15, "high": 200}),
    )
    for column_name, bounds in uniform_columns:
        source_builder.add_column(
            name=column_name,
            column_type=DataDesignerColumnType.SAMPLER,
            sampler_type=SamplerType.UNIFORM,
            params=bounds,
        )
    # Third column added as a ready-made config object; its name matches the "internal_*" pattern below.
    source_builder.add_column(
        SamplerColumnConfig(name="internal_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())
    )

    source_builder.add_constraint(ScalarInequalityConstraint(target_column="age", operator="gt", rhs=18))
    source_builder.add_constraint(ColumnInequalityConstraint(target_column="height", operator="gt", rhs="age"))

    source_builder.add_processor(SchemaTransformProcessorConfig(name="records", template={"age": "{{ age }}"}))
    source_builder.add_processor(DropColumnsProcessorConfig(name="cleanup", column_names=["internal_*"]))

    source_builder.add_profiler(JudgeScoreProfilerConfig(model_alias="stub-model", summary_score_sample_size=5))

    original = source_builder.build()
    wrapped = BuilderConfig(data_designer=original)
    reloaded = DataDesignerConfigBuilder.from_config(wrapped).build()

    # Both configs must have every declared field explicitly set; this guards
    # the per-field comparison below against silently comparing defaults.
    for config in (original, reloaded):
        assert set(config.model_fields_set) == expected_fields

    for name in sorted(expected_fields):
        assert getattr(reloaded, name) == getattr(original, name)
180+
181+
115182
def test_from_config_auto_wraps_bare_data_designer_config(stub_data_designer_config_str: str) -> None:
116183
"""Test that from_config auto-wraps a bare DataDesignerConfig (no 'data_designer' wrapper)."""
117184
builder = DataDesignerConfigBuilder.from_config(config=stub_data_designer_config_str)

0 commit comments

Comments
 (0)