Patch summary (free text placed before the first "diff --git" header):
  - src/datasets/builder.py: BuilderConfig.version now defaults to utils.Version("0.0.0"),
    and DatasetBuilder.VERSION now defaults to None.
  - src/datasets/commands/dummy_data.py: when builder_config is None, the version is read from
    dataset_builder.config.version instead of builder_cls.VERSION.
  - tests/test_builder.py: adds test_builder_config_version together with the dummy builders
    and the custom builder config it is parametrized over.
  - tests/test_dataset_common.py: the version falls back to dataset_builder.config.version
    instead of dataset_builder.VERSION.
NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy;
confirm the context-line indentation against the target files before applying.

diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 495e6a3f8cc..d91397ff00c 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -95,7 +95,7 @@ class BuilderConfig:
     """
 
     name: str = "default"
-    version: Optional[Union[str, utils.Version]] = "0.0.0"
+    version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
     data_dir: Optional[str] = None
     data_files: Optional[DataFilesDict] = None
     description: Optional[str] = None
@@ -193,8 +193,8 @@ class DatasetBuilder:
     pre-defined set of configurations in :meth:`datasets.DatasetBuilder.builder_configs`.
     """
 
-    # Default version.
-    VERSION = utils.Version("0.0.0")
+    # Default version
+    VERSION = None  # Default version set in BuilderConfig
 
     # Class for the builder config.
     BUILDER_CONFIG_CLASS = BuilderConfig
diff --git a/src/datasets/commands/dummy_data.py b/src/datasets/commands/dummy_data.py
index 91ebcb18a95..76a801e355e 100644
--- a/src/datasets/commands/dummy_data.py
+++ b/src/datasets/commands/dummy_data.py
@@ -295,14 +295,9 @@ def run(self):
         auto_generate_results = []
         with tempfile.TemporaryDirectory() as tmp_dir:
             for builder_config in builder_configs:
-                if builder_config is None:
-                    name = None
-                    version = builder_cls.VERSION
-                else:
-                    version = builder_config.version
-                    name = builder_config.name
-
+                name = builder_config.name if builder_config else None
                 dataset_builder = builder_cls(name=name, hash=dataset_module.hash, cache_dir=tmp_dir)
+                version = builder_config.version if builder_config else dataset_builder.config.version
                 mock_dl_manager = MockDownloadManager(
                     dataset_name=self._dataset_name,
                     config=builder_config,
diff --git a/tests/test_builder.py b/tests/test_builder.py
index 29659590c8d..ecc123b0067 100644
--- a/tests/test_builder.py
+++ b/tests/test_builder.py
@@ -756,3 +756,66 @@ def test_custom_writer_batch_size(tmp_path, writer_batch_size, default_writer_ba
     builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.FORCE_REDOWNLOAD)
     dataset = builder.as_dataset("train")
     assert len(dataset.data[0].chunks) == expected_chunks
+
+
+class DummyBuilderWithVersion(GeneratorBasedBuilder):
+    VERSION = "2.0.0"
+
+    def _info(self):
+        return DatasetInfo(features=Features({"text": Value("string")}))
+
+    def _split_generators(self, dl_manager):
+        pass
+
+    def _generate_examples(self):
+        pass
+
+
+class DummyBuilderWithBuilderConfigs(GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [BuilderConfig(name="custom", version="2.0.0")]
+
+    def _info(self):
+        return DatasetInfo(features=Features({"text": Value("string")}))
+
+    def _split_generators(self, dl_manager):
+        pass
+
+    def _generate_examples(self):
+        pass
+
+
+class CustomBuilderConfig(BuilderConfig):
+    def __init__(self, date=None, language=None, version="2.0.0", **kwargs):
+        name = f"{date}.{language}"
+        super().__init__(name=name, version=version, **kwargs)
+        self.date = date
+        self.language = language
+
+
+class DummyBuilderWithCustomBuilderConfigs(GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [CustomBuilderConfig(date="20220501", language="en")]
+    BUILDER_CONFIG_CLASS = CustomBuilderConfig
+
+    def _info(self):
+        return DatasetInfo(features=Features({"text": Value("string")}))
+
+    def _split_generators(self, dl_manager):
+        pass
+
+    def _generate_examples(self):
+        pass
+
+
+@pytest.mark.parametrize(
+    "builder_class, kwargs",
+    [
+        (DummyBuilderWithVersion, {}),
+        (DummyBuilderWithBuilderConfigs, {"name": "custom"}),
+        (DummyBuilderWithCustomBuilderConfigs, {"name": "20220501.en"}),
+        (DummyBuilderWithCustomBuilderConfigs, {"date": "20220501", "language": "ca"}),
+    ],
+)
+def test_builder_config_version(builder_class, kwargs, tmp_path):
+    cache_dir = str(tmp_path)
+    builder = builder_class(cache_dir=cache_dir, **kwargs)
+    assert builder.config.version == "2.0.0"
diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py
index 55986c8f432..e7fb64e2ea0 100644
--- a/tests/test_dataset_common.py
+++ b/tests/test_dataset_common.py
@@ -136,10 +136,7 @@ def check_load_dataset(self, dataset_name, configs, is_local=False, use_local_du
                     logger.info("Skip tests for this dataset for now")
                     return
 
-                if config is not None:
-                    version = config.version
-                else:
-                    version = dataset_builder.VERSION
+                version = config.version if config else dataset_builder.config.version
 
                 def check_if_url_is_valid(url):
                     if is_remote_url(url) and "\\" in url: