diff --git a/mmf/models/transformers/base.py b/mmf/models/transformers/base.py
index eda50e1a5..83feb6468 100644
--- a/mmf/models/transformers/base.py
+++ b/mmf/models/transformers/base.py
@@ -103,25 +103,13 @@ def build(self):
 
     def get_optimizer_parameters(self, config):
         lr = config.optimizer.params.lr
-
+        trunk_param_set = set()
         param_list = []
         parameters = []
-        head_configs = self.config.get("heads", [])
+
         for name, module in self.named_children():
-            # Heads can have different learning rates. This is handled here
-            if name == "heads":
-                # Parameters in the head which have a separate learning
-                # rate, are added as a separate param group
-                for head_config, head in zip(head_configs, self.heads):
-                    parameters, param_list = self.set_lr_for_parameters(
-                        config=head_config,
-                        module_name="{} head".format(head_config.get("type", "MLP")),
-                        base_lr=lr,
-                        module=head,
-                        parameters=parameters,
-                        param_list=param_list,
-                    )
-            elif name == "encoders":
+
+            if name == "encoders":
                 for key in module:
                     for modality in self.config.modalities:
                         if key == modality.key:
@@ -134,29 +122,56 @@ def get_optimizer_parameters(self, config):
                                 parameters=parameters,
                                 param_list=param_list,
                             )
-            else:
+            if name != "heads":
                 # For other modules in trunk, add to same param group
                 param_list += list(module.named_parameters())
-
+                trunk_param_set.update(list(module.parameters()))
+        head_configs = self.config.get("heads", [])
+        # Heads can have different learning rates. This is handled here
+        if len(head_configs) > 0:
+            # Parameters in the head which have a separate learning
+            # rate, are added as a separate param group
+            for head_config, head in zip(head_configs, self.heads):
+                parameters, param_list = self.set_lr_for_parameters(
+                    config=head_config,
+                    module_name="{} head".format(head_config.get("type", "MLP")),
+                    base_lr=lr,
+                    module=head,
+                    parameters=parameters,
+                    param_list=param_list,
+                    excluded_params=trunk_param_set,
+                )
         parameters += get_bert_configured_parameters(param_list)
         return parameters
 
     def set_lr_for_parameters(
-        self, config, module_name, base_lr, module, parameters, param_list
+        self,
+        config,
+        module_name,
+        base_lr,
+        module,
+        parameters,
+        param_list,
+        excluded_params=None,
     ):
         lr_multiplier = config.get("lr_multiplier", 1.0)
+        module_param = list(module.named_parameters())
+        if excluded_params is not None:
+            module_param = [
+                tup for tup in module_param if tup[1] not in excluded_params
+            ]
         if lr_multiplier != 1.0:
             logger.info(
                 f"Setting learning rate of {module_name} to be {base_lr} * {lr_multiplier}."
             )  # noqa
             parameters += get_bert_configured_parameters(
-                module, base_lr * lr_multiplier
+                module_param, base_lr * lr_multiplier
             )
         else:
             # Parameters for the modules with same learning rate as
             # trunk, add to same param group
-            param_list += list(module.named_parameters())
+            param_list += module_param
         return parameters, param_list
 
     def build_encoders(self):
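For context, the new `excluded_params` argument exists so that a parameter shared between a head and the trunk is only assigned to one optimizer param group. Below is a minimal, hypothetical sketch (not MMF code; `TinyModel` and the tied weight are made up for illustration) of that exclusion pattern in plain PyTorch:

```python
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical model whose head ties (shares) a weight with the trunk."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 4)
        self.head = nn.Linear(4, 4)
        # Weight tying: the head reuses the trunk's weight Parameter object.
        self.head.weight = self.trunk.weight


model = TinyModel()

# Collect the trunk's parameters first, mirroring trunk_param_set in the diff.
trunk_param_set = set(model.trunk.parameters())

# Keep only head parameters the trunk does not already own; membership checks
# on a set of Parameters compare by object identity, so the tied weight is
# filtered out and only the head's own bias remains.
head_only = [
    (name, p)
    for name, p in model.head.named_parameters()
    if p not in trunk_param_set
]
print([name for name, _ in head_only])  # ['bias']
```

Without the filter, a tied weight would land both in the trunk's `param_list` and in the head's separate param group, and `torch.optim` optimizers reject parameters that appear in more than one parameter group.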