@@ -26,39 +26,12 @@ using bf16__ = __nv_bfloat16;
 using bf16__ = __hip_bfloat16;
 #endif  // __HIP_PLATFORM_AMD__
 
-
-#ifdef __HIP_PLATFORM_AMD__
-
-template <int BLOCK_THREADS>
-__global__ void amax_final_reduce(const float * __restrict__ block_amax,
-                                  float * __restrict__ global_amax,
-                                  int num_blocks) {
-  float val = 0.f;
-
-  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
-    val = fmaxf(val, block_amax[i]);
-  }
-
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const float block_max =
-      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
-
-  if (threadIdx.x == 0) {
-    *global_amax = block_max;
-  }
-}
-
-#endif
+constexpr int amax_kernel_threads = 512;
 
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-#ifdef __HIP_PLATFORM_AMD__
-void amax_kernel(const InputType *input, float *amax, float * __restrict__ block_amax, const size_t N,
-                 const size_t num_aligned_elements) {
-#else
     void amax_kernel(const InputType *input, float *amax, const size_t N,
                      const size_t num_aligned_elements) {
-#endif
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
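Note: with the two-stage path deleted, every block now folds its result into the global amax through `atomicMaxFloat`, which is defined elsewhere in the tree and not part of this diff. A minimal sketch of the usual formulation is below; it relies on the fact that non-negative IEEE-754 floats order the same as their signed-integer bit patterns, which holds here because amax is zero-initialized and all candidates are absolute values:

```cpp
// Sketch only -- the repo's actual atomicMaxFloat is not shown in this diff.
// Safe because amax is memset to 0 and all candidates are >= 0, so the
// integer reinterpretation of the float bits is monotonic.
__device__ inline float atomicMaxFloat(float *addr, float value) {
  const int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}
```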
@@ -92,23 +65,12 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-#ifdef __HIP_PLATFORM_AMD__
-    if (block_amax != nullptr) {
-      // 2-stage: write per-block result
-      block_amax[blockIdx.x] = max;
-    } else {
-      // Atomic path: directly update global amax
-      atomicMaxFloat(amax, max);
-    }
-#else
     atomicMaxFloat(amax, max);
-#endif
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
-                        size_t block_capacity, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
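Both the surviving kernel and the deleted `amax_final_reduce` lean on the shared `reduce_max<num_warps>` helper, also outside this diff. A sketch of the typical shape, assuming `THREADS_PER_WARP == 32` and CUDA shuffle intrinsics:

```cpp
// Illustrative block-wide max reduction in the style of reduce_max<num_warps>;
// the real helper lives in the repo's utils header. Result is valid on thread 0.
template <int num_warps>
__device__ float reduce_max_sketch(float val, int warp_id) {
  __shared__ float warp_max[num_warps];
  const int lane = threadIdx.x % 32;
  // Stage 1: reduce within each warp via shuffles.
  for (int offset = 16; offset > 0; offset /= 2)
    val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
  if (lane == 0) warp_max[warp_id] = val;
  __syncthreads();
  // Stage 2: the first warp reduces the per-warp maxima (padding with 0.f is
  // safe because amax candidates are non-negative).
  val = (threadIdx.x < num_warps) ? warp_max[threadIdx.x] : 0.f;
  if (warp_id == 0)
    for (int offset = 16; offset > 0; offset /= 2)
      val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
  return val;
}
```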
@@ -127,54 +89,24 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
-#ifdef __HIP_PLATFORM_AMD__
-  if (block_capacity < num_blocks)
-    block_amax = nullptr;
-#endif
-
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-#else
       amax_kernel<nvec, true, InputType>
           <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
-#endif
       break;
     case Alignment::SAME_UNALIGNED:
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
-#else
       amax_kernel<nvec, false, InputType>
           <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
-#endif
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-#ifdef __HIP_PLATFORM_AMD__
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
-#else
       amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
-#endif
       break;
     }
   }
 
-#ifdef __HIP_PLATFORM_AMD__
-  if (block_amax != nullptr) {
-    constexpr int FINAL_REDUCE_THREADS = 256;
-    dim3 fr_block(FINAL_REDUCE_THREADS);
-    dim3 fr_grid(1);
-
-    amax_final_reduce<FINAL_REDUCE_THREADS>
-        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
-  }
-#endif
-
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
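For context on the `switch (align)` above: with a single input pointer, the alignment check can only report whether that pointer is a multiple of the vector width, so `Alignment::DIFFERENT` is unreachable, as the in-code comment notes. A hypothetical sketch of that classification (the real check belongs to the `VectorizedLoader` machinery and is not in this diff):

```cpp
#include <cstdint>

enum class Alignment { SAME_ALIGNED, SAME_UNALIGNED, DIFFERENT };

// Sketch: classify a single pointer against the nvec-element vector width.
template <typename T>
Alignment check_alignment_sketch(const T *ptr, int nvec) {
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  return (addr % (nvec * sizeof(T)) == 0) ? Alignment::SAME_ALIGNED
                                          : Alignment::SAME_UNALIGNED;
}
```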
@@ -183,12 +115,6 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
 }  // namespace transformer_engine
 
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
-}
-
-void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
-#endif
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
 
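Callers on the CUDA side are unaffected; on ROCm, code that previously went through the workspace variant now uses the single entry point:

```cpp
// Before this change (ROCm only, removed by this diff):
//   nvte_compute_amax_with_workspace(input, output, workspace, stream);
// After (all platforms):
nvte_compute_amax(input, output, stream);
```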
@@ -224,31 +150,11 @@ void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor
               to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
-#ifdef __HIP_PLATFORM_AMD__
-  // Optional workspace
-  float *block_amax = nullptr;
-  size_t block_capacity = 0;
-
-  if (workspace_ != nullptr) {
-    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
-    if (workspace.data.dptr != nullptr) {
-      NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
-                 "Workspace tensor for amax computation must be FP32, got dtype=",
-                 to_string(workspace.data.dtype));
-      block_amax = reinterpret_cast<float *>(workspace.data.dptr);
-      block_capacity = workspace.data.numel();
-    }
-  }
-#endif
-
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-#ifdef __HIP_PLATFORM_AMD__
-                               block_amax, block_capacity,
-#endif
                                stream););  // NOLINT(*)
 }
 
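As a reminder of what the `nvec = 32 / sizeof(IType)` dispatch works out to, the vector width is sized so each thread loads 32 bytes per vectorized access:

```cpp
#include <cuda_fp16.h>

// Per-dtype vector widths implied by nvec = 32 / sizeof(IType):
static_assert(32 / sizeof(float) == 8, "fp32 -> nvec = 8");
static_assert(32 / sizeof(__half) == 16, "fp16/bf16 -> nvec = 16");
// 1-byte FP8 element types would give nvec = 32.
```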