Add tensor parallelism for RWKV #1237
Open
jahatef wants to merge 30 commits into main from rwkv-tp
Changes from all commits (30 commits):

4c7cb11  inital tp commits (jahatef)
46904d5  setup (jahatef)
e2933ef  configs (jahatef)
d1112ab  merge (jahatef)
43d641d  time mixing tp (jahatef)
de02f37  time-mixing (jahatef)
dd441b6  time mixing debugging (jahatef)
a418670  reset time_faaaa (jahatef)
540d856  Add additional asserts and update post training readme (#1300) (AI-WAIFU)
12aac35  Fix failling tests (#1301) (AI-WAIFU)
97c7915  inital tp commits (jahatef)
5f89ed8  merge (jahatef)
91cb759  Add ERROR logging prefix and sort the prefixes alphabetically (#1308) (TheBatmanofButler)
49b263a  inital tp commits (jahatef)
48de682  cleanup (jahatef)
c6fac96  cleanup (jahatef)
5a259c0  Update local_setup.yml (jahatef)
c2d6c85  add Triton FLA (jahatef)
bdb3658  change version of rwkv-fla (jahatef)
ff7f328  fix a GQA issue (#1314) (#1315) (tiandeyu-cs)
1350b2c  fix 'intermediate_size' in Llama configuration files after the 'mlp_t… (tiandeyu-cs)
c4d7a54  Python 3.10 support (#1313) (markNZed)
ee2f142  Fix documentation for converting SFT/DPO weights back to HF Llama (#1… (jacobthebanana)
6e81f0b  fix bug (#1311) (AI-WAIFU)
df95419  Add support for dropout in sparse attention (#1312) (michaelc-yu)
d682529  adds pyproject files and tests (#1302) (LouisCastricato)
0bc11d6  undo merge error (#1325) (Quentin-Anthony)
c6db95c  inital tp commits (jahatef)
daac503  setup (jahatef)
bf478ce  Merge branch 'main' into rwkv-tp (Quentin-Anthony)
New config file (103 added lines):
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 24,
  "hidden_size": 2048,
  "num_attention_heads": 32, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 4,

  "attention_config": [[["rwkv"], 24]],

  "activation": "silu",

  # model settings
  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but take a while to build; set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,
  "seed": 1234,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
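The head arithmetic in the comments above can be checked directly: head_size = hidden_size / num_attention_heads, and the three new configs in this PR are all sized so that it comes out to 64. A minimal sketch in plain Python (not code from the PR; the dict keys mirror the config fields, and the assumption that dim_att defaults to hidden_size is the usual RWKV convention):

    # Illustrative check of the head-size comment in the new configs
    # (values copied from the three config files in this diff).
    configs = {
        "2048-hidden": {"hidden_size": 2048, "num_attention_heads": 32},
        "1024-hidden": {"hidden_size": 1024, "num_attention_heads": 16},
        "4096-hidden": {"hidden_size": 4096, "num_attention_heads": 64},
    }

    for name, cfg in configs.items():
        # dim_att typically defaults to hidden_size, so
        # head_size = hidden_size / num_attention_heads
        head_size = cfg["hidden_size"] // cfg["num_attention_heads"]
        assert head_size == 64, "head_size is 64 for all rwkv models"
        print(name, "head_size =", head_size)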
New config file (103 added lines):
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 24,
  "hidden_size": 1024,
  "num_attention_heads": 16, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 1,

  "attention_config": [[["rwkv"], 24]],

  "activation": "silu",

  # model settings
  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but take a while to build; set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 320000,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,
  "seed": 1234,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
New config file (102 added lines):
{
  # Parallelism is not yet supported for rwkv
  "pipe_parallel_size": 1,
  "model_parallel_size": 1,

  "num_layers": 32,
  "hidden_size": 4096,
  "num_attention_heads": 64, # head_size = dim_att / num_attention_heads.
                             # head_size is 64 for all rwkv models
  "seq_length": 4096,
  "max_position_embeddings": 4096,
  "output_layer_parallelism": "column",
  "norm": "rmsnorm",
  "rms_norm_epsilon": 1.0e-5,
  "train_micro_batch_size_per_gpu": 8,

  "attention_config": [[["rwkv"], 32]],

  "activation": "silu",

  # model settings
  #"pos_emb": "rotary",
  "rotary_pct": 0.25,
  "no_weight_tying": true,
  "gpt_j_residual": true,

  # these should provide some speedup but take a while to build; set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # init methods
  "init_method": "small_init",
  "output_layer_init_method": "wang_init",

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0008,
      "betas": [0.9, 0.95],
      "eps": 1.0e-8,
    }
  },
  "min_lr": 0.00008,

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "data_impl": "mmap",
  "num_workers": 1,

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.1,
  "hidden_dropout": 0,
  "attention_dropout": 0,

  # precision settings
  "bf16": {
    "bf16": true,
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 12,
    "hysteresis": 2,
    "min_loss_scale": 1,
  },

  # misc. training settings
  "train_iters": 500,
  "lr_decay_iters": 500,
  "distributed_backend": "nccl",
  "lr_decay_style": "constant",
  "warmup": 0.01,
  "checkpoint_factor": 100,
  "eval_interval": 100000,
  "eval_iters": 10,

  # logging
  "log_interval": 10,
  "steps_per_print": 10,
  "wall_clock_breakdown": true,
}
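All three configs ship with model_parallel_size set to 1. If tensor parallelism is enabled by raising that value, and the RWKV heads are split across ranks in the usual Megatron-style fashion of whole heads per rank, the head count would have to be divisible by the TP degree. The PR does not state this constraint; the sketch below is an assumption for illustration, and valid_tp_degrees is a hypothetical helper, not a function in this repository:

    # Hypothetical helper: which tensor-parallel degrees evenly split the
    # heads of each new config, assuming whole heads are assigned per rank.
    def valid_tp_degrees(num_attention_heads, candidates=(1, 2, 3, 4, 6, 8)):
        return [tp for tp in candidates if num_attention_heads % tp == 0]

    print(valid_tp_degrees(32))  # 2048-hidden config -> [1, 2, 4, 8]
    print(valid_tp_degrees(16))  # 1024-hidden config -> [1, 2, 4, 8]
    print(valid_tp_degrees(64))  # 4096-hidden config -> [1, 2, 4, 8]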
Review comment:
Similar comment here. Calling these attention heads is highly misleading.
Reply:
I kind of disagree, as rwkv code generally references time mixing as attention, and the RWKV kernel is often called a type of "linear attention." I can add a bunch of config options to decouple rwkv and transformer settings, but that would just create a lot of config args that serve essentially the same purpose, in my opinion.
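For context on why the same config key covers both cases: in the new configs the per-layer operator is selected through attention_config, e.g. [[["rwkv"], 24]], which expands to one operator type per layer. A simplified sketch of that expansion follows; it is illustrative only, not the repository's actual implementation, and expand_attention_config is a made-up name:

    # Simplified illustration of expanding an attention_config pattern into a
    # per-layer operator list (not the actual gpt-neox implementation).
    def expand_attention_config(attention_config, num_layers):
        layer_types = []
        for pattern, repeat in attention_config:
            for i in range(repeat):
                # cycle through the pattern; ["rwkv"] yields "rwkv" every layer
                layer_types.append(pattern[i % len(pattern)])
        return layer_types[:num_layers]

    print(expand_attention_config([[["rwkv"], 24]], 24))
    # -> ['rwkv', 'rwkv', ..., 'rwkv']  (24 entries: every block uses RWKV time mixing)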