
Commit fb59577

Merge commit 'af0a9f2be4f7c0944c36873960fa2d0c9d3d9f80'
2 parents (72c6938 + af0a9f2), commit fb59577

File tree: 9 files changed (+120, -56 lines)


include/triton/Dialect/Triton/IR/TritonOps.td
Lines changed: 2 additions & 0 deletions

@@ -108,6 +108,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";
 
   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }
 
 //

lib/Dialect/Triton/IR/Ops.cpp
Lines changed: 23 additions & 0 deletions

@@ -728,6 +728,29 @@ LogicalResult ReshapeOp::verify() {
 }
 
 //-- FpToFpOp --
+
+// Fold FpToFpOp when the input operand is a constant zero.
+OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
+  auto srcVal = getSrc();
+  auto dstTy = getType();
+
+  const llvm::fltSemantics &semantic =
+      llvm::cast<FloatType>(dstTy.getElementType()).getFloatSemantics();
+
+  if (matchPattern(srcVal, m_PosZeroFloat())) {
+    llvm::APFloat posZero =
+        llvm::APFloat::getZero(semantic, /*negative=*/false);
+    return DenseFPElementsAttr::get(dstTy, posZero);
+  }
+
+  if (matchPattern(srcVal, m_NegZeroFloat())) {
+    llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true);
+    return DenseFPElementsAttr::get(dstTy, negZero);
+  }
+
+  return {};
+}
+
 LogicalResult FpToFpOp::verify() {
   auto dstType = getType().getElementType();
   auto srcType = getSrc().getType().getElementType();
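Note: a constant-zero tt.fp_to_fp typically arises when a kernel materializes a zero-valued tile and immediately casts it to an fp8 element type. The Triton kernel below is a hypothetical sketch (not part of this commit; the kernel name and block size are invented) of such a pattern; after this change, canonicalization can collapse the resulting arith.constant + tt.fp_to_fp pair into a single fp8 zero constant.

# Hypothetical Triton kernel (illustration only, not from this diff).
import triton
import triton.language as tl


@triton.jit
def store_zero_fp8(out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # tl.zeros lowers to an arith.constant dense<0.000000e+00> fp32 tensor.
    zeros = tl.zeros((BLOCK,), dtype=tl.float32)
    # The fp32 -> fp8 cast of that constant emits tt.fp_to_fp with rounding = rtne,
    # which the new FpToFpOp::fold can replace with a dense fp8 zero constant.
    tl.store(out_ptr + offs, zeros.to(tl.float8e4nv, fp_downcast_rounding="rtne"))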

python/test/unit/language/test_compile_errors.py
Lines changed: 2 additions & 21 deletions

@@ -7,26 +7,7 @@
 import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
-
-
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
-
-
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_xpu():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "xpu"
-
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
+from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300, is_xpu
 
 
 def test_err_undefined_variable():

@@ -371,7 +352,7 @@ def test_fp8_support(dtype):
         if cc >= (8, 9):
             supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
-        if is_on_mi300():
+        if is_hip_mi300():
             supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
     elif is_xpu():
         supported_dtypes += [tl.float8e4b15, tl.float8e4nv]

python/test/unit/language/test_conversions.py
Lines changed: 3 additions & 13 deletions

@@ -1,24 +1,14 @@
 # fmt: off
 
 
-import os
 import numpy as np
 import torch
 import pytest
 import triton
 import triton.language as tl
 
-def is_interpreter():
-    return os.environ.get('TRITON_INTERPRET', '0') == '1'
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
 
-def is_cuda():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-def is_hip():
-    return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "hip"
-
-def is_on_mi300():
-    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942')
 
 def matching_int(dtype):
     if dtype.primitive_bitwidth == 8:

@@ -314,7 +304,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))
             or (src_dtype in ('float8e4nv', 'float8e4b15') and is_hip())
-            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_on_mi300()))):
+            or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))):
         # If the dtype should error out in the given device, we assert that and return
         with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
             launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)

@@ -365,7 +355,7 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
     if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or torch.cuda.is_available() and torch.cuda.get_device_capability(0) < (9, 0)):
         pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
 
-    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_on_mi300()):
+    if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or not is_hip_mi300()):
         pytest.xfail(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300")
 
     # dtype : (exponent_bits, mantissa_bits, exponent_bias)

python/test/unit/language/test_core.py
Lines changed: 3 additions & 3 deletions

@@ -29,6 +29,7 @@
     is_cuda,
     is_interpreter,
     is_hip,
+    is_hip_cdna,
     is_hip_mi200,
     is_xpu,
     get_arch,

@@ -3381,13 +3382,12 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
            pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
+        if not is_hip_cdna():
+            pytest.skip("scaled_dot only implemented for HIP CDNA")
         if (type_a not in ["e2m1", "e5m2"]) or (type_b not in ["e2m1", "e5m2", "bf16"]):
             pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
-        arch = triton.runtime.driver.active.get_current_target().arch
-        if "gfx11" in arch or "gfx12" in arch:
-            pytest.skip("scaled_dot not yet implemented for gfx11 and gfx12")
     if is_xpu():
         pytest.skip("scaled_dot isn't supported on XPU")

python/test/unit/language/test_pipeliner.py
Lines changed: 3 additions & 18 deletions

@@ -6,22 +6,7 @@
 import triton.language as tl
 import triton.tools.experimental_descriptor
 
-
-def is_cuda():
-    return triton.runtime.driver.active.get_current_target().backend == "cuda"
-
-
-def is_hopper():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
-
-
-def is_hip():
-    return triton.runtime.driver.active.get_current_target().backend == "hip"
-
-
-def is_hip_mi200():
-    target = triton.runtime.driver.active.get_current_target()
-    return target.backend == 'hip' and target.arch == 'gfx90a'
+from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200
 
 
 def check_capabilities():

@@ -229,8 +214,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 @pytest.mark.parametrize("scale", [True, False])
 def test_pipeline_matmul(scale, device):
     check_capabilities()
-    if scale and not is_cuda():
-        pytest.skip("NYI: scale_dot just implemented in CUDA")
+    if scale and not (is_cuda() or is_hip_cdna()):
+        pytest.skip("NYI: scale_dot just implemented in CUDA/HIP")
     M, N, K = 512, 512, 128
     BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
     NUM_STAGES = 4

python/triton/_internal_testing.py
Lines changed: 13 additions & 0 deletions

@@ -36,6 +36,10 @@ def is_cuda():
     return False if target is None else target.backend == "cuda"
 
 
+def is_hopper():
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+
+
 def is_hip():
     target = get_current_target()
     return False if target is None else target.backend == "hip"

@@ -46,6 +50,15 @@ def is_hip_mi200():
     return target.backend == 'hip' and target.arch == 'gfx90a'
 
 
+def is_hip_mi300():
+    target = get_current_target()
+    return target.backend == 'hip' and target.arch in ('gfx940', 'gfx941', 'gfx942')
+
+
+def is_hip_cdna():
+    return is_hip_mi200() or is_hip_mi300()
+
+
 def is_xpu():
     target = get_current_target()
     return False if target is None else target.backend == "xpu"
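As a usage illustration, a hypothetical test (not part of this commit; the test name and body are invented, and the device fixture is assumed from the existing test suite) could gate fnuz fp8 coverage with the new shared helpers in the same way the test files above now do:

# Hypothetical test sketch (illustration only, not from this diff).
import pytest

from triton._internal_testing import is_cuda, is_hip_mi300


def test_fnuz_fp8_roundtrip(device):
    # float8e4b8 / float8e5b16 are only expected to pass on AMD MI300 (CDNA3),
    # mirroring the guards added to test_conversions.py above.
    if is_cuda() or not is_hip_mi300():
        pytest.skip("fnuz fp8 types require AMDGPU MI300")
    ...  # test body elided in this sketch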

test/Triton/canonicalize.mlir
Lines changed: 71 additions & 0 deletions

@@ -50,3 +50,74 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
   tt.return %b : tensor<32x1xf32, #blocked0>
 }
 } // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_pos_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fp_to_fp_pos_zero_fold
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FN, #blocked> {
+  // CHECK-LABEL: fp_to_fp_neg_zero_fold
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<-0.000000e+00> : tensor<32x128xf8E4M3FN, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FN, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FN, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fp_to_fp_neg_zero_fold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fp_to_fp_neg_zero_fold
+  // We fold to the positive zero here given by definition f8E4M3FNUZ does not have negative zero encoding.
+  // CHECK-NEXT: %[[cst_folded:.+]] = arith.constant dense<0.000000e+00> : tensor<32x128xf8E4M3FNUZ, #blocked>
+  // CHECK-NEXT: tt.return %[[cst_folded]]
+  %cst = arith.constant dense<-0.00e+00> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fold_fp_to_fp_non_zero_nofold() -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fold_fp_to_fp_non_zero_nofold
+  // CHECK-NEXT: %[[cst:.+]] = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked>
+  // CHECK-NEXT: %[[cst_cvt:.+]] = tt.fp_to_fp %[[cst]]
+  // CHECK-NEXT: tt.return %[[cst_cvt]]
+  %cst = arith.constant dense<0xFF800000> : tensor<32x128xf32, #blocked>
+  %cst_converted = tt.fp_to_fp %cst, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+tt.func @fold_fp_to_fp_non_constant_nofold(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf8E4M3FNUZ, #blocked> {
+  // CHECK-LABEL: fold_fp_to_fp_non_constant_nofold
+  // CHECK-NEXT: %[[arg_cvt:.+]] = tt.fp_to_fp %arg0
+  // CHECK-NEXT: tt.return %[[arg_cvt]]
+  %cst_converted = tt.fp_to_fp %arg0, rounding = rtne : tensor<32x128xf32, #blocked> -> tensor<32x128xf8E4M3FNUZ, #blocked>
+  tt.return %cst_converted : tensor<32x128xf8E4M3FNUZ, #blocked>
+}
+} // end module

third_party/nvidia/backend/driver.c
Lines changed: 0 additions & 1 deletion

@@ -3,7 +3,6 @@
 #include <stdbool.h>
 #define PY_SSIZE_T_CLEAN
 #include <Python.h>
-#include <stdatomic.h>
 
 // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
 static bool gpuAssert(CUresult code, const char *file, int line) {
