
Commit 305b390

Update more non-core code to use config objects.

1 parent 51c6f47, commit 305b390

14 files changed: +45 -40 lines
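Every diff below follows the same pattern: the model provider builds a TransformerConfig once with core_transformer_config_from_args(args) and hands that single config object to the model, which forwards it to MegatronModule and its backbones, replacing the per-call init_method / scaled_init_method plumbing. A minimal provider-side sketch of that pattern (illustrative only, not code from this commit; the num_classes value is a placeholder):

    # Illustrative sketch of the config-object pattern adopted in these diffs.
    # Assumes a Megatron-LM checkout at this commit and an initialized Megatron run.
    from megatron import get_args
    from megatron.arguments import core_transformer_config_from_args
    from megatron.model.classification import Classification

    def model_provider(pre_process=True, post_process=True):
        """Build a classification model from one config object instead of raw init args."""
        args = get_args()
        config = core_transformer_config_from_args(args)  # args -> TransformerConfig
        model = Classification(config=config,
                               num_classes=3,  # placeholder value, for illustration only
                               num_tokentypes=2,
                               pre_process=pre_process,
                               post_process=post_process)
        return model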

megatron/arguments.py (+4)

@@ -413,6 +413,10 @@ def core_transformer_config_from_args(args):
         kw_args['activation_func'] = F.silu
         kw_args['gated_linear_unit'] = True
         kw_args['bias_gelu_fusion'] = False
+    if args.init_method_xavier_uniform:
+        kw_args['init_method'] = torch.nn.init.xavier_uniform_
+        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
+
     return TransformerConfig(**kw_args)

 def _add_transformer_engine_args(parser):

megatron/model/bert_model.py (+1 -2)

@@ -54,10 +54,9 @@ class BertLMHead(MegatronModule):
     """

     def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
-        super(BertLMHead, self).__init__()
+        super().__init__(config=config)

         args = get_args()
-        self.config = config
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
         self.parallel_output = parallel_output

megatron/model/classification.py (+3 -5)

@@ -17,25 +17,23 @@
 class Classification(MegatronModule):

     def __init__(self,
+                 config,
                  num_classes,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
-        super(Classification, self).__init__(share_embeddings_and_output_weights=False)
+        super().__init__(config=config, share_embeddings_and_output_weights=False)
         args = get_args()

         self.num_classes = num_classes
         self.pre_process = pre_process
         self.post_process = post_process
-        init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)

megatron/model/language_model.py (+2 -3)

@@ -412,10 +412,9 @@ def __init__(self,
             self.output_layer = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 args.padded_vocab_size,
-                bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                config=config,
                 init_method=self.init_method,
-                use_cpu_initialization=args.use_cpu_initialization,
-                perform_initialization=args.perform_initialization)
+                bias=False)  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
             self._output_layer_key = 'output_layer'

     def set_input_tensor(self, input_tensor):

megatron/model/multiple_choice.py (+2 -4)

@@ -17,23 +17,21 @@
 class MultipleChoice(MegatronModule):

     def __init__(self,
+                 config,
                  num_tokentypes=2,
                  pre_process=True,
                  post_process=True):
         super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
         args = get_args()

-        init_method = init_method_normal(args.init_method_std)
         self.pre_process = pre_process
         self.post_process = post_process

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method_normal(args.init_method_std,
-                                                         args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)

megatron/model/vision/classification.py (+2 -1)

@@ -13,7 +13,7 @@
 class VitClassificationModel(MegatronModule):
     """Vision Transformer Model."""

-    def __init__(self, num_classes, finetune=False,
+    def __init__(self, config, num_classes, finetune=False,
                  pre_process=True, post_process=True):
         super(VitClassificationModel, self).__init__()
         args = get_args()
@@ -24,6 +24,7 @@ def __init__(self, num_classes, finetune=False,
         self.pre_process = pre_process
         self.post_process = post_process
         self.backbone = VitBackbone(
+            config=config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             single_token_output=True

megatron/model/vision/dino.py (+9 -7)

@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
     return schedule


-def get_student_backbone_and_num_features(pre_process=True, post_process=True):
+def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
     args = get_args()

     if args.vision_backbone_type == 'vit':
-        student = VitBackbone(pre_process=pre_process,
+        student = VitBackbone(config,
+                              pre_process=pre_process,
                               post_process=post_process,
                               drop_path_rate=0.1,
                               single_token_output=True)
@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):

     return student, num_features

-def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
+def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
     args = get_args()

     if args.vision_backbone_type == 'vit':
-        teacher = VitBackbone(pre_process=pre_process,
+        teacher = VitBackbone(config,
+                              pre_process=pre_process,
                               post_process=post_process,
                               single_token_output=True)
         num_features = args.hidden_size
@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):


 class DINOPretrainModel(MegatronModule):
-    def __init__(self, pre_process=True, post_process=True):
+    def __init__(self, config, pre_process=True, post_process=True):
         super(DINOPretrainModel, self).__init__()
         args = get_args()
         self.out_dim = 65536
@@ -234,7 +236,7 @@ def __init__(self, pre_process=True, post_process=True):
         self.momentum_teacher = 0.996

         student_backbone, num_features = \
-            get_student_backbone_and_num_features(pre_process, post_process)
+            get_student_backbone_and_num_features(config, pre_process, post_process)

         self.student = MultiCropWrapper(
             student_backbone,
@@ -249,7 +251,7 @@ def __init__(self, pre_process=True, post_process=True):
         )

         teacher_backbone, num_features = \
-            get_teacher_backbone_and_num_features(pre_process, post_process)
+            get_teacher_backbone_and_num_features(config, pre_process, post_process)
         self.teacher = MultiCropWrapper(
             teacher_backbone,
             DINOHead(num_features, self.out_dim)

megatron/model/vision/inpainting.py (+2 -1)

@@ -18,14 +18,15 @@

 class VitInpaintingModel(MegatronModule):

-    def __init__(self, pre_process=True, post_process=True):
+    def __init__(self, config, pre_process=True, post_process=True):
         super(VitInpaintingModel, self).__init__()
         args = get_args()

         self.pre_process = pre_process
         self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.backbone = VitBackbone(
+            config=config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             class_token=False,

megatron/model/vision/vit_backbone.py (+2 -10)

@@ -130,6 +130,7 @@ class VitBackbone(MegatronModule):
     """Vision Transformer Model."""

     def __init__(self,
+                 config,
                  pre_process=True,
                  post_process=True,
                  class_token=True,
@@ -140,14 +141,6 @@ def __init__(self,
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
-        if args.init_method_xavier_uniform:
-            self.init_method = torch.nn.init.xavier_uniform_
-            self.scaled_init_method = torch.nn.init.xavier_uniform_
-        else:
-            self.init_method = init_method_normal(args.init_method_std)
-            self.scaled_init_method = scaled_init_method_normal(
-                args.init_method_std, args.num_layers
-            )

         self.pre_process = pre_process
         self.post_process = post_process
@@ -202,8 +195,7 @@ def __init__(self,

         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method,
-            self.scaled_init_method,
+            config,
             pre_process=self.pre_process,
             post_process=self.post_process,
             post_layer_norm=self.post_layer_norm,
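
After this change, VitBackbone no longer derives its weight initializers from the global args; callers pass the config as the first positional argument, as the vision providers below now do. A minimal construction sketch (argument values are illustrative, assuming an initialized Megatron run):

    # Sketch: constructing the ViT backbone after this refactor (illustrative values).
    from megatron import get_args
    from megatron.arguments import core_transformer_config_from_args
    from megatron.model.vision.vit_backbone import VitBackbone

    config = core_transformer_config_from_args(get_args())
    backbone = VitBackbone(config,
                           pre_process=True,
                           post_process=True,
                           single_token_output=True)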

pretrain_vision_classify.py (+4 -2)

@@ -12,16 +12,18 @@
 from megatron.model.vision.classification import MitClassificationModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
+from megatron.arguments import core_transformer_config_from_args


 def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     args = get_args()
-
+    config = core_transformer_config_from_args(args)
     if args.vision_backbone_type == 'vit':
         print_rank_0("building VIT model ...")
-        model = VitClassificationModel(num_classes=args.num_classes,
+        model = VitClassificationModel(config=config,
+                                       num_classes=args.num_classes,
                                        pre_process=pre_process,
                                        post_process=post_process)
     elif args.vision_backbone_type == 'mit':

pretrain_vision_dino.py (+3 -1)

@@ -16,10 +16,12 @@
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.arguments import core_transformer_config_from_args

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
-    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
+    config = core_transformer_config_from_args(get_args())
+    return DINOPretrainModel(config, pre_process=pre_process, post_process=post_process)

 def get_batch(data_iterator):
     """Build the batch."""

pretrain_vision_inpaint.py (+4 -1)

@@ -13,12 +13,15 @@
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
+from megatron.arguments import core_transformer_config_from_args

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
     args = get_args()
+    config = core_transformer_config_from_args(args)
     if args.vision_backbone_type == 'vit':
-        model = VitInpaintingModel(pre_process=pre_process,
+        model = VitInpaintingModel(config,
+                                   pre_process=pre_process,
                                    post_process=post_process)
     elif args.vision_backbone_type == 'mit':
         model = MitInpaintingModel(pre_process=pre_process,

tasks/glue/finetune.py (+3 -1)

@@ -8,6 +8,7 @@
 from megatron.model.classification import Classification
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
+from megatron.arguments import core_transformer_config_from_args


 def glue_classification(num_classes, Dataset,
@@ -28,10 +29,11 @@ def train_valid_datasets_provider():
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
     args = get_args()
+    config = core_transformer_config_from_args(args)

     print_rank_0('building classification model for {} ...'.format(
         args.task))
-    model = Classification(num_classes=num_classes, num_tokentypes=2,
+    model = Classification(config=config, num_classes=num_classes, num_tokentypes=2,
                            pre_process=pre_process, post_process=post_process)

     return model

tasks/race/finetune.py (+4 -2)

@@ -9,6 +9,7 @@
 from tasks.eval_utils import accuracy_func_provider
 from tasks.finetune_utils import finetune
 from tasks.race.data import RaceDataset
+from megatron.arguments import core_transformer_config_from_args


 def train_valid_datasets_provider():
@@ -26,9 +27,10 @@ def train_valid_datasets_provider():

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
-
+    config = core_transformer_config_from_args(get_args())
     print_rank_0('building multichoice model for RACE ...')
-    model = MultipleChoice(num_tokentypes=2,
+    model = MultipleChoice(config=config,
+                           num_tokentypes=2,
                            pre_process=pre_process,
                            post_process=post_process)
