Commit e378930

Add torch.backends.cuda.math_sdp.fp32_precision (#2848)
**Overview**

This PR adds a new float32 precision API, torch.backends.cuda.math_sdp.fp32_precision, to configure the fp32 precision behavior of SDPBackend.MATH.

**Rationale**

The test/test_transformers.py testing suite calculates numerical tolerances by comparing output tensors computed at the same precision ("reference") and at a higher precision ("golden"), both produced by SDPBackend.MATH. However, the golden output is calculated with TF32 rather than FP32, which is in fact less accurate than the FA/ME backends when they use IEEE rather than TF32 for their accumulation. The loss of precision causes false negatives in SDPA tests such as TestSDPACudaOnlyCUDA.test_flash_attention_vs_math_ref_grads_batch_size_8_seq_len_q_143_seq_len_k_4_head_dim_203_is_causal_False_dropout_p_0_22_float16_scale_l1_enable_gqa_True_n_heads1_cuda_float16, at least on the ROCm platform. The false negative disappears after forcing higher_precision_dtype = torch.float64.

**Major Changes**

To restore the precision of the golden output, a new API torch.backends.cuda.math_sdp.fp32_precision is introduced, which allows configuring the "matmul" precision used during SDPBackend.MATH, and a new decorator @math_sdp_precision("ieee") is added to all tests that use check_out_and_grad. Finally, an assert is added to the innermost function _check_equal as a sanity check to ensure math_sdp has the right precision configured for torch.float32 golden tensors.

**Known Issues**

The backward phase honors the configuration in effect when backward() is called, regardless of the configuration in effect when the graph was created.

---------

Co-authored-by: Xinya Zhang <[email protected]>
1 parent 060be4b commit e378930
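
For orientation, here is a minimal sketch of how the new knob is meant to be used from Python. It relies only on the attribute name and the "ieee" value introduced by this commit; everything else (the print call, the ordering) is illustrative rather than part of the change:

```python
import torch

# Read the fp32 precision currently configured for SDPBackend.MATH on CUDA.
# The new entry in Context.h defaults to Float32Precision::NONE, i.e. the
# math backend keeps following the regular matmul precision setting.
print(torch.backends.cuda.math_sdp.fp32_precision)

# Force IEEE fp32 accumulation so that a float32 "golden" reference computed
# with the math backend is not silently run in TF32.
torch.backends.cuda.math_sdp.fp32_precision = "ieee"
```

In the test suite this is applied per test through the @math_sdp_precision("ieee") decorator added in torch/testing/_internal/common_cuda.py (the last file in the diff below).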

File tree: 7 files changed, +75 −1 lines changed

aten/src/ATen/Context.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,8 @@ Float32Op str2op(const std::string& name) {
     return Float32Op::RNN;
   else if (name == "matmul")
     return Float32Op::MATMUL;
+  else if (name == "math_sdp")
+    return Float32Op::MATH_SDP;
   TORCH_CHECK(false, "Unknown op: ", name);
 }
```

aten/src/ATen/Context.h

Lines changed: 32 additions & 1 deletion
```diff
@@ -46,7 +46,7 @@ enum class CuBLASReductionOption : uint8_t {
   DisallowReducedPrecisionDisallowSplitK = 2,
 };
 enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
-enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
+enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL, MATH_SDP };
 enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };

 TORCH_API Float32Backend str2backend(const std::string& name);
@@ -522,6 +522,7 @@ class TORCH_API Context {
       float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
           ? Float32Precision::NONE
           : Float32Precision::TF32},
+      {{Float32Backend::CUDA, Float32Op::MATH_SDP}, Float32Precision::NONE},
   };

   Allocator* prev_allocator_ptr_{nullptr};
@@ -694,6 +695,36 @@ struct TORCH_API NoTF32Guard {
   bool changed = false;
 };

+template <Float32Backend target_backend, Float32Op target_op>
+struct Fp32PrecisonGuard {
+  Fp32PrecisonGuard(const Float32Precision new_precision) {
+    if (new_precision == Float32Precision::NONE) {
+      return;
+    }
+    saved_precision =
+        globalContext().float32Precision(target_backend, target_op);
+    changed = (new_precision != saved_precision);
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, new_precision);
+    }
+  }
+  Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
+  Fp32PrecisonGuard(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
+  ~Fp32PrecisonGuard() {
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, saved_precision);
+    }
+  }
+
+ private:
+  Float32Precision saved_precision;
+  bool changed = false;
+};
+
 struct TORCH_API ROCmBackwardPassGuard {
   ROCmBackwardPassGuard();
   ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
```

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -868,6 +868,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
           ? value.to(at::kFloat)
           : value;
   auto attn_mask = attn_mask_;
+  const auto math_sdp_precision = at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATH_SDP);
+  // Temporarily override matmul precision with value from cuda.math_sdp
+  // IEEE should be used when use fp32+math backend as golden reference.
+  at::Fp32PrecisonGuard<at::Float32Backend::CUDA, at::Float32Op::MATMUL> fp32guard(math_sdp_precision);
+
   // Naive, composite implementation defined here.

   // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
```
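
The guard above scopes the override to the math kernel itself. A hedged sketch of how a float32 golden reference could be produced under that setting, using the existing torch.nn.attention.sdpa_kernel API to pin the math backend (tensor shapes and variable names are illustrative, not part of this commit):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

# Keep fp32 matmuls inside the math backend in IEEE precision, so the
# "golden" output is a true fp32 reference instead of a TF32 one.
torch.backends.cuda.math_sdp.fp32_precision = "ieee"
with sdpa_kernel(SDPBackend.MATH):
    golden = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
```

Note that the known issue from the commit message still applies: the backward pass reads the setting at the time backward() is called, not at graph construction time.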

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -2054,6 +2054,7 @@
     "PropModule",
     # torch.backends.cuda
     "cuBLASModule",
+    "MathSDPModule",
     "cuFFTPlanCache",
     "cuFFTPlanCacheAttrContextProp",
     "cuFFTPlanCacheManager",
```

test/test_transformers.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -54,6 +54,7 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
+    math_sdp_precision,
 )

 if TEST_FAIRSEQ:
@@ -128,6 +129,12 @@ def _check_equal(
             _check_equal(gold, ref, tst, fudge_factor, tensor_name)
         return

+    if golden.is_cuda and golden.dtype == torch.float32:
+        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", (
+            "Testing script error: FP32 golden tensor must be calculated with IEEE"
+            " precision. Add @math_sdp_precision('ieee') to related tests to fix it."
+        )
+
     # Compute error between golden
     test_error = (golden - test).abs().max()
     ref_error = (golden - reference).abs().max()
@@ -3413,6 +3420,7 @@ def test_mem_eff_backwards_determinism(self, device):
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3528,6 +3536,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3641,6 +3650,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3786,6 +3796,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -4100,6 +4111,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("is_causal", [True, False])
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                             head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                             scale: str, is_causal: bool):
```

torch/backends/cuda/__init__.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
     "cuFFTPlanCache",
     "cuFFTPlanCacheManager",
     "cuBLASModule",
+    "MathSDPModule",
     "preferred_linalg_library",
     "preferred_blas_library",
     "preferred_rocm_fa_library",
@@ -206,6 +207,18 @@ def __setattr__(self, name, value):
         raise AttributeError("Unknown attribute " + name)


+class MathSDPModule:
+    def __getattr__(self, name):
+        if name == "fp32_precision":
+            return torch._C._get_fp32_precision_getter("cuda", "math_sdp")
+        raise AttributeError("Unknown attribute " + name)
+
+    def __setattr__(self, name, value):
+        if name == "fp32_precision":
+            return torch._C._set_fp32_precision_setter("cuda", "math_sdp", value)
+        raise AttributeError("Unknown attribute " + name)
+
+
 _LinalgBackends = {
     "default": torch._C._LinalgBackend.Default,
     "cusolver": torch._C._LinalgBackend.Cusolver,
@@ -591,3 +604,4 @@ def sdp_kernel(

 cufft_plan_cache = cuFFTPlanCacheManager()
 matmul = cuBLASModule()
+math_sdp = MathSDPModule()
```

torch/testing/_internal/common_cuda.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -229,6 +229,15 @@ def tf32_enabled():
         torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


+@contextlib.contextmanager
+def math_sdp_precision(target_precision: str):
+    saved_precision = torch.backends.cuda.math_sdp.fp32_precision
+    try:
+        torch.backends.cuda.math_sdp.fp32_precision = target_precision
+        yield
+    finally:
+        torch.backends.cuda.math_sdp.fp32_precision = saved_precision
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the
```
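
Because the helper is built with contextlib.contextmanager, the object returned by math_sdp_precision("ieee") works both as a context manager and, via ContextDecorator, as the @math_sdp_precision("ieee") decorator used in test_transformers.py above. A small sketch of both forms; the function name below is illustrative, not part of the commit:

```python
import torch
from torch.testing._internal.common_cuda import math_sdp_precision

# As a context manager: the previous setting is restored on exit.
with math_sdp_precision("ieee"):
    assert torch.backends.cuda.math_sdp.fp32_precision == "ieee"

# As a decorator, matching the @math_sdp_precision("ieee") usage in the tests.
@math_sdp_precision("ieee")
def compute_fp32_golden():
    # Any fp32 SDPBackend.MATH work done here sees IEEE accumulation.
    return torch.backends.cuda.math_sdp.fp32_precision

print(compute_fp32_golden())  # the saved precision is restored afterwards
```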
