Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self:
selection_strategy=seed_config.selection_strategy,
)

for processor in data_designer_config.processors or []:
builder.add_processor(processor)

for profiler in data_designer_config.profilers or []:
builder.add_profiler(profiler)

return builder

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
InvalidColumnTypeError,
InvalidConfigError,
)
from data_designer.config.mcp import ToolConfig
from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
from data_designer.config.processors import DropColumnsProcessorConfig, SchemaTransformProcessorConfig
from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams
from data_designer.config.seed import SamplingStrategy
from data_designer.config.seed import IndexRange, SamplingStrategy
from data_designer.config.seed_source import HuggingFaceSeedSource
from data_designer.config.seed_source_dataframe import DataFrameSeedSource
from data_designer.config.utils.code_lang import CodeLang
Expand Down Expand Up @@ -97,6 +98,87 @@ def test_from_config(stub_data_designer_builder_config_str):
assert isinstance(builder_from_object.get_column_config(name="code_id"), SamplerColumnConfig)


def test_from_config_restores_processors_and_profilers(stub_model_configs: list[ModelConfig]) -> None:
builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
builder.add_column(SamplerColumnConfig(name="test_id", sampler_type="uuid", params=UUIDSamplerParams()))
processor_config = SchemaTransformProcessorConfig(name="records", template={"id": "{{ test_id }}"})
profiler_config = JudgeScoreProfilerConfig(model_alias="stub-model", summary_score_sample_size=5)

builder.add_processor(processor_config)
builder.add_profiler(profiler_config)

loaded_builder = DataDesignerConfigBuilder.from_config(builder.get_builder_config())

assert loaded_builder.get_processor_configs() == [processor_config]
assert loaded_builder.get_profilers() == [profiler_config]


def test_from_config_restores_drop_column_processor_side_effects(stub_model_configs: list[ModelConfig]) -> None:
builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
for column_name in ("col_a", "col_b", "other"):
builder.add_column(SamplerColumnConfig(name=column_name, sampler_type="uuid", params=UUIDSamplerParams()))
processor_config = DropColumnsProcessorConfig(name="cleanup", column_names=["col_*"])

builder.add_processor(processor_config)

loaded_builder = DataDesignerConfigBuilder.from_config(builder.get_builder_config())

assert loaded_builder.get_processor_configs() == [processor_config]
assert loaded_builder.get_column_config("col_a").drop is True
assert loaded_builder.get_column_config("col_b").drop is True
assert loaded_builder.get_column_config("other").drop is False


def test_from_config_round_trips_all_data_designer_config_fields(stub_model_configs: list[ModelConfig]) -> None:
data_designer_config_fields = set(DataDesignerConfig.model_fields)

builder = DataDesignerConfigBuilder(
model_configs=stub_model_configs,
tool_configs=[
ToolConfig(
tool_alias="lookup-tools",
providers=["local-tools"],
allow_tools=["lookup"],
max_tool_call_turns=2,
timeout_sec=1.0,
)
],
)
builder.with_seed_dataset(
HuggingFaceSeedSource(path="datasets/test-repo/data/*.parquet", token="test-token"),
sampling_strategy=SamplingStrategy.SHUFFLE,
selection_strategy=IndexRange(start=1, end=3),
)
builder.add_column(
name="age",
column_type=DataDesignerColumnType.SAMPLER,
sampler_type=SamplerType.UNIFORM,
params={"low": 1, "high": 100},
)
builder.add_column(
name="height",
column_type=DataDesignerColumnType.SAMPLER,
sampler_type=SamplerType.UNIFORM,
params={"low": 15, "high": 200},
)
builder.add_column(
SamplerColumnConfig(name="internal_id", sampler_type=SamplerType.UUID, params=UUIDSamplerParams())
)
builder.add_constraint(ScalarInequalityConstraint(target_column="age", operator="gt", rhs=18))
builder.add_constraint(ColumnInequalityConstraint(target_column="height", operator="gt", rhs="age"))
builder.add_processor(SchemaTransformProcessorConfig(name="records", template={"age": "{{ age }}"}))
builder.add_processor(DropColumnsProcessorConfig(name="cleanup", column_names=["internal_*"]))
builder.add_profiler(JudgeScoreProfilerConfig(model_alias="stub-model", summary_score_sample_size=5))

original_config = builder.build()
loaded_config = DataDesignerConfigBuilder.from_config(BuilderConfig(data_designer=original_config)).build()

assert set(original_config.model_fields_set) == data_designer_config_fields
assert set(loaded_config.model_fields_set) == data_designer_config_fields
for field_name in sorted(data_designer_config_fields):
assert getattr(loaded_config, field_name) == getattr(original_config, field_name)


def test_from_config_auto_wraps_bare_data_designer_config(stub_data_designer_config_str: str) -> None:
"""Test that from_config auto-wraps a bare DataDesignerConfig (no 'data_designer' wrapper)."""
builder = DataDesignerConfigBuilder.from_config(config=stub_data_designer_config_str)
Expand Down
Loading