
Commit cf402b1

undo non-triton changes
1 parent 853bb77 commit cf402b1

File tree

11 files changed: +3, -243 lines

tests/cpp/operator/test_cast_current_scaling.cu

Lines changed: 0 additions & 39 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -197,43 +195,6 @@ TEST_P(CastCSTestSuite, TestCastCS) {
   );
 }
 
-#ifdef __HIP_PLATFORM_AMD__
-
-TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
-  using namespace transformer_engine;
-  using namespace test;
-
-  std::vector<size_t> shape{256, 1024};
-  const size_t N = product(shape);
-
-  // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
-  Tensor input("input", shape, DType::kFloat32);
-  Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
-  Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);
-
-  fillUniform(&input);
-
-  // Path 1: atomic-based amax (no workspace)
-  nvte_compute_amax(input.data(), out_atomic.data(), 0);
-
-  // Path 2: two-stage amax using workspace
-  std::vector<size_t> ws_shape{N};
-  Tensor workspace("workspace", ws_shape, DType::kFloat32);
-  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);
-
-  cudaDeviceSynchronize();
-  auto err = cudaGetLastError();
-  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
-
-  // Compare the resulting amax values
-  float amax_atomic = out_atomic.amax();
-  float amax_ws = out_ws.amax();
-
-  compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
-}
-
-#endif
-
 
 
 INSTANTIATE_TEST_SUITE_P(

transformer_engine/common/include/transformer_engine/recipe.h

Lines changed: 0 additions & 24 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -75,12 +73,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
     std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
     float margin, cudaStream_t stream);
 
-#ifdef __HIP_PLATFORM_AMD__
-
-constexpr int amax_kernel_threads = 512;
-
-#endif
-
 /*! \brief Compute an FP8 tensor's amax.
  *
  * The amax (maximum absolute value) of the input tensor is computed
@@ -92,22 +84,6 @@ constexpr int amax_kernel_threads = 512;
  */
 void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
 
-#ifdef __HIP_PLATFORM_AMD__
-
-/*! \brief Compute an FP8 tensor's amax.
- *
- * The amax (maximum absolute value) of the input tensor is computed
- * and written to the amax buffer of the output tensor.
- *
- * \param[in]     input      Input tensor. Must be unquantized.
- * \param[in,out] output     Output tensor. Must be an FP8 tensor with per-tensor scaling.
- * \param[out]    workspace  Output tensor. Must be FP32.
- * \param[in]     stream     CUDA stream used for the operation.
- */
-void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream);
-
-#endif
-
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.

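Note: with the workspace variant gone, nvte_compute_amax is again the sole amax entry point. A minimal calling sketch against the signatures above, assuming the caller has already built valid NVTETensor handles (unquantized input, FP8 output with per-tensor scaling); this is illustrative, not TE source:

#include <transformer_engine/recipe.h>

// Minimal sketch: compute an FP8 tensor's amax on a given stream.
void amax_example(NVTETensor input, NVTETensor output, cudaStream_t stream) {
  // Kept path: one kernel that reduces per block, then atomically
  // updates the output tensor's amax buffer.
  nvte_compute_amax(input, output, stream);

  // The removed AMD-only variant additionally took an FP32 workspace
  // (one partial amax per block) for a two-stage reduction:
  //   nvte_compute_amax_with_workspace(input, output, workspace, stream);
}
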
transformer_engine/common/recipe/current_scaling.cu

Lines changed: 2 additions & 96 deletions
@@ -26,39 +26,12 @@ using bf16__ = __nv_bfloat16;
 using bf16__ = __hip_bfloat16;
 #endif  //__HIP_PLATFORM_AMD__
 
-
-#ifdef __HIP_PLATFORM_AMD__
-
-template <int BLOCK_THREADS>
-__global__ void amax_final_reduce(const float* __restrict__ block_amax,
-                                  float* __restrict__ global_amax,
-                                  int num_blocks) {
-  float val = 0.f;
-
-  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
-    val = fmaxf(val, block_amax[i]);
-  }
-
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const float block_max =
-      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
-
-  if (threadIdx.x == 0) {
-    *global_amax = block_max;
-  }
-}
-
-#endif
+constexpr int amax_kernel_threads = 512;
 
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-#ifdef __HIP_PLATFORM_AMD__
-    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
-                     const size_t num_aligned_elements) {
-#else
     void amax_kernel(const InputType *input, float *amax, const size_t N,
                      const size_t num_aligned_elements) {
-#endif
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
@@ -92,23 +65,12 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-#ifdef __HIP_PLATFORM_AMD__
-    if (block_amax != nullptr) {
-      // 2-stage: write per-block result
-      block_amax[blockIdx.x] = max;
-    } else {
-      // Atomic path: directly update global amax
-      atomicMaxFloat(amax, max);
-    }
-#else
     atomicMaxFloat(amax, max);
-#endif
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
-                        size_t block_capacity, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
@@ -127,54 +89,24 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
-#ifdef __HIP_PLATFORM_AMD__
-  if (block_capacity < num_blocks)
-    block_amax = nullptr;
-#endif
-
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-#else
       amax_kernel<nvec, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
-#endif
       break;
     case Alignment::SAME_UNALIGNED:
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-#else
       amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
-#endif
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
-#else
       amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
-#endif
       break;
     }
   }
 
-#ifdef __HIP_PLATFORM_AMD__
-  if (block_amax != nullptr) {
-    constexpr int FINAL_REDUCE_THREADS = 256;
-    dim3 fr_block(FINAL_REDUCE_THREADS);
-    dim3 fr_grid(1);
-
-    amax_final_reduce<FINAL_REDUCE_THREADS>
-        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
-  }
-#endif
-
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -183,12 +115,6 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
 }  // namespace transformer_engine
 
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
-}
-
-void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
-#endif
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
 
@@ -224,31 +150,11 @@ void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor
                to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
-#ifdef __HIP_PLATFORM_AMD__
-  // Optional workspace
-  float* block_amax = nullptr;
-  size_t block_capacity = 0;
-
-  if (workspace_ != nullptr) {
-    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
-    if (workspace.data.dptr != nullptr) {
-      NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
-                 "Workspace tensor for amax computation must be FP32, got dtype=",
-                 to_string(workspace.data.dtype));
-      block_amax = reinterpret_cast<float*>(workspace.data.dptr);
-      block_capacity = workspace.data.numel();
-    }
-  }
-#endif
-
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-#ifdef __HIP_PLATFORM_AMD__
-                               block_amax, block_capacity,
-#endif
                                stream););  // NOLINT(*)
 }
 

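Note: the kernels above call atomicMaxFloat, which this diff uses but does not define. A self-contained sketch of the usual trick, plus a toy single-block final reduce in the spirit of the removed amax_final_reduce (using a plain shared-memory tree instead of TE's internal reduce_max helper); all names here are illustrative:

#include <cuda_runtime.h>

// For non-negative floats (amax buffers are zeroed before launch and
// fed absolute values), the IEEE-754 bit pattern is monotonic when
// reinterpreted as a signed int, so integer atomicMax gives float max.
__device__ inline void atomic_max_float_sketch(float *addr, float val) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(val));
}

// Toy second stage: one block scans the per-block partial maxima and
// writes the global amax without any atomics.
template <int BLOCK_THREADS>
__global__ void final_reduce_sketch(const float *block_amax,
                                    float *global_amax, int num_blocks) {
  __shared__ float smem[BLOCK_THREADS];
  float v = 0.f;
  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
    v = fmaxf(v, block_amax[i]);
  }
  smem[threadIdx.x] = v;
  __syncthreads();
  // Shared-memory tree reduction; BLOCK_THREADS must be a power of two.
  for (int s = BLOCK_THREADS / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) *global_amax = smem[0];
}
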
transformer_engine/pytorch/csrc/common.cpp

Lines changed: 0 additions & 34 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -12,10 +10,6 @@
 #include "pybind.h"
 #include "transformer_engine/transformer_engine.h"
 
-#ifdef __HIP_PLATFORM_AMD__
-#include "common/common.h"
-#endif
-
 namespace transformer_engine::pytorch {
 
 std::vector<size_t> getTensorShape(at::Tensor t) {
@@ -283,32 +277,4 @@ int roundup(const int value, const int multiple) {
   return ((value + multiple - 1) / multiple) * multiple;
 }
 
-#ifdef __HIP_PLATFORM_AMD__
-
-inline bool nvte_use_atomic_amax() {
-  const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX");
-  if (env_p && std::string(env_p) == "1")
-    return true;
-  return false;
-}
-
-TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
-  if (nvte_use_atomic_amax() || input_tensor.numel() == 0) {
-    // User chose atomic path, or empty tensor -> no need for workspace
-    return TensorWrapper{};
-  }
-
-  const auto N = input_tensor.numel();
-  constexpr size_t max_blocks_hw = 65535;
-
-  size_t max_blocks = DIVUP(N, static_cast<size_t>(amax_kernel_threads));
-  size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
-
-  at::Tensor ws = at::empty(workspace_blocks, at::CUDA(at::kFloat));
-
-  return makeTransformerEngineTensor(ws);
-}
-
-#endif
-
 }  // namespace transformer_engine::pytorch

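Note: the removed allocate_amax_workspace sized the workspace at one FP32 slot per potential block, capped at the 65535-block grid limit. A standalone check of that arithmetic using the shape from the removed test; the constants are copied from the deleted lines:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t amax_kernel_threads = 512;  // from the deleted recipe.h lines
  constexpr size_t max_blocks_hw = 65535;      // grid cap used in the deleted code

  const size_t N = 256 * 1024;  // shape {256, 1024} from the removed test
  const size_t max_blocks = (N + amax_kernel_threads - 1) / amax_kernel_threads;  // DIVUP
  const size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);

  std::printf("workspace_blocks = %zu\n", workspace_blocks);  // prints 512
  return 0;
}
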
transformer_engine/pytorch/csrc/common.h

Lines changed: 0 additions & 3 deletions
@@ -374,9 +374,6 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
 
 int roundup(const int value, const int multiple);
 
-#ifdef __HIP_PLATFORM_AMD__
-TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor);
-#endif
 }  // namespace transformer_engine::pytorch
 
 namespace std {

transformer_engine/pytorch/csrc/extensions/activation.cpp

Lines changed: 0 additions & 10 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -38,18 +36,10 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int
   auto [te_output_act, out_act] =
       my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));
 
-#ifdef __HIP_PLATFORM_AMD__
-  auto workspace = allocate_amax_workspace(te_input);
-#endif
   NVTE_SCOPED_GIL_RELEASE({
     act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
     // use te_output_act as input to the compute amax and find the amax of activated tensor
-#ifdef __HIP_PLATFORM_AMD__
-    nvte_compute_amax_with_workspace(te_output_act.data(), te_output.data(),
-                                     workspace.data(), at::cuda::getCurrentCUDAStream());
-#else
     nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
-#endif
   });
 
   // my_quantizer here has to be a Float8CurrentScalingQuantizer

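Note: bias.cpp and cast.cpp below revert the same call-site pattern, condensed here as one hedged sketch; compute_amax_at_callsite is not a real TE function, and te_input/te_output stand in for whatever wrapped tensors each site uses:

// Sketch of the reverted call-site pattern.
void compute_amax_at_callsite(TensorWrapper &te_input, TensorWrapper &te_output) {
  auto stream = at::cuda::getCurrentCUDAStream();
#ifdef __HIP_PLATFORM_AMD__  // branch removed by this commit
  // Two-stage path: per-block FP32 workspace, then a final reduce.
  auto workspace = allocate_amax_workspace(te_input);
  nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
                                   workspace.data(), stream);
#else  // kept on all builds: single atomic-based entry point
  nvte_compute_amax(te_input.data(), te_output.data(), stream);
#endif
}
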
transformer_engine/pytorch/csrc/extensions/bias.cpp

Lines changed: 0 additions & 8 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -51,13 +49,7 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
   // my_quantizer here has to be a Float8CurrentScalingQuantizer
   auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
   NVTE_SCOPED_GIL_RELEASE({
-#ifdef __HIP_PLATFORM_AMD__
-    nvte_compute_amax_with_workspace(input_tensor.data(), out_tensor.data(),
-                                     allocate_amax_workspace(input_tensor).data(),
-                                     at::cuda::getCurrentCUDAStream());
-#else
     nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
-#endif
   });
   // check if we need to do amax reudction (depending on model parallel configs)
   if (my_quantizer_cs->with_amax_reduction) {

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 0 additions & 8 deletions
@@ -1,6 +1,4 @@
 /*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -55,13 +53,7 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
   // my_quantizer here has to be a Float8CurrentScalingQuantizer
   auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
   NVTE_SCOPED_GIL_RELEASE({
-#ifdef __HIP_PLATFORM_AMD__
-    nvte_compute_amax_with_workspace(te_input.data(), te_output.data(),
-                                     allocate_amax_workspace(te_input).data(),
-                                     at::cuda::getCurrentCUDAStream());
-#else
     nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
-#endif
   });
   // check if we need to do amax reudction (depending on model parallel configs)
   if (my_quantizer_cs->with_amax_reduction) {
