create samples with padding to avoid truncations #186

Merged: 71 commits, Apr 1, 2025
Commits
71 commits
4f84301
introduce flash_attn_varlen_func and docwise position ids
sohamparikh Feb 12, 2025
f20cf6b
merge with main
sohamparikh Feb 18, 2025
dcb2bb1
make the basics work
sohamparikh Feb 19, 2025
5331d8d
use bos for separating docs
sohamparikh Feb 25, 2025
4e256aa
option to disable packing
sohamparikh Feb 25, 2025
74c2d94
fix
sohamparikh Feb 27, 2025
9f43df9
merge main
sohamparikh Feb 27, 2025
cf5fc8a
Merge branch 'main' into soham/cross-document-attn
sohamparikh Mar 5, 2025
90061b9
revert doc truncation
sohamparikh Mar 5, 2025
ba0e649
fix for sequence data parallel
sohamparikh Mar 6, 2025
7730a4c
make it work for abs positions and backup attn
sohamparikh Mar 6, 2025
f3c540b
pre-compute cumulative sequences, make position embeddings compatible
sohamparikh Mar 7, 2025
cd3244c
fix
sohamparikh Mar 7, 2025
f86f469
move to GPU once
sohamparikh Mar 7, 2025
b98ba1b
Merge branch 'main' into soham/cross-document-attn
sohamparikh Mar 7, 2025
1b17f59
single config flag, sequence_data_parallel works!
sohamparikh Mar 11, 2025
b3ee7c4
fix backupattn, absolute positions
sohamparikh Mar 11, 2025
fedbb07
move to preprocessor
sohamparikh Mar 11, 2025
8d768b7
imports, rename seqlens
sohamparikh Mar 11, 2025
79aca1a
remove unused config
sohamparikh Mar 11, 2025
801ac3e
add tests
sohamparikh Mar 11, 2025
489a5aa
comments
sohamparikh Mar 11, 2025
56c1fd2
pad legacy sampler
sohamparikh Mar 13, 2025
30da8bf
make it work
sohamparikh Mar 13, 2025
e6f405c
test
sohamparikh Mar 13, 2025
152d272
TODO for now
sohamparikh Mar 13, 2025
9624a9e
pad query lengths
sohamparikh Mar 13, 2025
481f870
rename config
sohamparikh Mar 14, 2025
a2a98ec
fix
sohamparikh Mar 14, 2025
8778878
maybe better naming
sohamparikh Mar 14, 2025
cf78f57
pad legacy sampler
sohamparikh Mar 13, 2025
466278b
make it work
sohamparikh Mar 13, 2025
217e6a0
test
sohamparikh Mar 13, 2025
da6c5d0
TODO for now
sohamparikh Mar 13, 2025
60dfd6f
merge with main
sohamparikh Mar 14, 2025
b2c0104
merge main
sohamparikh Mar 17, 2025
5b8239c
fix merging errors
sohamparikh Mar 17, 2025
5f3a812
simplify sampler
sohamparikh Mar 17, 2025
d5b39b0
rename
sohamparikh Mar 17, 2025
25813f4
new padded sampler
sohamparikh Mar 18, 2025
65e24db
fixes
sohamparikh Mar 18, 2025
6641997
remove comments
sohamparikh Mar 18, 2025
71b9e34
legacy does not support padding
sohamparikh Mar 18, 2025
12b5471
base sampled class
sohamparikh Mar 18, 2025
701d7de
fix
sohamparikh Mar 19, 2025
95ccc8a
use token cumsum for padded sampling
sohamparikh Mar 22, 2025
9f9c0be
rename to allow_truncations
sohamparikh Mar 22, 2025
f4691a1
warning
sohamparikh Mar 22, 2025
81adfd6
cleanup
sohamparikh Mar 22, 2025
50f38bb
fix
sohamparikh Mar 22, 2025
7408d7a
cleanup
sohamparikh Mar 22, 2025
4ccef04
fix
sohamparikh Mar 22, 2025
63465d3
simplify
jlamypoirier Mar 25, 2025
e8fccbb
cleanup
jlamypoirier Mar 25, 2025
966d5a1
cleanup
sohamparikh Mar 25, 2025
4a53e3f
fix
sohampnow Mar 25, 2025
de0bb13
fix
sohamparikh Mar 25, 2025
5748af0
fix
sohamparikh Mar 25, 2025
0147e3e
fix test
sohamparikh Mar 25, 2025
7d43786
fix
sohamparikh Mar 26, 2025
e3780fd
cleanup
sohamparikh Mar 26, 2025
1f21feb
sampling tests
sohamparikh Mar 26, 2025
77c46f3
length filter
sohamparikh Mar 26, 2025
88c9aff
clean
sohamparikh Mar 26, 2025
9e18838
fix cpp cumsum
sohamparikh Mar 27, 2025
c4586bc
many small tests
sohamparikh Mar 27, 2025
c202053
minor
sohamparikh Mar 28, 2025
17a14e3
review
sohamparikh Mar 29, 2025
c3aad50
fix low samples, test
sohamparikh Mar 29, 2025
2cd5cce
fix num_epochs
sohamparikh Apr 1, 2025
ab5d751
Merge branch 'main' into soham/padded-sampler
sohamparikh Apr 1, 2025
61 changes: 60 additions & 1 deletion fast_llm/csrc/data.cpp
@@ -27,7 +27,7 @@

/*
Helper methods for fast index mapping builds.
Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx.
Changes for Fast-LLM: Use int16 for dataset index, add verbose argument to build_sample_idx, add build_padded_token_cumsum.
*/

#include <iostream>
@@ -129,6 +129,65 @@ py::array build_sample_idx(const py::array_t<int32_t>& sizes_,

}

py::array build_padded_token_cumsum(const py::array_t<int32_t>& sizes_,
const int32_t seq_length,
const int32_t token_cumsum_rate,
const int64_t offset
) {
/*
Build token cumsums at regular intervals from document sizes with padding in mind.
We inject 0 or more padding tokens at the end of every sequence to fill the sequence length.
*/
int32_t seq_size = 0;
int64_t sizes_idx = 0;
int32_t samples = 0;
auto sizes = sizes_.unchecked<1>();
std::vector<int64_t> token_cumsum;

int64_t cumsum = offset;

while (sizes_idx < sizes.size()) {
int32_t size = sizes[sizes_idx];
if (size > seq_length) {
// Skip sequences that are too long, to avoid truncations
if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
sizes_idx += 1;
samples += 1;
} else if (seq_size + size > seq_length) {
// add padded tokens if a document does not fit in current sequence and start a new sequence
cumsum += seq_length - seq_size;
seq_size = 0;
} else {
// Increment here to account for padding. This ensures that the stored values match the beginning of the next document.
if (samples % token_cumsum_rate==0) token_cumsum.push_back(cumsum);
seq_size += size;
cumsum += size;
sizes_idx += 1;
samples += 1;
}
}

// Add a final (padded) entry so we know how many tokens there are in total.
cumsum += seq_length - seq_size;
token_cumsum.push_back(cumsum);


int64_t* token_cumsum_result = new int64_t[token_cumsum.size()];
memcpy(token_cumsum_result, token_cumsum.data(), token_cumsum.size() * sizeof(int64_t));

py::capsule free_when_done(token_cumsum_result, [](void *mem_) {
int64_t *mem = reinterpret_cast<int64_t*>(mem_);
delete[] mem;
});

const auto byte_size = sizeof(int64_t);
return py::array(std::vector<int64_t>{token_cumsum.size()},
{byte_size},
token_cumsum_result,
free_when_done);
}

PYBIND11_MODULE(data, m) {
m.def("build_sample_idx", &build_sample_idx);
m.def("build_padded_token_cumsum", &build_padded_token_cumsum);
}
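
For readers who prefer Python, here is a minimal sketch of the packing logic that build_padded_token_cumsum above implements (illustrative only, not part of the diff; the function name and arguments simply mirror the C++ signature):

import numpy as np

def padded_token_cumsum_sketch(sizes, seq_length, token_cumsum_rate, offset=0):
    # Documents longer than seq_length are skipped outright, and each sequence is
    # padded up to seq_length so that no document is ever truncated. Every
    # token_cumsum_rate-th document boundary is recorded in the cumsum.
    token_cumsum = []
    cumsum = offset
    seq_size = 0
    samples = 0
    i = 0
    while i < len(sizes):
        size = sizes[i]
        if size > seq_length:
            # Too long: skip the document entirely instead of truncating it.
            if samples % token_cumsum_rate == 0:
                token_cumsum.append(cumsum)
            i += 1
            samples += 1
        elif seq_size + size > seq_length:
            # Does not fit: pad out the current sequence and retry this document
            # at the start of the next sequence (the index is not advanced).
            cumsum += seq_length - seq_size
            seq_size = 0
        else:
            if samples % token_cumsum_rate == 0:
                token_cumsum.append(cumsum)
            seq_size += size
            cumsum += size
            i += 1
            samples += 1
    # Final padded entry records the total (padded) token count.
    cumsum += seq_length - seq_size
    token_cumsum.append(cumsum)
    return np.array(token_cumsum, dtype=np.int64)

# Example: sizes [3, 4, 6, 9] with seq_length 8 and token_cumsum_rate 1 gives
# [0, 3, 8, 14, 16]: the 3- and 4-token documents share the first sequence
# (1 pad token), the 6-token document gets its own sequence (2 pad tokens),
# and the 9-token document is skipped.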
9 changes: 9 additions & 0 deletions fast_llm/data/data/gpt/config.py
@@ -57,6 +57,15 @@ class GPTDataConfig(DataConfig, GPTLegacyConfig):
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
truncate_documents: bool = Field(
default=True,
desc=(
"If enabled, documents may be truncated while being packed to fit the sequence length."
"Otherwise, sequences will be padded such that every document lies entirely within a sample"
" (and documents exceeding the sequence length will be skipped altogether)."
),
hint=FieldHint.feature,
)

def _validate(self) -> None:
if not self.datasets:
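
As a rough illustration of what the new truncate_documents flag controls, here is a toy sketch of the two packing modes (not Fast-LLM code; the -100 fill value simply mirrors the padding value used later in sampled.py):

def pack_with_truncation(docs, sample_len):
    # Default behaviour: concatenate all documents and cut the stream into
    # fixed-length samples, truncating documents at sample boundaries.
    stream = [token for doc in docs for token in doc]
    return [stream[i : i + sample_len] for i in range(0, len(stream), sample_len)]

def pack_without_truncation(docs, sample_len, pad=-100):
    # truncate_documents=False: keep every document whole, pad samples instead,
    # and skip documents that are longer than a full sample.
    samples, current = [], []
    for doc in docs:
        if len(doc) > sample_len:
            continue  # too long, skipped altogether
        if len(current) + len(doc) > sample_len:
            samples.append(current + [pad] * (sample_len - len(current)))
            current = []
        current = current + doc
    if current:
        samples.append(current + [pad] * (sample_len - len(current)))
    return samples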
1 change: 1 addition & 0 deletions fast_llm/data/data/gpt/data.py
@@ -121,6 +121,7 @@ def setup(
sequence_length=self._max_sequence_length,
vocab_size=self._vocab_size,
tokenizer=self._tokenizer,
truncate_documents=self._config.truncate_documents,
cross_document_attention=self._cross_document_attention,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
1 change: 1 addition & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -71,6 +71,7 @@ class GPTSamplingData(SamplingData):
sequence_length: int
vocab_size: int
tokenizer: "Tokenizer"
truncate_documents: bool = True
cross_document_attention: bool = True


165 changes: 122 additions & 43 deletions fast_llm/data/dataset/gpt/sampled.py
@@ -12,12 +12,12 @@
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType
from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert

try:
from fast_llm.csrc.data import build_sample_idx # noqa
from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa

_extension_available = True
except ImportError:
@@ -89,6 +89,7 @@ def __init__(
self._sequence_length = sampling.sequence_length
self._cross_document_attention = sampling.cross_document_attention
self._config = sampling.config
self._truncate_documents = sampling.truncate_documents
self._device = torch.device("cuda" if self._config.gpu else "cpu")

if sampling.cache_directory is None:
@@ -124,15 +125,35 @@ def _sample(self) -> None:
"""
# Get the document sizes, the main information needed for sampling.
document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)

# Calculate basic stats.
documents_per_epoch = document_sizes.numel()
tokens_per_epoch = document_sizes.sum().item()

# Calculate basic stats.
if not self._truncate_documents:
assert _extension_available, (
"The C++ extension for dataset sampling is missing."
" Please make sure Fast-LLM is installed correctly."
)
long_docs_filter = document_sizes > self._sequence_length + 1
ignored_documents = sum(long_docs_filter)
if ignored_documents:
log_main_rank(
f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.",
log_fn=logger.warning,
)
tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
if tokens_per_epoch == 0:
raise RuntimeError(
f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
)
# TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
# We produce sequences of length `self._sequence_length + 1` so the last token has a label,
# but we also include that last label in the following sample,
# but in case of truncations we also include that last label in the following sample,
# so we need `sequence_length * num_samples + 1` tokens in total.
num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch)
num_epochs = math.ceil(
((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents)
/ tokens_per_epoch
)

# Prepare for shuffling.
generator = torch.Generator(device=self._device)
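
As a quick, hypothetical check of the epoch-count formula above (the truncate_documents boolean is used as a 0/1 integer; padding overhead is not modelled here, and the numbers are arbitrary):

import math

def required_epochs(sequence_length, num_samples, tokens_per_epoch, truncate_documents):
    # With truncation: sequence_length * num_samples + 1 tokens are needed,
    # since consecutive samples overlap by one label token. Without truncation:
    # each sample consumes a disjoint block of sequence_length + 1 tokens.
    return math.ceil(
        ((sequence_length + 1 - truncate_documents) * num_samples + 1 * truncate_documents)
        / tokens_per_epoch
    )

# sequence_length=1024, num_samples=1000, tokens_per_epoch=300_000:
#   truncate_documents=True  -> ceil(1_024_001 / 300_000) = 4 epochs
#   truncate_documents=False -> ceil(1_025_000 / 300_000) = 4 epochs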
@@ -154,13 +175,17 @@
"num_samples": self._num_samples,
"unshuffled_epochs": unshuffled_epochs,
"sequence_length": self._sequence_length,
"truncate_documents": self._truncate_documents,
"config": self._config.to_serialized(),
}
self._load_yaml_data(yaml_data)

if self._yaml_path is not None:
if self._yaml_path.is_file():
loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
if unshuffled_tokens is not None:
self._unshuffled_tokens = unshuffled_tokens
if loaded_yaml_data != yaml_data:
raise RuntimeError(
f"Invalid dataset cache for dataset {self.name}."
@@ -172,9 +197,6 @@
# Dataset is already sampled, skip.
logger.info(f"Using existing sampling for dataset {self.name}")
return
else:
self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

if shuffled_documents > 1e8:
warnings.warn(
@@ -232,51 +254,78 @@
# So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`.
# Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
# Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
if unshuffled_epochs > 0:
token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
document_sizes,
offset=0,
# TODO: Allowing for max 100% extra tokens for padding, is that enough?
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
)
if self._truncate_documents:
num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
else:
num_tokens_unshuffled = 0
self._unshuffled_tokens = num_tokens_unshuffled

if self._yaml_path is not None:
yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

if shuffled_epochs > 0:
token_cumsum_shuffled = self._get_token_cumsum(
token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
document_sizes[
# Torch indexing only works with int32 or int64
document_shuffling.to(
dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
)
],
offset=unshuffled_epochs * tokens_per_epoch,
dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch,
offset=num_tokens_unshuffled,
# TODO: Allowing for max 100% extra tokens for padding, is that enough?
dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
)
self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu))
self._token_cumsum_shuffled.save(token_cumsum_shuffled)
self._document_shuffling.save(
document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy(
document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(
force=self._config.gpu
)
)
# Free memory
del token_cumsum_shuffled
del document_shuffling

if unshuffled_epochs > 0:
token_cumsum_unshuffled = self._get_token_cumsum(
document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
if self._truncate_documents:
# Create the output tensor.
out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch)
# Get partial sums for regular intervals, excluding the last incomplete interval.
torch.sum(
sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE),
dim=1,
out=out[1:],
)
self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))

def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
# Create the output tensor.
out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype)
# Get partial sums for regular intervals, excluding the last incomplete interval.
torch.sum(
sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:]
)
# Pad with the begin offset
out[0] = offset
# Calculate the cumsum.
out.cumsum_(0)
# Crop unnecessary entries.
return out[
: torch.clamp_min_(
torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
0,
# Pad with the begin offset
out[0] = offset
# Calculate the cumsum.
out.cumsum_(0)
# Crop unnecessary entries.
out = out[
: torch.clamp_min_(
torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
0,
)
]
return out.numpy(force=self._config.gpu), None
else:
# TODO: dynamically handle int64 or int32 in CPP
out = build_padded_token_cumsum(
sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset
)
]
num_tokens = out[-1]
Review comment (collaborator):
Since we are throwing away long documents, we could end up not generating enough tokens. At the very least we need to add a check for it. Or maybe exclude long documents from tokens_per_epoch?

out = out[:-1][
: np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None)
]
return out, num_tokens

def __len__(self) -> int:
return self._num_samples
@@ -288,7 +337,9 @@ def __getitem__(self, index: int) -> typing.Any:
The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
"""
self._lazy_load()
token_start = index * self._sequence_length
# Tokens at the boundary are included in only one sample when packing without truncation.
# When packing with truncation, the last token of the previous sample is also the first token of the next sample.
token_start = index * (self._sequence_length + 1 - self._truncate_documents)
token_end = token_start + self._sequence_length + 1

if token_start < self._unshuffled_tokens:
@@ -302,6 +353,7 @@ def __getitem__(self, index: int) -> typing.Any:
token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1

document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset

token_count = token_start_array[token_start_cumsum_index]

token_ids = []
@@ -314,6 +366,25 @@ def __getitem__(self, index: int) -> typing.Any:
document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item()

document_size = self._indexed_dataset.get_document_size(document_index)

if not self._truncate_documents:
if document_size > self._sequence_length + 1:
# Document too long, ignore
document_sampling_index += 1
continue
tokens_in_sample = token_count % (self._sequence_length + 1)
if document_size + tokens_in_sample > self._sequence_length + 1:
# Document belongs to the next sample, need to account for padding.
padding_size = self._sequence_length + 1 - tokens_in_sample
if token_count > token_start:
# Add padding tokens to current sample
token_ids.append(np.full((padding_size,), -100, dtype=np.int64))
Assert.eq(token_count + padding_size, token_end)
break
else:
# Move on to the next sample.
token_count += padding_size

# Determine if the document belongs to the requested sample.
if token_count + document_size >= token_start:
# Determine which part of the document belong to the sample, and add it to the list.
Expand Down Expand Up @@ -343,7 +414,9 @@ def __getitem__(self, index: int) -> typing.Any:
)
token_ids = np.concatenate(token_ids, dtype=np.int64)
loss_masking_spans = (
np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None
(np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([]))
if self._config.use_loss_masking_spans
else None
)
Assert.eq(len(token_ids), self._sequence_length + 1)

@@ -357,9 +430,12 @@ def _lazy_load(self):
if not hasattr(self, "_documents_per_epoch"):
self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r")))

def _load_yaml_data(self, data: dict[str, typing.Any]):
def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
if (unshuffled_tokens := data.get("unshuffled_tokens")) is not None:
self._unshuffled_tokens = unshuffled_tokens
else:
self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch


@@ -380,9 +456,12 @@ def __init__(
self._indexed_dataset = indexed_dataset
self._num_samples = sampling.num_samples
self._sequence_length = sampling.sequence_length
if not sampling.truncate_documents:
raise NotImplementedError(
"Legacy sampling only supports document truncation. Please use the latest dataset format."
)
self._cross_document_attention = sampling.cross_document_attention
self._config = sampling.config
self._tokenizer = sampling.tokenizer

if sampling.cache_directory is None:
log_main_rank(
@@ -498,7 +577,7 @@ def __getitem__(self, idx: int) -> typing.Any:
for span in sample.loss_masking_spans:
spans.append(span + offset)
offset += len(sample.token_ids)
spans = np.stack(spans, dtype=np.int32)
spans = np.stack(spans, dtype=np.int32) if spans else np.array([])
else:
spans = None
sequence_lengths = (
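
Finally, a small sketch of how the padded sampler's __getitem__ maps a sample index to a token range (illustrative only; the helper name is invented and the numbers are arbitrary):

def sample_token_range(index, sequence_length, truncate_documents):
    # Mirrors the updated __getitem__: without truncation each sample owns a
    # disjoint block of sequence_length + 1 tokens (boundary tokens are not
    # shared); with truncation, consecutive samples overlap by one token, so
    # the last label of one sample is the first token of the next.
    token_start = index * (sequence_length + 1 - truncate_documents)
    token_end = token_start + sequence_length + 1
    return token_start, token_end

# sequence_length = 8:
#   truncate_documents=True:  sample 0 -> (0, 9), sample 1 -> (8, 17)  # 1-token overlap
#   truncate_documents=False: sample 0 -> (0, 9), sample 1 -> (9, 18)  # disjoint, padded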