Skip to content

Commit 0d40280

Browse files
nputikhinGoogle-ML-Automation
authored andcommitted
Handle PyCUtensorMapObject in extractTmaDesc in the launcher
Reenables failing tests PiperOrigin-RevId: 825528658
1 parent 1d5f010 commit 0d40280

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc
2+
--- a/third_party/nvidia/backend/cuda_utils.cc
3+
+++ b/third_party/nvidia/backend/cuda_utils.cc
4+
@@ -270,51 +270,16 @@ bool extractPointer(PyObject* obj, void*
5+
return true;
6+
}
7+
8+
+CUtensorMap* getTmaDesc(PyObject* obj);
9+
+
10+
// Extract a CUtensorMap descriptor from a python object, and store it to the
11+
// memory location pointed by ptr.
12+
bool extractTmaDesc(PyObject* obj, void* ptr) {
13+
- if (sizeof(CUtensorMap*) != 8) {
14+
- PyErr_SetString(PyExc_SystemError,
15+
- "extractTmaDesc() requires 64-bit compilation");
16+
- return false;
17+
- }
18+
-
19+
- UniquePyObjectPtr method_ret(
20+
- PyObject_CallMethod(obj, "tma_desc_cpu_ptr", nullptr));
21+
- // Checking the error retains context if tma_desc_cpu_ptr raises an exception.
22+
- if (PyErr_Occurred()) {
23+
- return false;
24+
- }
25+
-
26+
- if (!method_ret) {
27+
- PyErr_SetString(PyExc_SystemError, "Call to tma_desc_cpu_ptr() failed");
28+
+ CUtensorMap* tensor_map = getTmaDesc(obj);
29+
+ if (tensor_map == nullptr) {
30+
return false;
31+
}
32+
-
33+
- if (!PyLong_Check(method_ret.get())) {
34+
- PyErr_SetString(PyExc_TypeError,
35+
- "tma_desc_cpu_ptr() must return 64-bit int");
36+
- return false;
37+
- }
38+
-
39+
- uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret.get());
40+
- if (PyErr_Occurred()) {
41+
- return false;
42+
- }
43+
-
44+
- if (!ptr_as_uint) {
45+
- PyErr_SetString(PyExc_ValueError,
46+
- "received NULL ptr from tma_desc_cpu_ptr()");
47+
- return false;
48+
- }
49+
- if (ptr_as_uint % 64 != 0) {
50+
- PyErr_SetString(PyExc_ValueError,
51+
- "tma_desc_cpu_ptr() must be 64-byte aligned");
52+
- return false;
53+
- }
54+
-
55+
- *static_cast<CUtensorMap*>(ptr) =
56+
- *reinterpret_cast<CUtensorMap*>(ptr_as_uint);
57+
+ *static_cast<CUtensorMap*>(ptr) = *tensor_map;
58+
return true;
59+
}
60+
61+
@@ -392,6 +357,7 @@ struct ExtractionInfo {
62+
// Prefixes of types reprs supported by the extractor.
63+
llvm::SmallVector<llvm::StringRef> supported_type_repr_prefixes;
64+
std::size_t size; // Size required by the extracted value.
65+
+ std::size_t alignment; // Alignment requirement for the extracted value.
66+
ExtractorType extractor; // Function to call to extract the value.
67+
68+
// Builds an ExtractionInfo for a given type T and a list of type reprs that
69+
@@ -400,7 +366,7 @@ struct ExtractionInfo {
70+
static ExtractionInfo build(
71+
std::initializer_list<llvm::StringRef> supported_type_reprs,
72+
ExtractorType extractor = extractValue<T>) {
73+
- return {supported_type_reprs, sizeof(T), extractor};
74+
+ return {supported_type_reprs, sizeof(T), alignof(T), extractor};
75+
}
76+
77+
// Checks if the extractor supports extracting a given type repr.
78+
@@ -428,7 +394,7 @@ const ExtractionInfo kExtractionInfos[]{
79+
// Note: types are e.g. '*fp32', so no closing quote is intentional.
80+
ExtractionInfo::build<void*>({"'*"}, extractPointer),
81+
ExtractionInfo{
82+
- {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None
83+
+ {"None", "'none'"}, 0, 0, nullptr}, // Represent constexprs as None
84+
ExtractionInfo::build<CUtensorMap>({"'nvTmaDesc'"}, extractTmaDesc),
85+
};
86+
87+
@@ -628,7 +594,19 @@ PyObject* launch(PyObject* self, PyObjec
88+
if (extraction_info.size == 0) {
89+
continue; // skip adding constexpr parameters
90+
}
91+
- config.params[params_idx] = alloca(extraction_info.size);
92+
+ size_t alignment = std::max(1UL, extraction_info.alignment);
93+
+
94+
+ // Allocate enough space on the stack to guarantee an aligned block.
95+
+ size_t size_with_alignment = extraction_info.size + alignment - 1;
96+
+ void *param_storage_ptr = alloca(size_with_alignment);
97+
+
98+
+ void *aligned_ptr = std::align(alignment, extraction_info.size,
99+
+ param_storage_ptr, size_with_alignment);
100+
+ if (aligned_ptr == nullptr) {
101+
+ PyErr_SetString(PyExc_MemoryError, "Failed to align parameter storage");
102+
+ return nullptr;
103+
+ }
104+
+ config.params[params_idx] = aligned_ptr;
105+
if (!extraction_info.extractor(arg, config.params[params_idx])) {
106+
return nullptr;
107+
}
108+
@@ -940,6 +918,36 @@ static PyTypeObject PyCUtensorMapType =
109+
};
110+
// clang-format on
111+
112+
+namespace {
113+
+
114+
+// Extracts a pointer to `CUtensorMap` from a `PyCUtensorMapObject`.
115+
+CUtensorMap* getTmaDesc(PyObject* obj) {
116+
+ if (sizeof(CUtensorMap*) != 8) {
117+
+ PyErr_SetString(PyExc_SystemError,
118+
+ "getTmaDesc() requires 64-bit compilation");
119+
+ return nullptr;
120+
+ }
121+
+ if (Py_TYPE(obj) != static_cast<PyTypeObject*>(&PyCUtensorMapType)) {
122+
+ PyErr_Format(PyExc_TypeError,
123+
+ "object must be of type PyCUtensorMap, got %s",
124+
+ Py_TYPE(obj)->tp_name);
125+
+ return nullptr;
126+
+ }
127+
+ CUtensorMap* map = &((PyCUtensorMapObject*)obj)->tensorMap;
128+
+ // PyCUtensorMapObject aligns tensorMap to 128.
129+
+ uintptr_t align_128 = (uintptr_t)map & (128 - 1);
130+
+ if (align_128 != 0) {
131+
+ PyErr_Format(
132+
+ PyExc_ValueError,
133+
+ "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %ld",
134+
+ align_128);
135+
+ return nullptr;
136+
+ }
137+
+ return map;
138+
+}
139+
+
140+
+} // namespace
141+
+
142+
static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
143+
unsigned long long global_address;
144+
int swizzle;

third_party/triton/temporary/series.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ those to this list.
1515

1616
temporary_patch_list = [
1717
"//third_party/triton:temporary/utility-fix.patch",
18+
"//third_party/triton:temporary/launcher_tma_desc_fix.patch",
1819
# Add new patches just above this line
1920
]

0 commit comments

Comments
 (0)