feat(orttraining): add CPU fallback for FusedAdam optimizer#28233
Open
Rishi-Dave wants to merge 2 commits into microsoft:main from
Conversation
FusedAdam previously failed to instantiate on CPU-only PyTorch builds
because __init__ unconditionally allocated a torch.cuda.IntTensor and
imported the CUDA-only fused_ops C++ extension.
Detect torch.cuda.is_available() at construction time. When CUDA is
unavailable, emit a one-time UserWarning and build a standard PyTorch
optimizer that matches the requested AdamWMode:
- ADAM_L2_REGULARIZATION -> torch.optim.Adam
- ADAMW_TORCH -> torch.optim.AdamW
- ADAMW_TRANSFORMERS -> transformers.AdamW (falls back to
torch.optim.AdamW when transformers is
not installed)
step() and zero_grad() delegate to the fallback when present. The
CUDA path is unchanged.
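For illustration, a minimal sketch of the mode-to-optimizer mapping described above. This is not the PR's literal code: the `AdamWMode` enum here is a stand-in with assumed values, and `build_cpu_fallback` is a hypothetical helper name.

```python
# Sketch only: AdamWMode stand-in and helper name are assumptions.
import enum
import warnings

import torch


class AdamWMode(enum.Enum):  # stand-in for the module's enum; values assumed
    ADAM_L2_REGULARIZATION = 0
    ADAMW_TORCH = 1
    ADAMW_TRANSFORMERS = 2


def build_cpu_fallback(params, adam_w_mode, **hyperparams):
    """Pick a stock PyTorch optimizer matching the requested AdamWMode."""
    warnings.warn(
        "FusedAdam: CUDA is unavailable; using a standard CPU optimizer "
        "instead of the fused CUDA kernel.",
        UserWarning,
    )
    if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
        # torch.optim.Adam applies weight_decay as classic L2 regularization.
        return torch.optim.Adam(params, **hyperparams)
    if adam_w_mode == AdamWMode.ADAMW_TORCH:
        return torch.optim.AdamW(params, **hyperparams)
    # ADAMW_TRANSFORMERS: prefer transformers.AdamW if it is importable.
    try:
        from transformers import AdamW as TransformersAdamW
        return TransformersAdamW(params, **hyperparams)
    except ImportError:
        warnings.warn(
            "transformers is not installed; falling back to torch.optim.AdamW.",
            UserWarning,
        )
        return torch.optim.AdamW(params, **hyperparams)
```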
Adds a focused unit test that patches torch.cuda.is_available() so it
runs deterministically on any host.
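A hedged sketch of such a test. The diff's actual test loads the module via importlib (see the review notes below); this version uses `unittest.mock` instead, and the `FusedAdam` import path and constructor signature are assumptions.

```python
# Sketch: import path, signature, and test structure are assumptions.
from unittest import mock

import pytest
import torch

from onnxruntime.training.optim.fused_adam import FusedAdam  # path assumed


def test_cpu_fallback_updates_parameters():
    model = torch.nn.Linear(4, 2)
    # Force the CPU branch regardless of the host's real CUDA support.
    with mock.patch("torch.cuda.is_available", return_value=False):
        with pytest.warns(UserWarning):
            optimizer = FusedAdam(model.parameters(), lr=1e-3)
        before = [p.detach().clone() for p in model.parameters()]
        model(torch.randn(8, 4)).sum().backward()
        optimizer.step()
    # Parameters must have moved, proving the fallback actually stepped.
    assert any(
        not torch.equal(b, p) for b, p in zip(before, model.parameters())
    )
```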
Fixes microsoft#17403
Pull request overview
Adds CPU-safe behavior for `FusedAdam` in ORTTraining so CPU-only PyTorch environments no longer crash at import/initialization time, and introduces unit tests to validate the new fallback path.

Changes:

- Add `torch.cuda.is_available()` gating in `FusedAdam.__init__` and create a CPU fallback optimizer, with warnings, when the CUDA fused kernels are unavailable.
- Route `step()` and `zero_grad()` through the fallback optimizer when running on CPU.
- Add a dedicated unit test file that forces CUDA-off behavior and validates the fallback updates.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `orttraining/orttraining/python/training/optim/fused_adam.py` | Adds CPU fallback construction and delegates `step()`/`zero_grad()` when CUDA is unavailable. |
| `orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py` | New tests that patch `torch.cuda.is_available()` to exercise and validate the CPU fallback path. |
…in CPU fallback

- Pass `self.param_groups` directly to the fallback optimizer constructor instead of flattening to a single list of params, so per-group `lr`/`betas`/`eps`/`weight_decay` and any other group-level options the caller set are preserved.
- After fallback construction, alias `self.param_groups` and `self.state` to the fallback's, so user-visible `state_dict`/`load_state_dict` and runtime mutations like `opt.param_groups[0]['lr'] = ...` reach the optimizer that `step()` actually delegates to (see the sketch below).
- Override `state_dict`/`load_state_dict`/`add_param_group` to delegate to the fallback optimizer when active, so saved/restored moments come from the optimizer that holds them.
- `AdamWMode.ADAMW_TRANSFORMERS` without `transformers` installed: raise `RuntimeError` when `bias_correction=False` (`torch.optim.AdamW` always applies bias correction and would silently diverge). Fall through to `torch.optim.AdamW` when `bias_correction=True`, where the math matches.
- Update the test docstring to describe what the test actually does (`importlib` `spec_from_file_location` + `sys.modules` pre-registration), since `sys.path` is not modified.
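An illustrative sketch of the delegation pattern this commit describes, written as a standalone mixin. Method names mirror `torch.optim.Optimizer`; `FusedAdam`'s real internals and the mixin name are assumptions.

```python
# Sketch only: demonstrates the aliasing/delegation idea, not the PR's code.
from typing import Optional

import torch


class CpuFallbackDelegation:
    _cpu_fallback_optimizer: Optional[torch.optim.Optimizer] = None

    def _alias_fallback_state(self) -> None:
        # Alias user-visible containers so mutations such as
        # opt.param_groups[0]["lr"] = ... reach the optimizer step() drives.
        self.param_groups = self._cpu_fallback_optimizer.param_groups
        self.state = self._cpu_fallback_optimizer.state

    def state_dict(self):
        if self._cpu_fallback_optimizer is not None:
            # Saved moments come from the optimizer that actually holds them.
            return self._cpu_fallback_optimizer.state_dict()
        return super().state_dict()

    def load_state_dict(self, state_dict):
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.load_state_dict(state_dict)
        else:
            super().load_state_dict(state_dict)

    def add_param_group(self, param_group):
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.add_param_group(param_group)
        else:
            super().add_param_group(param_group)
```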
Summary
`FusedAdam.__init__` now detects `torch.cuda.is_available()` and falls back to a standard PyTorch optimizer on CPU instead of crashing. A `UserWarning` informs the user that the fused CUDA kernel is unavailable and a CPU implementation is in use. `step()` and `zero_grad()` delegate to the fallback when present; the CUDA path is unchanged.

Motivation
On CPU-only PyTorch builds, `FusedAdam` raises immediately in `__init__` because it unconditionally:

- allocates `torch.cuda.IntTensor([0])` as an overflow buffer, and
- imports `onnxruntime.training.ortmodule.torch_cpp_extensions.fused_ops`.

This makes it impossible to use `FusedAdam` in CPU-only test/dev environments or to write code that transparently works on either device. The maintainer (@baijumeswani) confirmed in the issue that a CPU fallback with a warning is the desired fix.

Fixes #17403
Changes

`orttraining/orttraining/python/training/optim/fused_adam.py`:

- Guard the CUDA-only setup with `if torch.cuda.is_available()`.
- When CUDA is unavailable, build `self._cpu_fallback_optimizer` based on `adam_w_mode`:
  - `ADAM_L2_REGULARIZATION` → `torch.optim.Adam` (weight_decay applied as L2 regularization)
  - `ADAMW_TORCH` → `torch.optim.AdamW`
  - `ADAMW_TRANSFORMERS` → `transformers.AdamW` (with a `torch.optim.AdamW` fallback when `transformers` is not installed, plus a second warning)
- Emit one `UserWarning` per instance.
- `step()` and `zero_grad()` early-return through the fallback when set (see the usage sketch below).

`orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py` (new):

- Patches `torch.cuda.is_available()` to return `False` so tests run deterministically on any host.
- Asserts the fallback `UserWarning`.
- Verifies `step()` produces parameter updates equivalent to `torch.optim.AdamW`.
- Verifies `AdamWMode.ADAM_L2_REGULARIZATION` instantiates and steps without raising.
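A possible end-to-end usage on a CPU-only build after this change. The import path, constructor arguments, and exact warning text are assumptions, not confirmed API.

```python
# Usage sketch: import path and signature assumed from the file layout above.
import warnings

import torch

from onnxruntime.training.optim.fused_adam import FusedAdam  # path assumed

model = torch.nn.Linear(10, 1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    optimizer = FusedAdam(model.parameters(), lr=1e-3)  # no CUDA required now

# On a CPU-only build, a UserWarning reports the fused kernel is unavailable.
print([str(w.message) for w in caught])

model(torch.randn(4, 10)).sum().backward()
optimizer.step()       # delegates to the CPU fallback optimizer
optimizer.zero_grad()  # likewise delegated
```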
Test Plan

- `python -m pytest orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py -v`: 3 passed.
- `lintrunner -a` on both files: clean, no changes applied.