
Commit 0a94bb4

xw285cornell, alugorey, and jithunnair-amd authored and committed
[ROCm] CK Flash Attention Backend (pytorch#143695)
Replace pytorch#138947 for re-import. Replaces #1592.

This PR contains the initial implementation of SDPA with the composable_kernel (CK) backend. The CK path can be forced by calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, not setting this option results in aotriton being used as the backend.

In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory-efficient attention vs. math, etc.). The CK path only gets called when flash attention is both enabled (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of @tridao's hard work; he is credited as a co-author.

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when building PyTorch.

Pull Request resolved: pytorch#143695
Approved by: https://github.com/malfet
Co-authored-by: Andy Lugo <[email protected]>
Co-authored-by: Jithun Nair <[email protected]>
1 parent 3251171 commit 0a94bb4
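
For reference, a minimal usage sketch based on the commit message above (tensor shapes and dtypes are illustrative assumptions; PyTorch must have been built for ROCm with USE_CK_FLASH_ATTENTION=1 for the CK path to exist):

import torch
import torch.nn.functional as F

# Force the composable_kernel (CK) flash attention backend on ROCm.
# Passing "aotriton" or "default" instead keeps the incumbent aotriton path.
torch.backends.cuda.preferred_rocm_fa_library("ck")

# (batch, heads, seq_len, head_dim) -- shapes chosen only for illustration
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# The SDPA selection heuristics are unchanged; when they pick flash attention,
# the CK kernels run in place of aotriton.
out = F.scaled_dot_product_attention(q, k, v)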

File tree

1,840 files changed: +249,657 −38 lines

Note: large commits have some content hidden by default; only a subset of the 1,840 changed files is shown below.

LICENSE
Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ All contributions by Cruise LLC:
 Copyright (c) 2022 Cruise LLC.
 All rights reserved.

+All contributions by Tri Dao:
+Copyright (c) 2024 Tri Dao.
+All rights reserved.
+
 All contributions by Arm:
 Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates

aten/src/ATen/CMakeLists.txt
Lines changed: 25 additions & 2 deletions

@@ -168,9 +168,28 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
 file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
 file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

-# flash_attention sources
+# flash_attention hip sources
 file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
-file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
+# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
+if(USE_FLASH_ATTENTION)
+  if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
+    set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
+    if(USE_CK_FLASH_ATTENTION STREQUAL "1")
+      if(DEFINED ENV{PYTORCH_ROCM_ARCH})
+        list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
+        if(NUM_ARCHS GREATER 1)
+          message(WARNING "Building CK for multiple archs can increase build time considerably!
+          Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
+        endif()
+      endif()
+      message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
+      file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
+      list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
+    endif()
+  endif()
+  file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
+  file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
+endif()

 #Mem_eff attention sources
 file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")

@@ -185,6 +204,7 @@ if(USE_FLASH_ATTENTION)
   list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})

   list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
+  list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip})
   list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
 endif()

@@ -325,6 +345,9 @@ if(USE_ROCM)
   # Next two lines are needed because TunableOp uses third-party/fmt
   list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
   list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
+  if(USE_FLASH_ATTENTION)
+    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
+  endif()
   list(APPEND ATen_HIP_SRCS
     ${ATen_HIP_SRCS}
     ${hip_hip}

aten/src/ATen/Context.cpp
Lines changed: 34 additions & 0 deletions

@@ -343,6 +343,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
 #endif
 }

+at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
+  return rocm_fa_preferred_backend;
+}
+
+void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
+
+  // TODO: add plumbing for hasCK for validity checking
+  TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
+      "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
+#ifdef USE_ROCM
+  if(b == at::ROCmFABackend::Ck) {
+    static const bool ck_unsupported = []() {
+      static const std::vector<std::string> archs = {
+          "gfx90a", "gfx942"
+      };
+      for (auto index: c10::irange(getNumGPUs())) {
+        if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
+          TORCH_WARN_ONCE(
+            "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
+          return true;
+        }
+      }
+      return false;
+    }();
+    if(!ck_unsupported) rocm_fa_preferred_backend = b;
+  }
+  else {
+    rocm_fa_preferred_backend = b;
+  }
+#endif
+  rocm_fa_preferred_backend = b;
+}
+
+
 bool Context::allowFP16ReductionCuBLAS() const {
   return allow_fp16_reduction_cublas;
 }

aten/src/ATen/Context.h
Lines changed: 8 additions & 0 deletions

@@ -4,6 +4,7 @@
 #include <ATen/CPUGeneratorImpl.h>
 #include <ATen/DeviceAccelerator.h>
 #include <ATen/LinalgBackend.h>
+#include <ATen/ROCmFABackend.h>
 #include <ATen/SDPBackend.h>
 #include <ATen/core/ATenGeneral.h>
 #include <ATen/core/DeprecatedTypeProperties.h>

@@ -239,6 +240,9 @@ class TORCH_API Context {
   at::BlasBackend blasPreferredBackend();
   void setBlasPreferredBackend(at::BlasBackend);

+  at::ROCmFABackend getROCmFAPreferredBackend() const;
+  void setROCmFAPreferredBackend(at::ROCmFABackend);
+
   // Note [Enabling Deterministic Operations]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // Operations in PyTorch that normally act nondeterministically, but have an

@@ -428,6 +432,10 @@
 #endif
       ? at::BlasBackend::Cublaslt
       : at::BlasBackend::Cublas;
+  at::ROCmFABackend rocm_fa_preferred_backend =
+      c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
+          ? at::ROCmFABackend::Ck
+          : at::ROCmFABackend::Default;
 #ifdef C10_MOBILE
   bool release_original_weights = true;
 #else

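A small sketch of the environment-variable default shown in the Context.h hunk above: when TORCH_ROCM_FA_PREFER_CK is set, rocm_fa_preferred_backend initializes to Ck without any explicit call. The timing assumption (export the variable before torch is imported, since c10::utils::check_env is read when the global context is created) is mine, not stated in the diff:

import os

# Assumption: must be set before importing torch so that
# c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") sees it at context init.
os.environ["TORCH_ROCM_FA_PREFER_CK"] = "1"

import torch  # the preferred ROCm flash attention backend now defaults to CK
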
aten/src/ATen/ROCmFABackend.h
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+
+#include <ostream>
+#include <string>
+
+namespace at {
+
+enum class ROCmFABackend : int8_t { Default, AOTriton, Ck };
+
+inline std::string ROCmFABackendToString(at::ROCmFABackend backend) {
+  switch (backend) {
+    case ROCmFABackend::Default:
+      return "at::ROCmFABackend::Default";
+    case ROCmFABackend::AOTriton:
+      return "at::ROCmFABackend::AOTriton";
+    case ROCmFABackend::Ck:
+      return "at::ROCmFABackend::Ck";
+    default:
+      TORCH_CHECK(false, "Unknown ROCm flash attention backend")
+  }
+}
+
+inline std::ostream& operator<<(
+    std::ostream& stream,
+    at::ROCmFABackend backend) {
+  return stream << ROCmFABackendToString(backend);
+}
+
+} // namespace at

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
Lines changed: 17 additions & 11 deletions

@@ -28,7 +28,7 @@
 #if USE_ROCM
 #if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
 #include <aotriton/flash.h>
-#define USE_AOTRITON 1
+#define USE_ROCM_ATTENTION 1
 #endif
 #endif

@@ -219,15 +219,21 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
-#if USE_AOTRITON
-  auto stream = at::cuda::getCurrentCUDAStream().stream();
-  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
-      auto dprops = at::cuda::getCurrentDeviceProperties();
-      if (debug) {
-          TORCH_WARN(
-              "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
-      }
-      return false;
+#if USE_ROCM_ATTENTION
+  if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
+    // User explicitly set CK as the flash attention backend. Return true for now
+    // TODO: Flesh out sanity checks
+    return true;
+  } else {
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        if (debug) {
+            TORCH_WARN(
+                "Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
+        }
+        return false;
+    }
   }
 #else
   return false;

@@ -254,7 +260,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
-#if USE_AOTRITON
+#if USE_ROCM_ATTENTION
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();

aten/src/ATen/native/transformers/hip/aotriton_adapter.h
Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
 inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
 {
   return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
-                                 aotriton::DType::kUInt64); // AOTriton excepts unsigned int64
+                                 aotriton::DType::kUInt64); // AOTriton accepts unsigned int64
 }

 } // namespace aotriton_adapter

aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip renamed to aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
Lines changed: 16 additions & 23 deletions

@@ -115,24 +115,18 @@ prepare_philox_arguments(float p_dropout, int64_t counter_offset) {
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
-        const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
-        const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
-        std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
-        std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
-        const float p_dropout,
-        const float softmax_scale,
-        bool is_causal,
-        int window_size_left,
-        int window_size_right,
-        const bool return_softmax,
-        std::optional<at::Generator> gen_) {
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-  // [ROCM specific]: must be at the beginning of the function
-  // Otherwise check_gpu_arch() checks cuda:0 device.
-  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
-
+mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+            const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
+            const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
+            std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
+            std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+            const float p_dropout,
+            const float softmax_scale,
+            bool is_causal,
+            int window_size_left,
+            int window_size_right,
+            const bool return_softmax,
+            std::optional<at::Generator> gen_) {
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
   check_gpu_arch(stream);

@@ -242,7 +236,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i

@@ -408,7 +402,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
+mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
         const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
         const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size

@@ -559,7 +553,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
+mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
               const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i

@@ -747,7 +741,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

   return { dq, dk, dv, softmax_d };
 }
-
-} // namespace pytorch_fmha
+} // namespace pytorch_flash

 #endif
Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <ostream>
+#include <string>
+#include <ck_tile/core.hpp>
+#include <ck_tile/ops/fmha.hpp>
+
+// keep sync with BlockAttentionBiasEnum
+enum class bias_enum
+{
+    no_bias          = 0,
+    elementwise_bias = 1,
+    alibi            = 2,
+};
+
+struct bias_info
+{
+    bias_enum type;
+    /*
+     * simple dispatch logic
+     *
+     * if type == elementwise_bias:
+     *      if rank_info == 0:
+     *           bias is 1*1*s*s
+     *      elif rank_info == 1:
+     *           bias is 1*h*s*s
+     *      elif rank_info == 2:
+     *           bias is b*h*s*s
+     *
+     * elif type == alibi:
+     *       if rank_info == 0:
+     *           alibi in 1*h
+     *       elif rank_info == 1:
+     *           alibi in b*h
+     */
+    int rank_info;
+
+    void serialize(std::ostream& os) const
+    {
+        if(type == bias_enum::no_bias)
+            os << "n";
+        else if(type == bias_enum::elementwise_bias)
+        {
+            os << "e";
+            if(rank_info != 0)
+            {
+                os << "[" << rank_info << "]";
+            }
+        }
+        else if(type == bias_enum::alibi)
+        {
+            os << "alibi";
+            if(rank_info != 0)
+            {
+                os << "[" << rank_info << "]";
+            }
+        }
+    }
+
+    static bias_info decode(std::string str)
+    {
+        bias_info info{bias_enum::no_bias, 0};
+        if(str == "0" || str == "n")
+        {
+            info.type = bias_enum::no_bias;
+        }
+        else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
+                str.compare(0, 11, "elementwise") == 0)
+        {
+            info.type = bias_enum::elementwise_bias;
+            auto found_0 = str.find(':');
+            if(found_0 != std::string::npos)
+            {
+                std::string e  = str.substr(found_0 + 1);
+                info.rank_info = atoi(e.c_str());
+            }
+        }
+        else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
+                str.compare(0, 5, "alibi") == 0)
+        {
+            info.type = bias_enum::alibi;
+            auto found_0 = str.find(':');
+            if(found_0 != std::string::npos)
+            {
+                std::string e  = str.substr(found_0 + 1);
+                info.rank_info = atoi(e.c_str());
+            }
+        }
+        return info;
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
+    {
+        bi.serialize(os);
+        return os;
+    }
+};
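
To make the string convention accepted by bias_info::decode above concrete, here is a hypothetical Python mirror of the same dispatch logic (illustration only; not part of the commit):

def decode_bias(s: str):
    # "0"/"n" -> no bias; "1"/"e"/"elementwise..." -> elementwise bias; "2"/"a"/"alibi..." -> alibi.
    # An optional ":<rank_info>" suffix selects the layout (e.g. rank 2 elementwise bias is b*h*s*s).
    if s in ("0", "n"):
        return ("no_bias", 0)
    kind = "elementwise_bias" if s[0] in ("1", "e") else "alibi"
    rank = int(s.split(":", 1)[1]) if ":" in s else 0
    return (kind, rank)

print(decode_bias("e:2"))      # ('elementwise_bias', 2) -> bias laid out as b*h*s*s
print(decode_bias("alibi:1"))  # ('alibi', 1) -> alibi slopes laid out as b*h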
