Add tensor parallelism for RWKV #1237

Open

wants to merge 30 commits into base: main

Changes from 1 commit

Commits (30)

4c7cb11
inital tp commits
jahatef Jun 4, 2024
46904d5
setup
jahatef Jun 19, 2024
e2933ef
configs
jahatef Sep 25, 2024
d1112ab
merge
jahatef Oct 3, 2024
43d641d
time mixing tp
jahatef Oct 3, 2024
de02f37
time-mixing
jahatef Oct 11, 2024
dd441b6
time mixing debugging
jahatef Oct 12, 2024
a418670
reset time_faaaa
jahatef Oct 13, 2024
540d856
Add additional asserts and update post training readme (#1300)
AI-WAIFU Oct 8, 2024
12aac35
Fix failling tests (#1301)
AI-WAIFU Oct 8, 2024
97c7915
inital tp commits
jahatef Jun 4, 2024
5f89ed8
merge
jahatef Nov 5, 2024
91cb759
Add ERROR logging prefix and sort the prefixes alphabetically (#1308)
TheBatmanofButler Oct 17, 2024
49b263a
inital tp commits
jahatef Jun 4, 2024
48de682
cleanup
jahatef Nov 6, 2024
c6fac96
cleanup
jahatef Nov 6, 2024
5a259c0
Update local_setup.yml
jahatef Nov 6, 2024
c2d6c85
add Triton FLA
jahatef Nov 10, 2024
bdb3658
change version of rwkv-fla
jahatef Nov 12, 2024
ff7f328
fix a GQA issue (#1314) (#1315)
tiandeyu-cs Nov 13, 2024
1350b2c
fix 'intermediate_size' in Llama configuration files after the 'mlp_t…
tiandeyu-cs Nov 13, 2024
c4d7a54
Python 3.10 support (#1313)
markNZed Nov 13, 2024
ee2f142
Fix documentation for converting SFT/DPO weights back to HF Llama (#1…
jacobthebanana Nov 13, 2024
6e81f0b
fix bug (#1311)
AI-WAIFU Nov 13, 2024
df95419
Add support for dropout in sparse attention (#1312)
michaelc-yu Nov 16, 2024
d682529
adds pyproject files and tests (#1302)
LouisCastricato Nov 16, 2024
0bc11d6
undo merge error (#1325)
Quentin-Anthony Nov 27, 2024
c6db95c
inital tp commits
jahatef Jun 4, 2024
daac503
setup
jahatef Jun 19, 2024
bf478ce
Merge branch 'main' into rwkv-tp
Quentin-Anthony Dec 19, 2024
Merge branch 'main' into rwkv-tp
Quentin-Anthony committed Dec 19, 2024
commit bf478ce9c5337a5d41d8ad845e34b8134bf5ca8a
28 changes: 14 additions & 14 deletions megatron/model/transformer.py
@@ -1303,20 +1303,20 @@ def forward(self, x, attention_mask, layer_past=None):
 else:
     raise KeyError(self.moe_type)

-with torch.enable_grad() if not self.eval else nullcontext():
-    if mlp_bias == None or (
-        self.num_experts > 1 and self.moe_type == "deepspeed"
-    ):
-        # No dropout either
-        assert mlp_bias is None
-        output = mlp_output + attention_output
-    else:
-        output = bias_dropout_fn(
-            mlp_output,
-            bias=mlp_bias.expand_as(attention_output),
-            residual=attention_output,
-            prob=self.hidden_dropout,
-        )
+with torch.enable_grad() if not self.eval else nullcontext():
+    if mlp_bias == None or (
+        self.num_experts > 1 and self.moe_type == "deepspeed"
+    ):
+        # No dropout either
+        assert mlp_bias is None
+        output = mlp_output + attention_output
+    else:
+        output = bias_dropout_fn(
+            mlp_output,
+            bias=mlp_bias.expand_as(attention_output),
+            residual=attention_output,
+            prob=self.hidden_dropout,
+        )

return output, moe_loss
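
For context, the `bias_dropout_fn` called in this hunk follows the usual Megatron-style bias-dropout-add pattern: the MLP bias is added, dropout is applied, and the attention output is added back as the residual. A minimal sketch of that pattern (illustrative, not necessarily the repo's exact implementation):

```python
import torch
import torch.nn.functional as F

def bias_dropout_add(x, bias, residual, prob, training):
    # Add the broadcast bias, apply dropout, then add the residual branch.
    out = F.dropout(x + bias, p=prob, training=training)
    return residual + out

# The bias-less branch in the hunk reduces to a plain residual add:
#     output = mlp_output + attention_output
```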

77 changes: 76 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -502,11 +502,86 @@ class NeoXArgsModel(NeoXArgsTemplate):

# Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905)
output_layer_parallelism: Literal["column"] = "column"

"""
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
"""

serve_model_weights: bool = False
"""
If true, serve model weight pointers over a socket connection
"""

weight_server_port: Union[int, List[int]] = 6000
"""
Port(s) to serve model weights over
If an integer is provided, the port for each GPU will be 6000 + global rank
If a list is provided, the ports will be used in order, e.g. rank0 will be weight_server_port[0]
"""

online_dataserver_ips: Union[str, List[str]] = "localhost"
"""
ip addresses to connect to for online data serving, defaults to localhost
"""

online_dataserver_ports: Union[int, List[int]] = 10000
"""
Port(s) to connect to for online data serving, defaults to 10000
"""

te_columnparallel: bool = False
"""
Use TransformerEngine for ColumnParallelLinear layer.
"""

te_rowparallel: bool = False
"""
Use TransformerEngine for RowParallelLinear layer.
"""

te_layernorm_mlp: bool = False
"""
Use TransformerEngine for LayerNormMLP layer.
"""

te_mha: bool = False
"""
Use TransformerEngine for MultiheadAttention layer.
"""

te_fp8_format: Literal["e4m3", "hybrid"] = "hybrid"
"""
Controls the FP8 data format used during forward and backward pass by TransformerEngine.
Hybrid uses E4M3 during forward pass, E5M2 during backward pass.
"""

te_fp8_wgrad: bool = True
"""
When set to False, override FP8 config options and do the wgrad computation
in higher precision.
"""

te_fp8_amax_history_len: int = 1
"""
The length of the amax history window used for scaling factor computation.
"""

te_fp8_amax_compute_algo: str = "most_recent"
"""
Algorithm used for choosing the `amax` value for the scaling factor computation. There are 2
predefined choices: `max` chooses the largest `amax` in the history window, while `most_recent`
always chooses the most recently seen value.
"""

te_fp8_margin: int = 0
"""
Margin for the scaling factor computation.
"""

te_fp8_mha: bool = False
"""
When set to True, use the FP8 implementation of Multi Head Attention.
"""

dim_att: int = None
"""
Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
"""
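
The fallback described here is straightforward; a one-line illustration (hypothetical helper, not the PR's code):

```python
def resolve_dim_att(dim_att, hidden_size):
    # An unset dim_att means the RWKV attention dimension equals the model width.
    return hidden_size if dim_att is None else dim_att
```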