Skip to content

Commit cf34004

Browse files
bertmaherxinyazhangantiagainstzhanglx13
authored
AMD requested cherry-picks for release/3.1.x (#4794)
Per @jataylo : * libamdhip64.so discovery: * triton-lang/triton@13edc45 * triton-lang/triton@e0613c6 * wmma: triton-lang/triton@4a1ea8e (already cherry-picked) * mfma: triton-lang/triton@0a66c1b --------- Co-authored-by: Xinya Zhang <[email protected]> Co-authored-by: Lei Zhang <[email protected]> Co-authored-by: Lixun Zhang <[email protected]>
1 parent 566b63c commit cf34004

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,18 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
8787
if (!encoding)
8888
// encoding not available
8989
return resultVals;
90-
if (!dyn_cast<BlockedEncodingAttr>(encoding) &&
91-
!dyn_cast<SliceEncodingAttr>(encoding)) {
92-
// TODO: constraining the encoding type here is necessary for avoiding
93-
// crashes in the getElemsPerThread call below happening in the
90+
Attribute baseEncoding = encoding;
91+
if (isa<AMDMfmaEncodingAttr>(baseEncoding))
92+
// TODO: this logic seems incorrect for mfma layout. Skip for now.
93+
// We saw mismatches for some flash-attention tests on AMD backend.
94+
// Note that this logic works for sliced layout whose parent is
95+
// mfma layout. Therefore, this is not combined with the following check.
96+
return resultVals;
97+
while (auto sliced = dyn_cast<SliceEncodingAttr>(baseEncoding))
98+
baseEncoding = sliced.getParent();
99+
if (isa<NvidiaMmaEncodingAttr, DotOperandEncodingAttr>(baseEncoding)) {
100+
// TODO: this logic seems incorrect for mma layout. Skip for now.
101+
// The following test crashes and some other miscompile:
94102
// test_core::test_fp8_dot_acc
95103
return resultVals;
96104
}

third_party/amd/backend/driver.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,54 @@
1313
include_dir = [os.path.join(dirname, "include")]
1414

1515

16+
def _find_already_mmapped_dylib_on_linux(lib_name):
17+
import platform
18+
if platform.system() != 'Linux':
19+
return None
20+
21+
# Use dl_iterate_phdr to walk through the list of shared libraries at runtime.
22+
# See https://www.man7.org/linux/man-pages/man3/dl_iterate_phdr.3.html for details.
23+
24+
import ctypes
25+
from ctypes import c_char, c_int, c_size_t, c_void_p, c_char_p, POINTER
26+
27+
class DlPhdrInfo(ctypes.Structure):
28+
_fields_ = [
29+
('dlpi_addr', c_void_p),
30+
('dlpi_name', c_char_p),
31+
# We don't care about the remaining fields.
32+
]
33+
34+
# callback_t must use POINTER(c_char) to avoid copying.
35+
callback_t = ctypes.CFUNCTYPE(c_int, POINTER(DlPhdrInfo), POINTER(c_size_t), POINTER(c_char))
36+
37+
# Load libc and get the dl_iterate_phdr symbol.
38+
try:
39+
dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr
40+
except:
41+
return None
42+
# argtypes must use c_char_p to accept create_string_buffer.
43+
dl_iterate_phdr.argtypes = [callback_t, c_char_p]
44+
dl_iterate_phdr.restype = c_int
45+
46+
max_path_length = 4096
47+
path = ctypes.create_string_buffer(max_path_length + 1)
48+
49+
# Define callback to get the loaded dylib path.
50+
def callback(info, size, data):
51+
dlpi_name = info.contents.dlpi_name
52+
p = Path(os.fsdecode(dlpi_name))
53+
if lib_name in p.name:
54+
# Found the dylib; get its path.
55+
ctypes.memmove(data, dlpi_name, min(max_path_length, len(dlpi_name)))
56+
return 1
57+
return 0
58+
59+
if dl_iterate_phdr(callback_t(callback), path):
60+
return os.fsdecode(ctypes.string_at(path))
61+
return None
62+
63+
1664
@functools.lru_cache()
1765
def _get_path_to_hip_runtime_dylib():
1866
lib_name = "libamdhip64.so"
@@ -24,13 +72,24 @@ def _get_path_to_hip_runtime_dylib():
2472
return env_libhip_path
2573
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
2674

75+
# If the shared object is already mmapped to address space, use it.
76+
mmapped_path = _find_already_mmapped_dylib_on_linux(lib_name)
77+
if mmapped_path:
78+
if os.path.exists(mmapped_path):
79+
return mmapped_path
80+
raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
81+
2782
paths = []
2883

2984
import site
3085
# First search the HIP runtime dynamic library packaged with PyTorch. It's very likely
3186
# that we run Triton together with PyTorch. This makes sure we use the same dynamic
3287
# library to avoid version mismatch.
33-
for path in site.getsitepackages():
88+
site_packages = site.getsitepackages()
89+
user_site = site.getusersitepackages()
90+
if site.ENABLE_USER_SITE: # ENABLE_USER_SITE is initialized in getusersitepackages()
91+
site_packages = [user_site] + site_packages
92+
for path in site_packages:
3493
path = os.path.join(path, "torch", "lib", lib_name)
3594
if os.path.exists(path):
3695
return path

0 commit comments

Comments
 (0)