Skip to content

Commit 31bd573

Browse files
committed
Make round_stochastically flexible to support additional dtypes in the future
1 parent ccab446 commit 31bd573

File tree

4 files changed

+33
-21
lines changed

4 files changed

+33
-21
lines changed

aten/src/ATen/native/cuda/StochasticRounding.cu

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,10 @@ __global__ void stochastic_rounding_kernel(
1515
curandStatePhilox4_32_10_t state;
1616
curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state);
1717

18+
round_stochastically<output_t, input_t, at::Half> rounder;
19+
1820
for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
19-
float inp = static_cast<float>(input[i]);
20-
output[i] = round_stochastically<output_t>(inp, curand_uniform(&state));
21+
output[i] = rounder(input[i], curand_uniform(&state));
2122
}
2223
}
2324

aten/src/ATen/native/cuda/StochasticRoundingAdam.cu

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ __global__ void stochastic_rounding_adam_step_kernel(
2121
curandStatePhilox4_32_10_t state;
2222
curand_init(seeds.first, tid, seeds.second, &state);
2323

24+
round_stochastically<scalar_t, float, at::Half> rounder;
25+
2426
float m_correction = 1.0 - powf(beta1, step);
2527
float v_correction = 1.0 - powf(beta2, step);
2628

@@ -54,11 +56,11 @@ __global__ void stochastic_rounding_adam_step_kernel(
5456

5557
weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps);
5658

57-
weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
58-
exp_avg[i] = round_stochastically<scalar_t>(m, random_values.y);
59-
exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(v), random_values.z);
59+
weights[i] = rounder(weight, random_values.x);
60+
exp_avg[i] = rounder(m, random_values.y);
61+
exp_avg_sq[i] = rounder(sqrtf(v), random_values.z);
6062
if (is_amsgrad) {
61-
max_exp_avg_sq[i] = round_stochastically<scalar_t>(sqrtf(max_v), random_values.w);
63+
max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w);
6264
}
6365
}
6466
}

aten/src/ATen/native/cuda/StochasticRoundingSGD.cu

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@ __global__ void stochastic_rounding_sgd_step_kernel(
1919
curandStatePhilox4_32_10_t state;
2020
curand_init(seeds.first, tid, seeds.second, &state);
2121

22+
round_stochastically<scalar_t, float, at::Half> rounder;
23+
2224
for (int i = tid; i < numel; i += blockDim.x * gridDim.x) {
2325
float weight = static_cast<float>(weights[i]);
2426
float gradient = static_cast<float>(gradients[i]) * (*inv_scale);
@@ -42,9 +44,9 @@ __global__ void stochastic_rounding_sgd_step_kernel(
4244

4345
weight -= lr * gradient;
4446

45-
weights[i] = round_stochastically<scalar_t>(weight, random_values.x);
47+
weights[i] = rounder(weight, random_values.x);
4648
if (momentum != 0.0f)
47-
momentum_buffer[i] = round_stochastically<scalar_t>(velocity, random_values.y);
49+
momentum_buffer[i] = rounder(velocity, random_values.y);
4850
}
4951
}
5052

aten/src/ATen/native/cuda/stochastic_rounding.cuh

Lines changed: 20 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -44,17 +44,24 @@ __device__ __forceinline__ float get_delta_fp16(float x) {
4444
}
4545

4646
// Natalia magic
47-
template <typename scalar_t>
48-
__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) {
49-
if (x == 0.0) {
50-
return scalar_t(0.0);
51-
}
52-
float delta = get_delta_fp16(x);
53-
float val;
54-
if (x < 0.0) {
55-
val = x - random_value * delta;
56-
} else {
57-
val = x + random_value * delta;
47+
template <typename out_type, typename in_type, typename round_to_prec=at::Half>
48+
struct round_stochastically {
49+
static_assert(std::is_same<round_to_prec, at::Half>::value, "round_stochastically only supports round_to_prec=at::Half");
50+
};
51+
52+
template <typename out_type, typename in_type>
53+
struct round_stochastically<out_type, in_type, at::Half> {
54+
__device__ __forceinline__ out_type operator()(in_type x, float random_value) {
55+
if (x == 0.0) {
56+
return out_type(0.0);
57+
}
58+
float delta = get_delta_fp16(static_cast<float>(x));
59+
float val;
60+
if (x < 0.0) {
61+
val = x - random_value * delta;
62+
} else {
63+
val = x + random_value * delta;
64+
}
65+
return maybe_upcast<out_type>(__float2half_rz(val));
5866
}
59-
return maybe_upcast<scalar_t>(__float2half_rz(val));
60-
}
67+
};

0 commit comments

Comments
 (0)