Commit 9738caa

weight decay (L2)
#59
1 parent 69cd23c commit 9738caa

1 file changed: +17 -0 lines changed

nn/base.py

Lines changed: 17 additions & 0 deletions
@@ -476,6 +476,23 @@ def initial(self, value: Optional[Union[nn.Tensor, RawTensorTypes, nn.init.Varia
       self.layer_dict.pop("init_by_layer", None)
       self.layer_dict["init"] = value
 
+  @property
+  def weight_decay(self) -> float:
+    """
+    Weight decay, which is equivalent to L2 loss on the parameters for SGD.
+    On RETURNN side, whether this is handled separately or is part of the main loss,
+    can be controlled via the ``decouple_constraints`` config option.
+    https://github.com/rwth-i6/returnn_common/issues/59#issuecomment-1073913421
+    """
+    return self.layer_dict.get("L2", 0.0)
+
+  @weight_decay.setter
+  def weight_decay(self, value: Optional[float]):
+    if value:
+      self.layer_dict["L2"] = value
+    else:
+      self.layer_dict.pop("L2", None)
+
 
 class LayerState(dict):
   """
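
The getter/setter pair added in this commit can be tried in isolation. The following is a minimal sketch, assuming a hypothetical stand-in class `ParameterSketch` that models only the `layer_dict` attribute; it is not the actual class in `nn/base.py`, and everything except the two property bodies is invented for illustration.

```python
from typing import Any, Dict, Optional


class ParameterSketch:
    """Hypothetical stand-in for the patched class; only layer_dict is modelled."""

    def __init__(self) -> None:
        self.layer_dict: Dict[str, Any] = {}

    @property
    def weight_decay(self) -> float:
        # A missing "L2" key reads back as 0.0, i.e. no weight decay.
        return self.layer_dict.get("L2", 0.0)

    @weight_decay.setter
    def weight_decay(self, value: Optional[float]) -> None:
        if value:
            self.layer_dict["L2"] = value
        else:
            # None and 0.0 are both falsy, so either one removes the key.
            self.layer_dict.pop("L2", None)


param = ParameterSketch()
default_decay = param.weight_decay          # nothing set yet
param.weight_decay = 1e-4
dict_after_set = dict(param.layer_dict)     # snapshot with "L2" present
param.weight_decay = None
dict_after_reset = dict(param.layer_dict)   # snapshot with "L2" removed
```

One consequence of the truthiness check in the setter is that assigning `0.0` behaves like assigning `None`: the `"L2"` key is dropped rather than stored as an explicit zero, so the layer dict stays free of no-op entries.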

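The docstring's statement that weight decay is equivalent to an L2 loss for SGD can be checked numerically. The sketch below is an assumption-laden illustration, not RETURNN code: it assumes the L2 term is written as `l2 * 0.5 * theta**2` (so its gradient is `l2 * theta`), and the function names and the numbers are made up for the example.

```python
def sgd_step_with_l2_loss(theta: float, grad: float, lr: float, l2: float) -> float:
    """Plain SGD step on loss + l2 * 0.5 * theta**2; the L2 term adds l2 * theta to the gradient."""
    return theta - lr * (grad + l2 * theta)


def sgd_step_with_weight_decay(theta: float, grad: float, lr: float, l2: float) -> float:
    """Shrink the parameter by the factor (1 - lr * l2), then apply the plain gradient step."""
    return (1.0 - lr * l2) * theta - lr * grad


# Arbitrary illustrative values.
theta, grad, lr, l2 = 0.8, 0.25, 0.1, 0.01
step_l2 = sgd_step_with_l2_loss(theta, grad, lr, l2)
step_decay = sgd_step_with_weight_decay(theta, grad, lr, l2)
difference = abs(step_l2 - step_decay)  # zero up to floating-point rounding
```

The two updates are algebraically identical only because the SGD step is linear in the gradient. With optimizers that rescale the gradient per parameter, the two forms generally differ, which is presumably why the docstring limits the claim to SGD.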