create samples with padding to avoid truncations #186
Merged
Changes from all commits
71 commits
4f84301 introduce flash_attn_varlen_func and docwise position ids (sohamparikh)
f20cf6b merge with main (sohamparikh)
dcb2bb1 make the basics work (sohamparikh)
5331d8d use bos for separating docs (sohamparikh)
4e256aa option to disable packing (sohamparikh)
74c2d94 fix (sohamparikh)
9f43df9 merge main (sohamparikh)
cf5fc8a Merge branch 'main' into soham/cross-document-attn (sohamparikh)
90061b9 revert doc truncation (sohamparikh)
ba0e649 fix for sequence data parallel (sohamparikh)
7730a4c make it work for abs positions and backup attn (sohamparikh)
f3c540b pre-compute cumulative sequences, make position embeddings compatible (sohamparikh)
cd3244c fix (sohamparikh)
f86f469 move to GPU once (sohamparikh)
b98ba1b Merge branch 'main' into soham/cross-document-attn (sohamparikh)
1b17f59 single config flag, sequence_data_parallel works! (sohamparikh)
b3ee7c4 fix backupattn, absolute positions (sohamparikh)
fedbb07 move to preprocessor (sohamparikh)
8d768b7 imports, rename seqlens (sohamparikh)
79aca1a remove unused config (sohamparikh)
801ac3e add tests (sohamparikh)
489a5aa comments (sohamparikh)
56c1fd2 pad legacy sampler (sohamparikh)
30da8bf make it work (sohamparikh)
e6f405c test (sohamparikh)
152d272 TODO for now (sohamparikh)
9624a9e pad query lengths (sohamparikh)
481f870 rename config (sohamparikh)
a2a98ec fix (sohamparikh)
8778878 maybe better naming (sohamparikh)
cf78f57 pad legacy sampler (sohamparikh)
466278b make it work (sohamparikh)
217e6a0 test (sohamparikh)
da6c5d0 TODO for now (sohamparikh)
60dfd6f merge with main (sohamparikh)
b2c0104 merge main (sohamparikh)
5b8239c fix merging errors (sohamparikh)
5f3a812 simplify sampler (sohamparikh)
d5b39b0 rename (sohamparikh)
25813f4 new padded sampler (sohamparikh)
65e24db fixes (sohamparikh)
6641997 remove comments (sohamparikh)
71b9e34 legacy does not support padding (sohamparikh)
12b5471 base sampled class (sohamparikh)
701d7de fix (sohamparikh)
95ccc8a use token cumsum for padded sampling (sohamparikh)
9f9c0be rename to allow_truncations (sohamparikh)
f4691a1 warning (sohamparikh)
81adfd6 cleanup (sohamparikh)
50f38bb fix (sohamparikh)
7408d7a cleanup (sohamparikh)
4ccef04 fix (sohamparikh)
63465d3 simplify (jlamypoirier)
e8fccbb cleanup (jlamypoirier)
966d5a1 cleanup (sohamparikh)
4a53e3f fix (sohamparikh)
de0bb13 fix (sohamparikh)
5748af0 fix (sohamparikh)
0147e3e fix test (sohamparikh)
7d43786 fix (sohamparikh)
e3780fd cleanup (sohamparikh)
1f21feb sampling tests (sohamparikh)
77c46f3 length filter (sohamparikh)
88c9aff clean (sohamparikh)
9e18838 fix cpp cumsum (sohamparikh)
c4586bc many small tests (sohamparikh)
c202053 minor (sohamparikh)
17a14e3 review (sohamparikh)
c3aad50 fix low samples, test (sohamparikh)
2cd5cce fix num_epochs (sohamparikh)
ab5d751 Merge branch 'main' into soham/padded-sampler (sohamparikh)
@@ -12,12 +12,12 @@
 from fast_llm.data.dataset.abstract import SampledDataset
 from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType
 from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset
-from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type
+from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
 from fast_llm.engine.config_utils.run import log_main_rank
 from fast_llm.utils import Assert

 try:
-    from fast_llm.csrc.data import build_sample_idx  # noqa
+    from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx  # noqa

     _extension_available = True
 except ImportError:
@@ -89,6 +89,7 @@ def __init__(
         self._sequence_length = sampling.sequence_length
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
+        self._truncate_documents = sampling.truncate_documents
         self._device = torch.device("cuda" if self._config.gpu else "cpu")

         if sampling.cache_directory is None:
@@ -124,15 +125,35 @@ def _sample(self) -> None:
         """
         # Get the document sizes, the main information needed for sampling.
         document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device)

-        # Calculate basic stats.
         documents_per_epoch = document_sizes.numel()
         tokens_per_epoch = document_sizes.sum().item()

+        # Calculate basic stats.
+        if not self._truncate_documents:
+            assert _extension_available, (
+                "The C++ extension for dataset sampling is missing."
+                " Please make sure Fast-LLM is installed correctly."
+            )
+            long_docs_filter = document_sizes > self._sequence_length + 1
+            ignored_documents = sum(long_docs_filter)
+            if ignored_documents:
+                log_main_rank(
+                    f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.",
+                    log_fn=logger.warning,
+                )
+            tokens_per_epoch = document_sizes[~long_docs_filter].sum().item()
+            if tokens_per_epoch == 0:
+                raise RuntimeError(
+                    f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
+                )
         # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
         # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-        # but we also include that last label in the following sample,
+        # but in case of truncations we also include that last label in the following sample,
         # so we need `sequence_length * num_samples + 1` tokens in total.
-        num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch)
+        num_epochs = math.ceil(
+            ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents)
+            / tokens_per_epoch
+        )

         # Prepare for shuffling.
         generator = torch.Generator(device=self._device)
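To make the epoch arithmetic above concrete, here is a small sketch (numbers are illustrative, not from the PR) of how the two modes differ: with truncation the samples tile one continuous token stream with a single shared boundary token, while without truncation every sample consumes a full `sequence_length + 1` tokens (documents plus padding).

import math

# Illustrative numbers, not from the PR.
sequence_length = 1024
num_samples = 10_000
tokens_per_epoch = 2_000_000  # without truncation, counted after dropping documents longer than sequence_length + 1

for truncate_documents in (True, False):
    # Mirrors the expression in `_sample`: the boolean flag is used as 0/1 in the arithmetic.
    num_epochs = math.ceil(
        ((sequence_length + 1 - truncate_documents) * num_samples + 1 * truncate_documents)
        / tokens_per_epoch
    )
    print(truncate_documents, num_epochs)

# With truncation:    ceil((1024 * 10_000 + 1) / 2_000_000) = 6 epochs
# Without truncation: ceil((1025 * 10_000) / 2_000_000)     = 6 epochs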
@@ -154,13 +175,17 @@ def _sample(self) -> None:
             "num_samples": self._num_samples,
             "unshuffled_epochs": unshuffled_epochs,
             "sequence_length": self._sequence_length,
+            "truncate_documents": self._truncate_documents,
             "config": self._config.to_serialized(),
         }
         self._load_yaml_data(yaml_data)

         if self._yaml_path is not None:
             if self._yaml_path.is_file():
                 loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r"))
+                unshuffled_tokens = loaded_yaml_data.pop("unshuffled_tokens", None)
+                if unshuffled_tokens is not None:
+                    self._unshuffled_tokens = unshuffled_tokens
                 if loaded_yaml_data != yaml_data:
                     raise RuntimeError(
                         f"Invalid dataset cache for dataset {self.name}."
@@ -172,9 +197,6 @@ def _sample(self) -> None:
                 # Dataset is already sampled, skip.
                 logger.info(f"Using existing sampling for dataset {self.name}")
                 return
-            else:
-                self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
-                yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

         if shuffled_documents > 1e8:
             warnings.warn(
@@ -232,51 +254,78 @@ def _sample(self) -> None:
         # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`.
         # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation.
         # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))`
+        if unshuffled_epochs > 0:
+            token_cumsum_unshuffled, num_tokens_unshuffled = self._get_token_cumsum(
+                document_sizes,
+                offset=0,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
+            )
+            if self._truncate_documents:
+                num_tokens_unshuffled = tokens_per_epoch * unshuffled_epochs
+            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled)
+        else:
+            num_tokens_unshuffled = 0
+        self._unshuffled_tokens = num_tokens_unshuffled
+
+        if self._yaml_path is not None:
+            yaml_data["unshuffled_tokens"] = num_tokens_unshuffled
+            self._yaml_path.parent.mkdir(parents=True, exist_ok=True)
+            yaml.safe_dump(yaml_data, self._yaml_path.open("w"))

         if shuffled_epochs > 0:
-            token_cumsum_shuffled = self._get_token_cumsum(
+            token_cumsum_shuffled, num_tokens_shuffled = self._get_token_cumsum(
                 document_sizes[
                     # Torch indexing only works with int32 or int64
                     document_shuffling.to(
                         dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32
                     )
                 ],
-                offset=unshuffled_epochs * tokens_per_epoch,
-                dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch,
+                offset=num_tokens_unshuffled,
+                # TODO: Allowing for max 100% extra tokens for padding, is that enough?
+                dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs),
             )
-            self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu))
+            self._token_cumsum_shuffled.save(token_cumsum_shuffled)
             self._document_shuffling.save(
-                document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy(
+                document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(
                     force=self._config.gpu
                 )
             )
             # Free memory
             del token_cumsum_shuffled
             del document_shuffling

-        if unshuffled_epochs > 0:
-            token_cumsum_unshuffled = self._get_token_cumsum(
-                document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch
-            )
-            self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu))
-
-    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor:
-        # Create the output tensor.
-        out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype)
-        # Get partial sums for regular intervals, excluding the last incomplete interval.
-        torch.sum(
-            sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:]
-        )
-        # Pad with the begin offset
-        out[0] = offset
-        # Calculate the cumsum.
-        out.cumsum_(0)
-        # Crop unnecessary entries.
-        return out[
-            : torch.clamp_min_(
-                torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
-                0,
-            )
-        ]
+    def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]:
+        if self._truncate_documents:
+            # Create the output tensor.
+            out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch)
+            # Get partial sums for regular intervals, excluding the last incomplete interval.
+            torch.sum(
+                sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE),
+                dim=1,
+                out=out[1:],
+            )
+            # Pad with the begin offset
+            out[0] = offset
+            # Calculate the cumsum.
+            out.cumsum_(0)
+            # Crop unnecessary entries.
+            out = out[
+                : torch.clamp_min_(
+                    torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"),
+                    0,
+                )
+            ]
+            return out.numpy(force=self._config.gpu), None
+        else:
+            # TODO: dynamically handle int64 or int32 in CPP
+            out = build_padded_token_cumsum(
+                sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset
+            )
+            num_tokens = out[-1]
+            out = out[:-1][
+                : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None)
+            ]
+            return out, num_tokens

     def __len__(self) -> int:
         return self._num_samples

Review comment on this hunk: Since we are throwing away long documents, we could end up not generating enough tokens. At the very least we need to add a check for it. Or maybe exclude long documents from …
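The `else` branch delegates to the C++ helper `build_padded_token_cumsum`. Below is a rough pure-Python reference of what such a padded cumsum might compute; the semantics are an assumption inferred from how the output is used above (not the actual C++ implementation): document sizes are accumulated, but whenever a document would cross a `sequence_length + 1` boundary the running total is first bumped to the next boundary (the padding), every `TOKEN_CUMSUM_RATE`-th value is recorded, and the final total is appended so the caller can read `out[-1]` as the padded token count.

import numpy as np

def padded_token_cumsum_reference(sizes: np.ndarray, sample_size: int, rate: int, offset: int) -> np.ndarray:
    """Hypothetical reference for `build_padded_token_cumsum` (semantics inferred, not the real kernel)."""
    cumsums = []
    total = offset
    for i, size in enumerate(sizes):
        if i % rate == 0:
            # Record the running total at regular document intervals, like the truncating branch above.
            cumsums.append(total)
        left_in_sample = sample_size - total % sample_size
        if size > left_in_sample:
            # The document would be truncated: pad out the current sample and start the next one.
            total += left_in_sample
        total += size
    cumsums.append(total)  # final entry is the padded token total, read back as `num_tokens`
    return np.array(cumsums, dtype=np.int64)

The caller then crops the array with `searchsorted`, keeping only the entries needed to cover `num_samples * (sequence_length + 1)` tokens.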
@@ -288,7 +337,9 @@ def __getitem__(self, index: int) -> typing.Any:
         The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`).
         """
         self._lazy_load()
-        token_start = index * self._sequence_length
+        # tokens at the boundary are included in only one sample when we pack without truncations
+        # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
+        token_start = index * (self._sequence_length + 1 - self._truncate_documents)
         token_end = token_start + self._sequence_length + 1

         if token_start < self._unshuffled_tokens:
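A quick way to see the comment above: with truncation, consecutive samples share one boundary token (the last label of sample `i` is the first token of sample `i + 1`); without truncation, samples are disjoint blocks of `sequence_length + 1` tokens. A tiny illustration with made-up numbers:

sequence_length = 8

def sample_range(index: int, truncate_documents: bool) -> tuple[int, int]:
    # Mirrors the arithmetic in `__getitem__`; the boolean acts as 0/1.
    token_start = index * (sequence_length + 1 - truncate_documents)
    token_end = token_start + sequence_length + 1
    return token_start, token_end

# Truncating/packed stream: adjacent ranges overlap by one token.
assert sample_range(0, True) == (0, 9)
assert sample_range(1, True) == (8, 17)

# Padded (no truncation): disjoint ranges of sequence_length + 1 tokens each.
assert sample_range(0, False) == (0, 9)
assert sample_range(1, False) == (9, 18)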
@@ -302,6 +353,7 @@ def __getitem__(self, index: int) -> typing.Any:
         token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1

         document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset

         token_count = token_start_array[token_start_cumsum_index]

         token_ids = []
@@ -314,6 +366,25 @@ def __getitem__(self, index: int) -> typing.Any:
             document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item()

             document_size = self._indexed_dataset.get_document_size(document_index)

+            if not self._truncate_documents:
+                if document_size > self._sequence_length + 1:
+                    # Document too long, ignore
+                    document_sampling_index += 1
+                    continue
+                tokens_in_sample = token_count % (self._sequence_length + 1)
+                if document_size + tokens_in_sample > self._sequence_length + 1:
+                    # Document belongs to the next sample, need to account for padding.
+                    padding_size = self._sequence_length + 1 - tokens_in_sample
+                    if token_count > token_start:
+                        # Add padding tokens to current sample
+                        token_ids.append(np.full((padding_size,), -100, dtype=np.int64))
+                        Assert.eq(token_count + padding_size, token_end)
+                        break
+                    else:
+                        # Move on to the next sample.
+                        token_count += padding_size
+
             # Determine if the document belongs to the requested sample.
             if token_count + document_size >= token_start:
                 # Determine which part of the document belong to the sample, and add it to the list.
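The new branch above decides, for each candidate document, whether it fits in the remainder of the current `sequence_length + 1` window; if it does not, the window is closed with `-100` padding (so those positions are ignored by the loss) and the document starts the next sample. A standalone sketch of the same decision, simplified to ignore shuffling and the cumsum bookkeeping (the helper name and stand-in token values are illustrative, not from the PR):

def pack_without_truncation(document_sizes: list[int], sample_size: int, pad_token: int = -100) -> list[list[int]]:
    """Illustrative packer: each sample has exactly `sample_size` slots, documents are never split,
    leftover slots are filled with `pad_token`, and documents longer than `sample_size` are skipped."""
    samples: list[list[int]] = []
    current: list[int] = []
    for doc_id, size in enumerate(document_sizes):
        if size > sample_size:
            continue  # too long: would require truncation, so it is ignored
        if len(current) + size > sample_size:
            # Close the current sample with padding and start a new one.
            current.extend([pad_token] * (sample_size - len(current)))
            samples.append(current)
            current = []
        current.extend([doc_id] * size)  # stand-in for the document's tokens
    if current:
        current.extend([pad_token] * (sample_size - len(current)))
        samples.append(current)
    return samples

# Example: sample_size = 5, documents of sizes 3, 4, 2, 6 -> the 6-token document is skipped.
print(pack_without_truncation([3, 4, 2, 6], 5))
# [[0, 0, 0, -100, -100], [1, 1, 1, 1, -100], [2, 2, -100, -100, -100]]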
@@ -343,7 +414,9 @@ def __getitem__(self, index: int) -> typing.Any:
         )
         token_ids = np.concatenate(token_ids, dtype=np.int64)
         loss_masking_spans = (
-            np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None
+            (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([]))
+            if self._config.use_loss_masking_spans
+            else None
         )
         Assert.eq(len(token_ids), self._sequence_length + 1)
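The reworked `loss_masking_spans` expression guards against an empty list: with padded sampling a sample can contain no spans at all, and `np.stack` raises on an empty sequence, so an empty array is returned instead. For example:

import numpy as np

spans = []  # e.g. a sample whose documents carry no loss-masking spans
try:
    np.stack(spans, dtype=np.int32)
except ValueError as error:
    print(error)  # "need at least one array to stack"

# The guarded form used in the diff:
result = np.stack(spans, dtype=np.int32) if spans else np.array([])
print(result.shape)  # (0,)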
@@ -357,9 +430,12 @@ def _lazy_load(self):
         if not hasattr(self, "_documents_per_epoch"):
             self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r")))

-    def _load_yaml_data(self, data: dict[str, typing.Any]):
+    def _load_yaml_data(self, data: dict[str, typing.Any]) -> None:
         self._documents_per_epoch = data["dataset"]["documents_per_epoch"]
-        self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
+        if unshuffled_tokens := data.get("unshuffled_tokens") is not None:
+            self._unshuffled_tokens = unshuffled_tokens
+        else:
+            self._unshuffled_tokens = data["unshuffled_epochs"] * data["dataset"]["tokens_per_epoch"]
         self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch

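For reference, the keys read and written here come from the `yaml_data` dict built earlier in `_sample`; the back-compat branch simply falls back to `unshuffled_epochs * tokens_per_epoch` when an older cache file has no `unshuffled_tokens` entry. A sketch of what such a cache might contain once loaded (values are illustrative, and the `config` contents are abbreviated):

# Illustrative cache contents matching the keys used in `_sample` and `_load_yaml_data`.
yaml_data = {
    "dataset": {
        "documents_per_epoch": 12_345,
        "tokens_per_epoch": 6_789_000,
    },
    "num_samples": 10_000,
    "unshuffled_epochs": 1,
    "sequence_length": 4096,
    "truncate_documents": False,
    "config": {},  # serialized sampling config, abbreviated
    # Written after sampling; absent from caches produced before this PR,
    # in which case `_load_yaml_data` falls back to unshuffled_epochs * tokens_per_epoch.
    "unshuffled_tokens": 6_789_000,
}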
@@ -380,9 +456,12 @@ def __init__(
         self._indexed_dataset = indexed_dataset
         self._num_samples = sampling.num_samples
         self._sequence_length = sampling.sequence_length
+        if not sampling.truncate_documents:
+            raise NotImplementedError(
+                "Legacy sampling only supports document truncation. Please use the latest dataset format."
+            )
         self._cross_document_attention = sampling.cross_document_attention
         self._config = sampling.config
         self._tokenizer = sampling.tokenizer

         if sampling.cache_directory is None:
             log_main_rank(
@@ -498,7 +577,7 @@ def __getitem__(self, idx: int) -> typing.Any:
                 for span in sample.loss_masking_spans:
                     spans.append(span + offset)
                 offset += len(sample.token_ids)
-            spans = np.stack(spans, dtype=np.int32)
+            spans = np.stack(spans, dtype=np.int32) if spans else np.array([])
         else:
             spans = None
         sequence_lengths = (