Add tensor parallelism for RWKV #1237

Status: Open. This PR wants to merge 30 commits into base branch main (head branch: rwkv-tp). The diff below shows changes from 19 of the 30 commits.

Commits
4c7cb11  inital tp commits (jahatef, Jun 4, 2024)
46904d5  setup (jahatef, Jun 19, 2024)
e2933ef  configs (jahatef, Sep 25, 2024)
d1112ab  merge (jahatef, Oct 3, 2024)
43d641d  time mixing tp (jahatef, Oct 3, 2024)
de02f37  time-mixing (jahatef, Oct 11, 2024)
dd441b6  time mixing debugging (jahatef, Oct 12, 2024)
a418670  reset time_faaaa (jahatef, Oct 13, 2024)
540d856  Add additional asserts and update post training readme (#1300) (AI-WAIFU, Oct 8, 2024)
12aac35  Fix failling tests (#1301) (AI-WAIFU, Oct 8, 2024)
97c7915  inital tp commits (jahatef, Jun 4, 2024)
5f89ed8  merge (jahatef, Nov 5, 2024)
91cb759  Add ERROR logging prefix and sort the prefixes alphabetically (#1308) (TheBatmanofButler, Oct 17, 2024)
49b263a  inital tp commits (jahatef, Jun 4, 2024)
48de682  cleanup (jahatef, Nov 6, 2024)
c6fac96  cleanup (jahatef, Nov 6, 2024)
5a259c0  Update local_setup.yml (jahatef, Nov 6, 2024)
c2d6c85  add Triton FLA (jahatef, Nov 10, 2024)
bdb3658  change version of rwkv-fla (jahatef, Nov 12, 2024)
ff7f328  fix a GQA issue (#1314) (#1315) (tiandeyu-cs, Nov 13, 2024)
1350b2c  fix 'intermediate_size' in Llama configuration files after the 'mlp_t… (tiandeyu-cs, Nov 13, 2024)
c4d7a54  Python 3.10 support (#1313) (markNZed, Nov 13, 2024)
ee2f142  Fix documentation for converting SFT/DPO weights back to HF Llama (#1… (jacobthebanana, Nov 13, 2024)
6e81f0b  fix bug (#1311) (AI-WAIFU, Nov 13, 2024)
df95419  Add support for dropout in sparse attention (#1312) (michaelc-yu, Nov 16, 2024)
d682529  adds pyproject files and tests (#1302) (LouisCastricato, Nov 16, 2024)
0bc11d6  undo merge error (#1325) (Quentin-Anthony, Nov 27, 2024)
c6db95c  inital tp commits (jahatef, Jun 4, 2024)
daac503  setup (jahatef, Jun 19, 2024)
bf478ce  Merge branch 'main' into rwkv-tp (Quentin-Anthony, Dec 19, 2024)
configs/neox_arguments.md: 23 additions, 0 deletions

@@ -843,6 +843,29 @@ Model Arguments



- **dim_att**: int
Review comment (Member): We should either have unified args (across mamba, rwkv, transformers) for these, or prepend these args with whatever block type they're targeting (e.g. rwkv_dim_att).


Default = None

Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.



- **head_size**: int

Default = None

Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.



- **ffn_dim**: int

Default = None

Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
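
Taken together, the three arguments above are all derived from the base model dimensions when they are left unset. Below is a minimal sketch of that derivation; the helper name `derive_rwkv_dims` and the 3.5x feed-forward expansion factor are assumptions for illustration, not the actual NeoX implementation.

```python
# Minimal sketch (not the actual NeoX code) of how the RWKV dimension
# defaults documented above could be resolved from the base model arguments.

def derive_rwkv_dims(hidden_size, num_attention_heads,
                     dim_att=None, head_size=None, ffn_dim=None,
                     expansion_factor=3.5):  # 3.5x is an assumed expansion factor
    if dim_att is None:
        # dim_att: total width of the RWKV time-mixing ("attention") path
        dim_att = hidden_size
    if head_size is None:
        # head_size: per-head width, dim_att // num_attention_heads
        head_size = dim_att // num_attention_heads
    if ffn_dim is None:
        # ffn_dim: channel-mixing width, derived from hidden_size and the
        # expansion factor when unset
        ffn_dim = int(hidden_size * expansion_factor)
    return dim_att, head_size, ffn_dim


# With the 1.5B config in this PR (hidden_size=2048, 32 heads):
# dim_att=2048, head_size=64, ffn_dim=7168 under the assumed 3.5x factor.
print(derive_rwkv_dims(hidden_size=2048, num_attention_heads=32))
```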


## NeoXArgsOptimizer

Optimizer Arguments
configs/rwkv/1.5B.yml: new file, 103 additions

@@ -0,0 +1,103 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 24,
"hidden_size": 2048,
"num_attention_heads": 32, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 4,

"attention_config": [[["rwkv"], 24]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,
"seed": 1234,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
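
As a quick cross-check of the `head_size = dim_att / num_attention_heads` comment used in these configs (with `dim_att` defaulting to `hidden_size`), all three RWKV configs added in this PR resolve to a head size of 64. The script below is a standalone, hypothetical check rather than part of the PR.

```python
# Hypothetical sanity check: the three RWKV configs in this PR should all
# resolve to head_size == 64, since dim_att defaults to hidden_size.
configs = {
    "configs/rwkv/1.5B.yml": {"hidden_size": 2048, "num_attention_heads": 32},
    "configs/rwkv/430M.yml": {"hidden_size": 1024, "num_attention_heads": 16},
    "configs/rwkv/7B.yml":   {"hidden_size": 4096, "num_attention_heads": 64},
}

for path, cfg in configs.items():
    dim_att = cfg["hidden_size"]                       # default: dim_att = hidden_size
    head_size = dim_att // cfg["num_attention_heads"]  # head_size = dim_att / heads
    assert head_size == 64, f"{path}: expected head_size 64, got {head_size}"
    print(f"{path}: head_size = {head_size}")
```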
configs/rwkv/430M.yml: new file, 103 additions

@@ -0,0 +1,103 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 24,
"hidden_size": 1024,
"num_attention_heads": 16, # head_size = dim_att / num_attention_heads.
Review comment (Member): Similar comment here. Calling these attention heads is highly misleading.

Reply from the PR author (Collaborator): I kind of disagree, as RWKV code generally references time mixing as attention, and the RWKV kernel is often called a type of "linear attention." I can add configs to decouple the RWKV and transformer config options, but that would just create a lot of config args with essentially the same purpose, in my opinion.

# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 1,

"attention_config": [[["rwkv"], 24]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,
"seed": 1234,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
configs/rwkv/7B.yml: new file, 102 additions

@@ -0,0 +1,102 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 32,
"hidden_size": 4096,
"num_attention_heads": 64, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 4096,
"max_position_embeddings": 4096,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 8,

"attention_config": [[["rwkv"], 32]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 500,
"lr_decay_iters": 500,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
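
These configs still pin `model_parallel_size` to 1 even though the PR adds tensor parallelism for RWKV. When raising it, the usual constraint is that the head count (and hence `dim_att`) must split evenly across tensor-parallel ranks; the helper below is a hypothetical illustration of that check, not code from this PR.

```python
# Hypothetical helper illustrating the usual tensor-parallel divisibility
# constraint: RWKV heads must be shardable evenly across model-parallel ranks.
def check_tp_compatible(hidden_size, num_attention_heads, model_parallel_size):
    dim_att = hidden_size  # default when dim_att is unset
    if num_attention_heads % model_parallel_size != 0:
        raise ValueError(
            f"{num_attention_heads} heads cannot be split evenly across "
            f"{model_parallel_size} tensor-parallel ranks"
        )
    heads_per_rank = num_attention_heads // model_parallel_size
    dim_att_per_rank = dim_att // model_parallel_size
    return heads_per_rank, dim_att_per_rank


# 7B config above with, say, model_parallel_size=4:
# 64 heads -> 16 heads per rank, dim_att 4096 -> 1024 per rank.
print(check_tp_compatible(hidden_size=4096, num_attention_heads=64,
                          model_parallel_size=4))
```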
megatron/model/gpt2_model.py: 1 addition, 0 deletions

@@ -258,6 +258,7 @@ def init_specs(self):
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
init_method=self.init_method,
layer_number=i,
)
)
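
The one-line change above threads the model-level `init_method` into the RWKV pipeline layer spec, so the RWKV block can initialize its weights the same way the transformer layers do. A rough sketch of the receiving side is below; the constructor body is an assumption for illustration, since the actual `RWKVResidualLayerPipe` implementation is not shown in this diff and its signature may differ.

```python
# Illustrative sketch only; the real RWKVResidualLayerPipe lives elsewhere in
# megatron/model and may accept additional arguments.
import torch.nn as nn


class RWKVResidualLayerPipe(nn.Module):
    def __init__(self, neox_args, init_method, layer_number):
        super().__init__()
        self.neox_args = neox_args
        self.layer_number = layer_number
        # init_method is the weight-initialization callable selected by the
        # config (e.g. "small_init"); storing it here lets the RWKV block's
        # projections be initialized consistently with the rest of the model.
        self.init_method = init_method
```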