Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dataset builder default version #4356

Merged
6 changes: 3 additions & 3 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class BuilderConfig:
"""

name: str = "default"
version: Optional[Union[str, utils.Version]] = "0.0.0"
version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
data_dir: Optional[str] = None
data_files: Optional[DataFilesDict] = None
description: Optional[str] = None
Expand Down Expand Up @@ -193,8 +193,8 @@ class DatasetBuilder:
pre-defined set of configurations in :meth:`datasets.DatasetBuilder.builder_configs`.
"""

# Default version.
VERSION = utils.Version("0.0.0")
# Default version
VERSION = None # Default version set in BuilderConfig

# Class for the builder config.
BUILDER_CONFIG_CLASS = BuilderConfig
Expand Down
9 changes: 2 additions & 7 deletions src/datasets/commands/dummy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,9 @@ def run(self):
auto_generate_results = []
with tempfile.TemporaryDirectory() as tmp_dir:
for builder_config in builder_configs:
if builder_config is None:
name = None
version = builder_cls.VERSION
else:
version = builder_config.version
name = builder_config.name

name = builder_config.name if builder_config else None
dataset_builder = builder_cls(name=name, hash=dataset_module.hash, cache_dir=tmp_dir)
version = builder_config.version if builder_config else dataset_builder.config.version
mock_dl_manager = MockDownloadManager(
dataset_name=self._dataset_name,
config=builder_config,
Expand Down
63 changes: 63 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,3 +756,66 @@ def test_custom_writer_batch_size(tmp_path, writer_batch_size, default_writer_ba
builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
dataset = builder.as_dataset("train")
assert len(dataset.data[0].chunks) == expected_chunks


class DummyBuilderWithVersion(GeneratorBasedBuilder):
VERSION = "2.0.0"

def _info(self):
return DatasetInfo(features=Features({"text": Value("string")}))

def _split_generators(self, dl_manager):
pass

def _generate_examples(self):
pass


class DummyBuilderWithBuilderConfigs(GeneratorBasedBuilder):
BUILDER_CONFIGS = [BuilderConfig(name="custom", version="2.0.0")]

def _info(self):
return DatasetInfo(features=Features({"text": Value("string")}))

def _split_generators(self, dl_manager):
pass

def _generate_examples(self):
pass


class CustomBuilderConfig(BuilderConfig):
def __init__(self, date=None, language=None, version="2.0.0", **kwargs):
name = f"{date}.{language}"
super().__init__(name=name, version=version, **kwargs)
self.date = date
self.language = language


class DummyBuilderWithCustomBuilderConfigs(GeneratorBasedBuilder):
BUILDER_CONFIGS = [CustomBuilderConfig(date="20220501", language="en")]
BUILDER_CONFIG_CLASS = CustomBuilderConfig

def _info(self):
return DatasetInfo(features=Features({"text": Value("string")}))

def _split_generators(self, dl_manager):
pass

def _generate_examples(self):
pass


@pytest.mark.parametrize(
"builder_class, kwargs",
[
(DummyBuilderWithVersion, {}),
(DummyBuilderWithBuilderConfigs, {"name": "custom"}),
(DummyBuilderWithCustomBuilderConfigs, {"name": "20220501.en"}),
(DummyBuilderWithCustomBuilderConfigs, {"date": "20220501", "language": "ca"}),
],
)
def test_builder_config_version(builder_class, kwargs, tmp_path):
cache_dir = str(tmp_path)
builder = builder_class(cache_dir=cache_dir, **kwargs)
assert builder.config.version == "2.0.0"
5 changes: 1 addition & 4 deletions tests/test_dataset_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,7 @@ def check_load_dataset(self, dataset_name, configs, is_local=False, use_local_du
logger.info("Skip tests for this dataset for now")
return

if config is not None:
version = config.version
else:
version = dataset_builder.VERSION
version = config.version if config else dataset_builder.config.version

def check_if_url_is_valid(url):
if is_remote_url(url) and "\\" in url:
Expand Down