diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu
new file mode 100644
index 0000000000000..8df7b20f0560d
--- /dev/null
+++ b/aten/src/ATen/native/cuda/StochasticRounding.cu
@@ -0,0 +1,63 @@
+#include <ATen/ATen.h>
+#include <ATen/native/cuda/stochastic_rounding.cuh>
+
+
+namespace at {
+namespace native {
+
+template <typename input_t, typename output_t>
+__global__ void stochastic_rounding_kernel(
+    const input_t* input,
+    output_t* output,
+    const int64_t numel,
+    std::pair<uint64_t, uint64_t> seed_and_offset) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);
+
+  round_stochastically<input_t, output_t, at::Half> rounder;
+
+  for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
+    output[i] = rounder(input[i], curand_uniform(&state));
+  }
+}
+
+Tensor stochastic_rounding_cuda(const Tensor& input, c10::optional<Generator> gen_) {
+
+  TORCH_CHECK(input.is_contiguous());
+
+  if (input.scalar_type() == kHalf) {
+    return input;
+  }
+
+  Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format());
+  const int64_t numel = input.numel();
+  if (numel == 0) {
+    return output;
+  }
+
+  const int block = 256;
+  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block;
+  unsigned int grid = (numel + block - 1) / block;
+  grid = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid);
+
+  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
+  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  {
+    // Each thread consumes at most ceil(numel / (block * grid)) random values.
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid));
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(), "stochastic_rounding_cuda", [&] {
+      stochastic_rounding_kernel<scalar_t, at::Half><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          input.data_ptr<scalar_t>(),
+          output.data_ptr<at::Half>(),
+          numel, rng_engine_inputs);
+    });
+  AT_CUDA_CHECK(cudaGetLastError());
+
+  return output;
+}
+
+} // namespace native
+} // namespace at
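
A quick sanity sketch of the new op (a hypothetical session; it assumes a CUDA build of this patch, with the `torch.stochastic_rounding` binding registered in `native_functions.yaml` below). Because the rounding is unbiased, the mean over many rounded copies of a value that is not representable in fp16 should recover that value:

```python
import torch

# Assumes a CUDA build of this patch exposing torch.stochastic_rounding.
x = torch.full((2 ** 16,), 1.0 + 2 ** -12, device='cuda')  # lies between two fp16 values
y = torch.stochastic_rounding(x)

print(y.dtype)                  # torch.float16
print(x[0].item())              # 1.000244140625
print(y.float().mean().item())  # ~1.00024: each element rounds up with probability 0.25
```
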
diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu
new file mode 100644
index 0000000000000..c06b2bafa4ab1
--- /dev/null
+++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu
@@ -0,0 +1,125 @@
+#include <ATen/ATen.h>
+#include <ATen/native/cuda/stochastic_rounding.cuh>
+
+
+namespace at {
+namespace native {
+
+template <typename scalar_t>
+__global__ void stochastic_rounding_adam_step_kernel(
+    scalar_t *weights, scalar_t *gradients,
+    scalar_t *exp_avg, scalar_t *exp_avg_sq, scalar_t *max_exp_avg_sq,
+    float *inv_scale, float *found_inf,
+    float lr, float beta1, float beta2,
+    float weight_decay, float eps, int step,
+    bool is_decoupled, bool is_amsgrad,
+    int numel, std::pair<uint64_t, uint64_t> seeds) {
+
+  if (*found_inf) return;
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  curandStatePhilox4_32_10_t state;
+  curand_init(seeds.first, tid, seeds.second, &state);
+
+  round_stochastically<float, scalar_t, at::Half> rounder;
+
+  float m_correction = 1.0f - powf(beta1, step);
+  float v_correction = 1.0f - powf(beta2, step);
+
+  for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
+    float weight = static_cast<float>(weights[i]);
+    float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
+    float m = static_cast<float>(exp_avg[i]);
+    // Stochastic Rounding Adam tracks the square root of the exponential moving
+    // average of the squared gradient, so square it on load and store sqrt(v) below.
+    float v = static_cast<float>(exp_avg_sq[i]);
+    v = v * v;
+    float4 random_values = curand_uniform4(&state);
+
+    if (weight_decay != 0.0f) {
+      if (is_decoupled)
+        weight *= (1.0f - lr * weight_decay);
+      else
+        gradient += weight_decay * weight;
+    }
+
+    // Update m and v.
+    m = beta1 * m + (1.0f - beta1) * gradient;
+    v = beta2 * v + (1.0f - beta2) * (gradient * gradient);
+
+    // AMSGrad keeps the element-wise max of all v seen so far.
+    float max_v = v;
+    if (is_amsgrad) {
+      float prev_max_v = static_cast<float>(max_exp_avg_sq[i]);
+      prev_max_v = prev_max_v * prev_max_v;
+      max_v = fmaxf(prev_max_v, v);
+    }
+
+    // m and v are unbiased by m_correction and v_correction here.
+    weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);
+
+    weights[i] = rounder(weight, random_values.x);
+    exp_avg[i] = rounder(m, random_values.y);
+    exp_avg_sq[i] = rounder(sqrtf(v), random_values.z);
+    if (is_amsgrad) {
+      max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w);
+    }
+  }
+}
+
+
+Tensor stochastic_rounding_adam_step_cuda(
+    Tensor& param,
+    const Tensor& grad,
+    Tensor& exp_avg,
+    Tensor& exp_avg_sq,
+    Tensor& max_exp_avg_sq,
+    const Tensor& inv_scale,
+    const Tensor& found_inf,
+    double lr, double beta1, double beta2,
+    double weight_decay, double eps, int64_t step,
+    bool is_decoupled, bool is_amsgrad, c10::optional<Generator> gen_) {
+
+  if (param.numel() == 0) return param;
+
+  TORCH_CHECK(param.is_contiguous());
+  TORCH_CHECK(grad.is_contiguous());
+  TORCH_CHECK(exp_avg.is_contiguous());
+  TORCH_CHECK(exp_avg_sq.is_contiguous());
+  TORCH_CHECK(max_exp_avg_sq.is_contiguous());
+
+  const int64_t numel = param.numel();
+  const int block_size = 256;
+  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
+  dim3 dim_block(block_size);
+  dim3 grid((numel + block_size - 1) / block_size);
+  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
+
+  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
+
+  // Each thread draws one float4 per loop iteration.
+  uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
+  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  {
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    param.scalar_type(), "stochastic_rounding_adam_step_cuda", [&] {
+      stochastic_rounding_adam_step_kernel<scalar_t><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          param.data_ptr<scalar_t>(),
+          grad.data_ptr<scalar_t>(),
+          exp_avg.data_ptr<scalar_t>(),
+          exp_avg_sq.data_ptr<scalar_t>(),
+          max_exp_avg_sq.data_ptr<scalar_t>(),
+          inv_scale.data_ptr<float>(),
+          found_inf.data_ptr<float>(),
+          static_cast<float>(lr), static_cast<float>(beta1), static_cast<float>(beta2),
+          static_cast<float>(weight_decay), static_cast<float>(eps), static_cast<int>(step),
+          is_decoupled, is_amsgrad,
+          numel, rng_engine_inputs);
+    }
+  );
+  AT_CUDA_CHECK(cudaGetLastError());
+  return param;
+}
+
+} // namespace native
+} // namespace at
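
For review convenience, a scalar Python sketch of the per-element update the kernel above performs (rounding and the AMSGrad branch omitted). The one non-standard detail is that `exp_avg_sq` holds the square root of the usual Adam second moment, so it is squared on load and square-rooted on store:

```python
import math

def sr_adam_reference(weight, grad, m, root_v, lr, beta1, beta2,
                      weight_decay, eps, step, is_decoupled=False):
    v = root_v * root_v  # exp_avg_sq stores sqrt(v); recover v before updating
    if weight_decay != 0.0:
        if is_decoupled:   # AdamW-style decoupled decay
            weight *= 1.0 - lr * weight_decay
        else:              # classic L2 penalty folded into the gradient
            grad += weight_decay * weight
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_correction = 1.0 - beta1 ** step
    v_correction = 1.0 - beta2 ** step
    weight -= (lr / m_correction) * m / (math.sqrt(v / v_correction) + eps)
    return weight, m, math.sqrt(v)  # sqrt(v) goes back into exp_avg_sq
```
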
diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu
new file mode 100644
index 0000000000000..86b92136f2e8f
--- /dev/null
+++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu
@@ -0,0 +1,95 @@
+#include <ATen/ATen.h>
+#include <ATen/native/cuda/stochastic_rounding.cuh>
+
+
+namespace at {
+namespace native {
+
+// SGD update math with Stochastic Rounding
+template <typename scalar_t>
+__global__ void stochastic_rounding_sgd_step_kernel(
+    scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer,
+    float* inv_scale, float* found_inf,
+    float weight_decay, float momentum, float dampening, float lr,
+    bool nesterov, bool first_run, int numel, std::pair<uint64_t, uint64_t> seeds) {
+
+  if (*found_inf) return;
+
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  curandStatePhilox4_32_10_t state;
+  curand_init(seeds.first, tid, seeds.second, &state);
+
+  round_stochastically<float, scalar_t, at::Half> rounder;
+
+  for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
+    float weight = static_cast<float>(weights[i]);
+    float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
+    float velocity = static_cast<float>(momentum_buffer[i]);
+    float4 random_values = curand_uniform4(&state);
+
+    if (weight_decay != 0.0f)
+      gradient += weight_decay * weight;
+
+    if (momentum != 0.0f) {
+      if (!first_run)
+        velocity = velocity * momentum + (1.0f - dampening) * gradient;
+      else
+        velocity = gradient;
+
+      if (nesterov)
+        gradient += momentum * velocity;
+      else
+        gradient = velocity;
+    }
+
+    weight -= lr * gradient;
+
+    weights[i] = rounder(weight, random_values.x);
+    if (momentum != 0.0f)
+      momentum_buffer[i] = rounder(velocity, random_values.y);
+  }
+}
+
+Tensor stochastic_rounding_sgd_step_cuda(
+    Tensor& param, const Tensor& grad, Tensor& momentum_buffer,
+    const Tensor& inv_scale, const Tensor& found_inf,
+    double lr, double momentum, double weight_decay, double dampening,
+    bool nesterov, bool first_run, c10::optional<Generator> gen_) {
+
+  if (param.numel() == 0) return param;
+
+  TORCH_CHECK(param.is_contiguous());
+  TORCH_CHECK(grad.is_contiguous());
+  TORCH_CHECK(momentum_buffer.is_contiguous());
+
+  const int64_t numel = param.numel();
+  const int block_size = 256;
+  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
+  dim3 dim_block(block_size);
+  dim3 grid((numel + block_size - 1) / block_size);
+  grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
+
+  auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
+  // Each thread draws one float4 per loop iteration.
+  uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
+  std::pair<uint64_t, uint64_t> rng_engine_inputs;
+  {
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    param.scalar_type(), "stochastic_rounding_sgd_step_cuda", [&] {
+      stochastic_rounding_sgd_step_kernel<scalar_t><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          param.data_ptr<scalar_t>(),
+          grad.data_ptr<scalar_t>(),
+          momentum_buffer.data_ptr<scalar_t>(),
+          inv_scale.data_ptr<float>(), found_inf.data_ptr<float>(),
+          static_cast<float>(weight_decay), static_cast<float>(momentum), static_cast<float>(dampening), static_cast<float>(lr),
+          nesterov, first_run, numel, rng_engine_inputs);
+    });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return param;
+}
+
+} // namespace native
+} // namespace at
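
The same scalar-reference treatment for the SGD kernel (rounding again omitted). `first_run` stands in for the usual "no momentum buffer yet" check, because the buffer tensor always exists by the time the fused op is called:

```python
def sr_sgd_reference(weight, grad, velocity, lr, momentum, dampening,
                     weight_decay, nesterov, first_run):
    if weight_decay != 0.0:
        grad += weight_decay * weight
    if momentum != 0.0:
        if first_run:
            velocity = grad  # dampening is skipped on the first update
        else:
            velocity = momentum * velocity + (1.0 - dampening) * grad
        grad = grad + momentum * velocity if nesterov else velocity
    return weight - lr * grad, velocity
```
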
diff --git a/aten/src/ATen/native/cuda/stochastic_rounding.cuh b/aten/src/ATen/native/cuda/stochastic_rounding.cuh
new file mode 100644
index 0000000000000..4c24a3b40b2d6
--- /dev/null
+++ b/aten/src/ATen/native/cuda/stochastic_rounding.cuh
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_fp16.h>
+
+#include <curand.h>
+#include <curand_kernel.h>
+#include <curand_philox4x32_x.h>
+#include <type_traits>
+#include <utility>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/CUDAGeneratorImpl.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+
+// 2^-10 is the relative step between consecutive normal FP16 numbers.
+// 2^-24 is the unit in the last place (ULP) at the precision limit, i.e. the
+// smallest subnormal FP16 step. The 24 is **NOT** related to the number of
+// mantissa bits of the single precision format.
+__device__ const float TWO_10 = 0.0009765625;
+__device__ const float TWO_24 = 0.000000059604644775390625;
+
+
+template <typename T>
+__device__ __forceinline__ T maybe_upcast(__half x){
+  return T(__half2float(x));
+}
+
+template<>
+__device__ __forceinline__ __half maybe_upcast<__half>(__half x){
+  return x;
+}
+
+// Returns the spacing between the two FP16 values adjacent to x.
+__device__ __forceinline__ float get_delta_fp16(float x) {
+  int exponent;
+  frexpf(x, &exponent);
+  exponent -= 1;
+  if (exponent >= -14)
+    return TWO_10 * exp2f(exponent);
+  else
+    return TWO_24;
+}
+
+// Natalia magic
+template <typename in_type, typename out_type, typename round_to_prec>
+struct round_stochastically {
+  static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half");
+};
+
+template <typename in_type, typename out_type>
+struct round_stochastically<in_type, out_type, at::Half> {
+  __device__ __forceinline__ out_type operator()(in_type x, float random_value) {
+    if (x == 0.0) {
+      return out_type(0.0);
+    }
+    float delta = get_delta_fp16(static_cast<float>(x));
+    float val;
+    if (x < 0.0) {
+      val = x - random_value * delta;
+    } else {
+      val = x + random_value * delta;
+    }
+    // Adding up to one ULP away from zero and truncating toward zero
+    // (__float2half_rz) yields the stochastic rounding probabilities.
+    return maybe_upcast<out_type>(__float2half_rz(val));
+  }
+};
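
The rounding rule itself is easiest to see in scalar Python. Adding `uniform(0, 1) * ulp` away from zero and then truncating toward zero (what `__float2half_rz` does) rounds away from zero with probability equal to the value's fractional position between the two neighbouring fp16 values. A sketch, ignoring the binade boundaries that the real half conversion handles implicitly:

```python
import math
import random

def delta_fp16(x):
    # Mirrors get_delta_fp16: fp16 ULP spacing at |x|.
    _, exponent = math.frexp(x)
    exponent -= 1
    return 2.0 ** (exponent - 10) if exponent >= -14 else 2.0 ** -24

def round_stochastically_sketch(x):
    if x == 0.0:
        return 0.0
    delta = delta_fp16(x)
    offset = random.random() * delta
    val = x - offset if x < 0.0 else x + offset
    # Stand-in for __float2half_rz: truncate toward zero on the fp16 grid.
    return math.trunc(val / delta) * delta
```

For `x` at fraction `f` between two grid points, truncation lands on the upper point exactly when the random draw exceeds `1 - f`, i.e. with probability `f`, which is the unbiasedness property the tests below rely on.
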
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a37b0af4d7723..59c64ca90b28f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6719,3 +6719,15 @@
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
   use_c10_dispatcher: full
+
+- func: stochastic_rounding(Tensor input, Generator? gen_=None) -> Tensor
+  dispatch:
+    CUDA: stochastic_rounding_cuda
+
+- func: stochastic_rounding_adam_step(Tensor(a!) param, Tensor grad, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, Tensor(d!) max_exp_avg_sq, Tensor inv_scale, Tensor found_inf, float lr, float beta1, float beta2, float weight_decay, float eps, int step, bool is_decoupled, bool is_amsgrad, Generator? gen_=None) -> Tensor(a!)
+  dispatch:
+    CUDA: stochastic_rounding_adam_step_cuda
+
+- func: stochastic_rounding_sgd_step(Tensor(a!) param, Tensor grad, Tensor(b!) momentum_buffer, Tensor inv_scale, Tensor found_inf, float lr, float momentum, float weight_decay, float dampening, bool nesterov, bool first_run, Generator? gen_=None) -> Tensor(a!)
+  dispatch:
+    CUDA: stochastic_rounding_sgd_step_cuda
diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index f09685bd53d5c..dfa556b24bcad 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -129,6 +129,12 @@ Algorithms
     :members:
 .. autoclass:: SGD
     :members:
+.. autoclass:: SRAdam
+    :members:
+.. autoclass:: SRAdamW
+    :members:
+.. autoclass:: SRSGD
+    :members:
 
 How to adjust learning rate
 ---------------------------
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index d5be6aee2207b..6a515b0c3b71f 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -249,6 +249,7 @@ Pointwise Ops
 .. autofunction:: sinh
 .. autofunction:: sqrt
 .. autofunction:: square
+.. autofunction:: stochastic_rounding
 .. autofunction:: tan
 .. autofunction:: tanh
 .. autofunction:: true_divide
diff --git a/test/run_test.py b/test/run_test.py
index a90381134f91d..e35956e49e73f 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -79,6 +79,7 @@
     'test_overrides',
     'test_jit_fuser_te',
     'test_tensorexpr',
+    'test_stochastic_rounding',
 ]
 
 # skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn
diff --git a/test/test_optim.py b/test/test_optim.py
index b0d502b5ef63e..656b8e48c000c 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1547,5 +1547,112 @@ def test_cosine_then_cyclic(self):
         self.assertLessEqual(last_lr, max_lr)
 
 
+def _calc_loss(weight, bias, input):
+    y = weight.mv(input)
+    if y.get_device() != bias.get_device():
+        y = y.cuda(bias.get_device())
+    return (y + bias).pow(2).sum()
+
+
+@unittest.skipIf(not torch.cuda.is_available(), 'No CUDA')
+class TestStochasticRoundingOptim(TestCase):
+
+    exact_dtype = True
+
+    def _test_basic_cases_template(
+            self, weight, bias, input, constructor, grad_scaler=None):
+
+        optimizer = constructor(weight, bias)
+        initial_value = None
+        for _i in range(10):
+            optimizer.zero_grad()
+            loss = _calc_loss(weight, bias, input)
+            if grad_scaler is None:
+                loss.backward()
+                optimizer.step()
+            else:
+                grad_scaler.scale(loss).backward()
+                grad_scaler.step(optimizer)
+                grad_scaler.update()
+            if initial_value is None:
+                initial_value = loss.item()
+        self.assertLess(_calc_loss(weight, bias, input).item(), initial_value)
+
+        # Check whether weight and bias can be represented in 16 bits.
+        with torch.no_grad():
+            for param_group in optimizer.param_groups:
+                for p in param_group['params']:
+                    half_p = p.clone().detach().to(torch.half).to(weight.dtype)
+                    diff = (p - half_p).abs()
+                    self.assertTrue(torch.equal(diff, torch.zeros_like(diff)))
+
+    def _test_basic_cases(self, dtype, constructor, grad_scaler=None):
+        self._test_basic_cases_template(
+            torch.nn.Parameter(torch.randn(10, 5).cuda().to(dtype)),
+            torch.nn.Parameter(torch.randn(10).cuda().to(dtype)),
+            torch.randn(5, requires_grad=True).cuda().to(dtype),
+            constructor, grad_scaler)
+
+        if torch.cuda.device_count() > 1:
+            self._test_basic_cases_template(
+                torch.nn.Parameter(torch.randn(10, 5).cuda(0).to(dtype)),
+                torch.nn.Parameter(torch.randn(10).cuda(1).to(dtype)),
+                torch.randn(5).cuda(0).to(dtype),
+                constructor, grad_scaler)
+
+    def _test_without_GradScaler(self, opt):
+        for dtype in (torch.float16, torch.float32, torch.float64):
+            self._test_basic_cases(
+                dtype, lambda weight, bias: opt([weight, bias], lr=1e-2), None)
+
+    def _test_with_GradScaler(self, opt):
+        for dtype in (torch.float16, torch.float32):
+            self._test_basic_cases(
+                dtype, lambda weight, bias: opt([weight, bias], lr=1e-2),
+                torch.cuda.amp.GradScaler())
+
+    def test_SRAdam(self):
+        self._test_without_GradScaler(optim.SRAdam)
+        self._test_with_GradScaler(optim.SRAdam)
+
+    def test_SRAdamW(self):
+        self._test_without_GradScaler(optim.SRAdamW)
+        self._test_with_GradScaler(optim.SRAdamW)
+
+    def test_SRSGD(self):
+        self._test_without_GradScaler(optim.SRSGD)
+        self._test_with_GradScaler(optim.SRSGD)
+
+    def _prepare_optimizer(self, opt, update=False):
+        weight = torch.nn.Parameter(torch.randn(10, 5).cuda())
+        bias = torch.nn.Parameter(torch.randn(10).cuda())
+        optimizer = opt([weight, bias], lr=1e-2)
+        if not update:
+            return optimizer
+        input = torch.randn(5).cuda()
+
+        optimizer.zero_grad()
+        _calc_loss(weight, bias, input).backward()
+        optimizer.step()
+        optimizer.zero_grad()
+
+        return optimizer
+
+    def _test_state_dict(self, opt_1, opt_2):
+        # Populate the SR optimizer's state, round-trip its own state dict, then
+        # exchange state dicts with the reference optimizer in both directions.
+        optimizer = self._prepare_optimizer(opt_1, True)
+        optimizer.load_state_dict(optimizer.state_dict())
+        optimizer2 = self._prepare_optimizer(opt_2, False)
+        optimizer2.load_state_dict(optimizer.state_dict())
+        optimizer2 = self._prepare_optimizer(opt_2, True)
+        optimizer.load_state_dict(optimizer2.state_dict())
+        self._prepare_optimizer(opt_1, False).load_state_dict(optimizer2.state_dict())
+
+    def test_state_dict_compatibility(self):
+        self._test_state_dict(optim.SRSGD, optim.SGD)
+        self._test_state_dict(optim.SRAdam, optim.Adam)
+        self._test_state_dict(optim.SRAdamW, optim.AdamW)
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_stochastic_rounding.py b/test/test_stochastic_rounding.py
new file mode 100644
index 0000000000000..f2c43f98f78d3
--- /dev/null
+++ b/test/test_stochastic_rounding.py
@@ -0,0 +1,34 @@
+import math
+
+import torch
+import pytest
+
+
+N = 2 ** 14
+
+
+@pytest.mark.parametrize('scale', tuple(range(-18, 11)))
+def test_stochastic_rounding(scale):
+
+    base = math.pow(2, scale)
+    original_value = (base + math.pow(2, scale + 1)) / 2.0 + .5 * base
+    x = torch.tensor([original_value] * N).cuda()
+    _, exponent = math.frexp(original_value)
+    exponent -= 1
+    rounded = torch.stochastic_rounding(x)
+
+    mean = torch.mean(rounded).item()
+    delta_fp16 = math.pow(2, -10 + exponent if exponent >= -14 else -24)
+    threshold = 1e-6
+    diff = math.fabs(original_value - mean)
+
+    # The right-hand condition, `diff < delta_fp16 / 2.0`, is for larger values of
+    # `original_value`. The larger `original_value` is, the larger `delta_fp16` is,
+    # so no matter how many elements we prepare, it is difficult to guarantee that
+    # `mean` is close enough to the original value.
+    assert diff < threshold or diff < delta_fp16 / 2.0
+
+
+def test_stochastic_rounding_half():
+    x = torch.randn((32, 32)).cuda().half()
+    y = torch.stochastic_rounding(x)
+    assert torch.eq(x, y).all()
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 5da13158a1a38..83baf1e9d94be 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -7544,6 +7544,29 @@ def merge_dicts(*dicts):
                 [100, 200]], dtype=torch.uint8)
 """)
 
+add_docstr(torch.stochastic_rounding,
+           r"""
+stochastic_rounding(input, generator=None) -> Tensor
+
+Rounds a tensor to half precision stochastically. If the dtype of :attr:`input` is
+already Half, the input is returned unchanged. This function supports only CUDA tensors.
+For a floating-point number :attr:`x` lying between two adjacent half-precision values
+:attr:`y` and :attr:`z`, :attr:`x` is rounded to :attr:`y` with probability
+:math:`\dfrac{| x - z |}{| y - z |}` and to :attr:`z` with probability
+:math:`\dfrac{| x - y |}{| y - z |}`, so the rounding is correct in expectation.
+
+See `Deep learning with limited numerical precision`_ for further details.
+
+.. _Deep learning with limited numerical precision: https://dl.acm.org/doi/10.5555/3045118.3045303
+
+Args:
+    input (Tensor): float tensor to round stochastically
+    generator (Generator, optional): a pseudorandom number generator for sampling
+
+Returns:
+    Tensor: A stochastically rounded half tensor
+
+""")
+
 add_docstr(torch._C.Generator,
            r"""
 Generator(device='cpu') -> Generator
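
A worked instance of the probability formula in the docstring, with hypothetical numbers (`z` is the fp16 neighbour just above `y = 1.0`):

```python
x, y, z = 1.0004, 1.0, 1.0 + 2 ** -10   # y, z: adjacent fp16 values around x
p_y = abs(x - z) / abs(y - z)           # P(round down to y) ~ 0.590
p_z = abs(x - y) / abs(y - z)           # P(round up to z)   ~ 0.410
assert abs(p_y + p_z - 1.0) < 1e-12
assert abs(p_y * y + p_z * z - x) < 1e-12  # the expected rounded value is x itself
```
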
diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py
index 20fb9406412e9..714a34b54cda0 100644
--- a/torch/optim/__init__.py
+++ b/torch/optim/__init__.py
@@ -13,6 +13,9 @@
 from .adamax import Adamax
 from .asgd import ASGD
 from .sgd import SGD
+from .sradam import SRAdam
+from .sradamw import SRAdamW
+from .srsgd import SRSGD
 from .rprop import Rprop
 from .rmsprop import RMSprop
 from .optimizer import Optimizer
@@ -27,6 +30,9 @@
 del adamax
 del asgd
 del sgd
+del sradam
+del sradamw
+del srsgd
 del rprop
 del rmsprop
 del optimizer
diff --git a/torch/optim/__init__.pyi b/torch/optim/__init__.pyi
index e82b5821e5ce3..4ead24b2e9089 100644
--- a/torch/optim/__init__.pyi
+++ b/torch/optim/__init__.pyi
@@ -11,3 +11,6 @@ from .rmsprop import RMSprop
 from .rprop import Rprop
 from .sgd import SGD as SGD
 from .sparse_adam import SparseAdam
+from .sradam import SRAdam
+from .sradamw import SRAdamW
+from .srsgd import SRSGD
diff --git a/torch/optim/_amp_helper.py b/torch/optim/_amp_helper.py
new file mode 100644
index 0000000000000..606a879ee8153
--- /dev/null
+++ b/torch/optim/_amp_helper.py
@@ -0,0 +1,18 @@
+import torch
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator
+
+
+def _combined_found_inf_helper(optimizer, grad_scaler, device):
+
+    found_inf_dict = grad_scaler._check_inf_per_device(optimizer)
+    # Combines found_inf tensors from all devices. As in GradScaler.update(),
+    # tensors are combined on the scale's device, which is an arbitrary but
+    # reasonable choice that avoids new context creation.
+    found_infs = [f.to(device, non_blocking=True) for f in found_inf_dict.values()]
+    assert len(found_infs) > 0, "No inf checks were recorded in _check_inf_per_device."
+    found_inf_combined = found_infs[0]
+    if len(found_infs) > 1:
+        with torch.no_grad():
+            for i in range(1, len(found_infs)):
+                found_inf_combined += found_infs[i]
+    return _MultiDeviceReplicator(found_inf_combined)
diff --git a/torch/optim/_amp_helper.pyi b/torch/optim/_amp_helper.pyi
new file mode 100644
index 0000000000000..65c0ca5ab0a8c
--- /dev/null
+++ b/torch/optim/_amp_helper.pyi
@@ -0,0 +1,8 @@
+import torch
+from torch.optim.optimizer import Optimizer
+from torch.cuda.amp.grad_scaler import GradScaler
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator
+
+
+def _combined_found_inf_helper(
+    optimizer: Optimizer, grad_scaler: GradScaler, device: torch.device) -> _MultiDeviceReplicator: ...
diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py
new file mode 100644
index 0000000000000..7bed8af84e029
--- /dev/null
+++ b/torch/optim/sradam.py
@@ -0,0 +1,116 @@
+import torch
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator
+from .adam import Adam
+from ._amp_helper import _combined_found_inf_helper
+
+
+def _apply_square_to_state_dict(state_dict):
+    # The tensors in `state_dict` alias the optimizer's live state, so convert
+    # on fresh tensors instead of in place to avoid corrupting that state.
+    with torch.no_grad():
+        new_state = {}
+        for key, state_per_param in state_dict['state'].items():
+            state_per_param = dict(state_per_param)
+            state_per_param['exp_avg_sq'] = state_per_param['exp_avg_sq'].square()
+            state_per_param['max_exp_avg_sq'] = state_per_param['max_exp_avg_sq'].square()
+            new_state[key] = state_per_param
+        state_dict['state'] = new_state
+    return state_dict
+
+
+def _apply_sqrt_to_state_dict(state_dict):
+    # Same aliasing caveat as above: the input may come straight from another
+    # optimizer's state_dict(), so do not modify its tensors in place.
+    with torch.no_grad():
+        new_state = {}
+        for key, state_per_param in state_dict['state'].items():
+            state_per_param = dict(state_per_param)
+            state_per_param['exp_avg_sq'] = state_per_param['exp_avg_sq'].sqrt()
+            if 'max_exp_avg_sq' not in state_per_param:
+                state_per_param['max_exp_avg_sq'] = torch.zeros_like(state_per_param['exp_avg_sq'])
+            else:
+                state_per_param['max_exp_avg_sq'] = state_per_param['max_exp_avg_sq'].sqrt()
+            new_state[key] = state_per_param
+        state_dict['state'] = new_state
+    return state_dict
+
+
+class SRAdam(Adam):
+    r"""Implements the Adam algorithm with Stochastic Rounding.
+
+    Adam was proposed in `Adam: A Method for Stochastic Optimization`_.
+
+    With Stochastic Rounding, the parameters as well as `exp_avg`, `exp_avg_sq`,
+    and optionally `max_exp_avg_sq` can be represented in 16 bits. See
+    :func:`torch.stochastic_rounding` for details. This optimizer requires CUDA.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    _step_supports_amp_scaling = True
+
+    def state_dict(self):
+        return _apply_square_to_state_dict(super().state_dict())
+
+    def load_state_dict(self, state_dict):
+        super().load_state_dict(_apply_sqrt_to_state_dict(state_dict))
+
+    @torch.no_grad()
+    def step(self, closure=None, grad_scaler=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            grad_scaler (:class:`torch.cuda.amp.GradScaler`, optional): scaler whose
+                inverse scale and inf/nan checks are applied inside the fused step.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if grad_scaler is not None:
+            inv_scale = grad_scaler._get_scale_async().double().reciprocal().float()
+            found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device)
+        else:
+            inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device())
+            found_inf = _MultiDeviceReplicator(
+                torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()))
+
+        inv_scale = _MultiDeviceReplicator(inv_scale)
+
+        for group in self.param_groups:
+            for param in group['params']:
+                if param.grad is None:
+                    continue
+                grad = param.grad
+                if grad.is_sparse:
+                    raise RuntimeError('SRAdam does not support sparse gradients')
+
+                state = self.state[param]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                    # Square root of the exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                    # Maintains max of all exp. moving avg. of sq. grad. values
+                    state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                torch.stochastic_rounding_adam_step(
+                    param, grad,
+                    state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'],
+                    inv_scale.get(param.device), found_inf.get(param.device),
+                    group['lr'], beta1, beta2,
+                    group['weight_decay'], group['eps'], state['step'],
+                    False, group['amsgrad'])
+
+        return loss
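
A hypothetical end-to-end usage sketch (again assuming a CUDA build of this patch). Because `_step_supports_amp_scaling` is set, `GradScaler.step` passes itself through to `optimizer.step(grad_scaler=...)`, and the gradient unscaling and inf/nan skip happen inside the fused kernel rather than in Python:

```python
import torch
from torch import optim

model = torch.nn.Linear(10, 10).cuda()
optimizer = optim.SRAdam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 10, device='cuda')).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # forwards itself: optimizer.step(grad_scaler=scaler)
    scaler.update()
```
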
+ """ + loss = None + if closure is not None: + with torch.no_grad(): + loss = closure() + + if grad_scaler is not None: + inv_scale = grad_scaler._get_scale_async().double().reciprocal().float() + found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device) + else: + inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator( + torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device())) + + inv_scale = _MultiDeviceReplicator(inv_scale) + + for group in self.param_groups: + for param in group['params']: + if param.grad is None: + continue + grad = param.grad + if grad.is_sparse: + raise RuntimeError('SRAdam does not support sparse gradients') + + state = self.state[param] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + beta1, beta2 = group['betas'] + + state['step'] += 1 + + torch.stochastic_rounding_adam_step( + param, grad, + state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'], + inv_scale.get(param.device), found_inf.get(param.device), + group['lr'], beta1, beta2, + group['weight_decay'], group['eps'], state['step'], + False, group['amsgrad']) + + return loss diff --git a/torch/optim/sradam.pyi b/torch/optim/sradam.pyi new file mode 100644 index 0000000000000..03d6c7d98d4c5 --- /dev/null +++ b/torch/optim/sradam.pyi @@ -0,0 +1,7 @@ +from typing import Callable, Optional, List +from torch.cuda.amp import GradScaler +from .adam import Adam + + +class RSAdam(Adam): + def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: GradScaler=...) -> Optional[float]: ... diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py new file mode 100644 index 0000000000000..0a3e4279ea595 --- /dev/null +++ b/torch/optim/sradamw.py @@ -0,0 +1,120 @@ +import torch +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator +from .adamw import AdamW +from ._amp_helper import _combined_found_inf_helper + + +def _apply_square_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].square_() + state_per_param['max_exp_avg_sq'].square() + return state_dict + + +def _apply_sqrt_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].sqrt_() + if 'max_exp_avg_sq' not in state_per_param: + state_per_param['max_exp_avg_sq'] = torch.zeros_like(state_per_param['exp_avg_sq']) + else: + state_per_param['max_exp_avg_sq'].sqrt_() + return state_dict + + +class SRAdamW(AdamW): + r"""Implements AdamW algorithm with Stochastic Rounding. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + With Stochastic Rounding, param, `exp_avg`, `exp_avg_sq`, and optionally `max_exp_avg_sq` + can be represented with 16 bits. See :func:`torch.stochastic_rounding` for details. + This optimizer requires CUDA. 
diff --git a/torch/optim/sradamw.pyi b/torch/optim/sradamw.pyi
new file mode 100644
index 0000000000000..4b9cd350dac1e
--- /dev/null
+++ b/torch/optim/sradamw.pyi
@@ -0,0 +1,7 @@
+from typing import Callable, Optional
+from torch.cuda.amp import GradScaler
+from .adamw import AdamW
+
+
+class SRAdamW(AdamW):
+    def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: Optional[GradScaler]=...) -> Optional[float]: ...
diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py
new file mode 100644
index 0000000000000..16d6588c153f4
--- /dev/null
+++ b/torch/optim/srsgd.py
@@ -0,0 +1,75 @@
+import torch
+from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator
+from .sgd import SGD
+from ._amp_helper import _combined_found_inf_helper
+
+
+class SRSGD(SGD):
+    r"""Implements stochastic gradient descent with Stochastic Rounding.
+
+    With Stochastic Rounding, the parameters and `momentum_buffer` can be
+    represented in 16 bits. See :func:`torch.stochastic_rounding` for details.
+    This optimizer requires CUDA.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+    """
+
+    _step_supports_amp_scaling = True
+
+    @torch.no_grad()
+    def step(self, closure=None, grad_scaler=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            grad_scaler (:class:`torch.cuda.amp.GradScaler`, optional): scaler whose
+                inverse scale and inf/nan checks are applied inside the fused step.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if grad_scaler is not None:
+            inv_scale = grad_scaler._get_scale_async().double().reciprocal().float()
+            found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device)
+        else:
+            inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device())
+            found_inf = _MultiDeviceReplicator(
+                torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()))
+
+        inv_scale = _MultiDeviceReplicator(inv_scale)
+
+        for group in self.param_groups:
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            nesterov = group['nesterov']
+
+            for param in group['params']:
+                if param.grad is None:
+                    continue
+                grad = param.grad
+                if grad.is_sparse:
+                    raise RuntimeError('SRSGD does not support sparse gradients')
+
+                first_run = False
+                param_state = self.state[param]
+                if 'momentum_buffer' not in param_state:
+                    first_run = True
+                    param_state['momentum_buffer'] = torch.zeros_like(param)
+                momentum_buffer = param_state['momentum_buffer']
+
+                torch.stochastic_rounding_sgd_step(
+                    param, grad, momentum_buffer,
+                    inv_scale.get(param.device), found_inf.get(param.device),
+                    group['lr'], momentum, weight_decay, dampening,
+                    nesterov, first_run)
+
+        return loss
diff --git a/torch/optim/srsgd.pyi b/torch/optim/srsgd.pyi
new file mode 100644
index 0000000000000..03ea939b0ccb0
--- /dev/null
+++ b/torch/optim/srsgd.pyi
@@ -0,0 +1,7 @@
+from typing import Callable, Optional
+from torch.cuda.amp import GradScaler
+from .sgd import SGD
+
+
+class SRSGD(SGD):
+    def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: Optional[GradScaler]=...) -> Optional[float]: ...
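
Finally, a sketch of the checkpoint interchange that `test_state_dict_compatibility` exercises (hypothetical session, CUDA build of this patch assumed): the square/sqrt conversions in `state_dict()`/`load_state_dict()` let SR checkpoints load into the plain optimizers and back.

```python
import torch
from torch import optim

model = torch.nn.Linear(4, 4).cuda()
sr = optim.SRAdam(model.parameters(), lr=1e-3)
model(torch.randn(2, 4, device='cuda')).sum().backward()
sr.step()  # populates exp_avg, exp_avg_sq (stored as sqrt), max_exp_avg_sq

plain = optim.Adam(model.parameters(), lr=1e-3)
plain.load_state_dict(sr.state_dict())  # state_dict() squares exp_avg_sq on the way out
sr.load_state_dict(plain.state_dict())  # load_state_dict() takes the sqrt on the way in
```
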