diff --git a/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp index 37045630e4..40a2717d21 100644 --- a/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp +++ b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp @@ -15,7 +15,12 @@ // This example shows how to interface with an AOT compiled function in a C++ // bundle. to build and run the example, run the following command in project // root bash -// cutlass_ir/compiler/python/examples/cute/tvm_ffi/aot_use_in_cpp_bundle.sh +// ```bash +// # Generate the object file from the AOT export script +// python examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_export.py +// # Build the C++ executable +// bash examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.sh +// ``` #include #include diff --git a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst index c82c74d8ff..abab9d071e 100644 --- a/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst +++ b/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.rst @@ -539,6 +539,167 @@ before launching the kernel based on that tensor's device index. For advanced scenarios that pass raw pointers instead of tensors, you should call ``cudaSetDevice`` explicitly through the CUDA Python API. +Call a compiled function from C++ via the TVM FFI registry +---------------------------------------------------------- + +The object returned by ``cute.compile(..., options="--enable-tvm-ffi")`` is itself a +``tvm_ffi.Function``: a native callable that follows the TVM FFI calling convention. +You can publish it in TVM FFI's process-global function registry under a string name with +``tvm_ffi.register_global_func``. Once registered, the same compiled kernel can be +looked up by name and invoked from any TVM FFI-supported language (e.g. 
C++) running +in the same process. This allows you to bypass the Python interpreter entirely and avoid +its CPU overhead by staying entirely in C++. +In such a case, Python is used only as an expressive DSL to describe the kernel and register it in the shared registry, +and C++ is used for efficient execution. + +The following is a minimal example of how to call a compiled CuTeDSL function in C++. + +This is the C++ code we will compile into a PyTorch extension. We name it ``extension.cpp`` here: + +.. code-block:: cpp + + #include <ATen/DLConvertor.h> // at::toDLPackNonOwning + #include <torch/extension.h> // pybind11 + at::Tensor + #include <tvm/ffi/container/tensor.h> // tvm::ffi::TensorView + #include <tvm/ffi/function.h> // tvm::ffi::Function + + #include <string> + + void apply_tvm_function(const std::string& name, at::Tensor &x, at::Tensor &y, at::Tensor &z) { + tvm::ffi::Function fn = tvm::ffi::Function::GetGlobalRequired(name); + DLTensor dl_x = {}; + DLTensor dl_y = {}; + DLTensor dl_z = {}; + at::toDLPackNonOwning(x, &dl_x); + at::toDLPackNonOwning(y, &dl_y); + at::toDLPackNonOwning(z, &dl_z); + fn(&dl_x, &dl_y, &dl_z); + } + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("apply_tvm_function", &apply_tvm_function, + "Look up a tvm-ffi global function by name and call it with three tensors."); + } + + +Then we need to compile and load this extension into PyTorch: + +..
code-block:: python + + import subprocess + import sys + + from torch.utils.cpp_extension import load + + def _tvm_ffi_config(flag: str) -> str: + """Ask the installed apache-tvm-ffi where its headers and lib live.""" + out = subprocess.check_output([sys.executable, "-m", "tvm_ffi.config", flag]) + return out.decode().strip() + + + def build_extension(): + include_dir = _tvm_ffi_config("--includedir") + dlpack_include_dir = _tvm_ffi_config("--dlpack-includedir") + lib_dir = _tvm_ffi_config("--libdir") + return load( + name="tvm_ffi_demo_ext", + sources=["extension.cpp"], + extra_include_paths=[include_dir, dlpack_include_dir], + extra_cflags=["-std=c++17"], + # -ltvm_ffi to link, and -rpath so the .so is found at runtime. This is + # the same libtvm_ffi.so that `import tvm_ffi` loads -> shared registry. + extra_ldflags=[f"-L{lib_dir}", "-ltvm_ffi", f"-Wl,-rpath,{lib_dir}"], + verbose=True, + ) + +With all the boilerplate code ready, now let's write a CuTeDSL kernel and use the C++ extension to call it. +In practice you might want to call the CuTeDSL function from C++ directly without going back to Python. We call from Python here just for demonstration purposes. + +.. 
code-block:: python + + import cutlass + import torch + import tvm_ffi + import cutlass.cute as cute + + @cute.jit + def add(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor): + add_kernel(x, y, z).launch(grid=[1, 1, 1], block=[16, 1, 1]) + + @cute.kernel + def add_kernel(x: cute.Tensor, y: cute.Tensor, z: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + if tidx < 16: + z[tidx] = x[tidx] + y[tidx] + + def main() -> None: + ext = build_extension() + + fake_x = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) + fake_y = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) + fake_z = cute.runtime.make_fake_compact_tensor(cutlass.BFloat16, (4, 4), stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) + compiled = cute.compile(add, fake_x, fake_y, fake_z, options="--enable-tvm-ffi",) + tvm_ffi.register_global_func("CuTeDSL_add", compiled, override=True) + + x = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") + y = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") + z = torch.randn((4, 4), dtype=torch.bfloat16, device="cuda") + + # Launch the C++ function. This is only for demonstration because it's the easiest way to run our C++ function. + # In practice you might be working with some C++ heavy framework and you should call the C++ function from C++ directly without going through Python. 
+ ext.apply_tvm_function("CuTeDSL_add", x, y, z) + assert torch.allclose(x + y, z, atol=1e-8, rtol=1e-8) + print("Successfully called CuTeDSL function from C++!") + + if __name__ == "__main__": + main() + +Calling convention of TVM-FFI in C++ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To call a compiled CuTeDSL function from C++, we first need to use Ahead-of-Time (AOT) compilation to obtain the compiled function as a TVM FFI function object, +and then register it in the TVM FFI global registry with a string name. This requires you to use ``cute.compile`` to compile the ``@cute.jit`` function with the ``--enable-tvm-ffi`` option; ``cute.compile`` then returns a TVM FFI function object. +Next, you need to register this function object as a `Global Function `_ with this API: ``tvm_ffi.register_global_func(func_name, f=None, override=False)``, +where ``func_name`` is the string name that identifies the function in the global registry, ``f`` is the TVM FFI function object returned by ``cute.compile``, and ``override=True`` allows overwriting an existing function with the same name in the registry. +In this way, you can make this compiled ``@cute.jit`` TVM-FFI function accessible from other languages, including C++. + +Note that this is not the only way to obtain a TVM-FFI function in C++. You can also export the compiled module to an object file and load it in C++ with TVM-FFI APIs. +See `Exporting Compiled Module`_ for more details. + +In C++, you can load the compiled function from the global registry with ``tvm::ffi::Function::GetGlobal`` (or ``tvm::ffi::Function::GetGlobalRequired``, which will throw if the function is not found). +The returned object will be of type ``tvm::ffi::Function``. + +The signature of the TVM FFI function loaded in C++ will use a unified ABI for all functions like this: + +..
code-block:: cpp + + void CallPacked(const AnyView* args, int32_t num_args, Any* result) const + +where ``const AnyView* args`` is the type-erased array of arguments, whose actual types are determined at runtime, +``int32_t num_args`` is the number of arguments, and ``Any* result`` is a pointer to the return value (if any). + +- The arguments are called "AnyView", meaning they are "non-owning" views of the underlying data and therefore the lifetime is determined by the actual data owner. +- The return value is called "Any", meaning it owns the data and is responsible for its lifetime. +- Both ``Any`` and ``AnyView`` are type-erased containers that can hold objects of different types. The actual type is decided by their ``type_index`` attribute at runtime, which is a `TVMFFITypeIndex `_ enum that represents a TVM-FFI type. + +However, you do not need to explicitly call TVM-FFI functions with this low-level packed format signature, because TVM-FFI overloads ``operator()``, which packs the arguments for ``CallPacked`` for you +so that you can call the function with the same signature you defined it with. +So in our elementwise addition example, the C++ signature of the TVM FFI function will be something like (note our kernel does not return anything, so the return value will be a ``tvm::ffi::Any`` with ``type_index`` of ``kTVMFFINone``): + +.. code-block:: cpp + + tvm::ffi::Any add(tvm::ffi::AnyView x, tvm::ffi::AnyView y, tvm::ffi::AnyView z) + +When we want to call the TVM FFI function in C++, we need to construct our inputs in a form that can be converted to ``AnyView`` and recognized by TVM-FFI. In this case, the conversion path we would take is +``DLTensor`` -> ``tvm::ffi::TensorView`` -> ``tvm::ffi::AnyView``. Both of these conversions can be implicit (supported by TVM-FFI already), so we just need to convert our tensor type to ``DLTensor``.
+PyTorch tensors are ``at::Tensor`` in C++, and we can use ``at::toDLPackNonOwning`` to get a ``DLTensor`` view of them. For custom tensor types, you might need to implement the conversion yourself. + +For other basic types, you can directly pass them and let TVM-FFI handle the conversion implicitly. You are unlikely to need to convert them manually, since they are general types that are widely recognized. + +See `layout `_ for more detail on how TVM-FFI's ``Any`` type works. + +See `tensor-classes `_ for more detail on how DLPack tensors and TVM-FFI tensors convert between each other. + + Exporting Compiled Module ------------------------- @@ -586,6 +747,39 @@ The exported object file exposes the function symbol ``__tvm_ffi_add_one`` that compatible with TVM FFI and can be used in various frameworks and programming languages. You can either build a shared library and load it back, or link the object file directly into your application and invoke the function via the ``InvokeExternC`` mechanism in TVM FFI. + +How it works is very similar to the C++ example above. The only difference is that instead of looking up the function from the TVM-FFI global registry, +the TVM-FFI function symbol is now exposed via a shared library: + +.. code-block:: cpp + + extern "C" int __tvm_ffi_add_one(void*, const TVMFFIAny*, int32_t, TVMFFIAny*); + + // If the tvm-ffi function symbol is already known at compile time and it's dynamically linked (or statically linked if you build a static library), + // then you can call the function directly through the symbol declared with extern "C".
+ void apply_tvm_function_via_extern_C(at::Tensor &a, at::Tensor &b){ + DLTensor dl_a = {}; + DLTensor dl_b = {}; + at::toDLPackNonOwning(a, &dl_a); + at::toDLPackNonOwning(b, &dl_b); + tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add_one, &dl_a, &dl_b); + } + + // If the tvm-ffi function symbol is not known until runtime, you can resolve it from its shared library at + // runtime by giving the library path and function name (dynamic loading). + void apply_tvm_function_via_dynamic_resolution(const std::string& lib_path, + const std::string& func_name, + at::Tensor &a, at::Tensor &b){ + tvm::ffi::Module mod = tvm::ffi::Module::LoadFromFile(lib_path); + tvm::ffi::Function fn = mod->GetFunction(func_name).value(); + DLTensor dl_a = {}; + DLTensor dl_b = {}; + at::toDLPackNonOwning(a, &dl_a); + at::toDLPackNonOwning(b, &dl_b); + fn(&dl_a, &dl_b); + } + + For more information, see the `quick start guide `_ in the official documentation.