Commit 532a4c4

torch distributed: add support for user-specified parameter synchronization
1 parent eb0f22e commit 532a4c4

File tree

1 file changed: +114 additions, -12 deletions


returnn/torch/distributed.py

Lines changed: 114 additions & 12 deletions
@@ -3,20 +3,72 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Any, Dict
+from abc import abstractmethod, ABC
+import logging
+import numpy
 import os
 import socket
-import logging
+from typing import Callable, Optional, Any, Dict, Type, Union
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from returnn.config import Config
-from returnn.util.basic import CollectionReadCheckCovered
+from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError
 
 _logger = logging.getLogger("returnn.torch.distributed")
 
 
+class ParamSynchronizer(ABC):
+    """
+    Custom parameter synchronization primitive.
+
+    Contains a callback that is called after every train step to synchronize model parameters
+    across processes/nodes.
+    """
+
+    @abstractmethod
+    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
+        """
+        `__init__` is called after the default global process group is created.
+        Can be used to initialize any additional custom process (sub)groups.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param rank: global rank of the current process across all nodes
+        :param size: global world size across all nodes
+        :param local_rank: local rank of the current process on the current node
+        :param local_size: local world size on the current node
+        """
+        super().__init__()
+
+    def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+        """
+        Creates an associated `DistributedDataParallel` for the given module for gradient synchronization.
+
+        This function can be left unimplemented if no gradient synchronization is done.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+        """
+        raise OptionalNotImplementedError
+
+    @abstractmethod
+    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
+        """
+        Parameter synchronization callback called after every train step.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param module: the NN being trained
+        :param train_step_idx: the current train step
+        :param kwargs: any additional kwargs.
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """forwards to :func:`step`"""
+        return self.step(*args, **kwargs)
+
+
 class DistributedContext:
     """
     This class setups some helper functions for torch distributed training
@@ -26,6 +78,9 @@ def __init__(self, options: Dict[str, Any]):
         import torch.distributed as dist
 
         self._opts = CollectionReadCheckCovered(options)
+        # Only used to generate forwards compatibility ensuring random kwargs, therefore
+        # the seed is not important
+        self._rng = numpy.random.default_rng()
 
         # when no backend is specified, both gloo and nccl backends will be created
         # the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +97,13 @@ def __init__(self, options: Dict[str, Any]):
             % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
         )
 
+        self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
+            "synchronizer", None
+        )
+        self._custom_sync: Optional[Callable] = None
         self._reduce_type = self._opts.get("reduce_type", "grad")
         self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
         if self._reduce_type == "param":
             assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +112,23 @@ def __init__(self, options: Dict[str, Any]):
             _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
         elif self._reduce_type == "grad":
             _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom":
+            if issubclass(self._custom_sync_class, ParamSynchronizer):
+                self._custom_sync = self._custom_sync_class(
+                    rank=self._rank,
+                    size=self._size,
+                    local_rank=self._local_rank,
+                    local_size=self._local_size,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+            elif isinstance(self._custom_sync_class, Callable):
+                self._custom_sync = self._custom_sync_class
+            else:
+                raise ValueError(
+                    f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
+                )
+
+            _logger.info(f"reduce_type custom: {type(self._custom_sync)}")
         else:
             raise ValueError(f"invalid reduce_type {self._reduce_type!r}")
 
@@ -70,6 +147,8 @@ def _check_no_unknown_opts(self):
         self._opts.get("options")
         if self._reduce_type == "param":
             self._opts.get("sync_on_cpu")
+        if self._reduce_type == "custom":
+            self._opts.get("synchronizer")
 
         self._opts.assert_all_read()
 
@@ -102,7 +181,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
         """
         if self._reduce_type == "param":
             return None
-        assert self._reduce_type == "grad"
+        assert self._reduce_type in ["custom", "grad"]
+
+        if self._reduce_type == "custom":
+            assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))
+
+            if isinstance(self._custom_sync, ParamSynchronizer):
+                try:
+                    return self._custom_sync.make_distributed_model(
+                        module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}
+                    )
+                except OptionalNotImplementedError:
+                    pass
+            else:
+                # callable short form does not have support for DistributedDataParallel
+                pass
+
+            return None
+
         cls = self._opts.get("class", DistributedDataParallel)
         if cls is not DistributedDataParallel:
             _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +211,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
 
     def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
         """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "custom":
+            with torch.no_grad():  # TODO: do we want this for all syncers?
+                self._custom_sync(
+                    module=module,
+                    train_step_idx=epoch_step_idx,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+        elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
             _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))
 
 
@@ -155,7 +258,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
 
     if sync_on_cpu:
        for param in module.parameters():
-            # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+            # Separately move each param to CPU (instead of the whole module), to save CPU memory.
            param_cpu = param.to(torch.device("cpu"))
            # On CPU, we are likely using Gloo, and Gloo does not support AVG
            dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +269,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
     if dist.get_backend() == "gloo":
         # Gloo does not support AVG
         reduce_op = dist.ReduceOp.SUM
+    elif hasattr(dist.ReduceOp, "AVG"):
+        reduce_op = dist.ReduceOp.AVG
     else:
-        if hasattr(dist.ReduceOp, "AVG"):
-            reduce_op = dist.ReduceOp.AVG
-        else:
-            # Older PyTorch versions do not have ReduceOp.AVG.
-            reduce_op = dist.ReduceOp.SUM
+        # Older PyTorch versions do not have ReduceOp.AVG.
+        reduce_op = dist.ReduceOp.SUM
 
     for param in module.parameters():
         dist.all_reduce(param.data, op=reduce_op)
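
Usage sketch (not part of the commit): a minimal ParamSynchronizer subclass that averages all model parameters across workers after every train step, plus the config wiring. The class name ParamAvgSynchronizer and the torch_distributed config variable are illustrative assumptions; the ParamSynchronizer interface, the "custom" reduce_type and the "synchronizer" option are taken from the diff above.

    # Illustrative sketch only -- names marked as assumptions in the text above.
    import torch
    import torch.distributed as dist

    from returnn.torch.distributed import ParamSynchronizer


    class ParamAvgSynchronizer(ParamSynchronizer):
        """Averages all model parameters across all workers after every train step."""

        def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
            # Called after the default global process group exists; custom (sub)groups could
            # be created here. **kwargs absorbs the randomly named forwards-compatibility kwarg.
            super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size, **kwargs)
            self._size = size

        # make_distributed_model() is left at its default (raises OptionalNotImplementedError),
        # i.e. no DistributedDataParallel wrapper and no gradient synchronization.

        def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
            # Invoked by DistributedContext.step_after_param_update() under torch.no_grad().
            for param in module.parameters():
                # SUM then divide also works on backends without ReduceOp.AVG (e.g. Gloo).
                dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
                param.data /= self._size


    # Assumed RETURNN config entry enabling the new code path:
    torch_distributed = {"reduce_type": "custom", "synchronizer": ParamAvgSynchronizer}

DistributedContext instantiates the class with rank/size/local_rank/local_size plus one randomly named forwards-compatibility kwarg and calls it after every train step; per the error message in the constructor, a plain callable with the same keyword signature as step can also be given as "synchronizer".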
