# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""Utility functions for developing unit-scaled models."""
import ast
import math
import re
import typing
from collections import OrderedDict
from dataclasses import dataclass
from types import FunctionType, ModuleType
from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union, cast
import einops
import torch
from pygments import highlight
from pygments.formatters import TerminalFormatter
from pygments.lexers import PythonLexer
from torch import Tensor, fx, nn
from . import functional
from ._internal_utils import generate__all__
@dataclass
class ScalePair:
"""Dataclass containing a pair of scalars, intended to represent the standard
deviation of an arbitrary tensor in the forward and backward passes."""
forward: Optional[float] = None
backward: Optional[float] = None
def __str__(self) -> str:
fwd = f"{self.forward:.3}" if self.forward is not None else "n/a"
bwd = f"{self.backward:.3}" if self.backward is not None else "n/a"
return f"(-> {fwd}, <- {bwd})"
ScaleDict = typing.OrderedDict[str, ScalePair]
class ScaleTracker(torch.autograd.Function):
"""Given a `nn.Tensor`, records its standard deviation in the forward and
backward pass in the supplied `ScalePair`."""
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
t: Tensor,
scale_tracker: ScalePair,
) -> Tensor:
scale_tracker.forward = float(t.std())
ctx.scale_tracker = scale_tracker # type: ignore
return t
@staticmethod
def backward( # type:ignore[override]
ctx: torch.autograd.function.FunctionCtx, t: Tensor
) -> Tuple[Tensor, None, None]:
ctx.scale_tracker.backward = float(t.std()) # type: ignore
return t, None, None
@staticmethod
def track(t: Tensor, scale_tracker: ScalePair) -> Tensor:
# Add typing information to `apply()` method from `torch.autograd.Function`
apply = cast(Callable[[Tensor, ScalePair], Tensor], ScaleTracker.apply)
return apply(t, scale_tracker)
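# Illustrative sketch: `ScaleTracker.track` returns its input unchanged, but as a
# side effect fills in the supplied `ScalePair` with the std of the tensor (in
# the forward pass) and of its gradient (in the backward pass), e.g.
#
#     pair = ScalePair()
#     t = torch.randn(2**10, requires_grad=True)
#     ScaleTracker.track(t, pair).sum().backward()
#     print(pair)  # roughly (-> 1.0, <- 0.0): the gradient of sum() is constant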
class ScaleTrackingInterpreter(fx.Interpreter):
"""Wraps an `fx.GraphModule` such than when executed it records the standard
deviation of every intermediate `nn.Tensor` in the forward and backward pass.
Args:
module (fx.GraphModule): the module to be instrumented.
"""
def __init__(self, module: fx.GraphModule):
super().__init__(module)
self.scales: typing.OrderedDict[str, ScalePair] = OrderedDict()
def run_node(self, n: fx.Node) -> Any:
out = super().run_node(n)
if isinstance(out, Tensor) and out.is_floating_point():
scale_pair = ScalePair()
out = ScaleTracker.track(out, scale_pair)
self.scales[n.name] = scale_pair
return out
def call_function(
self, target: fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
return super().call_function(target, args, kwargs)
def placeholder(
self,
target: fx.node.Target,
args: Tuple[fx.node.Argument, ...],
kwargs: Dict[str, Any],
) -> Any:
"""To handle functions being passed as arguments (for example constraints) the
tracer represents them as placeholder nodes. This method extracts the original
function from the node, as stored in the `target_to_function` dict."""
if isinstance(target, str) and target.startswith("function_placeholder__"):
return self.module.graph._tracer_extras["target_to_function"][
target
] # pragma: no cover
return super().placeholder(target, args, kwargs)
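# Illustrative sketch: running the interpreter populates `.scales` with one
# `ScalePair` per floating-point fx node output; the backward entries are filled
# in once `.backward()` is called on the output, e.g.
#
#     gm = fx.symbolic_trace(nn.Linear(16, 16))
#     interpreter = ScaleTrackingInterpreter(gm)
#     out = interpreter.run(torch.randn(8, 16))
#     out.backward(torch.randn(8, 16))
#     for name, pair in interpreter.scales.items():
#         print(name, pair)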
def _record_scales(
fx_graph_module: fx.GraphModule,
inputs: Tuple[Tensor, ...],
backward: Optional[Tensor] = None,
) -> ScaleDict:
"""Given a `torch.fx.GraphModule`, and dummy tensors to feed into the forward and
backward passes, returns a dictionary of the scales (standard deviations) of every
intermediate tensor in the model (forward and backward pass).
Args:
fx_graph_module (fx.GraphModule): the module to record.
inputs (Tuple[Tensor, ...]): fed into the forward pass for analysis.
backward (Tensor, optional): fed into the output's `.backward()` method for
analysis. Defaults to `None`, equivalent to calling plain `.backward()`.
Returns:
ScaleDict: An ordered dictionary with `ScalePair`s for each intermediate tensor.
"""
tracking_module = ScaleTrackingInterpreter(fx_graph_module)
out = tracking_module.run(*inputs)
out.backward(backward)
return tracking_module.scales
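# Illustrative sketch: `_record_scales` bundles the run-and-backward flow into a
# single call, returning the populated `ScaleDict`:
#
#     gm = fx.symbolic_trace(nn.Linear(16, 16))
#     scales = _record_scales(gm, (torch.randn(8, 16),), torch.randn(8, 16))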
def _annotate(code: str, scales: ScaleDict, syntax_highlight: bool) -> str:
"""Given a string representation of some code and an `ScaleDict` with accompanying
scales, annotates the code to include the scales on the right-hand side."""
function_placeholder_regex = r"function_placeholder__(\w+)"
def is_function_placeholder_line(code_line: str) -> bool:
return bool(re.search(f" = {function_placeholder_regex}$", code_line))
def cleanup_function_signature(code_line: str) -> str:
code_line = re.sub(f", {function_placeholder_regex}", "", code_line)
inner_code_line = code_line.split("(", 1)[1]
replacement = re.sub(r"_([a-zA-Z0-9_]+)_", r"\1", inner_code_line)
return code_line.replace(inner_code_line, replacement)
def annotate_line(code_line: str) -> str:
if code_line.startswith("torch.fx._symbolic_trace.wrap"):
return ""
code_line = code_line.split(";")[0]
if is_function_placeholder_line(code_line): # pragma: no cover
return ""
words = code_line.strip().split(" ")
if words:
if words[0] in scales:
return f"{code_line}; {scales[words[0]]}"
elif words[0] == "def":
parsed = ast.parse(code_line + "\n\t...").body[0]
assert isinstance(parsed, ast.FunctionDef)
arg_names = [arg.arg for arg in parsed.args.args]
scale_strs = [str(scales[a]) for a in arg_names if a in scales]
code_line = cleanup_function_signature(code_line)
if scale_strs:
return f"{code_line} {', '.join(scale_strs)}" # pragma: no cover
else:
return code_line
return code_line
def remove_empty_lines(code_lines: Iterator[str]) -> Iterator[str]:
return (line for line in code_lines if line.strip())
code_lines = map(annotate_line, code.splitlines())
code = "\n".join(remove_empty_lines(code_lines)).strip()
code = code.replace("unit_scaling_functional_", "U.")
if syntax_highlight:
return highlight(code, PythonLexer(), TerminalFormatter()) # pragma: no cover
return code
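# Illustrative sketch: `_annotate` appends each node's scales to the line that
# assigns it, matching on the assigned name, e.g.
#
#     scales = OrderedDict(relu=ScalePair(1.0, 0.5))
#     code = "def forward(self, x):\n    relu = torch.relu(x)\n    return relu"
#     print(_annotate(code, scales, syntax_highlight=False))
#     # def forward(self, x):
#     #     relu = torch.relu(x); (-> 1.0, <- 0.5)
#     #     return relu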
class _DeepTracer(fx.Tracer):
"""Version of `torch.fx.Tracer` which recurses into all sub-modules (if specified).
Args:
recurse_modules (bool): toggles recursive behaviour. Defaults to True.
autowrap_modules (Tuple[ModuleType, ...]): defaults to
`(math, einops, U.functional)`,
Python modules whose functions should be wrapped automatically
without needing to use fx.wrap().
autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
Python functions that should be wrapped automatically without
needing to use fx.wrap().
"""
def __init__(
self,
recurse_modules: bool = True,
autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
autowrap_functions: Tuple[Callable[..., Any], ...] = (),
) -> None:
super().__init__(
autowrap_modules=autowrap_modules, # type: ignore[arg-type]
autowrap_functions=autowrap_functions,
)
self.recurse_modules = recurse_modules
self.target_to_function: Dict[str, FunctionType] = {}
self.function_to_node: Dict[FunctionType, fx.Node] = {}
# Fixes: `TypeError: __annotations__ must be set to a dict object`
if id(FunctionType) in self._autowrap_function_ids:
self._autowrap_function_ids.remove(id(FunctionType))
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
return not self.recurse_modules
def create_arg(self, a: Any) -> fx.node.Argument:
"""Replaces callable arguments with strings for tracing."""
if isinstance(a, FunctionType): # pragma: no cover
node = self.function_to_node.get(a)
if node is None:
assert hasattr(
a, "__name__"
), f"can't create arg for unnamed function: {a}"
name = getattr(a, "__name__")
target = f"function_placeholder__{name}"
node = self.create_node("placeholder", target, (), {}, name)
self.target_to_function[target] = a
self.function_to_node[a] = node
return node
return super().create_arg(a)
def trace(
self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
) -> fx.Graph:
"""Adds the `target_to_function` dict to the graph so the interpreter can use it
downstream."""
graph = super().trace(root, concrete_args)
if not hasattr(graph, "_tracer_extras") or graph._tracer_extras is None:
graph._tracer_extras = {}
graph._tracer_extras["target_to_function"] = self.target_to_function
return graph
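# Illustrative sketch: unlike the default `fx.Tracer` (which treats `torch.nn`
# modules as leaves), `_DeepTracer` recurses into sub-modules, so the resulting
# graph exposes the underlying functional calls rather than opaque call_module
# nodes:
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     tracer = _DeepTracer()
#     graph = tracer.trace(model)
#     print(fx.GraphModule(tracer.root, graph).code)  # linear/relu appear as function calls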
def analyse_module(
module: nn.Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
backward: Optional[Tensor] = None,
recurse_modules: bool = True,
syntax_highlight: bool = True,
autowrap_modules: Tuple[ModuleType, ...] = (math, einops, functional),
autowrap_functions: Tuple[Callable[..., Any], ...] = (),
) -> str:
"""Given a `nn.Module` and dummy forward and backward tensors, generates code
representing the module annotated with the scales (standard deviation) of each
tensor in both forward and backward passes. Implemented using `torch.fx`.
Args:
module (nn.Module): the module to analyse.
inputs (Union[Tensor, Tuple[Tensor, ...]]): fed into the forward pass for
analysis.
backward (Tensor, optional): fed into the output's `.backward()` method for
analysis. Defaults to `None`, equivalent to calling plain `.backward()`.
recurse_modules (bool, optional): toggles recursive behaviour. Defaults to True.
syntax_highlight (bool, optional): if True, applies terminal syntax
highlighting to the returned code string. Defaults to True.
autowrap_modules (Tuple[ModuleType, ...]): defaults to
`(math, einops, U.functional)`,
Python modules whose functions should be wrapped automatically
without needing to use fx.wrap().
autowrap_functions (Tuple[Callable, ...]): defaults to `()`,
Python functions that should be wrapped automatically without
needing to use fx.wrap().
Returns:
str:
a code string representing the operations in the module with scale
annotations for each tensor, reflecting their standard deviations in the
forward and backward passes.
Examples::
>>> class MLP(nn.Module):
>>>     def __init__(self, d):
>>>         super().__init__()
>>>         self.fc1 = nn.Linear(d, d * 4)
>>>         self.relu = nn.ReLU()
>>>         self.fc2 = nn.Linear(d * 4, d)
>>>     def forward(self, x):
>>>         x = self.fc1(x)
>>>         x = self.relu(x)
>>>         x = self.fc2(x)
>>>         return x
>>> hidden_size = 2**10
>>> x = torch.randn(hidden_size, hidden_size).requires_grad_()
>>> bwd = torch.randn(hidden_size, hidden_size)
>>> code = analyse_module(MLP(hidden_size), x, bwd)
>>> print(code)
def forward(self, x): (-> 1.0, <- 0.236)
    fc1_weight = self.fc1.weight; (-> 0.018, <- 6.54)
    fc1_bias = self.fc1.bias; (-> 0.0182, <- 6.51)
    linear = torch._C._nn.linear(x, fc1_weight, fc1_bias); (-> 0.578, <- 0.204)
    relu = torch.nn.functional.relu(linear, inplace = False); (-> 0.337, <- 0.288)
    fc2_weight = self.fc2.weight; (-> 0.00902, <- 13.0)
    fc2_bias = self.fc2.bias; (-> 0.00904, <- 31.6)
    linear_1 = torch._C._nn.linear(relu, fc2_weight, fc2_bias); (-> 0.235, <- 0.999)
    return linear_1
""" # noqa: E501
tracer = _DeepTracer(recurse_modules, autowrap_modules, autowrap_functions)
fx_graph = tracer.trace(module)
fx_graph_module = fx.GraphModule(tracer.root, fx_graph)
if not isinstance(inputs, tuple):
inputs = (inputs,)
scales = _record_scales(fx_graph_module, inputs, backward)
return _annotate(fx_graph_module.code, scales, syntax_highlight=syntax_highlight)
__all__ = generate__all__(__name__)