"""

from __future__ import annotations
-from typing import Optional, Any, Dict
+from abc import abstractmethod, ABC
+import logging
+import numpy
import os
import socket
-import logging
+from typing import Callable, Optional, Any, Dict, Type, Union

import torch
from torch.nn.parallel import DistributedDataParallel

-from returnn.config import Config
-from returnn.util.basic import CollectionReadCheckCovered
+from returnn.util.basic import CollectionReadCheckCovered, OptionalNotImplementedError

_logger = logging.getLogger("returnn.torch.distributed")


+class ParamSynchronizer(ABC):
+    """
+    Custom parameter synchronization primitive.
+
+    Contains a callback that is called after every train step to synchronize model parameters
+    across processes/nodes.
+    """
+
+    @abstractmethod
+    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, **kwargs):
+        """
+        `__init__` is called after the default global process group is created.
+        It can be used to initialize any additional custom process (sub)groups.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param rank: global rank of the current process across all nodes
+        :param size: global world size across all nodes
+        :param local_rank: local rank of the current process on the current node
+        :param local_size: local world size on the current node
+        """
+        super().__init__()
+
+    def make_distributed_model(self, *, module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+        """
+        Creates an associated `DistributedDataParallel` instance for the given module for gradient synchronization.
+
+        This function can be left unimplemented if no gradient synchronization is done.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+        """
+        raise OptionalNotImplementedError
+
+    @abstractmethod
+    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
+        """
+        Parameter synchronization callback, called after every train step.
+
+        Note this function is passed a randomly named kwarg on every invocation to ensure forwards compatibility.
+
+        :param module: the NN being trained
+        :param train_step_idx: the current train step
+        :param kwargs: any additional kwargs
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """forwards to :func:`step`"""
+        return self.step(*args, **kwargs)
+
+
class DistributedContext:
    """
    This class setups some helper functions for torch distributed training
@@ -26,6 +78,9 @@ def __init__(self, options: Dict[str, Any]):
        import torch.distributed as dist

        self._opts = CollectionReadCheckCovered(options)
+        # Only used to generate the random kwargs that ensure forwards compatibility,
+        # therefore the seed is not important.
+        self._rng = numpy.random.default_rng()

        # when no backend is specified, both gloo and nccl backends will be created
        # the gloo backend will be used for collectives with CPU tensors and
@@ -42,8 +97,13 @@ def __init__(self, options: Dict[str, Any]):
            % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
        )

+        self._custom_sync_class: Optional[Union[Callable, Type[ParamSynchronizer]]] = self._opts.get(
+            "synchronizer", None
+        )
+        self._custom_sync: Optional[Callable] = None
        self._reduce_type = self._opts.get("reduce_type", "grad")
        self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
        if self._reduce_type == "param":
            assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +112,23 @@ def __init__(self, options: Dict[str, Any]):
            _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
        elif self._reduce_type == "grad":
            _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom":
+            if isinstance(self._custom_sync_class, type) and issubclass(self._custom_sync_class, ParamSynchronizer):
+                self._custom_sync = self._custom_sync_class(
+                    rank=self._rank,
+                    size=self._size,
+                    local_rank=self._local_rank,
+                    local_size=self._local_size,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+            elif isinstance(self._custom_sync_class, Callable):
+                self._custom_sync = self._custom_sync_class
+            else:
+                raise ValueError(
+                    f"synchronizer must either be a callable or a class inheriting from {ParamSynchronizer.__name__}"
+                )
+
+            _logger.info(f"reduce_type custom: {type(self._custom_sync)}")
        else:
            raise ValueError(f"invalid reduce_type {self._reduce_type!r}")

@@ -70,6 +147,8 @@ def _check_no_unknown_opts(self):
        self._opts.get("options")
        if self._reduce_type == "param":
            self._opts.get("sync_on_cpu")
+        if self._reduce_type == "custom":
+            self._opts.get("synchronizer")

        self._opts.assert_all_read()

@@ -102,7 +181,24 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
        """
        if self._reduce_type == "param":
            return None
-        assert self._reduce_type == "grad"
+        assert self._reduce_type in ["custom", "grad"]
+
+        if self._reduce_type == "custom":
+            assert isinstance(self._custom_sync, (ParamSynchronizer, Callable))
+
+            if isinstance(self._custom_sync, ParamSynchronizer):
+                try:
+                    return self._custom_sync.make_distributed_model(
+                        module=module, **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None}
+                    )
+                except OptionalNotImplementedError:
+                    pass
+            else:
+                # the callable short form has no support for DistributedDataParallel
+                pass
+
+            return None
+
        cls = self._opts.get("class", DistributedDataParallel)
        if cls is not DistributedDataParallel:
            _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +211,14 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis

    def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
        """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "custom":
+            with torch.no_grad():  # TODO: do we want this for all syncers?
+                self._custom_sync(
+                    module=module,
+                    train_step_idx=epoch_step_idx,
+                    **{f"fwd_compatible_random_kwarg_{self._rng.integers(0, 100)}": None},
+                )
+        elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
            _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))

@@ -155,7 +258,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):

    if sync_on_cpu:
        for param in module.parameters():
-            # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+            # Separately move each param to CPU (instead of the whole module), to save CPU memory.
            param_cpu = param.to(torch.device("cpu"))
            # On CPU, we are likely using Gloo, and Gloo does not support AVG
            dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +269,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
    if dist.get_backend() == "gloo":
        # Gloo does not support AVG
        reduce_op = dist.ReduceOp.SUM
+    elif hasattr(dist.ReduceOp, "AVG"):
+        reduce_op = dist.ReduceOp.AVG
    else:
-        if hasattr(dist.ReduceOp, "AVG"):
-            reduce_op = dist.ReduceOp.AVG
-        else:
-            # Older PyTorch versions do not have ReduceOp.AVG.
-            reduce_op = dist.ReduceOp.SUM
+        # Older PyTorch versions do not have ReduceOp.AVG.
+        reduce_op = dist.ReduceOp.SUM

    for param in module.parameters():
        dist.all_reduce(param.data, op=reduce_op)
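For illustration only (not part of the diff above), here is a minimal sketch of how the new `synchronizer` option could be used: a `ParamSynchronizer` subclass that averages parameters across workers every few train steps, followed by the callable short form. The class name `AveragingSynchronizer`, the `sync_step` parameter, and the config snippet are assumptions for this example, not definitions from the change.

```python
# Hypothetical example: average all parameters across workers every `sync_step` train steps.
import torch
import torch.distributed as dist

from returnn.torch.distributed import ParamSynchronizer


class AveragingSynchronizer(ParamSynchronizer):
    """Averages module parameters across all workers every `sync_step` train steps."""

    def __init__(self, *, rank: int, size: int, local_rank: int, local_size: int, sync_step: int = 100, **kwargs):
        super().__init__(rank=rank, size=size, local_rank=local_rank, local_size=local_size, **kwargs)
        self.size = size
        self.sync_step = sync_step

    def step(self, *, module: torch.nn.Module, train_step_idx: int, **kwargs):
        # Ignore the fwd_compatible_random_kwarg_* kwargs; only synchronize every `sync_step` steps.
        if (train_step_idx + 1) % self.sync_step != 0:
            return
        for param in module.parameters():
            # SUM then divide, since not every backend supports ReduceOp.AVG (e.g. Gloo).
            dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
            param.data /= self.size


# In the RETURNN config (sketch), assuming the usual torch_distributed options dict:
# torch_distributed = {"reduce_type": "custom", "synchronizer": AveragingSynchronizer}
# The callable short form would instead pass a plain function:
# torch_distributed = {"reduce_type": "custom", "synchronizer": my_sync_function}
```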