@@ -120,11 +120,15 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]):
120120 # Mask updates of parameters that did not receive a gradient?
121121 mask_updates : bool
122122
123+ # Promote half-precision variables to use single precision internal storage?
124+ promote_fp16 : bool
125+
123126 # Maps the parameter name to a tuple containing
124127 # - the current parameter value
128+ # - whether the parameter was promoted to single precision
125129 # - a parameter-specific learning rate (or None)
126130 # - an arbitrary sequence of additional optimizer-dependent state values
127- state : Dict [str , Tuple [dr .ArrayBase , Optional [LearningRate ], Extra ]]
131+ state : Dict [str , Tuple [dr .ArrayBase , bool , Optional [LearningRate ], Extra ]]
128132
129133 DRJIT_STRUCT = {
130134 "lr" : LearningRate ,
@@ -137,6 +141,7 @@ def __init__(
137141 params : Optional [Mapping [str , dr .ArrayBase ]] = None ,
138142 * ,
139143 mask_updates : bool = False ,
144+ promote_fp16 : bool = True ,
140145 ):
141146 """
142147 Create an empty Optimizer object with the learning rate ``lr`` and initial
@@ -171,13 +176,21 @@ def __init__(
171176 of the above steps, similar to PyTorch's `SparseAdam optimizer
172177 <https://pytorch.org/docs/1.9.0/generated/torch.optim.SparseAdam.html>`_.
173178 Dr.Jit supports this feature for all optimizers.
179+
180+ promote_fp16 (bool):
181+ If set to ``True`` (the default), the optimizer internally
182+ promotes half precision parameters to single precision to
183+ prevent issues where rounding interferes with the optimization.
184+ Accessing the current state via ``optimizer["parameter_name"]``
185+ will cast back to half precision.
174186 """
175187
176188 if isinstance (lr , float ) and lr < 0 :
177189 raise RuntimeError ("'lr' must be >0" )
178190
179191 self .lr = lr
180192 self .mask_updates = mask_updates
193+ self .promote_fp16 = promote_fp16
181194 self .state = {}
182195
183196 if params :
@@ -189,7 +202,14 @@ def __contains__(self, key: object, /) -> bool:
189202
190203 def __getitem__ (self , key : str , / ) -> dr .ArrayBase :
191204 """Retrieve a parameter value from the optimizer."""
192- return self .state [key ][0 ]
205+ entry = self .state [key ]
206+ value = entry [0 ]
207+
208+ # If previously promoted from FP16 -> FP32, cast back
209+ if entry [1 ]:
210+ value = dr .float16_array_t (value )(value )
211+
212+ return value
193213
194214 def __setitem__ (self , key : str , value : dr .ArrayBase , / ):
195215 """
@@ -245,10 +265,14 @@ def __setitem__(self, key: str, value: dr.ArrayBase, /):
245265 # Reattach the copy to the AD graph
246266 dr .enable_grad (value )
247267
268+ promoted = self .promote_fp16 and dr .type_v (value ) == dr .VarType .Float16
269+ if promoted :
270+ value = dr .float32_array_t (value )(value )
271+
248272 if prev is not None and prev [0 ].shape == value .shape :
249- self .state [key ] = value , * prev [1 :]
273+ self .state [key ] = value , promoted , * prev [2 :]
250274 else :
251- self ._reset (key , value )
275+ self ._reset (key , value , promoted )
252276
253277 def __len__ (self ) -> int :
254278 """Return the number of registered parameters."""
@@ -271,7 +295,7 @@ def learning_rate(self, key: Optional[str] = None) -> Optional[LearningRate]:
271295 if key is None :
272296 return self .lr
273297 else :
274- return self .state [key ][1 ]
298+ return self .state [key ][2 ]
275299
276300 def set_learning_rate (
277301 self ,
@@ -318,7 +342,7 @@ def set_learning_rate(
318342 elif isinstance (value , Mapping ):
319343 for k , lr in value .items ():
320344 state = self .state [k ]
321- self .state [k ] = (state [0 ], lr , * state [2 :])
345+ self .state [k ] = (* state [: 2 ], lr , * state [3 :])
322346 if kwargs :
323347 self .set_learning_rate (kwargs )
324348
@@ -368,10 +392,11 @@ def reset(self, key: Optional[str] = None) -> None:
368392 """
369393
370394 if key is not None :
371- self ._reset (key , self [key ])
395+ value , promoted , lr , extra = self .state [key ]
396+ self ._reset (key , value , promoted )
372397 else :
373- for k in self .state .keys ():
374- self ._reset (k , self [ k ] )
398+ for key , ( value , promoted , lr , extra ) in self .state .items ():
399+ self ._reset (key , value , promoted )
375400
376401 # --------------------------------------------------------------------
377402 # Functionality that must be provided by subclasses
@@ -414,7 +439,7 @@ def step(
414439 with dr .profile_range ('Optimizer.step()' ):
415440 cache = _LRCache ()
416441
417- for key , (value , lr , extra ) in self .state .items ():
442+ for key , (value , promoted , lr , extra ) in self .state .items ():
418443 # Fetch the parameter gradient and convert special array types
419444 # (e.g. complex numbers) into ones with element-wise semantics
420445 grad = value .grad .array
@@ -451,7 +476,7 @@ def step(
451476 dr .enable_grad (new_value )
452477
453478 # Update the optimizer state and schedule it for evaluation
454- new_state = new_value , lr , new_extra
479+ new_state = new_value , promoted , lr , new_extra
455480
456481 dr .schedule (new_state )
457482 self .state [key ] = new_state
@@ -475,8 +500,8 @@ def _step(
475500 )
476501
477502 # To be provided by subclasses
478- def _reset (self , key : str , value : dr .ArrayBase , / ) -> None :
479- raise Exception (f"Optimizer._reset({ key } , { value } ): missing implementation!" )
503+ def _reset (self , key : str , value : dr .ArrayBase , promoted : bool , / ) -> None :
504+ raise Exception (f"Optimizer._reset({ key } , { value } , { promoted } ): missing implementation!" )
480505
481506 # Blend between the old and new versions of the optimizer extra state
482507 def _select (self , mask : dr .ArrayBase , extra : Extra , new_extra : Extra , / ) -> Extra :
@@ -592,6 +617,7 @@ def __init__(
592617 momentum : float = 0.0 ,
593618 nesterov : bool = False ,
594619 mask_updates : bool = False ,
620+ promote_fp16 : bool = True ,
595621 ):
596622 """
597623 Args:
@@ -607,6 +633,11 @@ def __init__(
607633 cause past gradients to persist for a longer amount of time.
608634
609635 mask_updates (bool):
636+ Mask updates to zero-valued gradient components?
637+ See :py:func:`Optimizer.__init__()` for details on this parameter.
638+
639+ promote_fp16 (bool):
640+ Promote half-precision variables to single-precision internal storage?
610641 See :py:func:`Optimizer.__init__()` for details on this parameter.
611642
612643 params (Mapping[str, drjit.ArrayBase] | None):
@@ -623,7 +654,12 @@ def __init__(
623654 self .momentum = momentum
624655 self .nesterov = nesterov
625656
626- super ().__init__ (lr , params , mask_updates = mask_updates )
657+ super ().__init__ (
658+ lr ,
659+ params ,
660+ mask_updates = mask_updates ,
661+ promote_fp16 = promote_fp16
662+ )
627663
628664 # To be provided by subclasses
629665 def _step (
@@ -656,21 +692,21 @@ def _step(
656692
657693 return dr .fma (step , scale , value ), v_next
658694
659- def _reset (self , key : str , value : dr .ArrayBase , / ) -> None :
695+ def _reset (self , key : str , value : dr .ArrayBase , promoted : bool , / ) -> None :
660696 valarr = value .array
661697 tp = type (valarr )
662698 if self .momentum == 0 :
663699 m = None
664700 else :
665701 m = dr .opaque (tp , 0 , valarr .shape )
666- self .state [key ] = value , None , m
702+ self .state [key ] = value , promoted , None , m
667703
668704 def __repr__ (self ):
669705 """Return a human-readable string representation"""
670706 lr_dict : Dict [str , LearningRate ] = dict (default = self .lr )
671707 total_count = 0
672708 for k , state in self .state .items ():
673- lr = state [1 ]
709+ lr = state [2 ]
674710 total_count += dr .prod (state [0 ].shape )
675711 if lr is not None :
676712 lr_dict [k ] = lr
@@ -728,6 +764,7 @@ def __init__(
728764 alpha : float = 0.99 ,
729765 epsilon : float = 1e-8 ,
730766 mask_updates : bool = False ,
767+ promote_fp16 : bool = True ,
731768 ):
732769 """
733770 Construct a RMSProp optimizer instance.
@@ -746,14 +783,24 @@ def __init__(
746783 persist for a longer amount of time.
747784
748785 mask_updates (bool):
786+ Mask updates to zero-valued gradient components?
787+ See :py:func:`Optimizer.__init__()` for details on this parameter.
788+
789+ promote_fp16 (bool):
790+ Promote half-precision variables to single-precision internal storage?
749791 See :py:func:`Optimizer.__init__()` for details on this parameter.
750792
751793 params (Mapping[str, drjit.ArrayBase] | None):
752794 Optional dictionary-like object containing an initial set of
753795 parameters.
754796 """
755797
756- super ().__init__ (lr , params , mask_updates = mask_updates )
798+ super ().__init__ (
799+ lr ,
800+ params ,
801+ mask_updates = mask_updates ,
802+ promote_fp16 = promote_fp16
803+ )
757804
758805 if alpha < 0 or alpha >= 1 :
759806 raise RuntimeError ("'alpha' must be on the interval [0, 1)" )
@@ -792,18 +839,18 @@ def _step(
792839 return dr .fma (step , scale , value ), m_t
793840
794841 # Implementation detail of Optimizer.reset()
795- def _reset (self , key : str , value : dr .ArrayBase , / ) -> None :
842+ def _reset (self , key : str , value : dr .ArrayBase , promoted : bool , / ) -> None :
796843 valarr = value .array
797844 tp = type (valarr )
798845 m_t = dr .opaque (tp , 0 , valarr .shape )
799- self .state [key ] = value , None , m_t
846+ self .state [key ] = value , promoted , None , m_t
800847
801848 def __repr__ (self ):
802849 """Return a human-readable string representation"""
803850 lr_dict : Dict [str , LearningRate ] = dict (default = self .lr )
804851 total_count = 0
805852 for k , state in self .state .items ():
806- lr = state [1 ]
853+ lr = state [2 ]
807854 total_count += dr .prod (state [0 ].shape )
808855 if lr is not None :
809856 lr_dict [k ] = lr
@@ -887,6 +934,7 @@ def __init__(
887934 beta_2 : float = 0.999 ,
888935 epsilon : float = 1e-8 ,
889936 mask_updates : bool = False ,
937+ promote_fp16 : bool = True ,
890938 uniform : bool = False ,
891939 ):
892940 """
@@ -918,14 +966,24 @@ def __init__(
918966 instead of the per-element second moments.
919967
920968 mask_updates (bool):
969+ Mask updates to zero-valued gradient components?
970+ See :py:func:`Optimizer.__init__()` for details on this parameter.
971+
972+ promote_fp16 (bool):
973+ Promote half-precision variables to single-precision internal storage?
921974 See :py:func:`Optimizer.__init__()` for details on this parameter.
922975
923976 params (Mapping[str, drjit.ArrayBase] | None):
924977 Optional dictionary-like object containing an initial set of
925978 parameters.
926979 """
927980
928- super ().__init__ (lr , params , mask_updates = mask_updates )
981+ super ().__init__ (
982+ lr ,
983+ params ,
984+ mask_updates = mask_updates ,
985+ promote_fp16 = promote_fp16
986+ )
929987
930988 if beta_1 < 0 or beta_1 >= 1 :
931989 raise RuntimeError ("'beta_1' must be on the interval [0, 1)" )
@@ -965,10 +1023,11 @@ def _step(
9651023 # Compute the step size scale, which is a product of
9661024 # - EMA debiasing factor
9671025 # - Adaptive/parameter-specific scaling
968- Float32 = dr .float32_array_t ( dr . leaf_t (grad ) )
1026+ Base = dr .leaf_t (grad )
9691027 Float64 = dr .float64_array_t (dr .leaf_t (grad ))
970- ema_factor = Float32 (
971- - dr .sqrt (1 - Float64 (self .beta_2 ) ** t ) / (1 - Float64 (self .beta_1 ) ** t )
1028+ ema_factor = Base (
1029+ - dr .sqrt (1 - Float64 (self .beta_2 ) ** t ) /
1030+ (1 - Float64 (self .beta_1 ) ** t )
9721031 )
9731032 scale = cache .product (
9741033 dr .leaf_t (grad ), # Desired type
@@ -988,14 +1047,14 @@ def _step(
9881047 return dr .fma (step , scale , value ), (t , m_t , v_t )
9891048
9901049 # Implementation detail of Optimizer.reset()
991- def _reset (self , key : str , value : dr .ArrayBase , / ) -> None :
1050+ def _reset (self , key : str , value : dr .ArrayBase , promoted : bool , / ) -> None :
9921051 valarr = value .array
9931052 tp = type (valarr )
9941053 UInt = dr .uint32_array_t (dr .leaf_t (tp ))
9951054 t = UInt (0 )
9961055 m_t = dr .opaque (tp , 0 , valarr .shape )
9971056 v_t = dr .opaque (tp , 0 , valarr .shape )
998- self .state [key ] = value , None , (t , m_t , v_t )
1057+ self .state [key ] = value , promoted , None , (t , m_t , v_t )
9991058
10001059 # Blend between the old and new versions of the optimizer extra state
10011060 def _select (
@@ -1020,7 +1079,7 @@ def __repr__(self):
10201079 total_count = 0
10211080 for k , state in self .state .items ():
10221081 total_count += dr .prod (state [0 ].shape )
1023- lr = state [1 ]
1082+ lr = state [2 ]
10241083 if lr is not None :
10251084 lr_dict [k ] = lr
10261085
0 commit comments