From 97c7915fa997864cec44b49d24bc96a9059981f0 Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 4 Jun 2024 11:22:29 -0400 Subject: [PATCH] inital tp commits --- megatron/model/rwkv/v6/rwkv.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py index 970613f27..fa0eaa53f 100644 --- a/megatron/model/rwkv/v6/rwkv.py +++ b/megatron/model/rwkv/v6/rwkv.py @@ -297,11 +297,16 @@ class ParallelRWKV_ChannelMix(nn.Module): Channel Mix layer. The ffn in RWKV """ + def __init__(self, neox_args, layer_number, init_method): def __init__(self, neox_args, layer_number, init_method): super().__init__() self.neox_args = neox_args self.layer_number = layer_number + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) + + world_size = mpu.get_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(neox_args.hidden_size, world_size) @@ -363,7 +368,7 @@ class RWKVResidualLayer(nn.Module): """ RWKV layer definition """ - + def __init__(self, neox_args, init_method, layer_number): super().__init__() self.neox_args = neox_args