
Commit de7b2d8

Saving of croissant metadata files for HF datasets (#142)
Co-authored-by: Joel Lamy-Poirier <[email protected]>
1 parent a637560 commit de7b2d8

File tree

4 files changed: +131 −24 lines

fast_llm/data/preparator/gpt_memmap/prepare.py

Lines changed: 55 additions & 3 deletions
@@ -1,10 +1,14 @@
 import json
+import logging
 import multiprocessing
 import pathlib
+import shutil
 import typing
 
 import datasets
+import huggingface_hub
 import numpy as np
+import requests
 import torch.distributed
 import tqdm
 import transformers
@@ -16,6 +20,8 @@
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.data_type import DataType
 
+logger = logging.getLogger(__name__)
+
 
 class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
     config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig
@@ -97,6 +103,50 @@ def _load_dataset(self) -> datasets.Dataset:
         assert isinstance(dataset, datasets.Dataset)
         return dataset
 
+    def _get_croissant_metadata(self):
+        token = huggingface_hub.HfFolder.get_token()
+        try:
+            # Retrieve the dataset metadata in croissant format
+            url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
+            if token is None:
+                response = requests.get(url)
+            else:
+                response = requests.get(url, headers={"Authorization": f"Bearer {token}"})
+
+            if response.status_code != 200:
+                logger.warning(
+                    f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}"
+                )
+                return None
+
+            data = response.json()
+        except Exception as e:
+            logger.warning(f"Failed to get croissant metadata, {e}")
+            return None
+        if "error" in data:
+            logger.warning(f"Failed to get croissant metadata, error: {data['error']}")
+            return None
+
+        return data
+
+    def _save_croissant_metadata(self):
+        dataset_path = pathlib.Path(self._config.dataset.path)
+        croissant_path = pathlib.Path(self._config.output_path) / "croissant.json"
+
+        if dataset_path.is_dir():
+            # If the dataset is local, check if it has the metadata file and copy it
+            croissant_file = dataset_path / "croissant.json"
+            if croissant_file.is_file():
+                shutil.copy(croissant_file, croissant_path)
+            else:
+                logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
+            return
+        else:
+            # If the dataset is on HF hub, retrieve the metadata if provided and save it
+            data = self._get_croissant_metadata()
+            if data is not None:
+                json.dump(data, croissant_path.open("w"))
+
     def run(self) -> None:
         # Set transformers logging verbosity
         transformers.logging.set_verbosity_error()
@@ -207,9 +257,11 @@ def run(self) -> None:
             output_file = self._config.output_path / "fast_llm_dataset.json"
             json.dump({"datasets": dataset_dicts}, output_file.open("w"))
 
-            # Create an index file on rank 0
-            index_file = self._config.output_path / "index.txt"
-            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
+            self._save_croissant_metadata()
+
+            # Create an index file on rank 0
+            index_file = self._config.output_path / "index.txt"
+            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
 
         # Finalize distributed processing
         if self._config.distributed.world_size > 1:
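
The new `_get_croissant_metadata` helper simply queries the Hub's croissant endpoint and guards against failures. A minimal standalone sketch of the same request, assuming a public dataset id and no authentication token (the dataset id below is the one used in the new tests):

import requests

# Minimal sketch mirroring _get_croissant_metadata: fetch croissant metadata
# for a public dataset from the Hugging Face API, without an auth token.
dataset_id = "lhoestq/demo1"  # example dataset, also used in the new tests
response = requests.get(f"https://huggingface.co/api/datasets/{dataset_id}/croissant")
data = response.json() if response.status_code == 200 else None
if data is not None and "error" not in data:
    print(data["url"])  # e.g. https://huggingface.co/datasets/lhoestq/demo1
else:
    print("No croissant metadata available")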

setup.cfg

Lines changed: 3 additions & 2 deletions
@@ -19,8 +19,7 @@ install_requires =
 # FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
 CORE =
     # Available through the nvidia base image
-    # Keeping an older min version because later ones have no x86 wheel for Mac OS
-    torch>=2.2.2
+    torch>=2.5.0
     # Numpy major needs to match torch
     numpy>=1.24.4,<2.0.0
     # Used for checkpoints
@@ -34,12 +33,14 @@ OPTIONAL =
     transformers>=4.44.2
     hf-transfer>=0.1.8
     datasets>=3.1.0
+    huggingface-hub>=0.28.1
    # Weights and biases
     wandb>=0.17.7
     # Hydra
     hydra-core>=1.3.2
     omegaconf>=2.3.0
     # Miscellanous
+    requests>=2.32.3
     tqdm>=4.66.3
 
 DEV =

tests/data/test_memmap.py

Lines changed: 0 additions & 19 deletions
@@ -1,30 +1,11 @@
 import pathlib
-import tempfile
 
-import numpy as np
 import pytest
 
 from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig
-from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.dataset.gpt.sampled import GPTSample
-from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES
 from tests.common import DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset
 from tests.data.common import compare_indexed_dataset, get_dataset_config
 
-
-@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
-def test_write_memmap_dataset(dtype):
-    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
-    with tempfile.TemporaryDirectory() as temp_dir:
-        prefix = pathlib.Path(temp_dir)
-        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
-        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
-        for i, document in enumerate(documents):
-            assert np.array_equal(
-                dataset.get(i).token_ids, document.token_ids, equal_nan=True
-            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
-
-
 MEMMAP_DATASET_LENGTH = 6153
 MEMMAP_DATASET_TOKENS = 508327
 MEMMAP_DATASET_SAMPLES = {

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import json
+import pathlib
+import tempfile
+
+import numpy as np
+import pytest
+
+from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
+from fast_llm.data.dataset.gpt.sampled import GPTSample
+from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
+from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
+
+
+def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
+    config = GPTMemmapDatasetPreparatorConfig.from_dict(
+        {
+            "output_path": output_path,
+            "dataset": {"path": dataset_path_name},
+            "tokenizer": {"path": "no_tokenizer"},
+        },
+        {},
+    )
+    return config.get_dataset_preparator_class()(config=config)
+
+
+@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
+def test_write_memmap_dataset(dtype):
+    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
+    with tempfile.TemporaryDirectory() as temp_dir:
+        prefix = pathlib.Path(temp_dir)
+        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
+        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
+        for i, document in enumerate(documents):
+            assert np.array_equal(
+                dataset.get(i).token_ids, document.token_ids, equal_nan=True
+            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."
+
+
+def test_load_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        metadata = json.load(croissant_path.open("r"))
+        assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1"
+
+
+def test_absent_metadata_from_hub():
+    with tempfile.TemporaryDirectory(suffix="test") as local_folder:
+        get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
+
+
+def test_load_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        metadata = {"name": "test"}
+        json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w"))
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        croissant_path = pathlib.Path(local_folder) / "croissant.json"
+        assert croissant_path.is_file()
+        assert json.loads(croissant_path.open("r").read()) == metadata
+
+
+def test_absent_metadata_local():
+    with (
+        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
+        tempfile.TemporaryDirectory(suffix="test") as local_folder,
+    ):
+        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
+        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()
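
For context, once `run()` finishes on rank 0, the output directory carries `croissant.json` alongside the existing `fast_llm_dataset.json` and `index.txt` (when metadata was found or copied). A rough sketch of inspecting such an output folder, using a hypothetical path:

import json
import pathlib

# Hypothetical output directory produced by the memmap preparator.
output_path = pathlib.Path("/tmp/prepared_dataset")

# fast_llm_dataset.json and index.txt are written on rank 0;
# croissant.json is only written when metadata was found or copied.
for name in ("fast_llm_dataset.json", "index.txt", "croissant.json"):
    print(name, (output_path / name).is_file())

croissant_file = output_path / "croissant.json"
if croissant_file.is_file():
    print(json.load(croissant_file.open("r")).get("url"))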
