[Pipeline RL] Add support for PipelineRL #428
Draft
jlamypoirier wants to merge 109 commits into jlp_entropy_loss_tweaks from jlp_pipeline_rl
Changes from 96 commits
Commits
1a18929
Dataset interface
jlamypoirier fd63846
misc
jlamypoirier 2486caf
fix
jlamypoirier 92e93e8
Language model sample
jlamypoirier d6f6944
fix
jlamypoirier 5c802fa
fixes
jlamypoirier 95d1840
test
jlamypoirier eafd9cb
fixes
jlamypoirier c56df69
cleanup
jlamypoirier 7f437e1
misc
jlamypoirier dfd27f5
misc
jlamypoirier 90cd009
Memmap dataset
jlamypoirier acfd30e
fixes
jlamypoirier 34939e9
fixes
jlamypoirier c5fa072
int64
jlamypoirier cd28676
Test and fix preparator
jlamypoirier 435d214
fix
jlamypoirier f6bef55
fix
jlamypoirier e05d9a1
fix
jlamypoirier 9ba8d1b
fix
jlamypoirier b35b297
fixes
jlamypoirier abe2357
misc
jlamypoirier 1801d87
fix
jlamypoirier 2223b85
fix right stage mode
bigximik a9a4ace
newer transformers fixes
bigximik 97f2b60
fix distributed tests skip on single gpu
bigximik 0fdc978
set mamba 2 style model conversions to broke
bigximik 665deb5
Merge branch 'jlp/dataset_interface' of github.com:ServiceNow/Fast-LL…
bigximik 4d03889
Merge branch 'jlp/lm_sample' of github.com:ServiceNow/Fast-LLM into d…
bigximik 224c2ec
mmaba2 enable conversion tests
bigximik f1afbf2
Merge branch 'jlp/memmap_dataset' of github.com:ServiceNow/Fast-LLM i…
bigximik 00bba27
added model_and_sequence_data_group
bigximik 5b20276
added Iterable dataset base classes
bigximik 978a68f
added naive sampled iterable dataset
bigximik 066a0bf
added iterable dataset configs, streaming dataset and PipelineRL samp…
bigximik 68b3d65
added distributed data loader wrapper
bigximik 2fbfe99
added iterable dataset to gpt data
bigximik 0892523
appended comment
bigximik 54fadb4
changed base classes for iterable dataset configs
bigximik 4e11bf3
fix batch type
bigximik 8428df8
fix added name property to the class
bigximik 04ee4d7
add eof for tests
bigximik 1217998
change base class to torch iterable
bigximik c542dac
added straming dataset, sampling and base data tests
bigximik 3999a8e
merge from main
bigximik c6ef780
merge from main
bigximik a1556f8
change import
bigximik 63737b1
fix iterable sampler for spawn, add fake redis server to multi proces…
bigximik e843c8e
preparation for multi gpu tests
bigximik d5ce3f2
added multi gpu gptdata streaming test
bigximik c13c6df
added streming dataset requirements
bigximik e6d8f49
added streaming dataset installation to tests
bigximik 1e92dd4
removed cheking for max samples
bigximik 3ac4882
remved test eof, reduces timeout
bigximik 46db991
changed tests to work without eof or max_samplmes_count
bigximik 187055b
fix quen2 converter to accept qkv biases properly
bigximik 21833a0
fix import errors
rafapi 2f5f848
changes to config
bigximik 1e07bad
Merge branch 'denis/new_datasets' of github.com:ServiceNow/Fast-LLM i…
bigximik c8cb9fd
added tensor iterator
bigximik e367998
added trainer events
bigximik 5230b74
update test for changed config
bigximik 1a94de5
added 2 gpus trainer events test
bigximik 6cfd445
fix for multiple gpus
bigximik 333665d
updated test to multiple gpus
bigximik 5d1f474
added not implemented for pp streaming
bigximik 5f7cb29
removed PipelineRL sample and batch
bigximik d07a900
base radis and streaming dataset config class refactoring
bigximik 3a7ba92
refactoring of redis config, trainer event config, corresponding tests
bigximik 59f6f7d
removed eof message which is not supported
bigximik 2c20ebd
added implementation for initial_weights_step_message_type event
bigximik f4107c3
removed explicit msg ack
bigximik c32ef89
fix of training finished event
bigximik f637649
alternative streaming immplementaions: one stream and n streams witho…
bigximik e43ce95
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier 5545598
merge from main
bigximik 0d198ff
fix after merge added preprocessing empty configs
bigximik 70ef5c4
fix for tests with no import
bigximik 058c93c
fixes
jlamypoirier d34d39a
Merge remote-tracking branch 'origin/denis/new_datasets' into denis/n…
jlamypoirier ffb0a5f
Merge remote-tracking branch 'origin/main' into denis/new_datasets
jlamypoirier 359231f
removed cloudpickle
bigximik ca9e94e
Simplify distributed
jlamypoirier 9f0704c
Simplified pipeline RL
jlamypoirier 992f447
stuff
jlamypoirier 4144317
misc
jlamypoirier 6cf1e70
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier 4d07494
stuff
jlamypoirier c9d66dd
fixes
jlamypoirier b8b3d68
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier a4c1aa5
cleanup
jlamypoirier b2c99d3
stuff
jlamypoirier 4ddacde
Merge branch 'jlp_subtest' into jlp_pipeline_rl
jlamypoirier ea3fa20
stuff
jlamypoirier e8f3873
stuff
jlamypoirier 07dce8d
Merge branch 'jlp_subtest' into jlp_pipeline_rl
jlamypoirier 4b3b8d6
Merge remote-tracking branch 'origin/main' into jlp_subtest
jlamypoirier b62813a
Merge branch 'jlp_subtest' into jlp_pipeline_rl
jlamypoirier f1ca739
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier 961cdb9
comments
jlamypoirier 518cd66
Merge branch 'main' into jlp_pipeline_rl
tscholak 99bdad4
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier 39356fb
GRPO loss (#454)
jlamypoirier c044964
[Pipeline RL] GRPO sample, streaming dataset improvements (#458)
jlamypoirier caaa9f8
[Pipeline RL] Tensor-parallel GRPO loss (#464)
jlamypoirier 8f87302
Merge remote-tracking branch 'origin/main' into jlp_pipeline_rl
jlamypoirier 2e6c8a2
Merge branch 'jlp_entropy_loss_tweaks' into jlp_pipeline_rl
jlamypoirier afe8070
Fix merge
jlamypoirier 902d1df
Merge branch 'jlp_entropy_loss_tweaks' into jlp_pipeline_rl
jlamypoirier
New file (+72 lines):

```python
import itertools
import typing

import torch.utils.data

from fast_llm.core.distributed import broadcast_object


class SampledDatasetIterator(torch.utils.data.Sampler):
    """
    A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
    To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
    """

    def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
        super().__init__()
        self._total_samples = total_samples
        self._begin_index = begin_index
        self._batch_size = micro_batch_size * data_parallel
        self._start_idx = data_rank * micro_batch_size
        self._end_idx = (data_rank + 1) * micro_batch_size

    def __len__(self) -> int:
        return self._total_samples

    def __iter__(self) -> typing.Iterator[list[int]]:
        for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
            yield list(range(idx + self._start_idx, idx + self._end_idx))


class DistributedDataLoaderWrapper:
    """
    Wraps a regular dataloader so that only the process group leader
    loads data, and then broadcasts the batch to other ranks in the group.
    """

    def __init__(
        self,
        data_loader: torch.utils.data.dataloader.DataLoader,
        process_group: torch.distributed.ProcessGroup | None,
    ):
        self._data_loader = data_loader
        self._rank = 0 if process_group is None else process_group.rank()
        self._process_group = process_group

    def __iter__(self):
        if self._rank == 0:
            self._iterator = iter(self._data_loader)
        else:
            self._iterator = itertools.repeat(None)
        if self._process_group is None:
            return self._iterator
        return self

    def __next__(self):
        # TODO:
        # Instead of broadcasting a general object, make this iterator yield an actual Batch class.
        # Implement `get_state_dict` and `from_state_dict` in the Batch class so that we can
        # efficiently broadcast tensors directly. This avoids using `broadcast_object` on the
        # entire Batch object, which is inefficient for tensors because it serializes
        # (pickles) them before sending.

        try:
            data = next(self._iterator)  # may raise StopIteration
        except Exception as e:
            data = e
        data = broadcast_object(data, self._process_group, 0)

        if isinstance(data, Exception):
            raise data

        return data
```
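As a quick orientation, here is a minimal, hypothetical sketch of how these two classes might be wired together. The `MyMapDataset` stand-in and the single-process `process_group=None` path are illustrative only and not part of the diff.

```python
# Assumes SampledDatasetIterator and DistributedDataLoaderWrapper (defined above) are in scope.
import torch
import torch.utils.data


class MyMapDataset(torch.utils.data.Dataset):
    """Stand-in for any map-style dataset indexed by the sampler."""

    def __len__(self) -> int:
        return 1024

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.tensor([index])


# One list of `micro_batch_size` consecutive indices is yielded per step for this rank.
sampler = SampledDatasetIterator(
    total_samples=1024,
    begin_index=0,
    micro_batch_size=4,
    data_rank=0,
    data_parallel=1,
)
data_loader = torch.utils.data.DataLoader(MyMapDataset(), batch_sampler=sampler)

# With `process_group=None` (single process) the wrapper simply returns the underlying
# iterator; with a real group, only group rank 0 loads and the batch is broadcast.
wrapped = DistributedDataLoaderWrapper(data_loader, process_group=None)
for batch in wrapped:
    pass  # each `batch` is the collated micro-batch for this rank
```

The TODO in `__next__` points at a leaner alternative: broadcast the batch's tensors directly instead of pickling the whole object. A rough sketch of what that could look like, assuming CPU tensors and a gloo-style group (the `broadcast_tensor_state` helper is hypothetical, not part of the PR):

```python
import torch
import torch.distributed


def broadcast_tensor_state(
    state_dict: dict[str, torch.Tensor] | None,
    process_group: torch.distributed.ProcessGroup,
    src_rank: int,
) -> dict[str, torch.Tensor]:
    """Broadcast a dict of tensors from `src_rank` (a global rank) to the whole group."""
    is_src = torch.distributed.get_rank() == src_rank
    # Small metadata (names, shapes, dtypes) still travels via object broadcast...
    meta = [{name: (tuple(t.shape), t.dtype) for name, t in state_dict.items()}] if is_src else [None]
    torch.distributed.broadcast_object_list(meta, src=src_rank, group=process_group)
    # ...but the tensor payloads are broadcast directly, with no pickling step.
    out: dict[str, torch.Tensor] = {}
    for name, (shape, dtype) in meta[0].items():
        tensor = state_dict[name] if is_src else torch.empty(shape, dtype=dtype)
        torch.distributed.broadcast(tensor, src=src_rank, group=process_group)
        out[name] = tensor
    return out
```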
This is also not needed if we only use consumer groups; see the comment above for the implementation.
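For context: the consumer groups referenced here are presumably Redis Streams consumer groups, which the streaming dataset in this PR already builds on. With a consumer group, each data-parallel rank claims its own messages directly from the stream, so no single leader needs to load batches and re-broadcast them. A rough sketch with `redis-py`; the stream name, group name, and `read_next` helper are hypothetical:

```python
import redis

r = redis.Redis(host="localhost", port=6379)

STREAM = "rollouts"  # hypothetical stream name
GROUP = "trainer"    # hypothetical consumer group name

# Create the consumer group once; ignore the error if it already exists.
try:
    r.xgroup_create(STREAM, GROUP, id="0", mkstream=True)
except redis.ResponseError:
    pass


def read_next(consumer_name: str) -> dict | None:
    # Each rank uses a unique consumer name; Redis delivers every new message
    # to exactly one consumer in the group, so ranks never see duplicates.
    messages = r.xreadgroup(GROUP, consumer_name, {STREAM: ">"}, count=1, block=1000)
    if not messages:
        return None
    _, entries = messages[0]
    entry_id, fields = entries[0]
    r.xack(STREAM, GROUP, entry_id)  # acknowledge so the message is not redelivered
    return fields
```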