# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

"""Classes for simulating (non-standard) number formats."""

from dataclasses import dataclass
from typing import Tuple, cast

import torch
from torch import Tensor

from ._internal_utils import generate__all__

Shape = Tuple[int, ...]


@dataclass
class FPFormat:
    """Generic representation of a floating-point number format."""

    exponent_bits: int
    mantissa_bits: int
    rounding: str = "stochastic"  # "stochastic" | "nearest"
    srbits: int = 0  # Number of bits for stochastic rounding; zero => use all bits

    def __post_init__(self) -> None:
        assert self.exponent_bits >= 2, "FPFormat requires at least 2 exponent bits"
        assert (
            self.srbits == 0 or self.rounding == "stochastic"
        ), "Nonzero srbits for non-stochastic rounding"
        if self.srbits == 0 and self.rounding == "stochastic":
            self.srbits = 23 - self.mantissa_bits

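    # For example (following the default above), FPFormat(4, 3) uses stochastic
    # rounding with srbits = 23 - 3 = 20 random bits per value.
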
    @property
    def bits(self) -> int:
        """The number of bits used by the format."""
        return 1 + self.exponent_bits + self.mantissa_bits

    def __str__(self) -> str:  # pragma: no cover
        return (
            f"E{self.exponent_bits}M{self.mantissa_bits}-"
            + dict(stochastic="SR", nearest="RN")[self.rounding]
        )

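    # For example, str(FPFormat(4, 3)) gives "E4M3-SR" and
    # str(FPFormat(5, 2, rounding="nearest")) gives "E5M2-RN".
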
    @property
    def max_absolute_value(self) -> float:
        """The maximum absolute value representable by the format."""
        max_exponent = 2 ** (self.exponent_bits - 1) - 1
        return cast(float, 2**max_exponent * (2 - 2**-self.mantissa_bits))

    @property
    def min_absolute_normal(self) -> float:
        """The minimum absolute normal value representable by the format."""
        min_exponent = 1 - 2 ** (self.exponent_bits - 1)
        return cast(float, 2**min_exponent)

    @property
    def min_absolute_subnormal(self) -> float:
        """The minimum absolute subnormal value representable by the format."""
        return self.min_absolute_normal * 2.0**-self.mantissa_bits

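    # Worked example (derived from the formulae above): for E4M3,
    #   max_absolute_value     = 2**7 * (2 - 2**-3) = 240.0
    #   min_absolute_normal    = 2**-6              = 0.015625
    #   min_absolute_subnormal = 2**-9              = 0.001953125
    # and for E5M2 the corresponding values are 57344.0, 2**-14 and 2**-16.
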
    def quantise(self, x: Tensor) -> Tensor:
        """Non-differentiably quantise the given tensor in this format."""
        absmax = self.max_absolute_value
        downscale = 2.0 ** (127 - 2 ** (self.exponent_bits - 1))
        mask = torch.tensor(2 ** (23 - self.mantissa_bits) - 1, device=x.device)
        if self.rounding == "stochastic":
            srbitsbar = 23 - self.mantissa_bits - self.srbits
            offset = (
                torch.randint(
                    0, 2**self.srbits, x.shape, dtype=torch.int32, device=x.device
                )
                << srbitsbar
            )
            # Correct for bias. We can do this only for srbits < 23 - mantissa_bits,
            # but it is only likely to matter when srbits is small.
            if srbitsbar > 0:
                offset += 1 << (srbitsbar - 1)
        elif self.rounding == "nearest":
            offset = mask // 2
        else:  # pragma: no cover
            raise ValueError(
                f'Unexpected FPFormat(rounding="{self.rounding}"),'
                ' expected "stochastic" or "nearest"'
            )
        q = x.to(torch.float32)
        q = torch.clip(q, -absmax, absmax)  # clip the float32 copy, not the input
        q /= downscale
        q = ((q.view(torch.int32) + offset) & ~mask).view(torch.float32)
        q *= downscale
        return q.to(x.dtype)

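    # Usage sketch (illustrative): snap a tensor onto the E5M2 grid with
    # round-to-nearest, e.g.
    #   fmt = FPFormat(5, 2, rounding="nearest")
    #   y = fmt.quantise(torch.randn(16))
    # Values whose magnitude exceeds fmt.max_absolute_value are clipped first.
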
    def quantise_fwd(self, x: Tensor) -> Tensor:
        """Quantise the given tensor in the forward pass only."""

        class QuantiseForward(torch.autograd.Function):
            @staticmethod
            def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
                return self.quantise(x)

            @staticmethod
            def backward(  # type:ignore[override]
                ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
            ) -> Tensor:
                return grad_y

        return QuantiseForward.apply(x)  # type: ignore

    def quantise_bwd(self, x: Tensor) -> Tensor:
        """Quantise the given tensor in the backward pass only."""

        class QuantiseBackward(torch.autograd.Function):
            @staticmethod
            def forward(ctx: torch.autograd.function.FunctionCtx, x: Tensor) -> Tensor:
                return x

            @staticmethod
            def backward(  # type:ignore[override]
                ctx: torch.autograd.function.FunctionCtx, grad_y: Tensor
            ) -> Tensor:
                return self.quantise(grad_y)

        return QuantiseBackward.apply(x)  # type: ignore

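    # Usage sketch (illustrative): quantise_fwd quantises activations while
    # passing gradients through unchanged, and quantise_bwd quantises only the
    # incoming gradient, so a straight-through quantised matmul might look like
    #   fmt = FPFormat(4, 3)
    #   y = fmt.quantise_bwd(fmt.quantise_fwd(x) @ w)

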
def format_to_tuple(format: FPFormat) -> Tuple[int, int]:
    """Convert the format into a tuple of `(exponent_bits, mantissa_bits)`."""
    return (format.exponent_bits, format.mantissa_bits)


def tuple_to_format(t: Tuple[int, int]) -> FPFormat:
    """Given a tuple of `(exponent_bits, mantissa_bits)`, return the corresponding
    :class:`FPFormat`."""
    return FPFormat(*t)


__all__ = generate__all__(__name__)
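

# Illustrative self-check (a usage sketch, not part of the library's public API):
# exercises the formulae and the quantiser on a small tensor. The expected values
# follow directly from the definitions above.
if __name__ == "__main__":
    fmt = FPFormat(exponent_bits=4, mantissa_bits=3, rounding="nearest")
    print(fmt)  # E4M3-RN
    print(fmt.max_absolute_value)  # 240.0
    print(fmt.min_absolute_normal, fmt.min_absolute_subnormal)  # 0.015625 0.001953125
    # Quantised values are clipped to +/-240 and snapped onto the E4M3 grid.
    print(fmt.quantise(torch.linspace(-300.0, 300.0, 7)))
    # The tuple round trip keeps the bit widths but resets rounding/srbits defaults.
    assert tuple_to_format(format_to_tuple(fmt)).bits == fmt.bits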