From 65b9f70e47e638f8216ffaa7d84b1e5b33f803bb Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 6 Feb 2025 18:36:52 +0200
Subject: [PATCH 1/7] basic implementation of saving croissant metadata for
 datasets in processing

---
 .../data/preparator/gpt_memmap/prepare.py | 52 +++++++++++++++++++
 setup.cfg                                 |  2 +
 tests/data/test_porcessing_metadata.py    | 48 +++++++++++++++++
 3 files changed, 102 insertions(+)
 create mode 100644 tests/data/test_porcessing_metadata.py

diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 18c7cfbb..f8d7814e 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -1,6 +1,9 @@
 import json
+import logging
 import multiprocessing
 import pathlib
+import requests
+import shutil
 import typing
 
 import datasets
@@ -9,12 +12,16 @@
 import tqdm
 import transformers
 
+from huggingface_hub import HfFolder
+
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.preparator.config import DatasetPreparator
 from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType
 
+logger = logging.getLogger(__name__)
+
 
 class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
     config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig
@@ -64,6 +71,49 @@ def _load_dataset(self) -> datasets.Dataset:
         assert isinstance(dataset, datasets.Dataset)
         return dataset
 
+    def _get_croissant_metadata(self):
+        url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
+
+        token = HfFolder.get_token()
+        try:
+            if token is None:
+                response = requests.get(url)
+            else:
+                response = requests.get(url, headers={"Authorization": f"Bearer {token}"})
+
+            if response.status_code != 200:
+                logger.warning(
+                    f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}"
+                )
+                return None
+
+            data = response.json()
+        except Exception as e:
+            logger.warning(f"Failed to get croissant metadata, {e}")
+            return None
+        if "error" in data:
+            logger.warning(f"Failed to get croissant metadata, error: {data['error']}")
+            return None
+
+        return data
+
+    def _save_croissant_metadata(self):
+        dataset_path = pathlib.Path(self._config.dataset.path)
+        dst_croissant_file = pathlib.Path(self._config.output_path) / "croissant.json"
+        if dataset_path.is_dir():
+            croissant_file = dataset_path / "croissant.json"
+            if croissant_file.is_file():
+                shutil.copy(croissant_file, dst_croissant_file)
+            else:
+                logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
+                return
+        else:
+            data = self._get_croissant_metadata()
+            if data is None:
+                return
+            with dst_croissant_file.open("wt") as f:
+                json.dump(data, f)
+
     def run(self) -> None:
         # Set transformers logging verbosity
         transformers.logging.set_verbosity_error()
@@ -168,6 +218,8 @@ def run(self) -> None:
             output_file = self._config.output_path / "fast_llm_dataset.json"
             json.dump({"datasets": dataset_dicts}, output_file.open("w"))
 
+            self._save_croissant_metadata()
+
         # Create an index file on rank 0
         index_file = self._config.output_path / "index.txt"
         index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
diff --git a/setup.cfg b/setup.cfg
index 351d1bcb..78c8bc99 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -33,6 +33,7 @@ OPTIONAL =
     # Huggingface tools
     transformers>=4.44.2
     hf-transfer>=0.1.8
+    huggingface-hub>=0.28.1
     datasets>=3.1.0
     # Weights and biases
     wandb>=0.17.7
@@ -40,6 +41,7 @@ OPTIONAL =
     hydra-core>=1.3.2
     omegaconf>=2.3.0
     # Miscellanous
+    requests>=2.32.3
     tqdm>=4.66.3
 
 DEV =
diff --git a/tests/data/test_porcessing_metadata.py b/tests/data/test_porcessing_metadata.py
new file mode 100644
index 00000000..e9d13a7a
--- /dev/null
+++ b/tests/data/test_porcessing_metadata.py
@@ -0,0 +1,48 @@
+import tempfile
+import pathlib
+
+from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
+
+
+def get_prep(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
+    config = GPTMemmapDatasetPreparatorConfig.from_dict(
+        {
+            "output_path": output_path,
+            "dataset": {"path": dataset_path_name},
+            "tokenizer": {"path": "no_tokenizer"},
+        },
+        {},
+    )
+    return config.get_dataset_preparator_class()(config=config)
+
+
+def test_existing_metadata_hf_hub_dataset():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        prep = get_prep(local_folder, "lhoestq/demo1")
+        prep._save_croissant_metadata()
+        assert (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_absent_metadata_hf_hub_dataset():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        prep = get_prep(local_folder, "allenai/dolma")
+        prep._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_existing_metadata_local_dataset():
+    with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
+        (pathlib.Path(dataset_folder) / "croissant.json").touch()
+        with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+            prep = get_prep(local_folder, dataset_folder)
+            prep._save_croissant_metadata()
+            assert (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_absent_metadata_local_dataset():
+    with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
+        with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+            prep = get_prep(local_folder, dataset_folder)
+            prep._save_croissant_metadata()
+            assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
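A standalone sketch of the endpoint call that `_get_croissant_metadata` makes above, useful for trying it outside the preparator (this is not part of the patch; it assumes network access, and `lhoestq/demo1` is the public dataset the tests use):

    import os

    import requests

    # Same croissant endpoint the preparator queries; anonymous access works for public datasets.
    url = "https://huggingface.co/api/datasets/lhoestq/demo1/croissant"
    token = os.environ.get("HF_TOKEN")
    response = requests.get(url, headers={"Authorization": f"Bearer {token}"} if token else None, timeout=30)
    if response.status_code == 200:
        print(response.json().get("url"))  # JSON-LD metadata; "url" points back to the dataset page
    else:
        print("no croissant metadata:", response.status_code)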
From a513e3fdadd84d989be2ccba7c4f8dc0c74bd2bc Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 6 Feb 2025 19:23:56 +0200
Subject: [PATCH 2/7] fix: write the index file under rank 0 only

---
 fast_llm/data/preparator/gpt_memmap/prepare.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index f8d7814e..41d7f0ec 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -220,9 +220,9 @@ def run(self) -> None:
 
             self._save_croissant_metadata()
 
-        # Create an index file on rank 0
-        index_file = self._config.output_path / "index.txt"
-        index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
+            # Create an index file on rank 0
+            index_file = self._config.output_path / "index.txt"
+            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
 
         # Finalize distributed processing
         if self._config.distributed.world_size > 1:
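The fix above is an indentation-only change: the index file is now written inside the same rank-0 block that writes `fast_llm_dataset.json` and the croissant metadata. A minimal runnable sketch of the pattern, with a hypothetical `rank` argument standing in for the preparator's distributed state:

    import pathlib
    import tempfile

    def write_shared_outputs(rank: int, output_path: pathlib.Path, prefixes: list[str]) -> None:
        # Only the main rank writes files shared by all workers; before the fix,
        # index.txt was written unconditionally by every rank.
        if rank != 0:
            return
        (output_path / "index.txt").open("w").writelines(prefix + "\n" for prefix in prefixes)

    with tempfile.TemporaryDirectory() as tmp:
        write_shared_outputs(0, pathlib.Path(tmp), ["shard_0", "shard_1"])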
From db3fcb034452636b06c529f57f288cf751e58af1 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Thu, 6 Feb 2025 19:37:37 +0200
Subject: [PATCH 3/7] added comments

---
 fast_llm/data/preparator/gpt_memmap/prepare.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 41d7f0ec..6a2b3f51 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -72,10 +72,12 @@ def _load_dataset(self) -> datasets.Dataset:
         return dataset
 
     def _get_croissant_metadata(self):
-        url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
-
+        # Use HF hub functionality to get the API token from the logged-in state,
+        # or set the HF_TOKEN or HF_TOKEN_PATH env vars
        token = HfFolder.get_token()
         try:
+            # Retrieve the dataset metadata in croissant format
+            url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
             if token is None:
                 response = requests.get(url)
             else:
@@ -100,7 +102,9 @@ def _get_croissant_metadata(self):
     def _save_croissant_metadata(self):
         dataset_path = pathlib.Path(self._config.dataset.path)
         dst_croissant_file = pathlib.Path(self._config.output_path) / "croissant.json"
+
         if dataset_path.is_dir():
+            # If the dataset is local, check if it has the metadata file and copy it
             croissant_file = dataset_path / "croissant.json"
             if croissant_file.is_file():
                 shutil.copy(croissant_file, dst_croissant_file)
@@ -108,6 +112,7 @@ def _save_croissant_metadata(self):
                 logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
                 return
         else:
+            # If the dataset is on HF hub, retrieve the metadata if provided and save it
             data = self._get_croissant_metadata()
             if data is None:
                 return

From 5e1ae9d84748b9fecc8a9d4ce3f1f9f343e89136 Mon Sep 17 00:00:00 2001
From: Denis Kocetkov
Date: Fri, 7 Feb 2025 10:53:28 +0200
Subject: [PATCH 4/7] removed dependence on huggingface_hub by implementing
 get_token in the repo

---
 .../data/preparator/gpt_memmap/prepare.py |  7 ++-----
 fast_llm/ext_utils/hf_auth.py             | 74 +++++++++++++++++++
 setup.cfg                                 |  1 -
 3 files changed, 76 insertions(+), 6 deletions(-)
 create mode 100644 fast_llm/ext_utils/hf_auth.py

diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 6a2b3f51..c424af8f 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -12,13 +12,12 @@
 import tqdm
 import transformers
 
-from huggingface_hub import HfFolder
-
 from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
 from fast_llm.data.preparator.config import DatasetPreparator
 from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.ext_utils.hf_auth import hf_auth_get_token
 
 logger = logging.getLogger(__name__)
 
@@ -72,9 +71,7 @@ def _load_dataset(self) -> datasets.Dataset:
         return dataset
 
     def _get_croissant_metadata(self):
-        # Use HF hub functionality to get the API token from the logged-in state,
-        # or set the HF_TOKEN or HF_TOKEN_PATH env vars
-        token = HfFolder.get_token()
+        token = hf_auth_get_token()
         try:
             # Retrieve the dataset metadata in croissant format
             url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
+ """ + if token is None: + return None + return token.replace("\r", "").replace("\n", "").strip() or None diff --git a/setup.cfg b/setup.cfg index 78c8bc99..9c292ed1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,6 @@ OPTIONAL = # Huggingface tools transformers>=4.44.2 hf-transfer>=0.1.8 - huggingface-hub>=0.28.1 datasets>=3.1.0 # Weights and biases wandb>=0.17.7 From b2a5e047bcc5436a81a8150e5eae3032be806c03 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Feb 2025 15:58:28 -0500 Subject: [PATCH 5/7] Use hf hub --- .../data/preparator/gpt_memmap/prepare.py | 6 +- fast_llm/ext_utils/hf_auth.py | 74 ------------------- setup.cfg | 4 +- 3 files changed, 5 insertions(+), 79 deletions(-) delete mode 100644 fast_llm/ext_utils/hf_auth.py diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 7b148c4e..0fc8c7c4 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -2,12 +2,13 @@ import logging import multiprocessing import pathlib -import requests import shutil import typing import datasets +import huggingface_hub import numpy as np +import requests import torch.distributed import tqdm import transformers @@ -18,7 +19,6 @@ from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.ext_utils.hf_auth import hf_auth_get_token logger = logging.getLogger(__name__) @@ -104,7 +104,7 @@ def _load_dataset(self) -> datasets.Dataset: return dataset def _get_croissant_metadata(self): - token = hf_auth_get_token() + token = huggingface_hub.HfFolder.get_token() try: # Retrieve the dataset metadata in croissant format url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant" diff --git a/fast_llm/ext_utils/hf_auth.py b/fast_llm/ext_utils/hf_auth.py deleted file mode 100644 index 781df6e3..00000000 --- a/fast_llm/ext_utils/hf_auth.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This is a cut down version of getting an authentication token for HF hub -from https://github.com/huggingface/huggingface_hub/tree/d29290fbd2689fdc125e27ce0f91e8227a2f20de -It does not include getting the token in Google colab environment. -Also, it does not look for the token in the old path as it was marked -for removal in the future versions of the huggingface_hub. - -NOTE: It is needed to track changes in huggingface_hub or include it as dependency later. 
-""" - -import os -import pathlib - -from typing import Optional - -# default cache -default_home = os.path.join(os.path.expanduser("~"), ".cache") -HF_HOME = os.path.expanduser( - os.getenv( - "HF_HOME", - os.path.join(os.getenv("XDG_CACHE_HOME", default_home), "huggingface"), - ) -) - -HF_TOKEN_PATH = os.environ.get("HF_TOKEN_PATH", os.path.join(HF_HOME, "token")) - - -def hf_auth_get_token() -> Optional[str]: - """ - Get token if user is logged in. - - Token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located - in the Hugging Face home folder. Returns None if user is not logged in. To log in, use [`login`] or - `huggingface-cli login`. - - Returns: - `str` or `None`: The token, `None` if it doesn't exist. - """ - return _get_token_from_environment() or _get_token_from_file() - - -def _get_token_from_environment() -> Optional[str]: - # `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility) - return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) - - -def _get_token_from_file() -> Optional[str]: - try: - return _clean_token(pathlib.Path(HF_TOKEN_PATH).read_text()) - except FileNotFoundError: - return None - - -def _clean_token(token: Optional[str]) -> Optional[str]: - """Clean token by removing trailing and leading spaces and newlines. - - If token is an empty string, return None. - """ - if token is None: - return None - return token.replace("\r", "").replace("\n", "").strip() or None diff --git a/setup.cfg b/setup.cfg index 9c292ed1..c21f02a7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,8 +19,7 @@ install_requires = # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation CORE = # Available through the nvidia base image - # Keeping an older min version because later ones have no x86 wheel for Mac OS - torch>=2.2.2 + torch>=2.5.0 # Numpy major needs to match torch numpy>=1.24.4,<2.0.0 # Used for checkpoints @@ -34,6 +33,7 @@ OPTIONAL = transformers>=4.44.2 hf-transfer>=0.1.8 datasets>=3.1.0 + huggingface-hub>=0.28.1 # Weights and biases wandb>=0.17.7 # Hydra From 6e13771d4f7458440199d1f90f6ebc9494eacf4d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 14 Feb 2025 16:16:10 -0500 Subject: [PATCH 6/7] misc --- fast_llm/data/preparator/gpt_memmap/prepare.py | 10 ++++------ ...ssing_metadata.py => test_processing_metadata.py} | 12 ++++++------ 2 files changed, 10 insertions(+), 12 deletions(-) rename tests/data/{test_porcessing_metadata.py => test_processing_metadata.py} (82%) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 0fc8c7c4..e029137c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -131,23 +131,21 @@ def _get_croissant_metadata(self): def _save_croissant_metadata(self): dataset_path = pathlib.Path(self._config.dataset.path) - dst_croissant_file = pathlib.Path(self._config.output_path) / "croissant.json" + croissant_path = pathlib.Path(self._config.output_path) / "croissant.json" if dataset_path.is_dir(): # If the dataset is local, check if it has the metadata file and copy it croissant_file = dataset_path / "croissant.json" if croissant_file.is_file(): - shutil.copy(croissant_file, dst_croissant_file) + shutil.copy(croissant_file, croissant_path) else: logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file") return else: 
From 6e13771d4f7458440199d1f90f6ebc9494eacf4d Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 14 Feb 2025 16:16:10 -0500
Subject: [PATCH 6/7] misc

---
 fast_llm/data/preparator/gpt_memmap/prepare.py        | 10 ++++------
 ...ssing_metadata.py => test_processing_metadata.py}  | 12 ++++++------
 2 files changed, 10 insertions(+), 12 deletions(-)
 rename tests/data/{test_porcessing_metadata.py => test_processing_metadata.py} (82%)

diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index 0fc8c7c4..e029137c 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -131,23 +131,21 @@ def _get_croissant_metadata(self):
 
     def _save_croissant_metadata(self):
         dataset_path = pathlib.Path(self._config.dataset.path)
-        dst_croissant_file = pathlib.Path(self._config.output_path) / "croissant.json"
+        croissant_path = pathlib.Path(self._config.output_path) / "croissant.json"
 
         if dataset_path.is_dir():
             # If the dataset is local, check if it has the metadata file and copy it
             croissant_file = dataset_path / "croissant.json"
             if croissant_file.is_file():
-                shutil.copy(croissant_file, dst_croissant_file)
+                shutil.copy(croissant_file, croissant_path)
             else:
                 logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
                 return
         else:
             # If the dataset is on HF hub, retrieve the metadata if provided and save it
             data = self._get_croissant_metadata()
-            if data is None:
-                return
-            with dst_croissant_file.open("wt") as f:
-                json.dump(data, f)
+            if data is not None:
+                json.dump(data, croissant_path.open("w"))
 
     def run(self) -> None:
         # Set transformers logging verbosity
diff --git a/tests/data/test_porcessing_metadata.py b/tests/data/test_processing_metadata.py
similarity index 82%
rename from tests/data/test_porcessing_metadata.py
rename to tests/data/test_processing_metadata.py
index e9d13a7a..9298968d 100644
--- a/tests/data/test_porcessing_metadata.py
+++ b/tests/data/test_processing_metadata.py
@@ -1,11 +1,11 @@
-import tempfile
 import pathlib
+import tempfile
 
 from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
 from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
 
 
-def get_prep(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
+def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
     config = GPTMemmapDatasetPreparatorConfig.from_dict(
@@ -19,14 +19,14 @@ def get_prep(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPrepar
 
 def test_existing_metadata_hf_hub_dataset():
     with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-        prep = get_prep(local_folder, "lhoestq/demo1")
+        prep = get_preparator(local_folder, "lhoestq/demo1")
         prep._save_croissant_metadata()
         assert (pathlib.Path(local_folder) / "croissant.json").is_file()
 
 
 def test_absent_metadata_hf_hub_dataset():
     with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-        prep = get_prep(local_folder, "allenai/dolma")
+        prep = get_preparator(local_folder, "allenai/dolma")
         prep._save_croissant_metadata()
         assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
 
@@ -35,7 +35,7 @@ def test_existing_metadata_local_dataset():
     with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
         (pathlib.Path(dataset_folder) / "croissant.json").touch()
         with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-            prep = get_prep(local_folder, dataset_folder)
+            prep = get_preparator(local_folder, dataset_folder)
             prep._save_croissant_metadata()
             assert (pathlib.Path(local_folder) / "croissant.json").is_file()
 
@@ -43,6 +43,6 @@ def test_absent_metadata_local_dataset():
     with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
         with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-            prep = get_prep(local_folder, dataset_folder)
+            prep = get_preparator(local_folder, dataset_folder)
             prep._save_croissant_metadata()
             assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
From c1013f0445870ac484652206cfe3737744e8f5e6 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 14 Feb 2025 16:52:40 -0500
Subject: [PATCH 7/7] Improve tests

---
 tests/data/test_memmap.py              | 19 -------
 tests/data/test_prepare_gpt_memmap.py  | 73 ++++++++++++++++++++++++++
 tests/data/test_processing_metadata.py | 48 -----------------
 3 files changed, 73 insertions(+), 67 deletions(-)
 create mode 100644 tests/data/test_prepare_gpt_memmap.py
 delete mode 100644 tests/data/test_processing_metadata.py

diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py
index c6af54bb..6aaf83e8 100644
--- a/tests/data/test_memmap.py
+++ b/tests/data/test_memmap.py
@@ -1,30 +1,11 @@
 import pathlib
-import tempfile
 
-import numpy as np
 import pytest
 
 from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig
-from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
-from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES
 from tests.common import DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset
 from tests.data.common import compare_indexed_dataset, get_dataset_config
 
-
-@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
-def test_write_memmap_dataset(dtype):
-    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
-    with tempfile.TemporaryDirectory() as temp_dir:
-        prefix = pathlib.Path(temp_dir)
-        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
-        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
-        for i, document in enumerate(documents):
-            assert np.array_equal(
-                dataset.get(i).token_ids, document.token_ids, equal_nan=True
-            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
-
-
 MEMMAP_DATASET_LENGTH = 6153
 MEMMAP_DATASET_TOKENS = 508327
 MEMMAP_DATASET_SAMPLES = {
diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py
new file mode 100644
index 00000000..d2810d12
--- /dev/null
+++ b/tests/data/test_prepare_gpt_memmap.py
@@ -0,0 +1,73 @@
+import json
+import pathlib
+import tempfile
+
+import numpy as np
+import pytest
+
+from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
+from fast_llm.data.dataset.gpt.sampled import GPTSample
+from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
+
+
+def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
+    config = GPTMemmapDatasetPreparatorConfig.from_dict(
+        {
+            "output_path": output_path,
+            "dataset": {"path": dataset_path_name},
+            "tokenizer": {"path": "no_tokenizer"},
+        },
+        {},
+    )
+    return config.get_dataset_preparator_class()(config=config)
+
+
+@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
+def test_write_memmap_dataset(dtype):
+    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
+    with tempfile.TemporaryDirectory() as temp_dir:
+        prefix = pathlib.Path(temp_dir)
+        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
+        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
+        for i, document in enumerate(documents):
+            assert np.array_equal(
+                dataset.get(i).token_ids, document.token_ids, equal_nan=True
+            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
+
+
+def test_load_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        metadata = json.load(croissant_path.open("r"))
+        assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1"
+
+
+def test_absent_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_load_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        metadata = {"name": "test"}
+        json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w"))
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        assert json.loads(croissant_path.open("r").read()) == metadata
+
+
+def test_absent_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
diff --git a/tests/data/test_processing_metadata.py b/tests/data/test_processing_metadata.py
deleted file mode 100644
index 9298968d..00000000
--- a/tests/data/test_processing_metadata.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import pathlib
-import tempfile
-
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
-from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
-
-
-def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
-    config = GPTMemmapDatasetPreparatorConfig.from_dict(
-        {
-            "output_path": output_path,
-            "dataset": {"path": dataset_path_name},
-            "tokenizer": {"path": "no_tokenizer"},
-        },
-        {},
-    )
-    return config.get_dataset_preparator_class()(config=config)
-
-
-def test_existing_metadata_hf_hub_dataset():
-    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-        prep = get_preparator(local_folder, "lhoestq/demo1")
-        prep._save_croissant_metadata()
-        assert (pathlib.Path(local_folder) / "croissant.json").is_file()
-
-
-def test_absent_metadata_hf_hub_dataset():
-    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-        prep = get_preparator(local_folder, "allenai/dolma")
-        prep._save_croissant_metadata()
-        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
-
-
-def test_existing_metadata_local_dataset():
-    with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
-        (pathlib.Path(dataset_folder) / "croissant.json").touch()
-        with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-            prep = get_preparator(local_folder, dataset_folder)
-            prep._save_croissant_metadata()
-            assert (pathlib.Path(local_folder) / "croissant.json").is_file()
-
-
-def test_absent_metadata_local_dataset():
-    with tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder:
-        with tempfile.TemporaryDirectory(suffix="test") as local_folder:
-            prep = get_preparator(local_folder, dataset_folder)
-            prep._save_croissant_metadata()
-            assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
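To exercise the final behavior end to end, here is a short driver mirroring the new tests (the `no_tokenizer` placeholder comes from the test helper above; network access is needed for the Hub lookup):

    import json
    import pathlib
    import tempfile

    from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig

    with tempfile.TemporaryDirectory() as output_path:
        config = GPTMemmapDatasetPreparatorConfig.from_dict(
            {
                "output_path": output_path,
                "dataset": {"path": "lhoestq/demo1"},
                "tokenizer": {"path": "no_tokenizer"},
            },
            {},
        )
        preparator = config.get_dataset_preparator_class()(config=config)
        preparator._save_croissant_metadata()
        metadata = json.load((pathlib.Path(output_path) / "croissant.json").open("r"))
        print(metadata["url"])  # https://huggingface.co/datasets/lhoestq/demo1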