-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[CuTeDSL] Add docs of how to call cute.jit functions in C++ via TVM-FFI #3289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
8a55a6d
089b9ee
fa29f85
5b6b415
b826e83
eb2f630
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -539,6 +539,157 @@ before launching the kernel based on that tensor's device index. | |
| For advanced scenarios that pass raw pointers instead of tensors, you should call | ||
| ``cudaSetDevice`` explicitly through the CUDA Python API. | ||
|
|
||
| Call a compiled function from C++ via the TVM FFI registry | ||
| ---------------------------------------------------------- | ||
|
|
||
| The object returned by ``cute.compile(..., options="--enable-tvm-ffi")`` is itself a | ||
| ``tvm_ffi.Function``: a native callable that follows the TVM FFI calling convention. | ||
| You can publish it in TVM FFI's process-global function registry under a string name with | ||
| ``tvm_ffi.register_global_func``. Once registered, the same compiled kernel can be | ||
| looked up by name and invoked from any TVM FFI-supported language (e.g. C++) running | ||
| in the same process. This allows you to bypass the Python interpreter entirely and avoid | ||
| the CPU overhead by staying in the C++ environment only. | ||
| In such a case, Python is only used for expressiveness as a DSL to describe the kernel and register it in the shared registry, | ||
| and C++ is used for the actual execution with efficiency. | ||
|
|
||
| The following is a minimal example of how to call a compiled CuTeDSL function in C++. | ||
|
|
||
| This is the C++ code we will compile into a PyTorch extension. We name it ``extension.cpp`` here: | ||
|
|
||
| .. code-block:: cpp | ||
|
|
||
| #include <ATen/DLConvertor.h> // at::toDLPackNonOwning | ||
| #include <torch/extension.h> // pybind11 + at::Tensor | ||
| #include <tvm/ffi/container/tensor.h> // tvm::ffi::TensorView | ||
| #include <tvm/ffi/function.h> // tvm::ffi::Function | ||
|
|
||
| #include <string> | ||
|
|
||
| void apply_tvm_function(const std::string& name, at::Tensor &x, at::Tensor &y, at::Tensor &z) { | ||
| tvm::ffi::Function fn = tvm::ffi::Function::GetGlobalRequired(name); | ||
| DLTensor dl_x = {}; | ||
| DLTensor dl_y = {}; | ||
| DLTensor dl_z = {}; | ||
| at::toDLPackNonOwning(x, &dl_x); | ||
| at::toDLPackNonOwning(y, &dl_y); | ||
| at::toDLPackNonOwning(z, &dl_z); | ||
| fn(&dl_x, &dl_y, &dl_z); | ||
| } | ||
|
|
||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
| m.def("apply_tvm_function", &apply_tvm_function, | ||
| "Look up a tvm-ffi global function by name and call it with three tensors."); | ||
| } | ||
|
|
||
|
|
||
| Then we need to compile and load this extension into PyTorch: | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import subprocess | ||
| import sys | ||
|
|
||
| from torch.utils.cpp_extension import load | ||
|
|
||
| def _tvm_ffi_config(flag: str) -> str: | ||
| """Ask the installed apache-tvm-ffi where its headers and lib live.""" | ||
| out = subprocess.check_output([sys.executable, "-m", "tvm_ffi.config", flag]) | ||
| return out.decode().strip() | ||
|
|
||
|
|
||
| def build_extension(): | ||
| include_dir = _tvm_ffi_config("--includedir") | ||
| lib_dir = _tvm_ffi_config("--libdir") | ||
| return load( | ||
| name="tvm_ffi_demo_ext", | ||
| sources=["extension.cpp"], | ||
| extra_include_paths=[include_dir], | ||
| extra_cflags=["-std=c++17"], | ||
| # -ltvm_ffi to link, and -rpath so the .so is found at runtime. This is | ||
| # the same libtvm_ffi.so that `import tvm_ffi` loads -> shared registry. | ||
| extra_ldflags=[f"-L{lib_dir}", "-ltvm_ffi", f"-Wl,-rpath,{lib_dir}"], | ||
| verbose=True, | ||
| ) | ||
|
|
||
| With all the boilerplate code ready, now let's write a CuTeDSL kernel and use the C++ extension to call it. | ||
| In practice you might want to call the CuTeDSL function from C++ directly without going back to Python. We call from Python here just for demonstration purposes. | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| import cutlass | ||
| import torch | ||
| import tvm_ffi | ||
| import cutlass.cute as cute | ||
|
|
||
| @cute.jit | ||
| def add(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor): | ||
| add_kernel(x, y, z).launch(grid=[1, 1, 1], block=[16, 1, 1]) | ||
|
|
||
| @cute.kernel | ||
| def add_kernel(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor): | ||
| tidx, _, _ = cute.arch.thread_idx() | ||
| if tidx < 16: | ||
| z[tidx] = x[tidx] + y[tidx] | ||
|
|
||
| def main() -> None: | ||
| ext = build_extension() | ||
|
|
||
| fake_x = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) | ||
| fake_y = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) | ||
| fake_z = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) | ||
| compiled = cute.compile(add, fake_x, fake_y, fake_z, options="--enable-tvm-ffi",) | ||
| tvm_ffi.register_global_func("CuTeDSL_add", compiled, override=True) | ||
|
|
||
| x = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") | ||
| y = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") | ||
| z = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") | ||
| ext.apply_tvm_function("CuTeDSL_add", x, y, z) | ||
| assert torch.allclose(x + y, z, atol=1e-5, rtol=1e-5) | ||
| print("Successfully called CuTeDSL function from C++!") | ||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
|
|
||
| Calling convention of TVM-FFI in C++ | ||
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
| When you compile a ``@cute.jit`` function with the ``--enable-tvm-ffi`` option, ``cute.compile`` will return a TVM FFI function object. | ||
| Then you can register it as a `Global Function <https://tvm.apache.org/ffi/guides/export_func_cls.html#global-functions>`_ to make it accessible from other languages. | ||
|
|
||
| In C++, you can load the compiled function from the global registry with ``tvm::ffi::Function::GetGlobal`` (or ``tvm::ffi::Function::GetGlobalRequired``, which will throw if the function is not found). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How can we dump the tvm ffi function from cutedsl, can we leverage AOT? Can you explain this part more clearly. Other parts look good to me, thanks
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I add some description on this but really it's just doing AOT compilation and then calling this TVM-FFI API as in my example.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thing that may be relevant surfacing is https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp see how it is not using the global function registry and relies on the symbol being available
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes. I also added some explanation of how to load symbols from shared libraries. Although I think using the global registry is more straightforward if I'm working with frameworks running in both C++ and python since in this way I don't need to handle the shared library manually. |
||
| The returned object will be of type ``tvm::ffi::Function``. | ||
|
|
||
| The signature of the TVM FFI function loaded in C++ will use a unified ABI for all functions like this: | ||
|
|
||
| .. code-block:: cpp | ||
|
|
||
| void CallPacked(const AnyView* args, int32_t num_args, Any* result) const | ||
|
|
||
| where ``const AnyView* args`` is the type-erased array of arguments, whose actual types are determined at runtime, | ||
| ``int32_t num_args`` is the number of arguments, and ``Any* result`` is a pointer to the return value (if any). | ||
|
|
||
| - The arguments are called "AnyView", meaning they are "non-owning" views of the underlying data and therefore the lifetime is determined by the actual data owner. | ||
| - The return value is called "Any", meaning it owns the data and is responsible for its lifetime. | ||
| - Both ``Any`` and ``AnyView`` are type-erased containers that can hold objects from different types. The actual type is decided by their ``type_index`` attribute at runtime, which is a `TVMFFITypeIndex <https://tvm.apache.org/ffi/reference/cpp/generated/enum_c__api_8h_1a1925bb5d568a3f5c92a6c28934c9bcc2.html#_CPPv4N15TVMFFITypeIndex11kTVMFFINoneE>`_ enum that represents a TVM-FFI type. | ||
|
|
||
| However, you do not need to explicitly call TVM-FFI functions with this low-level packed format signature, because TVM-FFI has overridden the ``operator()`` method, which creates arguments of ``CallPacked`` for you | ||
| to allow you to call the function with the same signature as how you defined it. | ||
| So in our elementwise addition example, the C++ signature of the TVM FFI function will be something like (note our kernel does not return anything, so the return value will be a ``tvm::ffi::Any`` with ``type_index`` of ``kTVMFFINone``): | ||
|
|
||
| .. code-block:: cpp | ||
|
|
||
| tvm::ffi::Any add(tvm::ffi::AnyView x, tvm::ffi::AnyView y, tvm::ffi::AnyView z) | ||
|
|
||
| When we want to call the TVM FFI function in C++, we need to construct our inputs in a form that can be converted to ``AnyView`` and recognized by TVM-FFI. In this case, the conversion path we would take is | ||
| ``DLTensor`` -> ``tvm::ffi::TensorView`` -> ``tvm::ffi::AnyView``. The latter two conversions can be implicit (supported by TVM-FFI already), so we just need to convert our tensor type to ``DLTensor``. | ||
| For PyTorch tensors, they would be ``at::Tensor`` in C++ and we can use ``at::toDLPackNonOwning`` to get a ``DLTensor`` view. For custom tensor types, you might need to implement the conversion yourself. | ||
|
|
||
| For other basic types, you can directly pass them and let TVM-FFI handle the conversion implicitly. You are unlikely to need to convert them manually, since they are general types that are widely recognized. | ||
|
|
||
| See `layout <https://tvm.apache.org/ffi/concepts/any.html#layout>`_ for more detail on how TVM-FFI's ``Any`` type works. | ||
|
|
||
| See `tensor-classes <https://tvm.apache.org/ffi/concepts/tensor.html#tensor-classes>`_ for more detail on how DLPack tensors and TVM-FFI tensors convert between each other. | ||
|
|
||
|
|
||
| Exporting Compiled Module | ||
| ------------------------- | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one possible option is to actually make use of tvm ffi's export mechanism just like flashinfer
https://github.com/flashinfer-ai/flashinfer/blob/c3c40a7b90b792fc59f90f8f55c9e2de9c1b6833/csrc/flashinfer_gemm_binding.cu#L35
This way exported function also can be used in various tvm-ffi compatible scenarios and callable from pytorch
In this case, you can also do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
related https://tvm.apache.org/ffi/packaging/python_packaging.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @tqchen
I think the use case here is to create and export a TVM-FFI function in python, then call it in C++, which would be helpful if you are working with a C++ heavy framework and you just want to leverage CuTeDSL's flexibility and run a kernel without going to python region.
I feel like
TVM_FFI_DLL_EXPORT_TYPED_FUNCis more like exporting a C++ function to python, which seems to be the case in flashinfer and it's rather the opposite direction (we want to export in python and call it in C++).I include the pybind code in my example only because I need to run the C++ function to demonstrate the code is working (as I mentioned in the doc, in practice you probably want to call the C++ function in C++).