diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py
index ace9c5fe..e029137c 100644
--- a/fast_llm/data/preparator/gpt_memmap/prepare.py
+++ b/fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -1,10 +1,14 @@
 import json
+import logging
 import multiprocessing
 import pathlib
+import shutil
 import typing
 
 import datasets
+import huggingface_hub
 import numpy as np
+import requests
 import torch.distributed
 import tqdm
 import transformers
@@ -16,6 +20,8 @@
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType
 
+logger = logging.getLogger(__name__)
+
 
 class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
     config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig
@@ -97,6 +103,50 @@ def _load_dataset(self) -> datasets.Dataset:
         assert isinstance(dataset, datasets.Dataset)
         return dataset
 
+    def _get_croissant_metadata(self):
+        token = huggingface_hub.HfFolder.get_token()
+        try:
+            # Retrieve the dataset metadata in croissant format
+            url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
+            if token is None:
+                response = requests.get(url)
+            else:
+                response = requests.get(url, headers={"Authorization": f"Bearer {token}"})
+
+            if response.status_code != 200:
+                logger.warning(
+                    f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}"
+                )
+                return None
+
+            data = response.json()
+        except Exception as e:
+            logger.warning(f"Failed to get croissant metadata, {e}")
+            return None
+        if "error" in data:
+            logger.warning(f"Failed to get croissant metadata, error: {data['error']}")
+            return None
+
+        return data
+
+    def _save_croissant_metadata(self):
+        dataset_path = pathlib.Path(self._config.dataset.path)
+        croissant_path = pathlib.Path(self._config.output_path) / "croissant.json"
+
+        if dataset_path.is_dir():
+            # If the dataset is local, check if it has the metadata file and copy it
+            croissant_file = dataset_path / "croissant.json"
+            if croissant_file.is_file():
+                shutil.copy(croissant_file, croissant_path)
+            else:
+                logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
+                return
+        else:
+            # If the dataset is on HF hub, retrieve the metadata if provided and save it
+            data = self._get_croissant_metadata()
+            if data is not None:
+                json.dump(data, croissant_path.open("w"))
+
     def run(self) -> None:
         # Set transformers logging verbosity
         transformers.logging.set_verbosity_error()
@@ -207,9 +257,11 @@ def run(self) -> None:
             output_file = self._config.output_path / "fast_llm_dataset.json"
             json.dump({"datasets": dataset_dicts}, output_file.open("w"))
 
-            # Create an index file on rank 0
-            index_file = self._config.output_path / "index.txt"
-            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
+            self._save_croissant_metadata()
+
+            # Create an index file on rank 0
+            index_file = self._config.output_path / "index.txt"
+            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
 
         # Finalize distributed processing
         if self._config.distributed.world_size > 1:
diff --git a/setup.cfg b/setup.cfg
index 351d1bcb..c21f02a7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,8 +19,7 @@ install_requires =
 # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
 CORE =
     # Available through the nvidia base image
-    # Keeping an older min version because later ones have no x86 wheel for Mac OS
-    torch>=2.2.2
+    torch>=2.5.0
     # Numpy major needs to match torch
     numpy>=1.24.4,<2.0.0
     # Used for checkpoints
@@ -34,12 +33,14 @@ OPTIONAL =
     transformers>=4.44.2
     hf-transfer>=0.1.8
     datasets>=3.1.0
+    huggingface-hub>=0.28.1
     # Weights and biases
    wandb>=0.17.7
     # Hydra
     hydra-core>=1.3.2
     omegaconf>=2.3.0
     # Miscellanous
+    requests>=2.32.3
     tqdm>=4.66.3
 
 DEV =
diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py
index c6af54bb..6aaf83e8 100644
--- a/tests/data/test_memmap.py
+++ b/tests/data/test_memmap.py
@@ -1,30 +1,11 @@
 import pathlib
-import tempfile
 
-import numpy as np
 import pytest
 
 from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig
-from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
-from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES
 from tests.common import DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset
 from tests.data.common import compare_indexed_dataset, get_dataset_config
 
-
-@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
-def test_write_memmap_dataset(dtype):
-    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
-    with tempfile.TemporaryDirectory() as temp_dir:
-        prefix = pathlib.Path(temp_dir)
-        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
-        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
-        for i, document in enumerate(documents):
-            assert np.array_equal(
-                dataset.get(i).token_ids, document.token_ids, equal_nan=True
-            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
-
-
 MEMMAP_DATASET_LENGTH = 6153
 MEMMAP_DATASET_TOKENS = 508327
 MEMMAP_DATASET_SAMPLES = {
diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py
new file mode 100644
index 00000000..d2810d12
--- /dev/null
+++ b/tests/data/test_prepare_gpt_memmap.py
@@ -0,0 +1,73 @@
+import json
+import pathlib
+import tempfile
+
+import numpy as np
+import pytest
+
+from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
+from fast_llm.data.dataset.gpt.sampled import GPTSample
+from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
+
+
+def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
+    config = GPTMemmapDatasetPreparatorConfig.from_dict(
+        {
+            "output_path": output_path,
+            "dataset": {"path": dataset_path_name},
+            "tokenizer": {"path": "no_tokenizer"},
+        },
+        {},
+    )
+    return config.get_dataset_preparator_class()(config=config)
+
+
+@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
+def test_write_memmap_dataset(dtype):
+    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
+    with tempfile.TemporaryDirectory() as temp_dir:
+        prefix = pathlib.Path(temp_dir)
+        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
+        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
+        for i, document in enumerate(documents):
+            assert np.array_equal(
+                dataset.get(i).token_ids, document.token_ids, equal_nan=True
+            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
+
+
+def test_load_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        metadata = json.load(croissant_path.open("r"))
+        assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1"
+
+
+def test_absent_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_load_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        metadata = {"name": "test"}
+        json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w"))
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        assert json.loads(croissant_path.open("r").read()) == metadata
+
+
+def test_absent_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
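
For reviewers who want to exercise the new croissant path outside pytest, here is a minimal sketch mirroring the `get_preparator` helper from the new test file; the dataset id and output directory are placeholders, not values from this PR, and `_save_croissant_metadata` is normally invoked from `run()` on rank 0.

```python
# Minimal usage sketch (assumption: paths below are placeholders; the config
# wiring mirrors get_preparator() in tests/data/test_prepare_gpt_memmap.py).
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig

config = GPTMemmapDatasetPreparatorConfig.from_dict(
    {
        "output_path": "/tmp/prepared",  # croissant.json is written here
        "dataset": {"path": "lhoestq/demo1"},  # HF hub id, or a local directory
        "tokenizer": {"path": "no_tokenizer"},
    },
    {},
)
preparator = config.get_dataset_preparator_class()(config=config)

# run() calls this on rank 0 after writing fast_llm_dataset.json; it can also
# be called directly to fetch (hub) or copy (local) the metadata in isolation.
preparator._save_croissant_metadata()
```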