From 2b3e9ae265bbcd4524d0c40ba3cff53e25ce8c53 Mon Sep 17 00:00:00 2001
From: Name
Date: Fri, 27 Mar 2026 15:03:48 +0100
Subject: [PATCH] AMP guard in EfficientNet to avoid NaN gradients in the case of fp16

---
 avex/models/efficientnet.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/avex/models/efficientnet.py b/avex/models/efficientnet.py
index 34650c4..b10309c 100644
--- a/avex/models/efficientnet.py
+++ b/avex/models/efficientnet.py
@@ -218,10 +218,22 @@ def forward(
         x = self.process_audio(x)
 
         # Extract features with optional gradient checkpointing
-        if self.gradient_checkpointing and self.training:
-            features = self._checkpointed_features(x)
+        # Empirically, EfficientNet's backward under CUDA autocast can produce NaN
+        # gradients even when activations and loss are finite. When autocast is
+        # enabled, run the feature extractor in FP32 for stability while keeping
+        # AMP for the rest of the training loop.
+        needs_guard = x.is_cuda and torch.is_autocast_enabled() and torch.get_autocast_dtype("cuda") == torch.float16
+        if needs_guard:
+            with torch.autocast(device_type="cuda", enabled=False):
+                if self.gradient_checkpointing and self.training:
+                    features = self._checkpointed_features(x.float())
+                else:
+                    features = self.model.features(x.float())
         else:
-            features = self.model.features(x)
+            if self.gradient_checkpointing and self.training:
+                features = self._checkpointed_features(x)
+            else:
+                features = self.model.features(x)
 
         # Return unpooled spatial features if requested
         if self.return_features_only: