Fix find_tied_params for models with shared layers (#2986)
* Add test case

* Fix find_tied_params

* Sort params in test

* Refactor variable naming, add comments

* Apply suggestions from code review

Co-authored-by: Zach Mueller <[email protected]>

* Fix docstrings quality

---------

Co-authored-by: Zach Mueller <[email protected]>
qubvel and muellerzr authored Aug 13, 2024
1 parent cd5698b commit 851cf34
Showing 2 changed files with 90 additions and 28 deletions.
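For context, the failure this commit fixes shows up in models that reuse the same module instance in two places. The previous implementation recursed through `model.named_children()`, which returns each module instance only once, so the second alias of a shared layer was never visited and its tie went undetected. A minimal sketch of the symptom (the variable names are illustrative, not from the patch):

```python
import torch.nn as nn

from accelerate.utils.modeling import find_tied_parameters

# A model that reuses the SAME Linear instance twice.
layer = nn.Linear(10, 10)
model = nn.Sequential(layer, layer)

# named_children() deduplicates shared instances, so a recursive walk
# never visits the "1" alias of the shared layer.
print([name for name, _ in model.named_children()])  # ['0']

# With this fix, every alias of each shared parameter is reported.
print(find_tied_parameters(model))
# -> [['0.bias', '1.bias'], ['0.weight', '1.weight']] (group order may vary)
```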
113 changes: 85 additions & 28 deletions src/accelerate/utils/modeling.py
@@ -23,7 +23,7 @@
import tempfile
import warnings
from collections import OrderedDict, defaultdict
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
@@ -616,7 +616,65 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
)


-def find_tied_parameters(model: nn.Module, **kwargs):
+def _get_named_modules(
+    module: torch.nn.Module,
+    memo: Optional[Set[torch.nn.Module]] = None,
+    prefix: str = "",
+    remove_duplicate: bool = True,
+):
+    """
+    Return an iterator over all modules in the network, yielding both the name of each module and the module
+    itself. Copied from PyTorch's `torch.nn.Module.named_modules` for compatibility with torch versions < 2.0, with
+    the `remove_duplicate` option added.
+    Args:
+        memo (set of `torch.nn.Module`, *optional*):
+            A memo to store the set of modules already added to the result
+        prefix (`str`, *optional*):
+            A prefix that will be added to the name of the module
+        remove_duplicate (`bool`, *optional*):
+            Whether to remove the duplicated module instances in the result or not
+    Yields:
+        (str, Module): Tuple of name and module
+    Note:
+        Duplicate modules are returned only once when `remove_duplicate` is `True`.
+    """
+    if memo is None:
+        memo = set()
+    if module not in memo:
+        if remove_duplicate:
+            memo.add(module)
+        yield prefix, module
+        for name, sub_module in module._modules.items():
+            if sub_module is None:
+                continue
+            submodule_prefix = prefix + ("." if prefix else "") + name
+            yield from _get_named_modules(sub_module, memo, submodule_prefix, remove_duplicate)
+
+
+def _get_named_parameters(module: torch.nn.Module, prefix="", recurse=True, remove_duplicate: bool = True):
+    """
+    Helper yielding the names and parameters of a module. Copied from PyTorch's
+    `torch.nn.Module.named_parameters` for compatibility with torch versions < 2.0, with the `remove_duplicate`
+    option added.
+    """
+    memo = set()
+    modules = (
+        _get_named_modules(module, prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, module)]
+    )
+    for module_prefix, module in modules:
+        members = module._parameters.items()
+        for k, v in members:
+            if v is None or v in memo:
+                continue
+            if remove_duplicate:
+                memo.add(v)
+            name = module_prefix + ("." if module_prefix else "") + k
+            yield name, v
+
+
+def find_tied_parameters(model: torch.nn.Module, **kwargs):
"""
Find the tied parameters in a given model.
@@ -645,33 +703,32 @@ def find_tied_parameters(model: nn.Module, **kwargs):
    [['linear1.weight', 'linear2.weight']]
    ```
    """
-    # Initialize result and named_parameters before recursing.
-    named_parameters = kwargs.get("named_parameters", None)
-    prefix = kwargs.get("prefix", "")
-    result = kwargs.get("result", {})
-
-    if named_parameters is None:
-        named_parameters = {n: p for n, p in model.named_parameters()}
-    else:
-        # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
-        # of the submodule it belongs to. So while recursing we track the names that are not in the initial
-        # `named_parameters`.
-        for name, parameter in model.named_parameters():
-            full_name = name if prefix == "" else f"{prefix}.{name}"
-            if full_name not in named_parameters:
-                # When we find one, it has to be one of the existing parameters.
-                for new_name, new_param in named_parameters.items():
-                    if new_param is parameter:
-                        if new_name not in result:
-                            result[new_name] = []
-                        result[new_name].append(full_name)
-
-    # Once we have treated direct parameters, we move to the child modules.
-    for name, child in model.named_children():
-        child_name = name if prefix == "" else f"{prefix}.{name}"
-        find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
-
-    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()])
+    # Get ALL model parameters and their names
+    all_named_parameters = {name: param for name, param in _get_named_parameters(model, remove_duplicate=False)}
+
+    # Get ONLY unique named parameters:
+    # if a parameter is tied and has multiple names, it is included only once
+    no_duplicate_named_parameters = {
+        name: param for name, param in _get_named_parameters(model, remove_duplicate=True)
+    }
+
+    # The difference of the two sets gives us the tied parameters
+    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
+
+    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
+    # which names refer to the same parameter. To identify this, we need to group them together.
+    tied_param_groups = {}
+    for tied_param_name in tied_param_names:
+        tied_param = all_named_parameters[tied_param_name]
+        for param_name, param in no_duplicate_named_parameters.items():
+            # Compare whether the parameters are the same; if so, group their names together
+            if param is tied_param:
+                if param_name not in tied_param_groups:
+                    tied_param_groups[param_name] = []
+                tied_param_groups[param_name].append(tied_param_name)
+
+    return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])


def retie_parameters(model, tied_params):
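To make the set-difference logic concrete, here is a small worked trace using the shared-layer model from the new test. It calls `_get_named_parameters`, the private helper added by this patch, so the intermediate values mirror those inside `find_tied_parameters`:

```python
import torch.nn as nn

from accelerate.utils.modeling import _get_named_parameters

layer = nn.Linear(10, 10)
model = nn.Sequential(layer, layer)

# With duplicates kept, every alias of a shared parameter is listed:
all_named = dict(_get_named_parameters(model, remove_duplicate=False))
print(sorted(all_named))  # ['0.bias', '0.weight', '1.bias', '1.weight']

# With duplicates removed, each parameter appears under one name only:
unique_named = dict(_get_named_parameters(model, remove_duplicate=True))
print(sorted(unique_named))  # ['0.bias', '0.weight']

# The set difference is exactly the extra aliases of tied parameters:
print(sorted(set(all_named) - set(unique_named)))  # ['1.bias', '1.weight']
```

From there, the patch only has to map each extra alias back to its canonical name by object identity to build the tied groups.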
5 changes: 5 additions & 0 deletions tests/test_modeling_utils.py
@@ -262,6 +262,11 @@ def test_find_tied_parameters(self):
        model.block1.linear1.weight = model.block2.linear1.weight
        assert find_tied_parameters(model) == [["block1.linear1.weight", "block2.linear1.weight"]]

+        layer = nn.Linear(10, 10)
+        model = nn.Sequential(layer, layer)
+        tied_params = find_tied_parameters(model)
+        assert sorted(tied_params) == [["0.bias", "1.bias"], ["0.weight", "1.weight"]]

    def test_retie_parameters(self):
        model = sequential_model(2)
        retie_parameters(model, [["linear1.weight", "linear2.weight"]])
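One detail worth noting about the new grouping loop: parameters are matched with `is` (object identity), so two layers that merely happen to hold equal values are not reported as tied. A quick illustrative sketch:

```python
import torch.nn as nn

from accelerate.utils.modeling import find_tied_parameters

# Two distinct Linear layers with EQUAL weight values are not tied...
a, b = nn.Linear(4, 4), nn.Linear(4, 4)
b.weight.data.copy_(a.weight.data)
model = nn.Sequential(a, b)
print(find_tied_parameters(model))  # []

# ...tying means sharing the same Parameter object:
b.weight = a.weight
print(find_tied_parameters(model))  # [['0.weight', '1.weight']]
```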
