151 changes: 151 additions & 0 deletions media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst
@@ -539,6 +539,157 @@ before launching the kernel based on that tensor's device index.
For advanced scenarios that pass raw pointers instead of tensors, you should call
``cudaSetDevice`` explicitly through the CUDA Python API.

Call a compiled function from C++ via the TVM FFI registry
----------------------------------------------------------

The object returned by ``cute.compile(..., options="--enable-tvm-ffi")`` is itself a
``tvm_ffi.Function``: a native callable that follows the TVM FFI calling convention.
You can publish it in TVM FFI's process-global function registry under a string name with
``tvm_ffi.register_global_func``. Once registered, the same compiled kernel can be
looked up by name and invoked from any language that TVM FFI supports (for example, C++),
as long as the caller runs in the same process. The call path can then stay entirely in C++,
which avoids the CPU overhead of going through the Python interpreter.
In this setup, Python serves only as the DSL front end that describes the kernel and registers it
in the shared registry, while C++ performs the actual calls.
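
To make the register-then-look-up flow concrete, here is a minimal, standard-library-only toy model of a name-keyed global registry. It is a sketch of the idea, not the TVM FFI implementation: ``register_global``, ``get_global`` and ``get_global_required`` are hypothetical names chosen to mirror ``tvm_ffi.register_global_func``, ``GetGlobal`` and ``GetGlobalRequired``, and the override behavior shown is an assumption of this sketch.

```python
from typing import Callable, Dict, Optional

# Toy stand-in for a process-global function registry.
# This is NOT tvm_ffi; all names below are hypothetical.
_REGISTRY: Dict[str, Callable] = {}


def register_global(name: str, fn: Callable, override: bool = False) -> None:
    """Publish fn under name; refuse to replace an entry unless override is set."""
    if name in _REGISTRY and not override:
        raise ValueError(f"global function {name!r} is already registered")
    _REGISTRY[name] = fn


def get_global(name: str) -> Optional[Callable]:
    """Return the registered callable, or None if the name is unknown."""
    return _REGISTRY.get(name)


def get_global_required(name: str) -> Callable:
    """Return the registered callable, or raise if the name is unknown."""
    fn = get_global(name)
    if fn is None:
        raise KeyError(f"global function {name!r} is not registered")
    return fn


def add(x: int, y: int) -> int:
    return x + y


# The producer registers under a string name ...
register_global("demo.add", add)
# ... and any consumer in the same process finds it by that name alone.
looked_up = get_global_required("demo.add")
result = looked_up(2, 3)
```

The consumer never imports ``add`` directly; the string name is the only contract between the two sides, which is what lets the producer and the consumer live in different languages.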

The following is a minimal example of how to call a compiled CuTeDSL function in C++.

This is the C++ code we will compile into a PyTorch extension. We name it ``extension.cpp`` here:

.. code-block:: cpp

    #include <ATen/DLConvertor.h>          // at::toDLPackNonOwning
    #include <torch/extension.h>           // pybind11 + at::Tensor
    #include <tvm/ffi/container/tensor.h>  // tvm::ffi::TensorView
    #include <tvm/ffi/function.h>          // tvm::ffi::Function

    #include <string>

    void apply_tvm_function(const std::string& name, at::Tensor &x, at::Tensor &y, at::Tensor &z) {
        tvm::ffi::Function fn = tvm::ffi::Function::GetGlobalRequired(name);
        DLTensor dl_x = {};
        DLTensor dl_y = {};

Contributor

One possible option is to make use of TVM FFI's export mechanism, just like FlashInfer does:

https://github.com/flashinfer-ai/flashinfer/blob/c3c40a7b90b792fc59f90f8f55c9e2de9c1b6833/csrc/flashinfer_gemm_binding.cu#L35

This way, the exported function can also be used in various tvm-ffi-compatible scenarios and is callable from PyTorch.

In this case, you can also do

namespace ffi = tvm::ffi;

void ApplyTVMFFIFunc(ffi::Function f, ffi::TensorView x, ffi::TensorView y, ffi::TensorView z) {
    f(x, y, z);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(ApplyTVMFFIFunc, ApplyTVMFFIFunc);

Author

Hi @tqchen

I think the use case here is to create and export a TVM-FFI function in Python and then call it from C++. That is helpful if you are working with a C++-heavy framework and just want to leverage CuTeDSL's flexibility to run a kernel without going back into Python.

I feel like TVM_FFI_DLL_EXPORT_TYPED_FUNC is more for exporting a C++ function to Python, which seems to be the case in flashinfer; that is the opposite direction (we want to export from Python and call from C++).
I include the pybind code in my example only because I need to run the C++ function to demonstrate that the code works (as I mentioned in the doc, in practice you would probably call the C++ function from C++).

        DLTensor dl_z = {};
        // Borrow each at::Tensor as a non-owning DLTensor view; no data is copied.
        at::toDLPackNonOwning(x, &dl_x);
        at::toDLPackNonOwning(y, &dl_y);
        at::toDLPackNonOwning(z, &dl_z);
        // operator() packs the DLTensor pointers into AnyView arguments for us.
        fn(&dl_x, &dl_y, &dl_z);
    }

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("apply_tvm_function", &apply_tvm_function,
              "Look up a tvm-ffi global function by name and call it with three tensors.");
    }


Then we need to compile and load this extension into PyTorch:

.. code-block:: python

    import subprocess
    import sys

    from torch.utils.cpp_extension import load

    def _tvm_ffi_config(flag: str) -> str:
        """Ask the installed apache-tvm-ffi where its headers and lib live."""
        out = subprocess.check_output([sys.executable, "-m", "tvm_ffi.config", flag])
        return out.decode().strip()


    def build_extension():
        include_dir = _tvm_ffi_config("--includedir")
        lib_dir = _tvm_ffi_config("--libdir")
        return load(
            name="tvm_ffi_demo_ext",
            sources=["extension.cpp"],
            extra_include_paths=[include_dir],
            extra_cflags=["-std=c++17"],
            # -ltvm_ffi to link, and -rpath so the .so is found at runtime. This is
            # the same libtvm_ffi.so that `import tvm_ffi` loads -> shared registry.
            extra_ldflags=[f"-L{lib_dir}", "-ltvm_ffi", f"-Wl,-rpath,{lib_dir}"],
            verbose=True,
        )

With the boilerplate in place, we can now write a CuTeDSL kernel and call it through the C++ extension.
In practice you would typically call the CuTeDSL function from C++ directly, without returning to Python; we call it from Python here only for demonstration.

.. code-block:: python

    import cutlass
    import torch
    import tvm_ffi
    import cutlass.cute as cute

    @cute.jit
    def add(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor):
        add_kernel(x, y, z).launch(grid=[1, 1, 1], block=[16, 1, 1])

    @cute.kernel
    def add_kernel(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor):
        tidx, _, _ = cute.arch.thread_idx()
        if tidx < 16:
            z[tidx] = x[tidx] + y[tidx]

    def main() -> None:
        ext = build_extension()

        fake_x = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4)
        fake_y = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4)
        fake_z = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4)
        compiled = cute.compile(add, fake_x, fake_y, fake_z, options="--enable-tvm-ffi")
        tvm_ffi.register_global_func("CuTeDSL_add", compiled, override=True)

        x = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda")
        y = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda")
        z = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda")
        ext.apply_tvm_function("CuTeDSL_add", x, y, z)
        assert torch.allclose(x + y, z, atol=1e-5, rtol=1e-5)
        print("Successfully called CuTeDSL function from C++!")

    if __name__ == "__main__":
        main()

Calling convention of TVM-FFI in C++
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When you compile a ``@cute.jit`` function with the ``--enable-tvm-ffi`` option, ``cute.compile`` returns a TVM FFI function object.
You can then register it as a `Global Function <https://tvm.apache.org/ffi/guides/export_func_cls.html#global-functions>`_ to make it accessible from other languages.

In C++, you can load the compiled function from the global registry with ``tvm::ffi::Function::GetGlobal`` (or ``tvm::ffi::Function::GetGlobalRequired``, which throws if the function is not found).

Contributor

How can we dump the TVM FFI function from CuTeDSL? Can we leverage AOT? Can you explain this part more clearly? The other parts look good to me, thanks.

Author

I added some description of this, but really it's just doing AOT compilation and then calling this TVM-FFI API as in my example.

Contributor

One thing that may be worth surfacing is https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp

See how it does not use the global function registry and instead relies on the symbol being available.

Author

Ah yes. I also added some explanation of how to load symbols from shared libraries. That said, I think using the global registry is more straightforward when working with frameworks that run in both C++ and Python, since that way I don't need to handle the shared library manually.
Also, the script path in aot_use_in_cpp_bundle.cpp is a bit outdated. I fixed that as well.

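``GetGlobalRequired`` returns a ``tvm::ffi::Function``; ``GetGlobal`` returns a ``std::optional<tvm::ffi::Function>`` that is empty when the name is not registered.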

Every TVM FFI function loaded in C++ exposes the same unified ABI, regardless of its logical signature:

.. code-block:: cpp

    void CallPacked(const AnyView* args, int32_t num_args, Any* result) const

where ``const AnyView* args`` is the type-erased array of arguments, whose actual types are determined at runtime,
``int32_t num_args`` is the number of arguments, and ``Any* result`` is a pointer to the return value (if any).

- Arguments are passed as ``AnyView``: a non-owning view of the underlying data, so the lifetime is controlled by the actual data owner.
- The return value is an ``Any``: it owns its data and is responsible for that data's lifetime.
- Both ``Any`` and ``AnyView`` are type-erased containers that can hold objects of different types. The actual type is determined at runtime by their ``type_index`` attribute, a `TVMFFITypeIndex <https://tvm.apache.org/ffi/reference/cpp/generated/enum_c__api_8h_1a1925bb5d568a3f5c92a6c28934c9bcc2.html#_CPPv4N15TVMFFITypeIndex11kTVMFFINoneE>`_ enum value that identifies a TVM-FFI type.
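
The idea of a type-erased container whose real type is recovered from a runtime tag can be sketched in a few lines of plain Python. This is only an illustration of the pattern: ``TypeIndex``, ``TaggedValue``, ``erase`` and ``describe`` are hypothetical names, the tag values are made up for this sketch rather than taken from ``TVMFFITypeIndex``, and the ownership distinction between ``Any`` and ``AnyView`` is not modeled.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Any


class TypeIndex(IntEnum):
    # Hypothetical tags for this sketch; not the real TVMFFITypeIndex values.
    NONE = 0
    BOOL = 1
    INT = 2
    FLOAT = 3
    STR = 4


@dataclass(frozen=True)
class TaggedValue:
    """A type-erased slot: one runtime tag plus an opaque payload."""
    type_index: TypeIndex
    payload: Any


def erase(value: Any) -> TaggedValue:
    """Wrap a Python value and record its type as a runtime tag."""
    if value is None:
        return TaggedValue(TypeIndex.NONE, None)
    if isinstance(value, bool):  # must precede int: bool is a subclass of int
        return TaggedValue(TypeIndex.BOOL, value)
    if isinstance(value, int):
        return TaggedValue(TypeIndex.INT, value)
    if isinstance(value, float):
        return TaggedValue(TypeIndex.FLOAT, value)
    if isinstance(value, str):
        return TaggedValue(TypeIndex.STR, value)
    raise TypeError(f"unsupported type for this sketch: {type(value).__name__}")


def describe(slot: TaggedValue) -> str:
    """Consumers dispatch on the tag, never on the static type of the slot."""
    if slot.type_index is TypeIndex.NONE:
        return "none"
    return f"{slot.type_index.name.lower()}:{slot.payload}"


# Every argument has the same static type, whatever it holds.
args = [erase(3), erase(2.5), erase("bf16"), erase(None)]
tags = [slot.type_index for slot in args]
```

Because every slot has the same static type, a single function signature can carry arbitrary argument lists; the callee inspects each tag to decide how to interpret the payload.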

However, you do not need to call TVM-FFI functions through this low-level packed signature. ``tvm::ffi::Function`` overloads ``operator()``, which builds the ``CallPacked`` arguments for you,
so you can call the function with the same arguments you defined it with.
In our elementwise addition example, the C++ signature of the TVM FFI function is therefore conceptually as follows (our kernel returns nothing, so the return value is a ``tvm::ffi::Any`` whose ``type_index`` is ``kTVMFFINone``):

.. code-block:: cpp

    tvm::ffi::Any add(tvm::ffi::AnyView x, tvm::ffi::AnyView y, tvm::ffi::AnyView z)
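
The relationship between the packed entry point and the friendlier call operator can be modeled in plain Python. The sketch below is a toy analogue, not TVM FFI code: ``PackedFunction``, ``call_packed`` and ``add_packed`` are hypothetical names, Python lists stand in for device tensors, and a one-element list stands in for the ``Any* result`` out-parameter.

```python
from typing import Any, Callable, List, Sequence

# Shape of a packed body in this sketch: (args, num_args, result_slot) -> None
PackedBody = Callable[[Sequence[Any], int, List[Any]], None]


class PackedFunction:
    """Toy analogue of a function object with one uniform, packed entry point."""

    def __init__(self, body: PackedBody) -> None:
        self._body = body

    def call_packed(self, args: Sequence[Any], num_args: int, result: List[Any]) -> None:
        # Low-level entry point: argument array, count, and an out-parameter.
        self._body(args, num_args, result)

    def __call__(self, *args: Any) -> Any:
        # Convenience path: pack the positional arguments, then delegate.
        result: List[Any] = [None]  # one-slot stand-in for the result pointer
        self.call_packed(args, len(args), result)
        return result[0]


def add_packed(args: Sequence[Any], num_args: int, result: List[Any]) -> None:
    """Elementwise z = x + y, written against the packed convention."""
    if num_args != 3:
        raise TypeError(f"expected 3 arguments, got {num_args}")
    x, y, z = args
    for i in range(len(z)):
        z[i] = x[i] + y[i]
    # The result slot is left untouched: this function returns nothing.


add = PackedFunction(add_packed)
x, y, z = [1.0, 2.0], [10.0, 20.0], [0.0, 0.0]
ret = add(x, y, z)  # same effect as add.call_packed((x, y, z), 3, [None])
```

The caller writes ``add(x, y, z)`` as if the function had a natural signature, while the callee only ever sees the uniform packed form. The output is written into ``z`` and the returned value stays empty, mirroring a kernel that returns nothing.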

To call the TVM FFI function from C++, we need to construct inputs in a form that TVM-FFI recognizes and can convert to ``AnyView``. In this case, the conversion path is
``DLTensor`` -> ``tvm::ffi::TensorView`` -> ``tvm::ffi::AnyView``. The latter two conversions are implicit (TVM-FFI already supports them), so we only need to obtain a ``DLTensor`` for our tensor and pass its address, as the example does with ``&dl_x``.
PyTorch tensors are ``at::Tensor`` in C++, and ``at::toDLPackNonOwning`` fills in a ``DLTensor`` view of one. For custom tensor types, you might need to implement the conversion yourself.
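
For a custom tensor type, the conversion mostly amounts to filling in the fields that DLPack's ``DLTensor`` carries: data pointer, device, rank, dtype, shape, strides and byte offset. The following sketch shows that bookkeeping in plain Python for a compact row-major bfloat16 tensor on a CUDA device. ``DLTensorDesc``, ``compact_row_major_strides`` and ``describe_custom_tensor`` are hypothetical helpers, the pointer value is a placeholder, and a real conversion would populate the C struct instead of a dataclass.

```python
from dataclasses import dataclass
from typing import Tuple

# Enum values as defined in dlpack.h.
KDL_CUDA = 2    # DLDeviceType::kDLCUDA
KDL_BFLOAT = 4  # DLDataTypeCode::kDLBfloat


@dataclass(frozen=True)
class DLTensorDesc:
    """Python mirror of the DLTensor fields (hypothetical helper type)."""
    data: int                    # raw device address as an integer
    device_type: int
    device_id: int
    ndim: int
    dtype: Tuple[int, int, int]  # (code, bits, lanes)
    shape: Tuple[int, ...]
    strides: Tuple[int, ...]     # counted in elements, not bytes
    byte_offset: int = 0


def compact_row_major_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Strides of a densely packed tensor whose last dimension is contiguous."""
    strides = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return tuple(reversed(strides))


def describe_custom_tensor(ptr: int, shape: Tuple[int, ...], device_id: int = 0) -> DLTensorDesc:
    """Describe a compact bfloat16 CUDA buffer the way a DLTensor would."""
    return DLTensorDesc(
        data=ptr,
        device_type=KDL_CUDA,
        device_id=device_id,
        ndim=len(shape),
        dtype=(KDL_BFLOAT, 16, 1),
        shape=shape,
        strides=compact_row_major_strides(shape),
    )


# 0x7F0000000000 is a placeholder address used only for illustration.
desc = describe_custom_tensor(0x7F0000000000, (4, 4))
```

For the ``(4, 4)`` tensors in the example above, the strides come out as ``(4, 1)``. Note that DLPack strides count elements, so the bfloat16 element size does not appear in them.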

Other basic types, such as integers and floating-point numbers, can be passed directly; TVM-FFI converts them implicitly, so you are unlikely to need any manual conversion.

See `layout <https://tvm.apache.org/ffi/concepts/any.html#layout>`_ for more detail on how TVM-FFI's ``Any`` type works.

See `tensor-classes <https://tvm.apache.org/ffi/concepts/tensor.html#tensor-classes>`_ for more detail on how DLPack tensors and TVM-FFI tensors convert between each other.


Exporting Compiled Module
-------------------------
