
[RFC] Iterable Dataset #2785


Open · wants to merge 6 commits into main

Conversation

felipemello1
Contributor

@felipemello1 felipemello1 commented Jun 4, 2025

Core Issues

  1. No support for iterable dataset:

    • Dataset has to be fully loaded in memory
    • With map-style, no control over multi-sample operations (e.g. packing or skipping)
    • Map-style is slower
    • No support for streaming (see the sketch after this list)
  2. No support for weighted dataset:

    • We have it in a single newly added dev recipe/config, but API needs polishing
    • We also support ConcatDataset, but it's map style and there is no weighting
  3. No support for on-the-fly data packing:

    • It's done before training, taking a long time for large datasets
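As a reference for the streaming point above, a minimal sketch (using the Hugging Face datasets library, not torchtune code) of the map-style vs. iterable difference:

```python
from datasets import load_dataset

# Map-style: the split is fully downloaded/materialized and supports random access.
map_ds = load_dataset("tatsu-lab/alpaca", split="train")
print(map_ds[0])

# Iterable/streaming: samples are yielded on the fly, so the dataset never has to
# fit in memory and multi-sample ops (packing, skipping) can happen in the stream.
stream_ds = load_dataset("tatsu-lab/alpaca", split="train", streaming=True)
print(next(iter(stream_ds)))
```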

UX Issues

  1. Unclear boundaries between HF and torchtune args:
def alpaca_dataset(
    # --- torchtune specific args ---
    tokenizer: ModelTokenizer,
    packed: bool = False,

    # --- message specific args ---
    train_on_input: bool = True,

    # --- HF loading args ---
    source: str = "tatsu-lab/alpaca",
    column_map: Optional[Dict[str, str]] = None,
    split: str = "train",

    # --- HF dataset method ---
    filter_fn: Optional[Callable] = None,
    **load_dataset_kwargs: Dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:
  2. Lack of dataloader args:
    • Args are scattered in the config
    • Important args are not exposed (e.g. num_workers, pin_memory)
dataset:
  _component_: torchtune.datasets.multimodal.the_cauldron_dataset
seed: null
batch_size: 8
shuffle: True
collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
  3. Different datasets have different arguments due to different message transforms

Principles

  • Common API signatures for all datasets
  • Offload what we can to HF datasets methods directly
  • Less protagonism from our functions (e.g. config manipulations, instantiation). Not the focus of this diff.

Proposal

In the diff

Felipe Mello added 4 commits June 4, 2025 12:01

pytorch-bot bot commented Jun 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2785

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fa61d02 with merge base 4ff30ca:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 4, 2025
@krammnic
Contributor

krammnic commented Jun 4, 2025

I like this

@codecov-commenter

codecov-commenter commented Jun 4, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 62.80%. Comparing base (9cb77af) to head (fa61d02).
Report is 4 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2785      +/-   ##
==========================================
+ Coverage   60.08%   62.80%   +2.72%     
==========================================
  Files         435      435              
  Lines       26742    26743       +1     
==========================================
+ Hits        16067    16797     +730     
+ Misses      10675     9946     -729     

☔ View full report in Codecov by Sentry.

Contributor

@Darktex Darktex left a comment

QQ: what's the plan for introducing packing-on-the-fly?

I was thinking that it may just be more expedient to get everything done in one go (datamixing, iterables, pack on the fly)
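For concreteness, a minimal sketch of what packing-on-the-fly over an iterable could look like (a hypothetical greedy packer, not the implementation proposed in this PR; assumes samples are already tokenized into a "tokens" field):

```python
from typing import Dict, Iterable, Iterator, List


def pack_on_the_fly(
    samples: Iterable[Dict[str, List[int]]], max_seq_len: int
) -> Iterator[Dict[str, List[int]]]:
    """Greedily concatenate tokenized samples into buffers of at most max_seq_len."""
    buffer: List[int] = []
    for sample in samples:
        tokens = sample["tokens"][:max_seq_len]  # naive truncation of oversized samples
        if buffer and len(buffer) + len(tokens) > max_seq_len:
            yield {"tokens": buffer}
            buffer = []
        buffer.extend(tokens)
    if buffer:
        yield {"tokens": buffer}
```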

# Alternative: use dataclasses?

def setup_data(
dataset_cfg: ConfigDict,
Contributor

This is at the core of why I wouldn't get into the business of constructing HF datasets ourselves.

I'd much rather standardize Tune's interface around generic iterable datasets, and this is an interface that very naturally plugs into HF datasets. If and when a user wants to use one of those, they can construct it themselves and simply pass it in.

Contributor

I actually had a somewhat similar thought but at the definition of the HfIterableDataset. At the top level I think some mixing of dataset config, dataloader config, and other stuff like packing will be necessary (though I would definitely like to see this signature simplified).

However, regarding the comment about not constructing HF datasets ourselves, I'm not sure I agree. Hugging Face has its own IterableDataset class, which is what plugs into interleave_datasets, load_dataset, split_dataset_by_node, ... While I agree it's not ideal to take an external dependency for something so fundamental, I am also cognizant that there is a lot of stuff in there. So I am wary to build from scratch because (a) we take on a lot of new maintenance burden and (b) if we really want to interop with all those nice APIs we wind up being pretty tied to HF's IterableDataset design either way. But do lmk if I'm missing your point here.

Contributor

In general, I agree that eventually Tune's interface should be based on generic iterable principles, but in this case I think it's a bit of "don't pre-optimize."

The small datasets library gives us a lot of power right now and I don't think we have enough signal to know really what a perfect generic iterable design would be.

Contributor Author

An extra note: it should be easy for users to load local data in many types of formats (reference).

The setup_data only assumes an HfIterableDataset when using interleave_datasets. We might move this interleave logic to build_multidataset. Then, if the user wants to completely swap HfIterableDataset with their own class, setup_data will instantiate it and pass it to the dataloader.
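On the local-data point, a sketch (file path is a placeholder) of how HF load_dataset covers local formats through the same entry point:

```python
from datasets import load_dataset

# JSON/JSONL, CSV, Parquet, etc. all go through load_dataset with a format name
# plus data_files, and can be streamed the same way as hub datasets.
local_ds = load_dataset(
    "json", data_files="data/my_sft_data.jsonl", split="train", streaming=True
)
```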


def state_dict(self):
state_dict = self.ds.state_dict()
state_dict["weight"] = self.weight
Contributor

We also need to keep track of how far along we are in the iterator, no?

Contributor

I believe that will be tracked in the dataloader state dict

Contributor Author

@felipemello1 felipemello1 Jun 5, 2025

Not the dataloader; I think it's the IterableDataset state dict, which we load/save directly and just append the extra weight arg to.
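A sketch of what that could look like (hypothetical wrapper; assumes a datasets version whose IterableDataset exposes state_dict()/load_state_dict()):

```python
class WeightedIterable:
    """Wrap an HF iterable dataset and carry its sampling weight in the checkpoint state."""

    def __init__(self, ds, weight: float = 1.0):
        self.ds = ds
        self.weight = weight

    def state_dict(self):
        state_dict = self.ds.state_dict()
        state_dict["weight"] = self.weight
        return state_dict

    def load_state_dict(self, state_dict):
        # Pop our extra key before handing the rest back to the HF dataset.
        self.weight = state_dict.pop("weight", self.weight)
        self.ds.load_state_dict(state_dict)
```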


- Common API signatures for all datasets
- Offload what we can to HF datasets methods directly
- Less protagonism from our functions (e.g. config manipulations, instantiation). Not the focus of this diff.
Contributor

What does this mean?

Contributor Author

@felipemello1 felipemello1 Jun 5, 2025

Our recipes have all sorts of extra logic to add/remove/manipulate configs. It's not intuitive or transparent. For example:

  • we often have in our recipes config.pop, config.get("arg", "some default"), cfg[arg] = new_arg
  • if/else that changes objects, e.g. the collate function is replaced if packing=True, the dataset type is changed to ConcatDataset if the input is a list

I would prefer a design that just does instantiate(cfg), without manipulating it extensively.

An example of how we could achieve that through the config:

  • Use ${tokenizer} in the dataset definition
  • Pass ${tokenizer.max_seq_len} to packing directly (see the interpolation sketch below)
  • Expose the component + dataclasses
  • Have a default_sft.yaml that all configs use, listing all args

But I didn't word it too well. I tried to say that although I believe in it, I didn't focus too much on it in this diff.
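A small sketch of the interpolation idea (assuming OmegaConf-style configs, as torchtune configs use; values are illustrative):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "tokenizer": {"max_seq_len": 4096},
        "dataset": {"packing": {"max_seq_len": "${tokenizer.max_seq_len}"}},
    }
)
# The dataset section reuses the tokenizer field without any recipe-side config surgery.
assert cfg.dataset.packing.max_seq_len == 4096
```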

Contributor

At a high level I don't think the configs need to expose how the options will be used inside of the recipe. The configs are just common knobs users would want to control, and keeping them as simple and short as possible is better than exposing the structure of the recipe through the config and encouraging programming via config. The recipe itself (likely through a dataclass) should be responsible for explaining what the different options and defaults do.

That said, I think it's good that we're making datasets much more extendable here.

Comment on lines +140 to +146
dataset_defaults:
  shuffle_buffer_size: 1000
  num_shards_per_worker: 16
  seed: ${seed}
  tokenizer: ${tokenizer}
  recipe_transform:
    _component_: torchtune.datasets.SFTTransform
Contributor

If I understand correctly, these are properties that are shared across all datasets in a MultiDataset, right? Maybe controversial but personally I would try to nest the defaults under some kind of top-level builder. Otherwise it can be easy to miss that they're coupled to the dataset configs.

Contributor Author

@felipemello1 felipemello1 Jun 5, 2025

Can you give an example of what you have in mind?

I thought about having an equivalent of the current ConcatDataset. The multidataset instantiation and interleave would be moved from setup_data to this builder, something like:

def build_multidataset(dataset_list, dataset_default_args, multidataset_stopping_strategy, seed):
    iterable_datasets, weights = [], []

    # instantiate each dataset, merging the shared defaults into its config
    for base_cfg in dataset_list:
        weight = base_cfg.get("weight", 1.0)
        weights.append(weight)

        base_cfg = OmegaConf.merge(dataset_default_args, base_cfg)
        ds = instantiate(base_cfg)
        iterable_datasets.append(ds)

    # Interleave for multidataset
    if len(iterable_datasets) > 1:
        weights = normalize_weights(weights)  # sum to 1
        ds = interleave_datasets(
            iterable_datasets,
            probabilities=weights,
            seed=seed,
            # strategies: https://huggingface.co/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy
            stopping_strategy=multidataset_stopping_strategy,
        )
    else:
        ds = iterable_datasets[0]

    return ds

dataset:
  _component_: torchtune.datasets.build_multidataset
  dataset_list:
    - _component_: torchtune.datasets.build_alpaca_dataset
      load_args:
        split: "valid"
      weight: 0.8
    - _component_: torchtune.datasets.build_gsm8k_dataset
      message_transform:
        system_prompt: "bar"
      weight: 0.2
  dataset_default_args:
    shuffle_buffer_size: 1000
    num_shards_per_worker: 16
    seed: ${seed}
    tokenizer: ${tokenizer}
  multidataset_stopping_strategy: "first_exhausted" # or "all_exhausted"

I think it's doable, but in this setting the builder would have to instantiate the datasets... unless we add extra logic to setup_data, like we do today, but it would be ugly with all the config manipulation.

def setup_data(...):
    ....
    if hasattr(dataset_config, "dataset_list"):
        # Pop the arg, which is weird: we pass an arg that is not used by the constructor
        dataset_default_args = dataset_config.pop("dataset_default_args")
        iterable_datasets, weights = [], []
        for base_cfg in dataset_config.dataset_list:
            weight = base_cfg.get("weight", 1.0)
            weights.append(weight)

            base_cfg = OmegaConf.merge(dataset_default_args, base_cfg)
            ds = instantiate(base_cfg)
            iterable_datasets.append(ds)

        multidataset = instantiate(dataset_config, dataset_list=iterable_datasets, weights=weights)

Contributor

I like the second option, but we can simplify the recipe by saying that cfg.datasets always returns a list and making concat_dataset a function instead of a special dataset.

ds = concat_datasets([config.instantiate(d, self.tokenizer) for d in cfg.datasets])

Then we need to handle how we define packed too.

Comment on lines +151 to +155
dataset_setup:
  packing:
    _component_: torchtune.datasets.packing.SFTPacking
    max_seq_len: ${tokenizer.max_seq_len}
  multidataset_stopping_strategy: "first_exhausted" # or "all_exhausted"
Contributor

This I don't really understand either. Seems like packing should be its own thing and stopping strategy should be treated similarly to the other defaults above (i.e. it is basically a property of the multidataset)?

Contributor Author

The motivation is that these 2 args are NOT dataset input args. I.e., I cannot do ds = dataset(**config), so I put them apart. I would like to avoid having the recipes selectively removing/adding args from/to a config, e.g. cfg.pop(arg_a). IMO it adds bloat and complexity.

If we add the multidataset builder, then multidataset_stopping_strategy becomes a dataset arg.
Packing could become its own thing, I guess.

Comment on lines +206 to +214
load_args: Dict,
message_transform: Callable,
tokenizer: Callable,
recipe_transform: Callable,
shuffle_buffer_size: Optional[int] = 1000,
seed: Optional[int] = 42,
num_shards_per_worker: int = 16,
weight: float = 1.0,
filter_args: Optional[Dict] = None,
Contributor

This is a very unintuitive set of args imo. We are basically mixing arbitrary dicts to pass through to HF + composable stuff like transforms + very fundamental stuff like seed. We may need to layer this differently

Contributor Author

@felipemello1 felipemello1 Jun 5, 2025

Do you have something in mind?

We could move shuffle and split_by_node to setup_data. But this complicates things a bit if we end up adding a build_concat_dataset, because we need to shuffle -> split -> interleave, in this order.

The new signature would be:

class HfIterableDataset(IterableDataset, Stateful):
    def __init__(
        self,
        *,
        load_args: Dict,
        message_transform: Callable,
        tokenizer: Callable,
        recipe_transform: Callable,
        # shuffle_buffer_size: Optional[int] = 1000,
        # seed: Optional[int] = 42,
        num_shards_per_worker: int = 16,
        weight: float = 1.0,
        filter_args: Optional[Dict] = None,
    ):
        ...

We could also replace these three args with a single transform = [message_transform, tokenizer, recipe_transform] (see the composition sketch below).
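A sketch of what that single transform could be (hypothetical helper, not part of the PR):

```python
from typing import Any, Callable, Dict, Sequence


class ComposedTransform:
    """Apply message_transform, tokenizer, and recipe_transform as one pipeline."""

    def __init__(self, transforms: Sequence[Callable[[Dict[str, Any]], Dict[str, Any]]]):
        self.transforms = transforms

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        for transform in self.transforms:
            sample = transform(sample)
        return sample


# Usage (names assumed to be defined elsewhere):
# transform = ComposedTransform([message_transform, tokenizer, recipe_transform])
```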


# Consolidate all dataloader args here (currently scattered)
##########
dataloader:
  _component_: torchdata.stateful_dataloader.StatefulDataLoader
Contributor

Do we have any use cases or examples where someone built their own dataloader? If we don't have strong evidence internally or externally that this is needed, maybe we can just expose the arguments and assume that we're using the StatefulDataLoader?

Contributor Author

Could you elaborate a bit on why removing the component here would be a net positive? Personally, I like that we are pulling back the curtain and telling the user "this is exactly how these args will be used".

Contributor

Yes, I think torchtune should always "lazily" expose options to the config. This keeps our testing surface smaller, as we're not claiming to support other dataloaders out of the box (what's the contract we're expecting here? state_dict, for one). I think you can keep this section but just with the component removed.
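A sketch of the "expose only the knobs" approach if the recipe hard-codes StatefulDataLoader (config keys are illustrative):

```python
from torchdata.stateful_dataloader import StatefulDataLoader


def build_dataloader(dataset, collate_fn, dataloader_cfg: dict) -> StatefulDataLoader:
    # Only plain arguments come from the config; the class itself is fixed in the recipe.
    return StatefulDataLoader(
        dataset,
        batch_size=dataloader_cfg.get("batch_size", 8),
        num_workers=dataloader_cfg.get("num_workers", 0),
        pin_memory=dataloader_cfg.get("pin_memory", True),
        collate_fn=collate_fn,
    )
```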

num_shards_per_worker: 16
seed: ${seed}
tokenizer: ${tokenizer}
recipe_transform:
Contributor

IMO, the recipe transform should not be modifiable from the config. It's very much tied to how your specific recipe works and therefore should be a 1:1 mapping.

Contributor Author

@felipemello1 felipemello1 Jun 5, 2025

I am OK with removing it. I think it's a tradeoff of config bloat vs. the recipe doing too much magic. Exposing it helps with self-documentation. But I think most of the team prefers recipe magic over config bloat.

Contributor

The recipe isn't doing magic, because the recipe code is visible to the user; configs are just like parameters to a function.

```python
def alpaca_dataset(
*,
load_args: Optional[Dict],
Contributor

nit: load_dataset_kwargs to mirror the naming from hf

# Unify args
if not message_transform and isinstance(message_transform, dict):
# Remove component key since we're using alpaca_message_transform as default
message_transform.pop("_component_", None)
Contributor

Components should have no idea about a config system as much as possible. They are standalone pieces that can be passed around or used elsewhere.

Any config manipulation should be done in the recipe.





Options:

1. Make setup_data a utility, and have two utilities supporting old and new config formats.
Contributor

Is the idea here to have something like recipes/recipe_utils.py that would contain all these builder functions?

Or maybe it's put in torchtune core?

I'm not sure how to handle this cc @pbontrager

Contributor

One idea would be to convert things at the config level. If you pass in the old config value, then we recognize that in init and translate the config to the new config. I think we need to support the old configs for a bit, and the actual same datasets, but not the exact same implementations; they won't be map-style, for example.
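A rough sketch of that config-level translation (entirely hypothetical key names, just to illustrate the idea):

```python
from omegaconf import DictConfig, OmegaConf


def translate_legacy_dataset_cfg(cfg: DictConfig) -> DictConfig:
    """Map an old map-style dataset config onto the new iterable layout (illustrative only)."""
    if "load_args" in cfg:
        return cfg  # already in the new format
    return OmegaConf.create(
        {
            "_component_": cfg.get("_component_"),
            "load_args": {
                "path": cfg.get("source"),
                "split": cfg.get("split", "train"),
            },
        }
    )
```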

Contributor

@pbontrager pbontrager left a comment

Left a number of comments. Hope we can discuss again and iterate quickly.




#########

# Option 1: Direct Class Usage (current SFTDataset approach)
dataset:
Contributor

datasets


# Option 1: Direct Class Usage (current SFTDataset approach)
dataset:
- _component_: torchtune.datasets.HfIterableDataset
Contributor

The main benefit of exposing the component here is if we support multiple datasets; otherwise it's just boilerplate. I know we've gone back and forth on this, but I think the current setup is a great abstraction for covering every use case and at the same time not how people are trained to understand datasets at all. I think the existing datasets we have could just be cleaned up and extended a bit, but not thrown out completely. You'd end up with something like:

datasets:
  - _component_: torchtune.datasets.alpaca_dataset
    weight: .8
  - _component_: torchtune.datasets.sft_dataset
    source: "tatsu-lab/gsm8k"
    column_map:
      input: "prompt"
      output: "response"
      system_prompt: "bar"
    weight: .2

You can generalize SFTDataset with the new transform to be the new IterableDataset, but maybe hide the message transforms and dataset transforms altogether with builders. A user can write their own builder if they need to. Also, the args that are needed here don't need to be standardized since they're local to the component, but you could create a standard interface if we wanted.

Contributor Author

@felipemello1 felipemello1 Jun 6, 2025

I think that keeping what we have would be a step back.

  • the functions cannot be the same, because there are new args related to the iterable dataset, and we have to remove the 'packed' option.

  • a big part of this RFC was to standardize our dataset inputs, which are currently inconsistent, so we can a) rely more heavily on HF; b) reduce cognitive load; c) make it easy to extend

Compare these two, for example:

class PreferenceDataset(Dataset):
    def __init__(
        self,
        *,
        source: str,
        message_transform: Transform,
        tokenizer: ModelTokenizer,
        filter_fn: Optional[Callable] = None,
        packed: bool = False,
        **load_dataset_kwargs: dict[str, Any],
    ) -> None:

def alpaca_dataset(
    tokenizer: ModelTokenizer,
    *,
    source: str = "tatsu-lab/alpaca",
    column_map: Optional[dict[str, str]] = None,  -----> # different
    train_on_input: bool = True,  #  ----->  different
    packed: bool = False,
    filter_fn: Optional[Callable] = None,
    split: str = "train",  #  -----> different
    **load_dataset_kwargs: dict[str, Any],
) -> Union[SFTDataset, PackedDataset]:

TextCompletionDataset has add_eos

slimorca_dataset has new_system_prompt

cnn_dailymail_articles_dataset has max_seq_len


# Common Dataset Arguments
# Used as cfg = dataset_defaults.update(dataset_cfg)
#########
dataset_defaults:
Contributor

I don't like this defaults section. What objects are these parameters for?

Contributor Author

If we don't have a defaults section, then when using multidataset, these have to be repeated for every single dataset. In setup_data, we take these defaults and apply them to the dataset configs before instantiation.
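A small sketch of that merge step (assuming OmegaConf; per-dataset values win over the defaults):

```python
from omegaconf import OmegaConf

defaults = OmegaConf.create({"shuffle_buffer_size": 1000, "seed": 42})
ds_cfg = OmegaConf.create({"load_args": {"path": "tatsu-lab/alpaca"}, "seed": 0})

merged = OmegaConf.merge(defaults, ds_cfg)  # later configs override earlier ones
assert merged.seed == 0 and merged.shuffle_buffer_size == 1000
```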


Location: torchtune/datasets/hf_iterable_dataset.py

```python
class HfIterableDataset(IterableDataset, Stateful):
Contributor

I actually like SFTDataset. I think you can rename it to IterableDataset and make it an abstract class where you have to define/override a recipe_transform and message_transform method and then keep flat init args that match HF more closely. I think this would move us away from having a bunch of separately defined transforms and datasets.

Contributor Author

@felipemello1 felipemello1 Jun 6, 2025

I think this would move us away from having a bunch of separately defined transforms and datasets.

I don't follow. Can you expand on it? Which transforms and which datasets? In both cases there is only one recipe transform, 'SFTTransform'.

where you have to define/override a recipe_transform and message_transform

This is the same in the proposed case, right? We could create an SFTIterableDataset that has a recipe_transform hardcoded. But message transform has to be an input.

keep flat init args that match HF more closely

So on top of the args described here, you think that we should replace load_args with [path, name, data_dir, split, streaming, load_args_kwargs]?
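A sketch of the abstract-class variant being discussed (hypothetical names; flat HF-style loading args):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from torch.utils.data import IterableDataset


class SFTIterableDataset(IterableDataset, ABC):
    """Iterable SFT dataset with a fixed recipe transform and an overridable message transform."""

    def __init__(self, path: str, name: Optional[str] = None, split: str = "train", streaming: bool = True):
        self.path, self.name, self.split, self.streaming = path, name, split, streaming

    @abstractmethod
    def message_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Subclasses map raw columns to messages here."""

    def recipe_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Default SFT behavior, hardcoded by the recipe rather than exposed in the config.
        return sample
```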

# ii) we already do it currently
# Alternative: use dataclasses?

def setup_data(
Contributor

I think given all of my comments above, this can be simplified a lot. Also I don't think we need all of these method arguments. You can just pass in the config, plus anything that's computed in setup but not saved to self (though you could argue that most of those could just be computed inside of setup_data).

Contributor Author

@felipemello1 felipemello1 Jun 6, 2025

Which parts do you think can be simplified?

  1. We could get rid of dataset_defaults, but if the user has 10 datasets, I am not sure how to make it easier for them.
  2. data_setup_cfg can be replaced with 'packed' and 'interleave_strategy', but then those would be scattered args in the config.

Not sure what else could be deleted.




base_cfg = OmegaConf.merge(dataset_defaults, base_cfg)
ds = instantiate(base_cfg)
iterable_datasets.append(ds)
Contributor Author

@felipemello1 felipemello1 Jun 9, 2025

ds = concat_dataset(
    [
        instantiate(
            dataset_cfg,
            tokenizer=tokenizer,
            seed=seed,
            recipe_transform=recipe_transform,
        )
        for dataset_cfg in dataset_cfgs
    ],
    stopping_strategy=stopping_strategy,
)

@felipemello1 felipemello1 mentioned this pull request Jun 26, 2025
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
8 participants