
Commit e3f5762

opt.py: feature to promote half precision variables to single precision
This commit tweaks the built-in Dr.Jit optimizers (SGD, RMSProp, Adam) so that they can optionally promote half-precision variables to single precision, which improves the stability of the optimization. Pass `promote_fp16=False` to the optimizer constructor to disable this behavior.
1 parent 3bab90b commit e3f5762
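
A minimal usage sketch (illustrative, not part of the commit; it assumes the `drjit.auto.ad` backend and that the optimizers are imported from `drjit.opt`):

    import drjit as dr
    from drjit.auto.ad import Float16
    from drjit.opt import Adam

    # FP16 -> FP32 promotion is enabled by default
    opt = Adam(lr=1e-2)
    opt["x"] = Float16(100)       # stored internally as Float32

    loss = (opt["x"] - 101) ** 2  # reads cast back to Float16
    dr.backward(loss)
    opt.step()

    assert dr.type_v(opt["x"]) == dr.VarType.Float16

    # Disable the promotion to keep the internal state in FP16
    opt = Adam(lr=1e-2, promote_fp16=False)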

2 files changed: +121 -29 lines

drjit/opt.py

Lines changed: 87 additions & 28 deletions
@@ -120,11 +120,15 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
     # Mask updates of parameters that did not receive a gradient?
     mask_updates: bool
 
+    # Promote half-precision variables to use single precision internal storage?
+    promote_fp16: bool
+
     # Maps the parameter name to a tuple containing
     # - the current parameter value
+    # - whether the parameter was promoted to single precision
     # - a parameter-specific learning rate (or None)
     # - an arbitrary sequence of additional optimizer-dependent state values
-    state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]]
+    state: Dict[str, Tuple[dr.ArrayBase, bool, Optional[LearningRate], Extra]]
 
     DRJIT_STRUCT = {
         "lr": LearningRate,
@@ -137,6 +141,7 @@ def __init__(
         params: Optional[Mapping[str, dr.ArrayBase]] = None,
         *,
         mask_updates: bool = False,
+        promote_fp16: bool = True,
     ):
         """
         Create an empty Optimizer object with the learning rate ``lr`` and initial
@@ -171,13 +176,21 @@
                 of the above steps, similar to PyTorch's `SparseAdam optimizer
                 <https://pytorch.org/docs/1.9.0/generated/torch.optim.SparseAdam.html>`_.
                 Dr.Jit supports this feature for all optimizers.
+
+            promote_fp16 (bool):
+                If set to ``True`` (the default), the optimizer internally
+                promotes half precision parameters to single precision to
+                prevent issues where rounding interferes with the optimization.
+                Accessing the current state via ``optimizer["parameter_name"]``
+                will cast back to half precision.
         """
 
         if isinstance(lr, float) and lr < 0:
             raise RuntimeError("'lr' must be >0")
 
         self.lr = lr
         self.mask_updates = mask_updates
+        self.promote_fp16 = promote_fp16
         self.state = {}
 
         if params:
@@ -189,7 +202,14 @@ def __contains__(self, key: object, /) -> bool:
 
     def __getitem__(self, key: str, /) -> dr.ArrayBase:
         """Retrieve a parameter value from the optimizer."""
-        return self.state[key][0]
+        entry = self.state[key]
+        value = entry[0]
+
+        # If previously promoted from FP16 -> FP32, cast back
+        if entry[1]:
+            value = dr.float16_array_t(value)(value)
+
+        return value
 
     def __setitem__(self, key: str, value: dr.ArrayBase, /):
         """
@@ -245,10 +265,14 @@ def __setitem__(self, key: str, value: dr.ArrayBase, /):
             # Reattach the copy to the AD graph
             dr.enable_grad(value)
 
+        promoted = self.promote_fp16 and dr.type_v(value) == dr.VarType.Float16
+        if promoted:
+            value = dr.float32_array_t(value)(value)
+
         if prev is not None and prev[0].shape == value.shape:
-            self.state[key] = value, *prev[1:]
+            self.state[key] = value, promoted, *prev[2:]
         else:
-            self._reset(key, value)
+            self._reset(key, value, promoted)
 
     def __len__(self) -> int:
         """Return the number of registered parameters."""
@@ -271,7 +295,7 @@ def learning_rate(self, key: Optional[str] = None) -> Optional[LearningRate]:
         if key is None:
             return self.lr
         else:
-            return self.state[key][1]
+            return self.state[key][2]
 
     def set_learning_rate(
         self,
@@ -318,7 +342,7 @@ def set_learning_rate(
         elif isinstance(value, Mapping):
             for k, lr in value.items():
                 state = self.state[k]
-                self.state[k] = (state[0], lr, *state[2:])
+                self.state[k] = (*state[:2], lr, *state[3:])
         if kwargs:
             self.set_learning_rate(kwargs)
 
@@ -368,10 +392,11 @@ def reset(self, key: Optional[str] = None) -> None:
         """
 
         if key is not None:
-            self._reset(key, self[key])
+            value, promoted, lr, extra = self.state[key]
+            self._reset(key, value, promoted)
         else:
-            for k in self.state.keys():
-                self._reset(k, self[k])
+            for key, (value, promoted, lr, extra) in self.state.items():
+                self._reset(key, value, promoted)
 
     # --------------------------------------------------------------------
     # Functionality that must be provided by subclasses
@@ -414,7 +439,7 @@ def step(
         with dr.profile_range('Optimizer.step()'):
             cache = _LRCache()
 
-            for key, (value, lr, extra) in self.state.items():
+            for key, (value, promoted, lr, extra) in self.state.items():
                 # Fetch the parameter gradient and convert special array types
                 # (e.g. complex numbers) into ones with element-wise semantics
                 grad = value.grad.array
@@ -451,7 +476,7 @@
                 dr.enable_grad(new_value)
 
                 # Update the optimizer state and schedule it for evaluation
-                new_state = new_value, lr, new_extra
+                new_state = new_value, promoted, lr, new_extra
 
                 dr.schedule(new_state)
                 self.state[key] = new_state
@@ -475,8 +500,8 @@ def _step(
         )
 
     # To be provided by subclasses
-    def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
-        raise Exception(f"Optimizer._reset({key}, {value}): missing implementation!")
+    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
+        raise Exception(f"Optimizer._reset({key}, {value}, {promoted}): missing implementation!")
 
     # Blend between the old and new versions of the optimizer extra state
     def _select(self, mask: dr.ArrayBase, extra: Extra, new_extra: Extra, /) -> Extra:
@@ -592,6 +617,7 @@ def __init__(
         momentum: float = 0.0,
         nesterov: bool = False,
         mask_updates: bool = False,
+        promote_fp16: bool = True,
     ):
         """
         Args:
@@ -607,6 +633,11 @@
                 cause past gradients to persist for a longer amount of time.
 
             mask_updates (bool):
+                Mask updates to zero-valued gradient components?
+                See :py:func:`Optimizer.__init__()` for details on this parameter.
+
+            promote_fp16 (bool):
+                Promote half-precision variables to single precision internal storage?
                 See :py:func:`Optimizer.__init__()` for details on this parameter.
 
             params (Mapping[str, drjit.ArrayBase] | None):
@@ -623,7 +654,12 @@
         self.momentum = momentum
         self.nesterov = nesterov
 
-        super().__init__(lr, params, mask_updates=mask_updates)
+        super().__init__(
+            lr,
+            params,
+            mask_updates=mask_updates,
+            promote_fp16=promote_fp16
+        )
 
     # To be provided by subclasses
     def _step(
@@ -656,21 +692,21 @@ def _step(
 
         return dr.fma(step, scale, value), v_next
 
-    def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
+    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
         valarr = value.array
         tp = type(valarr)
         if self.momentum == 0:
             m = None
         else:
             m = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, m
+        self.state[key] = value, promoted, None, m
 
     def __repr__(self):
         """Return a human-readable string representation"""
         lr_dict: Dict[str, LearningRate] = dict(default=self.lr)
         total_count = 0
         for k, state in self.state.items():
-            lr = state[1]
+            lr = state[2]
             total_count += dr.prod(state[0].shape)
             if lr is not None:
                 lr_dict[k] = lr
@@ -728,6 +764,7 @@ def __init__(
         alpha: float = 0.99,
         epsilon: float = 1e-8,
         mask_updates: bool = False,
+        promote_fp16: bool = True,
     ):
         """
         Construct a RMSProp optimizer instance.
@@ -746,14 +783,24 @@
                 persist for a longer amount of time.
 
             mask_updates (bool):
+                Mask updates to zero-valued gradient components?
+                See :py:func:`Optimizer.__init__()` for details on this parameter.
+
+            promote_fp16 (bool):
+                Promote half-precision variables to single precision internal storage?
                 See :py:func:`Optimizer.__init__()` for details on this parameter.
 
             params (Mapping[str, drjit.ArrayBase] | None):
                 Optional dictionary-like object containing an initial set of
                 parameters.
         """
 
-        super().__init__(lr, params, mask_updates=mask_updates)
+        super().__init__(
+            lr,
+            params,
+            mask_updates=mask_updates,
+            promote_fp16=promote_fp16
+        )
 
         if alpha < 0 or alpha >= 1:
             raise RuntimeError("'alpha' must be on the interval [0, 1)")
@@ -792,18 +839,18 @@
         return dr.fma(step, scale, value), m_t
 
     # Implementation detail of Optimizer.reset()
-    def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
+    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
         valarr = value.array
         tp = type(valarr)
         m_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, m_t
+        self.state[key] = value, promoted, None, m_t
 
     def __repr__(self):
         """Return a human-readable string representation"""
         lr_dict: Dict[str, LearningRate] = dict(default=self.lr)
         total_count = 0
         for k, state in self.state.items():
-            lr = state[1]
+            lr = state[2]
             total_count += dr.prod(state[0].shape)
             if lr is not None:
                 lr_dict[k] = lr
@@ -887,6 +934,7 @@ def __init__(
         beta_2: float = 0.999,
         epsilon: float = 1e-8,
         mask_updates: bool = False,
+        promote_fp16: bool = True,
         uniform: bool = False,
     ):
         """
@@ -918,14 +966,24 @@
                 instead of the per-element second moments.
 
             mask_updates (bool):
+                Mask updates to zero-valued gradient components?
+                See :py:func:`Optimizer.__init__()` for details on this parameter.
+
+            promote_fp16 (bool):
+                Promote half-precision variables to single precision internal storage?
                 See :py:func:`Optimizer.__init__()` for details on this parameter.
 
             params (Mapping[str, drjit.ArrayBase] | None):
                 Optional dictionary-like object containing an initial set of
                 parameters.
         """
 
-        super().__init__(lr, params, mask_updates=mask_updates)
+        super().__init__(
+            lr,
+            params,
+            mask_updates=mask_updates,
+            promote_fp16=promote_fp16
+        )
 
         if beta_1 < 0 or beta_1 >= 1:
             raise RuntimeError("'beta_1' must be on the interval [0, 1)")
@@ -965,10 +1023,11 @@ def _step(
         # Compute the step size scale, which is a product of
         # - EMA debiasing factor
         # - Adaptive/parameter-specific scaling
-        Float32 = dr.float32_array_t(dr.leaf_t(grad))
+        Base = dr.leaf_t(grad)
         Float64 = dr.float64_array_t(dr.leaf_t(grad))
-        ema_factor = Float32(
-            -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t)
+        ema_factor = Base(
+            -dr.sqrt(1 - Float64(self.beta_2) ** t) /
+            (1 - Float64(self.beta_1) ** t)
         )
         scale = cache.product(
             dr.leaf_t(grad), # Desired type
@@ -988,14 +1047,14 @@
         return dr.fma(step, scale, value), (t, m_t, v_t)
 
     # Implementation detail of Optimizer.reset()
-    def _reset(self, key: str, value: dr.ArrayBase, /) -> None:
+    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
         valarr = value.array
         tp = type(valarr)
         UInt = dr.uint32_array_t(dr.leaf_t(tp))
         t = UInt(0)
         m_t = dr.opaque(tp, 0, valarr.shape)
         v_t = dr.opaque(tp, 0, valarr.shape)
-        self.state[key] = value, None, (t, m_t, v_t)
+        self.state[key] = value, promoted, None, (t, m_t, v_t)
 
     # Blend between the old and new versions of the optimizer extra state
     def _select(
@@ -1020,7 +1079,7 @@ def __repr__(self):
         total_count = 0
         for k, state in self.state.items():
             total_count += dr.prod(state[0].shape)
-            lr = state[1]
+            lr = state[2]
             if lr is not None:
                 lr_dict[k] = lr
 

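In short, every per-parameter state tuple gains a `promoted` flag in slot 1, shifting the learning rate and extra state one position to the right, which is why the index-based accesses above change from `state[1]` to `state[2]`. A schematic of the new layout (illustrative, using Adam's extra state as the example):

    # before: state[key] = (value, lr, extra)
    # after:  state[key] = (value, promoted, lr, extra)
    #
    # e.g. for Adam, whose extra state is (t, m_t, v_t):
    value, promoted, lr, (t, m_t, v_t) = opt.state["x"]
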
tests/test_opt.py

Lines changed: 34 additions & 1 deletion
@@ -310,5 +310,38 @@ def test09_optimize_nested(t):
     opt.step()
     opt["x"].grad = t(-1, -1, -1)
     opt.step()
-    print(opt["x"])
     assert dr.width(opt["x"]) == 1
+
+
+@pytest.mark.parametrize("optimizer_class", [SGD, RMSProp, Adam])
+@pytest.mark.parametrize("promote_fp16", [False, True])
+@pytest.test_arrays("is_diff,float,shape=(*),float32")
+def test10_promote_fp16(optimizer_class, promote_fp16, t):
+    """Demonstrate that without FP16->FP32 promotion, rounding error can break optimizations"""
+    t16 = dr.float16_array_t(t)
+
+    # Starting point and target
+    x = t16(100)
+    target = 101
+
+    opt = optimizer_class(lr=1e-2, promote_fp16=promote_fp16)
+    opt["x"] = x
+    n_steps = 200
+
+    for _ in range(n_steps):
+        loss = (opt["x"] - target) ** 2  # Simple quadratic loss
+        dr.backward(loss)
+        opt.step()
+
+    # Check internal representation
+    assert dr.type_v(opt.state["x"][0]) == (dr.VarType.Float32 if promote_fp16 else dr.VarType.Float16)
+
+    # Check returned result
+    final_value = opt["x"]
+    assert dr.type_v(final_value) == dr.VarType.Float16
+
+    success = (abs(final_value[0] - target) < 0.1) == promote_fp16
+    if not success:
+        print(f" Target: {target}, Final: {final_value[0]:.8f}")
+    assert success
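
The constants in this test are chosen so that FP16 rounding provably stalls the optimization: near x = 100, adjacent FP16 values are spaced ulp = 2^(6-10) = 0.0625 apart, while the first SGD step only has magnitude lr * |2(x - target)| = 0.01 * 2 = 0.02 (and roughly lr = 0.01 for RMSProp and Adam, whose steps are gradient-normalized). Each update is therefore smaller than half an ulp (0.03125), so x + step rounds back to x in pure FP16 arithmetic and the parameter never moves, whereas the promoted FP32 state accumulates the small steps and converges to the target.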
