|
33 | 33 | InvalidColumnTypeError, |
34 | 34 | InvalidConfigError, |
35 | 35 | ) |
| 36 | +from data_designer.config.mcp import ToolConfig |
36 | 37 | from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig |
37 | 38 | from data_designer.config.processors import DropColumnsProcessorConfig, SchemaTransformProcessorConfig |
38 | 39 | from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint |
39 | 40 | from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams |
40 | | -from data_designer.config.seed import SamplingStrategy |
| 41 | +from data_designer.config.seed import IndexRange, SamplingStrategy |
41 | 42 | from data_designer.config.seed_source import HuggingFaceSeedSource |
42 | 43 | from data_designer.config.seed_source_dataframe import DataFrameSeedSource |
43 | 44 | from data_designer.config.utils.code_lang import CodeLang |
@@ -112,6 +113,72 @@ def test_from_config_restores_processors_and_profilers(stub_model_configs: list[ |
112 | 113 | assert loaded_builder.get_profilers() == [profiler_config] |
113 | 114 |
|
114 | 115 |
|
| 116 | +def test_from_config_restores_drop_column_processor_side_effects(stub_model_configs: list[ModelConfig]) -> None: |
| 117 | + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) |
| 118 | + for column_name in ("col_a", "col_b", "other"): |
| 119 | + builder.add_column(SamplerColumnConfig(name=column_name, sampler_type="uuid", params=UUIDSamplerParams())) |
| 120 | + processor_config = DropColumnsProcessorConfig(name="cleanup", column_names=["col_*"]) |
| 121 | + |
| 122 | + builder.add_processor(processor_config) |
| 123 | + |
| 124 | + loaded_builder = DataDesignerConfigBuilder.from_config(builder.get_builder_config()) |
| 125 | + |
| 126 | + assert loaded_builder.get_processor_configs() == [processor_config] |
| 127 | + assert loaded_builder.get_column_config("col_a").drop is True |
| 128 | + assert loaded_builder.get_column_config("col_b").drop is True |
| 129 | + assert loaded_builder.get_column_config("other").drop is False |
| 130 | + |
| 131 | + |
| 132 | +def test_from_config_round_trips_all_data_designer_config_fields(stub_model_configs: list[ModelConfig]) -> None: |
| 133 | + data_designer_config_fields = set(DataDesignerConfig.model_fields) |
| 134 | + |
| 135 | + builder = DataDesignerConfigBuilder( |
| 136 | + model_configs=stub_model_configs, |
| 137 | + tool_configs=[ |
| 138 | + ToolConfig( |
| 139 | + tool_alias="lookup-tools", |
| 140 | + providers=["local-tools"], |
| 141 | + allow_tools=["lookup"], |
| 142 | + max_tool_call_turns=2, |
| 143 | + timeout_sec=1.0, |
| 144 | + ) |
| 145 | + ], |
| 146 | + ) |
| 147 | + builder.with_seed_dataset( |
| 148 | + HuggingFaceSeedSource(path="datasets/test-repo/data/*.parquet", token="test-token"), |
| 149 | + sampling_strategy=SamplingStrategy.SHUFFLE, |
| 150 | + selection_strategy=IndexRange(start=1, end=3), |
| 151 | + ) |
| 152 | + builder.add_column( |
| 153 | + name="age", |
| 154 | + column_type=DataDesignerColumnType.SAMPLER, |
| 155 | + sampler_type=SamplerType.UNIFORM, |
| 156 | + params={"low": 1, "high": 100}, |
| 157 | + ) |
| 158 | + builder.add_column( |
| 159 | + name="height", |
| 160 | + column_type=DataDesignerColumnType.SAMPLER, |
| 161 | + sampler_type=SamplerType.UNIFORM, |
| 162 | + params={"low": 15, "high": 200}, |
| 163 | + ) |
| 164 | + builder.add_column( |
| 165 | + SamplerColumnConfig(name="internal_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams()) |
| 166 | + ) |
| 167 | + builder.add_constraint(ScalarInequalityConstraint(target_column="age", operator="gt", rhs=18)) |
| 168 | + builder.add_constraint(ColumnInequalityConstraint(target_column="height", operator="gt", rhs="age")) |
| 169 | + builder.add_processor(SchemaTransformProcessorConfig(name="records", template={"age": "{{ age }}"})) |
| 170 | + builder.add_processor(DropColumnsProcessorConfig(name="cleanup", column_names=["internal_*"])) |
| 171 | + builder.add_profiler(JudgeScoreProfilerConfig(model_alias="stub-model", summary_score_sample_size=5)) |
| 172 | + |
| 173 | + original_config = builder.build() |
| 174 | + loaded_config = DataDesignerConfigBuilder.from_config(BuilderConfig(data_designer=original_config)).build() |
| 175 | + |
| 176 | + assert set(original_config.model_fields_set) == data_designer_config_fields |
| 177 | + assert set(loaded_config.model_fields_set) == data_designer_config_fields |
| 178 | + for field_name in sorted(data_designer_config_fields): |
| 179 | + assert getattr(loaded_config, field_name) == getattr(original_config, field_name) |
| 180 | + |
| 181 | + |
115 | 182 | def test_from_config_auto_wraps_bare_data_designer_config(stub_data_designer_config_str: str) -> None: |
116 | 183 | """Test that from_config auto-wraps a bare DataDesignerConfig (no 'data_designer' wrapper).""" |
117 | 184 | builder = DataDesignerConfigBuilder.from_config(config=stub_data_designer_config_str) |
|
0 commit comments