feat(orttraining): add CPU fallback for FusedAdam optimizer #28233

Open

Rishi-Dave wants to merge 2 commits into microsoft:main from Rishi-Dave:rishidave/feat/fused-adam-cpu-fallback

Conversation

@Rishi-Dave
Contributor

Summary

  • FusedAdam.__init__ now detects torch.cuda.is_available() and falls back to a standard PyTorch optimizer on CPU instead of crashing.
  • A one-time UserWarning informs the user that the fused CUDA kernel is unavailable and a CPU implementation is in use.
  • step() and zero_grad() delegate to the fallback when present; the CUDA path is unchanged (see the sketch after this list).
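A minimal sketch of this gating and delegation pattern, assuming a `self._cpu_fallback_optimizer` attribute as named in this description (the class below is illustrative, not the code in the PR):

```python
import warnings

import torch


class FusedAdamSketch(torch.optim.Optimizer):
    """Illustrative CPU-fallback pattern; not the actual patch."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        super().__init__(list(params), dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay))
        self._cpu_fallback_optimizer = None
        if not torch.cuda.is_available():
            # One-time warning: the fused CUDA kernel cannot be used on this host.
            warnings.warn(
                "FusedAdam: CUDA is unavailable; falling back to torch.optim.AdamW on CPU.",
                UserWarning,
            )
            self._cpu_fallback_optimizer = torch.optim.AdamW(self.param_groups)

    def step(self, closure=None):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.step(closure)
        raise NotImplementedError("fused CUDA path elided in this sketch")

    def zero_grad(self, set_to_none=True):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.zero_grad(set_to_none=set_to_none)
        return super().zero_grad(set_to_none=set_to_none)
```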

Motivation

On CPU-only PyTorch builds, FusedAdam raises immediately in __init__ because it unconditionally:

  1. Allocates torch.cuda.IntTensor([0]) as an overflow buffer.
  2. Imports the CUDA-only C++ extension onnxruntime.training.ortmodule.torch_cpp_extensions.fused_ops.

This makes it impossible to use FusedAdam in CPU-only test/dev environments or to write code that transparently works on either device. The maintainer (@baijumeswani) confirmed in the issue that a CPU fallback with a warning is the desired fix.

Fixes #17403

Changes

orttraining/orttraining/python/training/optim/fused_adam.py:

  • Wrap the two CUDA-specific allocations in if torch.cuda.is_available().
  • On CPU, build self._cpu_fallback_optimizer based on adam_w_mode (see the sketch after this list):
    • ADAM_L2_REGULARIZATION -> torch.optim.Adam (weight_decay applied as L2 regularization)
    • ADAMW_TORCH -> torch.optim.AdamW
    • ADAMW_TRANSFORMERS -> transformers.AdamW (with torch.optim.AdamW fallback when transformers is not installed, plus a second warning)
  • Emit a single UserWarning per instance.
  • step() and zero_grad() early-return through the fallback when set.
  • Update the docstring to drop the "GPU-only" claim.
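A hedged sketch of how that mode dispatch could look; the helper name and argument handling are assumptions made for illustration, while the enum values and fallback behavior follow the bullets above:

```python
import warnings

import torch


def _build_cpu_fallback(param_groups, adam_w_mode, AdamWMode):
    """Map the requested AdamWMode to a stock optimizer (illustrative helper)."""
    if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
        # torch.optim.Adam applies weight_decay as classic L2 regularization.
        return torch.optim.Adam(param_groups)
    if adam_w_mode == AdamWMode.ADAMW_TORCH:
        return torch.optim.AdamW(param_groups)
    # AdamWMode.ADAMW_TRANSFORMERS
    try:
        from transformers import AdamW as TransformersAdamW

        return TransformersAdamW(param_groups)
    except ImportError:
        warnings.warn(
            "transformers is not installed; using torch.optim.AdamW for the "
            "ADAMW_TRANSFORMERS mode instead.",
            UserWarning,
        )
        return torch.optim.AdamW(param_groups)
```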

orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py (new):

  • Patches torch.cuda.is_available() to return False so tests run deterministically on any host (a sketch of that pattern follows this list).
  • Asserts instantiation succeeds and emits a UserWarning.
  • Asserts a single step() produces parameter updates equivalent to torch.optim.AdamW.
  • Asserts AdamWMode.ADAM_L2_REGULARIZATION instantiates and steps without raising.
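A minimal sketch of that kind of test, assuming FusedAdam is importable as onnxruntime.training.optim.FusedAdam (the actual test file loads the module via importlib rather than importing it this way):

```python
from unittest import mock

import pytest
import torch

from onnxruntime.training.optim import FusedAdam  # assumed import path


def test_cpu_fallback_warns_and_updates_parameters():
    model = torch.nn.Linear(4, 2)
    # Force the CPU path regardless of the host's actual CUDA availability.
    with mock.patch("torch.cuda.is_available", return_value=False):
        with pytest.warns(UserWarning):
            optimizer = FusedAdam(model.parameters(), lr=1e-3)

        model(torch.randn(8, 4)).sum().backward()
        before = [p.detach().clone() for p in model.parameters()]
        optimizer.step()
        assert any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
```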

Test Plan

  • python -m pytest orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py -v — 3 passed.
  • lintrunner -a on both files — clean, no changes applied.
  • The CUDA code path itself is unchanged; the existing allocations are only wrapped in a conditional, so there is no behavioral change for existing GPU users.

FusedAdam previously failed to instantiate on CPU-only PyTorch builds
because __init__ unconditionally allocated a torch.cuda.IntTensor and
imported the CUDA-only fused_ops C++ extension.

Detect torch.cuda.is_available() at construction time. When CUDA is
unavailable, emit a one-time UserWarning and build a standard PyTorch
optimizer that matches the requested AdamWMode:

- ADAM_L2_REGULARIZATION -> torch.optim.Adam
- ADAMW_TORCH            -> torch.optim.AdamW
- ADAMW_TRANSFORMERS     -> transformers.AdamW (falls back to
                            torch.optim.AdamW when transformers is
                            not installed)

step() and zero_grad() delegate to the fallback when present. The
CUDA path is unchanged.

Adds a focused unit test that patches torch.cuda.is_available() so it
runs deterministically on any host.

Fixes microsoft#17403

Copilot AI left a comment


Pull request overview

Adds a CPU-safe behavior for FusedAdam in ORTTraining so CPU-only PyTorch environments no longer crash at import/initialization time, and introduces unit tests to validate the new fallback path.

Changes:

  • Add torch.cuda.is_available() gating in FusedAdam.__init__ and create a CPU fallback optimizer with warnings when CUDA fused kernels are unavailable.
  • Route step() and zero_grad() through the fallback optimizer when running on CPU.
  • Add a dedicated unit test file that forces CUDA-off behavior and validates the fallback updates.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • orttraining/orttraining/python/training/optim/fused_adam.py: Adds CPU fallback construction and delegates step()/zero_grad() when CUDA is unavailable.
  • orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py: New tests that patch torch.cuda.is_available() to exercise and validate the CPU fallback path.


…in CPU fallback

- Pass self.param_groups directly to the fallback optimizer constructor instead
  of flattening to a single list of params, so per-group lr/betas/eps/weight_decay
  and any other group-level options the caller set are preserved.
- After fallback construction, alias self.param_groups and self.state to the
  fallback's so user-visible state_dict/load_state_dict and runtime mutations
  like opt.param_groups[0]['lr']=... reach the optimizer that step() actually
  delegates to.
- Override state_dict/load_state_dict/add_param_group to delegate to the
  fallback optimizer when active, so saved/restored moments come from the
  optimizer that holds them (a sketch of this delegation follows below).
- AdamWMode.ADAMW_TRANSFORMERS without 'transformers' installed: raise
  RuntimeError when bias_correction=False (torch.optim.AdamW always applies
  bias correction and would silently diverge). Fall through to torch.optim.AdamW
  when bias_correction=True, where the math matches.
- Update the test docstring to describe what the test actually does (importlib
  spec_from_file_location + sys.modules pre-registration), since sys.path is
  not modified.
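A hedged sketch of that state-handling delegation, continuing the illustrative FusedAdamSketch class from earlier; the attribute and method names follow the commit message, and the exact handling in the patch may differ:

```python
import torch


class FusedAdamSketch(torch.optim.Optimizer):
    # __init__, step and zero_grad as in the earlier sketch; only the
    # state-handling overrides described in this commit are shown here.

    def state_dict(self):
        if self._cpu_fallback_optimizer is not None:
            # The fallback holds the exp_avg / exp_avg_sq moments, so save its state.
            return self._cpu_fallback_optimizer.state_dict()
        return super().state_dict()

    def load_state_dict(self, state_dict):
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.load_state_dict(state_dict)
            # Re-alias so user-visible views track the optimizer that actually steps.
            self.param_groups = self._cpu_fallback_optimizer.param_groups
            self.state = self._cpu_fallback_optimizer.state
            return
        super().load_state_dict(state_dict)

    def add_param_group(self, param_group):
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.add_param_group(param_group)
            self.param_groups = self._cpu_fallback_optimizer.param_groups
            return
        super().add_param_group(param_group)
```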