diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py
index 10b4a44fd5702..1491d2d546953 100644
--- a/orttraining/orttraining/python/training/optim/fused_adam.py
+++ b/orttraining/orttraining/python/training/optim/fused_adam.py
@@ -10,6 +10,7 @@
 This file is adapted from fused adam in NVIDIA/apex, commit a109f85
 """

+import warnings
 from enum import IntEnum

 import torch
@@ -31,7 +32,10 @@ class FusedAdam(torch.optim.Optimizer):
         when adam_w_mode = 1 and `torch/Adam `_
         when adam_w_mode = 2

-    Currently GPU-only.
+    On CUDA-capable systems this optimizer uses fused CUDA kernels for efficiency.
+    On CPU-only systems (or when ``torch.cuda.is_available()`` returns ``False``) it
+    automatically falls back to an equivalent standard PyTorch optimizer with a
+    one-time :class:`UserWarning`. Performance will be reduced in fallback mode.

     This version of fused Adam implements 2 fusions.

@@ -83,16 +87,62 @@ def __init__(
         self._adam_w_mode = adam_w_mode
         self._set_grad_none = set_grad_none

-        # Skip buffer
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        if torch.cuda.is_available():
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

-        from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops  # noqa: PLC0415
+            from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops  # noqa: PLC0415

-        self._multi_tensor_adam = fused_ops.multi_tensor_adam
-        self._multi_tensor_applier = MultiTensorApply(2048 * 32)
-        self._TorchTensorVector = fused_ops.TorchTensorVector
+            self._multi_tensor_adam = fused_ops.multi_tensor_adam
+            self._multi_tensor_applier = MultiTensorApply(2048 * 32)
+            self._TorchTensorVector = fused_ops.TorchTensorVector
+            self._cpu_fallback_optimizer = None
+        else:
+            warnings.warn(
+                "FusedAdam CUDA kernels are unavailable; falling back to a standard PyTorch optimizer on CPU. "
+                "Performance will be reduced.",
+                UserWarning,
+                stacklevel=2,
+            )
+            # Build an equivalent standard PyTorch optimizer for the CPU path.
+            # Retrieve the flat list of parameters from the already-registered param_groups.
+            _params = [p for group in self.param_groups for p in group["params"]]
+            if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
+                self._cpu_fallback_optimizer = torch.optim.Adam(
+                    _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                )
+            elif adam_w_mode == AdamWMode.ADAMW_TORCH:
+                self._cpu_fallback_optimizer = torch.optim.AdamW(
+                    _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                )
+            else:
+                # AdamWMode.ADAMW_TRANSFORMERS (default): prefer transformers.AdamW
+                try:
+                    from transformers import AdamW as _TransformersAdamW  # noqa: PLC0415
+
+                    self._cpu_fallback_optimizer = _TransformersAdamW(
+                        _params,
+                        lr=lr,
+                        betas=betas,
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        correct_bias=bias_correction,
+                    )
+                except ImportError:
+                    warnings.warn(
+                        "transformers package not available; using torch.optim.AdamW as CPU fallback "
+                        "for AdamWMode.ADAMW_TRANSFORMERS.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._cpu_fallback_optimizer = torch.optim.AdamW(
+                        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                    )

     def zero_grad(self, set_to_none=True):
+        if self._cpu_fallback_optimizer is not None:
+            self._cpu_fallback_optimizer.zero_grad(set_to_none=self._set_grad_none or set_to_none)
+            return
         if self._set_grad_none or set_to_none:
             for group in self.param_groups:
                 for p in group["params"]:
@@ -109,6 +159,9 @@ def step(self, closure=None):
         The remaining arguments are deprecated, and are only retained (for the moment) for error-checking
         purposes.
         """
+        if self._cpu_fallback_optimizer is not None:
+            return self._cpu_fallback_optimizer.step(closure)
+
         loss = None
         if closure is not None:
             loss = closure()
diff --git a/orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py b/orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py
new file mode 100644
index 0000000000000..9f33a4ce65c45
--- /dev/null
+++ b/orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py
@@ -0,0 +1,128 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+"""Unit tests for FusedAdam CPU fallback (issue #17403).
+
+These tests patch torch.cuda.is_available to return False so they run
+deterministically on both CPU-only and CUDA machines.
+
+Import strategy:
+  * Locate the optim/ source directory relative to this test file.
+  * Pre-register "_multi_tensor_apply" in sys.modules by importing it
+    directly (it is pure Python with no external deps).
+  * Load fused_adam.py via importlib with __package__ set to the name we
+    used for the pre-registered _multi_tensor_apply module so that the
+    relative import resolves correctly.
+
+This avoids touching the training/__init__.py which requires the compiled
+onnxruntime C extension (not available in the source-tree environment).
+The CUDA extension import inside fused_adam.__init__ is guarded by
+``if torch.cuda.is_available():`` and never runs with the mock in place.
+"""
+
+import importlib.util
+import sys
+import warnings
+from pathlib import Path
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+
+# ---------------------------------------------------------------------------
+# Locate the optim source directory.
+# File layout:
+#   orttraining/orttraining/test/python/            <- __file__ (parents[0])
+#   orttraining/orttraining/test/                   <- parents[1]
+#   orttraining/orttraining/                        <- parents[2]
+#   orttraining/orttraining/python/training/optim/  <- _OPTIM_DIR
+# ---------------------------------------------------------------------------
+_OPTIM_DIR = Path(__file__).resolve().parents[2] / "python" / "training" / "optim"
+assert _OPTIM_DIR.is_dir(), f"optim dir not found: {_OPTIM_DIR}"
+
+# Step 1: load _multi_tensor_apply as a top-level module (no package needed,
+# it has zero external imports) and register it under the name that the
+# relative import inside fused_adam.py expects.
+_PKG = "fused_adam_pkg"
+_mta_spec = importlib.util.spec_from_file_location(
+    f"{_PKG}._multi_tensor_apply",
+    _OPTIM_DIR / "_multi_tensor_apply.py",
+)
+_mta_mod = importlib.util.module_from_spec(_mta_spec)
+sys.modules[f"{_PKG}._multi_tensor_apply"] = _mta_mod
+_mta_spec.loader.exec_module(_mta_mod)
+
+# Step 2: load fused_adam.py with __package__ = _PKG so its relative import
+# "from ._multi_tensor_apply import ..." resolves to the entry above.
+_fa_spec = importlib.util.spec_from_file_location(
+    f"{_PKG}.fused_adam",
+    _OPTIM_DIR / "fused_adam.py",
+)
+_fa_mod = importlib.util.module_from_spec(_fa_spec)
+_fa_mod.__package__ = _PKG
+# Patch cuda before executing the module body so the CUDA block is skipped.
+with patch("torch.cuda.is_available", return_value=False):
+    _fa_spec.loader.exec_module(_fa_mod)
+
+AdamWMode = _fa_mod.AdamWMode
+FusedAdam = _fa_mod.FusedAdam
+
+
+def _make_param(shape=(3, 3)):
+    """Return an nn.Parameter with a synthetic gradient."""
+    p = nn.Parameter(torch.randn(*shape))
+    p.grad = torch.randn(*shape)
+    return p
+
+
+@patch("torch.cuda.is_available", return_value=False)
+class TestFusedAdamCpuFallback:
+    """All tests run with CUDA disabled to exercise the CPU fallback path."""
+
+    def test_instantiation_warns_and_succeeds(self, _mock_cuda):
+        """FusedAdam must instantiate without error and emit a UserWarning."""
+        param = nn.Parameter(torch.randn(3, 3))
+        with warnings.catch_warnings(record=True) as caught:
+            warnings.simplefilter("always")
+            opt = FusedAdam([param], lr=1e-3)
+
+        assert opt is not None, "FusedAdam should instantiate on CPU"
+        user_warnings = [w for w in caught if issubclass(w.category, UserWarning)]
+        assert len(user_warnings) >= 1, "Expected at least one UserWarning about CPU fallback"
+        assert any("CPU" in str(w.message) or "fallback" in str(w.message).lower() for w in user_warnings)
+
+    def test_step_updates_params_like_adamw(self, _mock_cuda):
+        """After one step, params must change in the same direction as torch.optim.AdamW."""
+        torch.manual_seed(42)
+        weight_init = torch.randn(4, 4)
+        grad = torch.randn(4, 4)
+
+        # FusedAdam (CPU fallback) path — ADAMW_TORCH maps to torch.optim.AdamW
+        p_fused = nn.Parameter(weight_init.clone())
+        p_fused.grad = grad.clone()
+        with warnings.catch_warnings(record=True):
+            warnings.simplefilter("always")
+            opt_fused = FusedAdam([p_fused], lr=1e-3, adam_w_mode=AdamWMode.ADAMW_TORCH)
+        opt_fused.step()
+
+        # Reference: plain torch.optim.AdamW with matching hyperparams
+        p_ref = nn.Parameter(weight_init.clone())
+        p_ref.grad = grad.clone()
+        opt_ref = torch.optim.AdamW([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0)
+        opt_ref.step()
+
+        assert not torch.allclose(p_fused.data, weight_init), "Parameters should have changed after step"
+        assert torch.allclose(p_fused.data, p_ref.data, atol=1e-5), (
+            f"FusedAdam CPU fallback should match torch.optim.AdamW.\n"
+            f"Max diff: {(p_fused.data - p_ref.data).abs().max().item()}"
+        )
+
+    def test_adam_l2_mode_instantiates_and_steps(self, _mock_cuda):
+        """AdamWMode.ADAM_L2_REGULARIZATION must instantiate and step without error."""
+        param = _make_param()
+        with warnings.catch_warnings(record=True):
+            warnings.simplefilter("always")
+            opt = FusedAdam([param], lr=1e-3, adam_w_mode=AdamWMode.ADAM_L2_REGULARIZATION)
+
+        before = param.data.clone()
+        opt.step()
+        assert not torch.allclose(param.data, before), "Parameters should change after step"
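
For reviewers, a minimal usage sketch of the fallback behavior on a CPU-only machine. It assumes FusedAdam is importable from onnxruntime.training.optim (the existing package layout for this file); the model and hyperparameters are illustrative and not part of the patch:

    import warnings

    import torch
    import torch.nn as nn

    from onnxruntime.training.optim import FusedAdam  # assumed public import path

    model = nn.Linear(4, 2)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Emits a one-time UserWarning when torch.cuda.is_available() is False.
        optimizer = FusedAdam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()       # forwarded to the CPU fallback optimizer when CUDA is absent
    optimizer.zero_grad()  # likewise forwarded to the fallback

On a CUDA machine the same code exercises the original fused kernel path unchanged, so the sketch doubles as a check that the public API is identical in both modes.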