Iterable Dataset #2852


Status: Open. Wants to merge 6 commits into base: impl-step-based-ckpt.

Conversation

@felipemello1 (Contributor) commented on Jun 26, 2025

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Enable Iterable datasets in torchtune.

Context: built on top of the ongoing step-based-ckpt PR: #2384

Tips when reviewing this PR

Follow this order:

  1. recipes/configs/llama3_2/3B_full.yaml: see the configs
  2. torchtune/datasets/_iterable_base.py: base class for iterable dataset
  3. torchtune/datasets/_hf_iterable.py: dataset based on HF -- can be replaced easily; downstream code does not expect HF.
  4. torchtune/datasets/_interleaved.py: interleave the datasets
  5. torchtune/data/_metrics.py: metrics transform to create the metrics
  6. torchtune/data/_aggregator.py: aggregate the metrics at the recipe level
  7. recipes/full_finetune_distributed.py: everything put together
  8. unit tests


Changelog

  1. Datasets are infinite
  2. The user no longer defines epochs, but training steps (how many times we update the optimizer)
  3. Support for dataset mixing -- follow-up PRs will enable curriculum learning
  4. Support for dataset metric logging -- the user can see epochs per dataset, the distribution of token lengths, etc. It is easy to add new metrics.
  5. HF-agnostic: even though the current dataset is HF, the dataloader, packing, data mixing, and metric logging are agnostic to it
  6. Well tested in the distributed setting -- WARNING: we need better testing for the multiprocess dataloader. It doesn't guarantee determinism, so I postponed testing this setting

Config and builder design based on the discussions after this RFC: #2785
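
To make the changelog concrete, here is a minimal, self-contained sketch (not the PR's code; all names are illustrative) of the kind of interface it implies: an infinite, checkpointable iterable dataset that treats epochs as a metric rather than a stop condition and tags each sample with per-dataset metrics.

from typing import Any, Iterator

from torch.utils.data import IterableDataset


class ToyInfiniteTextDataset(IterableDataset):
    def __init__(self, name: str, samples: list[list[int]]):
        self.dataset_name = name
        self._samples = samples
        self._num_epochs = 0
        self._index = 0

    def __iter__(self) -> Iterator[dict[str, Any]]:
        while True:  # infinite: epochs are tracked as a metric, not a stop condition
            if self._index >= len(self._samples):
                self._index = 0
                self._num_epochs += 1
            tokens = self._samples[self._index]
            self._index += 1
            yield {
                "tokens": tokens,
                "metrics": [
                    (self.dataset_name, "num_epochs", self._num_epochs),
                    (self.dataset_name, "seq_len", len(tokens)),
                ],
            }

    def state_dict(self) -> dict[str, Any]:
        # Enough to resume mid-epoch from a step-based checkpoint.
        return {"index": self._index, "num_epochs": self._num_epochs}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self._index = state["index"]
        self._num_epochs = state["num_epochs"]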

Next steps:
1. Gather feedback on metric logging. E.g. we can add more aggregation types.
2. Polish the code a little bit
3. Add packing from this RFC: #2819
4. Add curriculum learning
5. Docs?

Test plan

[Screenshots of test runs]

UNTESTED: resuming from checkpoint in the recipe. However, we have plenty of tests showing that resuming works for these iterable datasets.


pytorch-bot bot commented Jun 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2852

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job, 1 Unrelated Failure

As of commit 93fa743 with merge base 3d73591:

NEW FAILURE - The following job has failed:

  • GPU tests / gpu_test (3.11, stable) (gh)
    tests/recipes/test_qat_lora_finetune_distributed.py::TestQATLoRAFinetuneDistributedRecipe::test_training_state_on_resume_with_async_checkpointing[llama3/8B_qat_lora-llama3-tune-False]

CANCELLED JOB - The following job was cancelled. Please retry:

  • GPU tests / gpu_test (3.10, stable) (gh)
    tests/recipes/test_qat_lora_finetune_distributed.py::TestQATLoRAFinetuneDistributedRecipe::test_training_state_on_resume_with_async_checkpointing[llama3/8B_qat_lora-llama3-tune-False]

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 26, 2025
@felipemello1 changed the title from "first commit" to "Iterable Dataset" on Jun 26, 2025
@@ -94,3 +95,72 @@ def slimorca_dataset(
)
return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len)
return ds


def slimorca_iterable_dataset(
felipemello1 (author):

Added here to demonstrate the datamix iterable dataset with this example. Personally, I dislike exposing all of the args and defaults. I would prefer to expose only what's specific to this builder.

return tokenized_dict


def sft_iterable_dataset(
felipemello1 (author):

Its only purpose is to hardcode the 'output_transform'.

Comment on lines +101 to +104
logger.warning(
f"Child dataset {self._datasets[ds_name].dataset_name} was exhausted. "
"This is unexpected for an infinite dataset. Re-initializing its iterator."
)
felipemello1 (author):

Not 100% sure I like this.

Contributor:

Let's do this: simply have a subclass for InfiniteIterable so this is super explicit
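
A rough illustration of that suggestion (TuneIterableDataset exists in this PR; the infinite subclass and the interleaver shown here are hypothetical sketches, not the PR's code):

from abc import ABC, abstractmethod
from typing import Any, Iterator

from torch.utils.data import IterableDataset


class TuneIterableDataset(IterableDataset, ABC):
    @abstractmethod
    def __iter__(self) -> Iterator[dict[str, Any]]:
        ...


class InfiniteTuneIterableDataset(TuneIterableDataset, ABC):
    """Marker base class: subclasses promise that iteration never stops,
    so exhaustion is a bug rather than an expected condition."""


class InterleavedDatasetSketch:
    def __init__(self, datasets: list[InfiniteTuneIterableDataset]):
        # Enforce the contract up front instead of warning and re-initializing
        # iterators at runtime.
        for ds in datasets:
            if not isinstance(ds, InfiniteTuneIterableDataset):
                raise TypeError(
                    f"{type(ds).__name__} must be an InfiniteTuneIterableDataset"
                )
        self._datasets = datasets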

@@ -101,3 +102,64 @@ def alpaca_dataset(
original Alpaca dataset, `yahma/alpaca-cleaned <https://huggingface.co/datasets/yahma/alpaca-cleaned>`_.
See the dataset page and :func:`~torchtune.datasets.alpaca_dataset` for more details.
"""


def alpaca_iterable_dataset(
felipemello1 (author):

Added here to demonstrate the datamix iterable dataset with this example. Personally, I dislike exposing all of the args and defaults. I would prefer to expose only what's specific to this builder.

Contributor:

But you are doing this with ``load_dataset_kwargs``, right? Or did you mean something else?

Contributor:

nit: it's a function, so... get_alpaca_iterable_dataset?

@Darktex (Contributor) left a comment:

Great PR! I mainly had a question on the interaction with packing and on the SFT transform


# Each stat becomes its own metric
# For percentiles, it is an approximation by computing avg of averages
metrics[(ds_name, f"{metric_name}_mean")] = {

if not grouped:
return reduced

for key, metric_dicts in grouped.items():
Contributor:

Seems weird to write the code below twice. Can you factor it out and simply call it after reducing on rank zero?
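
A sketch of that refactor (the helper names and the shape of `grouped` are assumptions, not the PR's actual aggregator code): compute the per-key aggregation in one helper and call it once, after the rank-zero reduce.

from statistics import mean


def _aggregate_grouped(
    grouped: dict[tuple[str, str], list[float]]
) -> dict[tuple[str, str], float]:
    # The one place that turns grouped raw values into final metrics.
    return {key: mean(values) for key, values in grouped.items()}


def finalize_metrics(
    reduced: dict[tuple[str, str], float],
    grouped: dict[tuple[str, str], list[float]],
    is_rank_zero: bool,
) -> dict[tuple[str, str], float]:
    if not grouped or not is_rank_zero:
        return reduced
    reduced.update(_aggregate_grouped(grouped))
    return reduced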

agg_type: AggregationType


class MetricTransform(Protocol):
Contributor:

Can you explain a bit more about what the API is? It's not clear to me because we have __call__, which doesn't have a name that tells me what it does :D
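
For reference, a sketch of one way to spell out the contract (`agg_type` / `AggregationType` appear in the diff above; the other names are illustrative): a metric transform is a per-sample callable that appends Metric records to the sample.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol


class AggregationType(Enum):
    SUM = "sum"
    MEAN = "mean"


@dataclass
class Metric:
    dataset_name: str
    name: str
    value: float
    agg_type: AggregationType


class MetricTransform(Protocol):
    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Return the sample with Metric records appended under sample['metrics']."""
        ...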

...


class StandardMetricTransform(MetricTransform):
Contributor:

Hmmm, the name makes me think this is an ABC. Maybe DefaultTrainingMetricTransform?


@@ -178,3 +180,99 @@ def __call__(self, sample: Mapping[str, Any]) -> dict[str, Any]:
tokenized_dict = transformed_sample

return tokenized_dict


class SFTOutputTransform(Transform):
Contributor:

This transform is critical for end-to-end performance, so this is a piece where we need to spend cycles to optimize.

I see several ways to find more performance in this code:

  1. No numpy. Even if we haven't moved to GPU yet, doing everything in torch ensures that the code will work regardless and makes it more robust
  2. I think np.where allocates?
  3. Asserts are debug-only statements that get disabled if you pass the -O flag, so this check feels more like it should be a runtime check

I asked an LLM to rewrite it given these constraints and it gave me this, which looks reasonable on the surface:

from typing import Any, Mapping

import torch

# NOTE: `Transform` is assumed to be torchtune's transform base/protocol,
# imported in the real file.
CROSS_ENTROPY_IGNORE_IDX = -100          # set to whatever you use

class SFTOutputTransform(Transform):
    """
    Build the `"labels"` tensor for causal-LM SFT.

    Expects each dataset element to contain **1-D** torch tensors
    * ``"tokens"`` – token IDs, dtype=torch.long
    * ``"mask"``   – bool/int where **True** marks positions to ignore

    Produces ``"labels"`` of the same shape such that

        labels[t] =  tokens[t+1]                # shift left
        labels[t] =  IGNORE_IDX  if mask[t+1]   # respect mask
        labels[-1] = IGNORE_IDX                 # last token has no target

    All ops are vectorised; only one fresh tensor (`labels`) is allocated.
    """

    __slots__ = ()      # tiny perf win (avoids per-instance __dict__)

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        try:
            tokens: torch.Tensor = sample["tokens"]
            mask:   torch.Tensor = sample["mask"]
        except KeyError:
            raise ValueError(
                "SFTOutputTransform expects 'tokens' and 'mask' in the sample."
            )

        if tokens.ndim != 1 or mask.ndim != 1:
            raise ValueError("Both 'tokens' and 'mask' must be 1-D tensors.")

        # ── build labels ────────────────────────────────────────────────
        # 1. pre-fill with IGNORE so we don’t need extra assignments later
        labels = tokens.new_full(tokens.shape, CROSS_ENTROPY_IGNORE_IDX)

        # 2. left-shift via cheap views (no copy)
        labels[:-1].copy_(tokens[1:])

        # 3. apply mask in-place (single fused kernel on GPU/CPU)
        labels[:-1].masked_fill_(mask[1:].bool(), CROSS_ENTROPY_IGNORE_IDX)

        # (labels[-1] is already IGNORE_IDX from the new_full above)

        # ── return a shallow-copied mapping so the original sample stays intact
        out = dict(sample)
        out["labels"] = labels
        return out
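
A quick sanity check of that snippet (my example values, not from the PR):

sample = {
    "tokens": torch.tensor([10, 11, 12, 13]),
    "mask": torch.tensor([True, True, False, False]),
}
out = SFTOutputTransform()(sample)
# Shifted labels are [11, 12, 13, IGNORE]; mask[1:] = [True, False, False]
# then masks out the first position, giving [-100, 12, 13, -100].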

self._sampling_generator = torch.Generator().manual_seed(seed)

# Normalize weights to sum to 1
# TODO: make it a property? rely on ds.weight?
Contributor:

I'd make it a property so it remains visible. Very cheap anyway
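
A minimal sketch of that property (class and attribute names are assumptions, not the PR's code):

import torch


class WeightedInterleaveSketch:
    def __init__(self, weights: dict[str, float], seed: int = 0):
        self._raw_weights = weights
        self._sampling_generator = torch.Generator().manual_seed(seed)

    @property
    def normalized_weights(self) -> torch.Tensor:
        # Recomputed on access: cheap, always reflects the current raw weights,
        # and keeps the door open for curriculum-style weight updates later.
        w = torch.tensor(list(self._raw_weights.values()), dtype=torch.float)
        return w / w.sum()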


while True:
# Sample which dataset to use
ds_idx = torch.multinomial(
Contributor:

Hmmmmm, I think we should also log what we did. It's reasonable and cheap to accumulate a list that maps iteration_id to dataset_id in self.datasets. When this guy dumps its state, this log should be part of it
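
A sketch of that bookkeeping (names are illustrative; the real interleaved dataset would pull a sample from the chosen child rather than return its name):

import torch


class SamplingLogSketch:
    """Records which child dataset served each iteration and exposes the log
    through state_dict(), so it gets dumped together with the rest of the state."""

    def __init__(self, dataset_names: list[str], weights: list[float], seed: int = 0):
        self._dataset_names = dataset_names
        w = torch.tensor(weights, dtype=torch.float)
        self._weights = w / w.sum()
        self._sampling_generator = torch.Generator().manual_seed(seed)
        self._iteration = 0
        self._sampling_log: list[tuple[int, str]] = []  # (iteration_id, dataset_name)

    def _choose_dataset(self) -> str:
        ds_idx = torch.multinomial(
            self._weights, 1, generator=self._sampling_generator
        ).item()
        ds_name = self._dataset_names[ds_idx]
        self._sampling_log.append((self._iteration, ds_name))
        self._iteration += 1
        return ds_name

    def state_dict(self) -> dict:
        # One small entry per step, so checkpointing the log is cheap.
        return {"iteration": self._iteration, "sampling_log": list(self._sampling_log)}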


from torch.utils.data import IterableDataset


class TuneIterableDataset(IterableDataset, ABC):
Contributor:

We need this guy to interact with packing, and IIUC that isn't currently happening?

The algorithm we should implement is this (rough sketch after the list):

  1. One batch can be made of multiple calls to next. We keep taking samples until we exceed the max seq len. When we do, we put the last one aside (we'll use it to start the next batch), pad the current one to max len, and return.
  2. The calls to next will go to the interleaved dataset, so we automatically construct mixed batches from multiple datasets without much effort.
  3. Also, every time we call next we should make space for logging transforms (which we are; you already wrote them). I think it's ok to make your metric transforms and aggregators an optional property here so the semantics are clearer.
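
A rough sketch of that greedy packing loop (the function name and the bare token-list representation are assumptions; a real version would also carry masks/labels and document boundaries, and would plug the metric transforms in at each next call):

from collections.abc import Iterable, Iterator


def pack_samples(
    sample_iter: Iterator[list[int]], max_seq_len: int, pad_id: int = 0
) -> Iterable[list[int]]:
    """Greedily pack an infinite (e.g. interleaved) stream of token lists."""
    carry: list[int] | None = None
    while True:  # assumes the upstream iterator never stops, as in this PR
        pack: list[int] = []
        while True:
            sample = carry if carry is not None else next(sample_iter)
            carry = None
            if pack and len(pack) + len(sample) > max_seq_len:
                carry = sample  # set aside; it will start the next pack
                break
            pack.extend(sample[:max_seq_len])  # truncate oversized single samples
            if len(pack) >= max_seq_len:
                break
        yield pack + [pad_id] * (max_seq_len - len(pack))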
