-
Notifications
You must be signed in to change notification settings - Fork 2
Fix EfficientNet AMP NaN grads (FP16-only guard) #170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -218,10 +218,22 @@ def forward( | |
| x = self.process_audio(x) | ||
|
|
||
| # Extract features with optional gradient checkpointing | ||
| if self.gradient_checkpointing and self.training: | ||
| features = self._checkpointed_features(x) | ||
| # Empirically, EfficientNet's backward under CUDA autocast can produce NaN | ||
| # gradients even when activations and loss are finite. When autocast is | ||
| # enabled, run the feature extractor in FP32 for stability while keeping | ||
| # AMP for the rest of the training loop. | ||
| needs_guard = x.is_cuda and torch.is_autocast_enabled() and torch.get_autocast_dtype("cuda") == torch.float16 | ||
| if needs_guard: | ||
| with torch.autocast(device_type="cuda", enabled=False): | ||
| if self.gradient_checkpointing and self.training: | ||
| features = self._checkpointed_features(x.float()) | ||
| else: | ||
| features = self.model.features(x.float()) | ||
| else: | ||
| features = self.model.features(x) | ||
| if self.gradient_checkpointing and self.training: | ||
| features = self._checkpointed_features(x) | ||
| else: | ||
| features = self.model.features(x) | ||
|
Comment on lines
+225
to
+236
|
||
|
|
||
| # Return unpooled spatial features if requested | ||
| if self.return_features_only: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
needs_guardcurrently triggers whenever CUDA autocast FP16 is enabled, includingmodel.eval()/torch.no_grad()paths. Since the reported issue is NaN gradients, consider gating this guard onself.trainingand/ortorch.is_grad_enabled()so FP16 autocast inference/eval keeps the expected throughput and memory benefits.