62 changes: 62 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRounding.cu
@@ -0,0 +1,62 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

template <typename input_t, typename output_t>
__global__ void stochastic_rounding_kernel(
const input_t* input,
output_t* output,
const int64_t numel,
std::pair<uint64_t, uint64_t> seed_and_offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);

for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
float inp = static_cast<float>(input[i]);
Review comment:
If you make these changes https://github.com/csarofeen/pytorch/pull/17/files#r420422970 the cast to float won't be necessary.

output[i] = round_stochastically<output_t>(inp, curand_uniform(&state));
}
}

Tensor stochastic_rounding_cuda(const Tensor& input, c10::optional<Generator> gen_) {

TORCH_CHECK(input.is_contiguous());

// Input is already fp16; nothing to round.
if (input.scalar_type() == kHalf) {
return input;
}

Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format());
const int64_t numel = input.numel();
if (numel == 0) {
return output;
}

const int block = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block;
Review comment (@mcarilli, May 5, 2020):
This is only correct if the kernel's number of registers per thread is <= 32; otherwise register pressure limits your occupancy. You can recompile kernels with --ptxas-options=-v as an nvcc option and nvcc will print how many registers they use (this is easiest to do with the kernels in an extension; I'm not sure how you would pass that option to nvcc in a pytorch build).

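A hedged sketch of the runtime alternative, not part of this diff: CUDA's occupancy API accounts for the kernel's actual register count, so no 32-registers-per-thread assumption is needed (the fp32-to-fp16 instantiation is illustrative).

// Ask the runtime how many blocks of `block` threads fit on one SM.
int max_blocks_per_sm = 0;
AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_blocks_per_sm,
stochastic_rounding_kernel<float, at::Half>,
block, /*dynamicSMemSize=*/0));
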
unsigned int grid = (numel + block - 1) / block;
grid = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid));
}

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "stochastic_rounding_cuda", [&] {
stochastic_rounding_kernel<scalar_t, at::Half><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
Review comment:
My biggest concern is, upstream will probably ask you to rewrite this with TensorIterator in some form, as @zasdfgbnm hinted. (A rough sketch follows this file.)

input.data_ptr<scalar_t>(),
output.data_ptr<at::Half>(),
numel, rng_engine_inputs);
});
AT_CUDA_CHECK(cudaGetLastError());

return output;
}

} // namespace native
} // namespace at
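
Following up on the TensorIterator review comment above, a rough sketch of what that reformulation might look like (TensorIteratorConfig and its builder calls are an assumption about the current ATen API; the per-thread Philox state is elided and would follow ATen's CUDA distribution kernels):

// Hypothetical: iterate over (output, input) pairs; dtypes intentionally differ.
auto iter = at::TensorIteratorConfig()
.add_output(output)
.add_input(input)
.check_all_same_dtype(false)
.build();
// A gpu_kernel-style loop over `iter` would then apply round_stochastically
// elementwise, with RNG state set up per thread as in the kernel above.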
123 changes: 123 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRoundingAdam.cu
@@ -0,0 +1,123 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

template <typename scalar_t>
__global__ void stochastic_rounding_adam_step_kernel(
scalar_t *weights, scalar_t *gradients,
scalar_t *exp_avg, scalar_t *exp_avg_sq, scalar_t *max_exp_avg_sq,
float *inv_scale, float *found_inf,
float lr, float beta1, float beta2,
float weight_decay, float eps, int step,
bool is_decoupled, bool is_amsgrad,
int numel, std::pair<uint64_t, uint64_t> seeds) {

if (*found_inf) return;

int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, tid, seeds.second, &state);

float m_correction = 1.0f - powf(beta1, step);
float v_correction = 1.0f - powf(beta2, step);

for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
float weight = static_cast<float>(weights[i]);
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
float m = static_cast<float>(exp_avg[i]);
// Stochastic Rounding Adam stores the square root of the exponential moving
// average of squared gradients, so the value stays within fp16's representable
// range (e.g. v = 1e-12 underflows fp16, while sqrt(v) = 1e-6 does not).
// Square it on load to recover v.
float v = static_cast<float>(exp_avg_sq[i]);
v = v * v;
float4 random_values = curand_uniform4(&state);

if (weight_decay != 0.0f) {
if (is_decoupled)
weight *= (1 - lr * weight_decay);
else
gradient += weight_decay * weight;
}

// Update m and v.
m = beta1 * m + (1.0f - beta1) * gradient;
v = beta2 * v + (1.0f - beta2) * (gradient * gradient);

// AMSGrad keeps the elementwise maximum of v seen so far; bias correction is applied in the update below.
float max_v = v;
if (is_amsgrad) {
float prev_max_v = static_cast<float>(max_exp_avg_sq[i]);
prev_max_v = prev_max_v * prev_max_v;
max_v = fmaxf(prev_max_v, v);
}

weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);

weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
exp_avg[i] = round_stochastically<scalar_t>(m, random_values.y);
exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(v), random_values.z);
if (is_amsgrad) {
max_exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(max_v), random_values.w);
}
}
}


Tensor stochastic_rounding_adam_step_cuda(
Tensor& param,
const Tensor& grad,
Tensor& exp_avg,
Tensor& exp_avg_sq,
Tensor& max_exp_avg_sq,
const Tensor& inv_scale,
const Tensor& found_inf,
double lr, double beta1, double beta2,
double weight_decay, double eps, int64_t step,
bool is_decoupled, bool is_amsgrad, c10::optional<Generator> gen_) {

if (param.numel() == 0) return param;

TORCH_CHECK(param.is_contiguous());
TORCH_CHECK(grad.is_contiguous());
TORCH_CHECK(exp_avg.is_contiguous());
TORCH_CHECK(exp_avg_sq.is_contiguous());
TORCH_CHECK(max_exp_avg_sq.is_contiguous());

const int64_t numel = param.numel();
const int block_size = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());

// Each loop iteration consumes one curand_uniform4 call, i.e. 4 random values.
uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
param.scalar_type(), "stochastic_rounding_adam_step_cuda", [&] {
stochastic_rounding_adam_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
param.data_ptr<scalar_t>(),
grad.data_ptr<scalar_t>(),
exp_avg.data_ptr<scalar_t>(),
exp_avg_sq.data_ptr<scalar_t>(),
max_exp_avg_sq.data_ptr<scalar_t>(),
inv_scale.data_ptr<float>(),
found_inf.data_ptr<float>(),
lr, beta1, beta2, weight_decay, eps, step,
is_decoupled, is_amsgrad,
numel, rng_engine_inputs);
}
);
AT_CUDA_CHECK(cudaGetLastError());
return param;
}

} // namespace native
} // namespace at
93 changes: 93 additions & 0 deletions aten/src/ATen/native/cuda/StochasticRoundingSGD.cu
@@ -0,0 +1,93 @@
#include <ATen/ATen.h>
#include <ATen/native/cuda/stochastic_rounding.cuh>


namespace at {
namespace native {

// SGD update math with Stochastic Rounding
template <typename scalar_t>
__global__ void stochastic_rounding_sgd_step_kernel(
scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer,
float* inv_scale, float* found_inf,
float weight_decay, float momentum, float dampening, float lr,
bool nesterov, bool first_run, int numel, std::pair<uint64_t, uint64_t> seeds) {

if (*found_inf) return;

int tid = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, tid, seeds.second, &state);

for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
float weight = static_cast<float>(weights[i]);
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
float velocity = static_cast<float>(momentum_buffer[i]);
float4 random_values = curand_uniform4(&state);
Review comment:
You generate 4 RNG values and only use 2. I don't think that's a big problem though.

if (weight_decay != 0.0f)
gradient += weight_decay * weight;

if (momentum != 0.0f) {
if (!first_run)
velocity = velocity * momentum + (1.0f - dampening) * gradient;
else
velocity = gradient;

if (nesterov)
gradient += momentum * velocity;
else
gradient = velocity;
}

weight -= lr * gradient;

weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
if (momentum != 0.0f)
momentum_buffer[i] = round_stochastically<scalar_t>(velocity, random_values.y);
}
}

Tensor stochastic_rounding_sgd_step_cuda(
Tensor& param, const Tensor& grad, Tensor& momentum_buffer,
const Tensor& inv_scale, const Tensor& found_inf,
double lr, double momentum, double weight_decay, double dampening,
bool nesterov, bool first_run, c10::optional<Generator> gen_) {

if (param.numel() == 0) return param;

TORCH_CHECK(param.is_contiguous());
TORCH_CHECK(grad.is_contiguous());
TORCH_CHECK(momentum_buffer.is_contiguous());

const int64_t numel = param.numel();
const int block_size = 256;
const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size;
dim3 dim_block(block_size);
dim3 grid((numel + block_size - 1) / block_size);
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);

auto gen = get_generator_or_default<CUDAGeneratorImpl>(gen_, cuda::detail::getDefaultCUDAGenerator());
uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
param.scalar_type(), "stochastic_rounding_sgd_step_cuda", [&] {
stochastic_rounding_sgd_step_kernel<scalar_t><<<grid, dim_block, 0, c10::cuda::getCurrentCUDAStream()>>>(
param.data_ptr<scalar_t>(),
grad.data_ptr<scalar_t>(),
momentum_buffer.data_ptr<scalar_t>(),
inv_scale.data_ptr<float>(), found_inf.data_ptr<float>(),
static_cast<float>(weight_decay), static_cast<float>(momentum), static_cast<float>(dampening), static_cast<float>(lr),
nesterov, first_run, numel, rng_engine_inputs);
});
AT_CUDA_CHECK(cudaGetLastError());
return param;
}

} // namespace native
} // namespace at
60 changes: 60 additions & 0 deletions aten/src/ATen/native/cuda/stochastic_rounding.cuh
@@ -0,0 +1,60 @@
#pragma once

#include <math.h>
#include <utility>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>

#include <ATen/Utils.h>
#include <ATen/Generator.h>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAFunctions.h>

// 2^-10 is the spacing between consecutive fp16 values in [1, 2), i.e. one ULP at unit scale.
// 2^-24 is the smallest positive subnormal fp16 value, the constant spacing in the subnormal range.
// The 24 is **NOT** related to the 24-bit significand of single precision: it is
// fp16's minimum normal exponent (2^-14) times the mantissa step (2^-10).
__device__ const float TWO_10 = 0.0009765625;
__device__ const float TWO_24 = 0.000000059604644775390625;


template<typename T>
__device__ __forceinline__ T maybe_upcast(__half x){
return T(__half2float(x));
}

template<>
__device__ __forceinline__ __half maybe_upcast<__half>(__half x){
return x;
}

__device__ __forceinline__ float get_delta_fp16(float x) {
int exponent;
// frexpf writes x = m * 2^exponent with m in [0.5, 1); subtract 1 to get the
// IEEE exponent (x = 1.f * 2^exponent).
frexpf(x, &exponent);
exponent -= 1;
if (exponent >= -14)
return ldexpf(TWO_10, exponent);  // fp16 ULP in x's binade: 2^(exponent - 10)
else
return TWO_24;  // subnormal range: constant spacing
}
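
For intuition, the spacing this computes (illustrative values, derived from the fp16 format):

// get_delta_fp16(1.0f)  == 2^-10  (ULP in [1, 2))
// get_delta_fp16(4.5f)  == 2^-8   (ULP in [4, 8))
// get_delta_fp16(1e-6f) == 2^-24  (subnormal range)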

// Natalia magic

Review comment:
Keep this comment.

template <typename scalar_t>
__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) {
Review comment (@mcarilli, May 5, 2020):
Right now, this function only supports rounding to fp16. I think its syntax is misleading.

To make the usage clearer, and to establish an API that supports stochastic rounding to other types in the future, I think you should define it as follows:

template<typename round_to_prec, typename out_type, typename in_type=float>
struct round_stochastically {
  static_assert(false, "round_stochastically only supports round_to_prec=at::Half");
  __device__ __forceinline__ out_type operator()(in_type x, float random_value) {}
};

template <typename out_type, typename in_type=float>
struct round_stochastically<at::Half, out_type, in_type> {
  __device__ __forceinline__ at::Half operator()(in_type x, float random_value) {
    // what we have now
  }
};

Then the caller should say

weights[i] = round_stochastically<at::Half, scalar_t>(weight, random_values.x);

Reply (Collaborator, Author):
L59's maybe_upcast casts from at::Half to float/double if necessary, and the stochastic rounding SGD & Adam kernels use this functionality.

Reply (@mcarilli, May 6, 2020):
We might want to support stochastic rounding to other formats (like bfloat16) later. The API should allow the caller to set a type that determines the rounding precision, even if the actual rounding code for that precision isn't implemented yet.

Reply (Collaborator, Author):
Understood. Updated.

if (x == 0.0) {
return scalar_t(0.0);
}
float delta = get_delta_fp16(x);
float val;
// Perturb away from zero by a uniform [0, 1) fraction of the fp16 spacing,
// then truncate toward zero: each neighboring fp16 value is selected with
// probability proportional to x's distance from the other neighbor, so the
// result is unbiased.
if (x < 0.0) {
val = x - random_value * delta;
} else {
val = x + random_value * delta;
}
return maybe_upcast<scalar_t>(__float2half_rz(val));
}
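
For reference, a short sketch of why this rounding is unbiased, assuming 0 < x lies between adjacent fp16 values a and a + delta (and ignoring binade boundaries):

$$\Pr[\mathrm{rz}(x + u\delta) = a + \delta] = \frac{x - a}{\delta}, \qquad \mathbb{E}[\mathrm{out}] = a \cdot \frac{(a + \delta) - x}{\delta} + (a + \delta) \cdot \frac{x - a}{\delta} = x,$$

where u is uniform on [0, 1) and rz denotes round-toward-zero (__float2half_rz). The negative branch mirrors this away from zero.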
12 changes: 12 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6719,3 +6719,15 @@
# It is undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full

- func: stochastic_rounding(Tensor input, Generator? gen_=None) -> Tensor
dispatch:
CUDA: stochastic_rounding_cuda

- func: stochastic_rounding_adam_step(Tensor(a!) param, Tensor grad, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, Tensor(d!) max_exp_avg_sq, Tensor inv_scale, Tensor found_inf, float lr, float beta1, float beta2, float weight_decay, float eps, int step, bool is_decoupled, bool is_amsgrad, Generator? gen_=None) -> Tensor(a!)
dispatch:
CUDA: stochastic_rounding_adam_step_cuda

- func: stochastic_rounding_sgd_step(Tensor(a!) param, Tensor grad, Tensor(b!) momentum_buffer, Tensor inv_scale, Tensor found_inf, float lr, float momentum, float weight_decay, float dampening, bool nesterov, bool first_run, Generator? gen_=None) -> Tensor(a!)
dispatch:
CUDA: stochastic_rounding_sgd_step_cuda
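
A hypothetical C++ usage sketch for the ops registered above (the generated at:: bindings are inferred from the schemas; values are illustrative, not part of this diff):

#include <ATen/ATen.h>

// Round an fp32 CUDA tensor to fp16 stochastically, using the default generator.
at::Tensor x = at::randn({1024}, at::kCUDA);
at::Tensor y = at::stochastic_rounding(x);  // y has dtype kHalf

// In-place SGD step with stochastic rounding; inv_scale and found_inf are the
// usual AMP grad-scaling tensors. (param, grad, momentum_buffer are assumed to
// be preexisting contiguous fp16 CUDA tensors.)
at::stochastic_rounding_sgd_step(
param, grad, momentum_buffer, inv_scale, found_inf,
/*lr=*/0.1, /*momentum=*/0.9, /*weight_decay=*/0.0, /*dampening=*/0.0,
/*nesterov=*/false, /*first_run=*/true, c10::nullopt);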
6 changes: 6 additions & 0 deletions docs/source/optim.rst
@@ -129,6 +129,12 @@ Algorithms
:members:
.. autoclass:: SGD
:members:
.. autoclass:: SRAdam
:members:
.. autoclass:: SRAdamW
:members:
.. autoclass:: SRSGD
:members:

How to adjust learning rate
---------------------------
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -249,6 +249,7 @@ Pointwise Ops
.. autofunction:: sinh
.. autofunction:: sqrt
.. autofunction:: square
.. autofunction:: stochastic_rounding
.. autofunction:: tan
.. autofunction:: tanh
.. autofunction:: true_divide
1 change: 1 addition & 0 deletions test/run_test.py
@@ -79,6 +79,7 @@
'test_overrides',
'test_jit_fuser_te',
'test_tensorexpr',
'test_stochastic_rounding',
]

# skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn