Saving of croissant metadata files for HF datasets #142

Merged · 8 commits · Feb 14, 2025
58 changes: 55 additions & 3 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -1,10 +1,14 @@
import json
import logging
import multiprocessing
import pathlib
import shutil
import typing

import datasets
import huggingface_hub
import numpy as np
import requests
import torch.distributed
import tqdm
import transformers
@@ -16,6 +20,8 @@
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.data_type import DataType

logger = logging.getLogger(__name__)


class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
    config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig
@@ -97,6 +103,50 @@ def _load_dataset(self) -> datasets.Dataset:
        assert isinstance(dataset, datasets.Dataset)
        return dataset

    def _get_croissant_metadata(self):
        token = huggingface_hub.HfFolder.get_token()
        try:
            # Retrieve the dataset metadata in croissant format
            url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
[Review thread attached to the croissant URL line above]

Collaborator:
now that I think about this again, why do we need to go to hf.co and get the croissant metadata from there? why can't we use the hf datasets api?

Contributor Author:
It seems that, at the moment at least, there is nothing for it in the datasets library: https://github.com/search?q=repo%3Ahuggingface%2Fdatasets%20croissant&type=code

Collaborator:

Ok, fair enough. Looks like the original Croissant metadata can only be retrieved from that endpoint then.

Btw, have you checked this approach:

```python
import datasets

builder = datasets.load_dataset_builder('dataset_name')
dataset_info = builder.info
```

dataset_info should be of type datasets.info.DatasetInfo and contain:

  • description: A textual description of the dataset.
  • features: The schema of the dataset (column names and types).
  • splits: Information about available splits (e.g., train, test, validation).
  • size_in_bytes: The dataset size.
  • citation: The citation reference.
  • license: The dataset's license.

This looks useful and could be all we need (cc @chrish42). We could convert it into Croissant and save it to disk. The benefit here is that this can be used with any hf dataset on disk as well, including previously downloaded and cached ones.

Furthermore, irrespective of whether we use Croissant or dataset_info, I'm wondering how we want to handle the features field. fast-llm prepare only keeps the main text field (and optionally a list of character spans that should be excluded from the loss, see #113). I think the features field should be modified based on that...
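
As a tiny illustration of that last point (purely hypothetical; the kept column names are assumptions building on the snippet above):

```python
# Keep only the columns that fast-llm prepare actually retains (hypothetical selection).
kept_columns = {"text"}  # plus loss-masking span columns if/when PR #113 applies
pruned_features = {
    name: feature for name, feature in (dataset_info.features or {}).items() if name in kept_columns
}
```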

Contributor Author:
I looked for dataset info via the URL, as described here: Hugging Face Dataset Viewer - Info, instead of using the builder. However, it did not provide any additional information compared to the Croissant format and failed on the same datasets.

I have now tested it using the builder:

  • It seems to be slower, as it downloads scripts and at least partially executes them.
  • On the plus side, it was able to read 6 of the 7 repositories for which neither the Croissant format nor the dataset info URL worked.

However, if switching to the builder, I don't see a reason to convert that information to Croissant. The main purpose of Croissant metadata is to be actionable: its recordSet and distribution fields allow actual data loading. So, if we use dataset info, it would make more sense to simply save it to a YAML file, for example.
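
For concreteness, a minimal sketch of that "save the dataset info to YAML" idea; the helper name, output filename, and choice of fields below are illustrative assumptions, not part of this PR:

```python
# Sketch only: dump a readable subset of datasets.DatasetInfo to YAML next to the prepared output.
import pathlib

import datasets
import yaml


def save_dataset_info_yaml(dataset_name: str, output_path: str) -> None:
    info = datasets.load_dataset_builder(dataset_name).info
    summary = {
        "description": info.description,
        "citation": info.citation,
        "license": info.license,
        "size_in_bytes": info.size_in_bytes,
        # Features and splits are objects; keep a YAML-safe view of them.
        "features": {name: str(feature) for name, feature in (info.features or {}).items()},
        "splits": {name: split.num_examples for name, split in (info.splits or {}).items()},
    }
    with (pathlib.Path(output_path) / "dataset_info.yaml").open("w") as f:
        yaml.safe_dump(summary, f, sort_keys=False)
```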

Contributor Author (@bigximik), Feb 7, 2025:
Also, since datasets on HF are actually git repos, we can, as @sebpaquet proposed, save their URL and commit_sha for lineage.
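
A possible sketch of that idea using the Hub API (huggingface_hub.HfApi().dataset_info() is an existing call; the returned record layout is only an illustration):

```python
# Sketch: record the source repo URL and the exact commit the dataset metadata was resolved against.
import huggingface_hub


def get_hub_dataset_lineage(repo_id: str) -> dict:
    info = huggingface_hub.HfApi().dataset_info(repo_id)
    return {
        "url": f"https://huggingface.co/datasets/{repo_id}",
        "commit_sha": info.sha,
    }
```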

Contributor Author:
What this PR currently saves (or what we would save if we switch to dataset info) is information about the source being processed.

If we were to also record the transformation applied by the current operation, as @tscholak proposed:

> I'm wondering how we want to handle the features field. fast-llm prepare only keeps the main text field (and optionally a list of character spans that should be excluded from the loss, see PR #113). I think the features field should be modified based on that...

I propose that instead of trying to define a format to describe the transformation, we simply store the command, configuration, command Git URL, and commit SHA used for the data transformation.

Additionally, the command can provide a human-readable description of what it does, which can be included in the info.

This way, we will have an exact record of how this dataset was produced, ensuring proper lineage tracking.
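
To make the proposal concrete, a rough sketch of what such a lineage record could look like; the helper, every field name, and the output filename are hypothetical:

```python
# Hypothetical lineage record for the proposal above; nothing here exists in fast-llm today.
import json
import pathlib
import subprocess
import sys


def save_lineage_record(output_path: pathlib.Path, config: dict, code_url: str, description: str) -> None:
    record = {
        "command": sys.argv,  # the exact command line used for the transformation
        "config": config,  # the fully resolved preparator configuration
        "code_url": code_url,  # Git URL of the code that ran the transformation
        "commit_sha": subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip(),
        "description": description,  # human-readable summary of what the command does
    }
    json.dump(record, (output_path / "lineage.json").open("w"), indent=2)
```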

Collaborator:
> I have now tested it using the builder:
>
>   • It seems to be slower, as it downloads scripts and at least partially executes them.
>   • On the plus side, it was able to read 6 of the 7 repositories for which neither the Croissant format nor the dataset info URL worked.

How much slower? Shouldn't it be small compared to the whole of prepare?

> I propose that instead of trying to define a format to describe the transformation, we simply store the command, configuration, command Git URL, and commit SHA used for the data transformation.

I don't think this is enough; we need some backup in case the source somehow disappears.

[End of review thread]

            if token is None:
                response = requests.get(url)
            else:
                response = requests.get(url, headers={"Authorization": f"Bearer {token}"})

            if response.status_code != 200:
                logger.warning(
                    f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}"
                )
                return None

            data = response.json()
        except Exception as e:
            logger.warning(f"Failed to get croissant metadata, {e}")
            return None
        if "error" in data:
            logger.warning(f"Failed to get croissant metadata, error: {data['error']}")
            return None

        return data

    def _save_croissant_metadata(self):
        dataset_path = pathlib.Path(self._config.dataset.path)
        croissant_path = pathlib.Path(self._config.output_path) / "croissant.json"

        if dataset_path.is_dir():
            # If the dataset is local, check if it has the metadata file and copy it
            croissant_file = dataset_path / "croissant.json"
            if croissant_file.is_file():
                shutil.copy(croissant_file, croissant_path)
            else:
                logger.warning(f"Source local dataset {self._config.dataset.path} does not have croissant file")
                return
        else:
            # If the dataset is on HF hub, retrieve the metadata if provided and save it
            data = self._get_croissant_metadata()
            if data is not None:
                json.dump(data, croissant_path.open("w"))

    def run(self) -> None:
        # Set transformers logging verbosity
        transformers.logging.set_verbosity_error()
@@ -207,9 +257,11 @@ def run(self) -> None:
            output_file = self._config.output_path / "fast_llm_dataset.json"
            json.dump({"datasets": dataset_dicts}, output_file.open("w"))

            # Create an index file on rank 0
            index_file = self._config.output_path / "index.txt"
            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])
            self._save_croissant_metadata()

            # Create an index file on rank 0
            index_file = self._config.output_path / "index.txt"
            index_file.open("w").writelines([dataset_dict["prefix"] + "\n" for dataset_dict in dataset_dicts])

        # Finalize distributed processing
        if self._config.distributed.world_size > 1:
5 changes: 3 additions & 2 deletions setup.cfg
@@ -19,8 +19,7 @@ install_requires =
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install -e ".[CORE]" --no-build-isolation
CORE =
# Available through the nvidia base image
# Keeping an older min version because later ones have no x86 wheel for Mac OS
torch>=2.2.2
torch>=2.5.0
# Numpy major needs to match torch
numpy>=1.24.4,<2.0.0
# Used for checkpoints
@@ -34,12 +33,14 @@ OPTIONAL =
transformers>=4.44.2
hf-transfer>=0.1.8
datasets>=3.1.0
huggingface-hub>=0.28.1
# Weights and biases
wandb>=0.17.7
# Hydra
hydra-core>=1.3.2
omegaconf>=2.3.0
# Miscellanous
requests>=2.32.3
tqdm>=4.66.3

DEV =
19 changes: 0 additions & 19 deletions tests/data/test_memmap.py
@@ -1,30 +1,11 @@
import pathlib
import tempfile

import numpy as np
import pytest

from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES
from tests.common import DATASET_PREFIX, DATASET_SAMPLING_CACHE, get_test_dataset
from tests.data.common import compare_indexed_dataset, get_dataset_config


@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
def test_write_memmap_dataset(dtype):
    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
    with tempfile.TemporaryDirectory() as temp_dir:
        prefix = pathlib.Path(temp_dir)
        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
        for i, document in enumerate(documents):
            assert np.array_equal(
                dataset.get(i).token_ids, document.token_ids, equal_nan=True
            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."


MEMMAP_DATASET_LENGTH = 6153
MEMMAP_DATASET_TOKENS = 508327
MEMMAP_DATASET_SAMPLES = {
73 changes: 73 additions & 0 deletions tests/data/test_prepare_gpt_memmap.py
@@ -0,0 +1,73 @@
import json
import pathlib
import tempfile

import numpy as np
import pytest

from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator


def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDatasetPreparator:
    config = GPTMemmapDatasetPreparatorConfig.from_dict(
        {
            "output_path": output_path,
            "dataset": {"path": dataset_path_name},
            "tokenizer": {"path": "no_tokenizer"},
        },
        {},
    )
    return config.get_dataset_preparator_class()(config=config)


@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values())
def test_write_memmap_dataset(dtype):
    documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)]
    with tempfile.TemporaryDirectory() as temp_dir:
        prefix = pathlib.Path(temp_dir)
        GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents)
        dataset = GPTMemmapDataset(name="foo", prefix=prefix)
        for i, document in enumerate(documents):
            assert np.array_equal(
                dataset.get(i).token_ids, document.token_ids, equal_nan=True
            ), f"Mismatch for document {i}: {document} != {dataset.get(i)}."


def test_load_metadata_from_hub():
with tempfile.TemporaryDirectory(suffix="test") as local_folder:
get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata()
croissant_path = pathlib.Path(local_folder) / "croissant.json"
assert croissant_path.is_file()
metadata = json.load(croissant_path.open("r"))
assert metadata["url"] == "https://huggingface.co/datasets/lhoestq/demo1"


def test_absent_metadata_from_hub():
with tempfile.TemporaryDirectory(suffix="test") as local_folder:
get_preparator(local_folder, "allenai/dolma")._save_croissant_metadata()
assert not (pathlib.Path(local_folder) / "croissant.json").is_file()


def test_load_metadata_local():
    with (
        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
        tempfile.TemporaryDirectory(suffix="test") as local_folder,
    ):
        metadata = {"name": "test"}
        json.dump(metadata, (pathlib.Path(dataset_folder) / "croissant.json").open("w"))
        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
        croissant_path = pathlib.Path(local_folder) / "croissant.json"
        assert croissant_path.is_file()
        assert json.loads(croissant_path.open("r").read()) == metadata


def test_absent_metadata_local():
    with (
        tempfile.TemporaryDirectory(suffix="dataset") as dataset_folder,
        tempfile.TemporaryDirectory(suffix="test") as local_folder,
    ):
        get_preparator(local_folder, dataset_folder)._save_croissant_metadata()
        assert not (pathlib.Path(local_folder) / "croissant.json").is_file()