Skip to content

Commit 52aecfa

Browse files
authored
fix: exclude df from seed source serialization (#193)
* exclude from serialization
* add unit test
* should -> must
1 parent 7b5ea13 commit 52aecfa

2 files changed

Lines changed: 25 additions & 3 deletions

File tree

src/data_designer/config/seed_source.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ class HuggingFaceSeedSource(SeedSource):
5353

5454
path: str = Field(
5555
...,
56-
description="Path to the seed data in HuggingFace. Wildcards are allowed. Examples include 'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', 'datasets/my-username/my-dataset/**/*.parquet'",
56+
description=(
57+
"Path to the seed data in HuggingFace. Wildcards are allowed. Examples include "
58+
"'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', "
59+
"and 'datasets/my-username/my-dataset/**/*.parquet'"
60+
),
5761
)
5862
token: str | None = None
5963
endpoint: str = "https://huggingface.co"
@@ -64,7 +68,14 @@ class DataFrameSeedSource(SeedSource):
6468

6569
model_config = ConfigDict(arbitrary_types_allowed=True)
6670

67-
df: pd.DataFrame
71+
df: pd.DataFrame = Field(
72+
...,
73+
exclude=True,
74+
description=(
75+
"DataFrame to use directly as the seed dataset. NOTE: if you need to write a Data Designer config, "
76+
"you must use `LocalFileSeedSource` instead, since DataFrame objects are not serializable."
77+
),
78+
)
6879

6980

7081
SeedSourceT = Annotated[

tests/config/test_seed_source.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from data_designer.config.errors import InvalidFilePathError
10-
from data_designer.config.seed_source import LocalFileSeedSource
10+
from data_designer.config.seed_source import DataFrameSeedSource, LocalFileSeedSource
1111

1212

1313
def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path:
@@ -59,3 +59,14 @@ def test_local_source_from_dataframe(tmp_path: Path):
5959

6060
assert source.path == filepath
6161
pd.testing.assert_frame_equal(df, pd.read_parquet(filepath))
62+
63+
64+
def test_dataframe_seed_source_serialization():
65+
"""Test that DataFrameSeedSource excludes the DataFrame field during serialization."""
66+
df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
67+
source = DataFrameSeedSource(df=df)
68+
69+
# Test model_dump excludes the df field
70+
serialized = source.model_dump(mode="json")
71+
assert "df" not in serialized
72+
assert serialized == {"seed_type": "df"}

0 commit comments

Comments
 (0)