|
| 1 | +diff --git a/third_party/nvidia/backend/cuda_utils.cc b/third_party/nvidia/backend/cuda_utils.cc |
| 2 | +--- a/third_party/nvidia/backend/cuda_utils.cc |
| 3 | ++++ b/third_party/nvidia/backend/cuda_utils.cc |
| 4 | +@@ -270,51 +270,16 @@ bool extractPointer(PyObject* obj, void* |
| 5 | + return true; |
| 6 | + } |
| 7 | + |
| 8 | ++CUtensorMap* getTmaDesc(PyObject* obj); |
| 9 | ++ |
| 10 | + // Extract a CUtensorMap descriptor from a python object, and store it to the |
| 11 | + // memory location pointed by ptr. |
| 12 | + bool extractTmaDesc(PyObject* obj, void* ptr) { |
| 13 | +- if (sizeof(CUtensorMap*) != 8) { |
| 14 | +- PyErr_SetString(PyExc_SystemError, |
| 15 | +- "extractTmaDesc() requires 64-bit compilation"); |
| 16 | +- return false; |
| 17 | +- } |
| 18 | +- |
| 19 | +- UniquePyObjectPtr method_ret( |
| 20 | +- PyObject_CallMethod(obj, "tma_desc_cpu_ptr", nullptr)); |
| 21 | +- // Checking the error retains context if tma_desc_cpu_ptr raises an exception. |
| 22 | +- if (PyErr_Occurred()) { |
| 23 | +- return false; |
| 24 | +- } |
| 25 | +- |
| 26 | +- if (!method_ret) { |
| 27 | +- PyErr_SetString(PyExc_SystemError, "Call to tma_desc_cpu_ptr() failed"); |
| 28 | ++ CUtensorMap* tensor_map = getTmaDesc(obj); |
| 29 | ++ if (tensor_map == nullptr) { |
| 30 | + return false; |
| 31 | + } |
| 32 | +- |
| 33 | +- if (!PyLong_Check(method_ret.get())) { |
| 34 | +- PyErr_SetString(PyExc_TypeError, |
| 35 | +- "tma_desc_cpu_ptr() must return 64-bit int"); |
| 36 | +- return false; |
| 37 | +- } |
| 38 | +- |
| 39 | +- uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret.get()); |
| 40 | +- if (PyErr_Occurred()) { |
| 41 | +- return false; |
| 42 | +- } |
| 43 | +- |
| 44 | +- if (!ptr_as_uint) { |
| 45 | +- PyErr_SetString(PyExc_ValueError, |
| 46 | +- "received NULL ptr from tma_desc_cpu_ptr()"); |
| 47 | +- return false; |
| 48 | +- } |
| 49 | +- if (ptr_as_uint % 64 != 0) { |
| 50 | +- PyErr_SetString(PyExc_ValueError, |
| 51 | +- "tma_desc_cpu_ptr() must be 64-byte aligned"); |
| 52 | +- return false; |
| 53 | +- } |
| 54 | +- |
| 55 | +- *static_cast<CUtensorMap*>(ptr) = |
| 56 | +- *reinterpret_cast<CUtensorMap*>(ptr_as_uint); |
| 57 | ++ *static_cast<CUtensorMap*>(ptr) = *tensor_map; |
| 58 | + return true; |
| 59 | + } |
| 60 | + |
| 61 | +@@ -392,6 +357,7 @@ struct ExtractionInfo { |
| 62 | + // Prefixes of types reprs supported by the extractor. |
| 63 | + llvm::SmallVector<llvm::StringRef> supported_type_repr_prefixes; |
| 64 | + std::size_t size; // Size required by the extracted value. |
| 65 | ++ std::size_t alignment; // Alignment requirement for the extracted value. |
| 66 | + ExtractorType extractor; // Function to call to extract the value. |
| 67 | + |
| 68 | + // Builds an ExtractionInfo for a given type T and a list of type reprs that |
| 69 | +@@ -400,7 +366,7 @@ struct ExtractionInfo { |
| 70 | + static ExtractionInfo build( |
| 71 | + std::initializer_list<llvm::StringRef> supported_type_reprs, |
| 72 | + ExtractorType extractor = extractValue<T>) { |
| 73 | +- return {supported_type_reprs, sizeof(T), extractor}; |
| 74 | ++ return {supported_type_reprs, sizeof(T), alignof(T), extractor}; |
| 75 | + } |
| 76 | + |
| 77 | + // Checks if the extractor supports extracting a given type repr. |
| 78 | +@@ -428,7 +394,7 @@ const ExtractionInfo kExtractionInfos[]{ |
| 79 | + // Note: types are e.g. '*fp32', so no closing quote is intentional. |
| 80 | + ExtractionInfo::build<void*>({"'*"}, extractPointer), |
| 81 | + ExtractionInfo{ |
| 82 | +- {"None", "'none'"}, 0, nullptr}, // Represent constexprs as None |
| 83 | ++ {"None", "'none'"}, 0, 0, nullptr}, // Represent constexprs as None |
| 84 | + ExtractionInfo::build<CUtensorMap>({"'nvTmaDesc'"}, extractTmaDesc), |
| 85 | + }; |
| 86 | + |
| 87 | +@@ -628,7 +594,19 @@ PyObject* launch(PyObject* self, PyObjec |
| 88 | + if (extraction_info.size == 0) { |
| 89 | + continue; // skip adding constexpr parameters |
| 90 | + } |
| 91 | +- config.params[params_idx] = alloca(extraction_info.size); |
| 92 | ++ // Name the type explicitly: `1UL` is not std::size_t on LLP64 targets. |
| 93 | ++ size_t alignment = std::max<std::size_t>(1, extraction_info.alignment); |
| 94 | ++ // Allocate enough space on the stack to guarantee an aligned block. |
| 95 | ++ size_t size_with_alignment = extraction_info.size + alignment - 1; |
| 96 | ++ void *param_storage_ptr = alloca(size_with_alignment); |
| 97 | ++ // std::align returns the aligned address, or nullptr if it cannot fit. |
| 98 | ++ void *aligned_ptr = std::align(alignment, extraction_info.size, |
| 99 | ++ param_storage_ptr, size_with_alignment); |
| 100 | ++ if (aligned_ptr == nullptr) { |
| 101 | ++ PyErr_SetString(PyExc_MemoryError, "Failed to align parameter storage"); |
| 102 | ++ return nullptr; |
| 103 | ++ } |
| 104 | ++ config.params[params_idx] = aligned_ptr; |
| 105 | + if (!extraction_info.extractor(arg, config.params[params_idx])) { |
| 106 | + return nullptr; |
| 107 | + } |
| 108 | +@@ -940,6 +918,36 @@ static PyTypeObject PyCUtensorMapType = |
| 109 | + }; |
| 110 | + // clang-format on |
| 111 | + |
| 112 | ++namespace { |
| 113 | ++ |
| 114 | ++// Returns obj's embedded `CUtensorMap`; nullptr + Python error on failure. |
| 115 | ++CUtensorMap* getTmaDesc(PyObject* obj) { |
| 116 | ++ if (sizeof(CUtensorMap*) != 8) { |
| 117 | ++ PyErr_SetString(PyExc_SystemError, |
| 118 | ++ "getTmaDesc() requires 64-bit compilation"); |
| 119 | ++ return nullptr; |
| 120 | ++ } |
| 121 | ++ if (Py_TYPE(obj) != &PyCUtensorMapType) { |
| 122 | ++ PyErr_Format(PyExc_TypeError, |
| 123 | ++ "object must be of type PyCUtensorMap, got %s", |
| 124 | ++ Py_TYPE(obj)->tp_name); |
| 125 | ++ return nullptr; |
| 126 | ++ } |
| 127 | ++ CUtensorMap* map = &reinterpret_cast<PyCUtensorMapObject*>(obj)->tensorMap; |
| 128 | ++ // tensorMap is expected to be 128-byte aligned; check the low 7 bits. |
| 129 | ++ uintptr_t align_128 = reinterpret_cast<uintptr_t>(map) & (128 - 1); |
| 130 | ++ if (align_128 != 0) { |
| 131 | ++ // %zu with a size_t argument: passing uintptr_t for %ld is a mismatch. |
| 132 | ++ PyErr_Format(PyExc_ValueError, |
| 133 | ++ "CUtensorMap must be aligned to 128B, but got (&map) mod 128 = %zu", |
| 134 | ++ static_cast<size_t>(align_128)); |
| 135 | ++ return nullptr; |
| 136 | ++ } |
| 137 | ++ return map; |
| 138 | ++} |
| 139 | ++ |
| 140 | ++} // namespace |
| 141 | ++ |
| 142 | + static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) { |
| 143 | + unsigned long long global_address; |
| 144 | + int swizzle; |
0 commit comments