40 changes: 40 additions & 0 deletions backends/aoti/slim/c10/cuda/Exception.h
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef CUDA_AVAILABLE

#include <cuda.h>
#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/macros/Macros.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

/// Checks a CUDA expression and aborts on error.
/// @param EXPR The CUDA expression to check.
#define ET_CUDA_CHECK(EXPR) \
do { \
const cudaError_t __err = EXPR; \
ET_CHECK_MSG( \
__err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \
} while (0)

/// Checks a CUDA expression and, on error, logs it without aborting (non-fatal).
/// @param EXPR The CUDA expression to check.
#define ET_CUDA_LOG_WARN(EXPR) \
do { \
const cudaError_t __err = EXPR; \
if (SLIMTENSOR_UNLIKELY(__err != cudaSuccess)) { \
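/* cudaGetLastError() clears CUDA's sticky error state; the saved __err is still logged below. */ \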
[[maybe_unused]] auto error_unused = cudaGetLastError(); \
ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \
} \
} while (0)

#endif // CUDA_AVAILABLE
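
For orientation, a minimal usage sketch of the two macros (illustrative only, not code from this PR): ET_CUDA_CHECK is the fatal variant; ET_CUDA_LOG_WARN is the non-fatal one for cleanup paths.

#include <cstddef>

#include <cuda_runtime.h>

#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>

void sync_and_copy(void* dst, const void* src, size_t nbytes) {
  // Fatal on failure: aborts with the CUDA error string.
  ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, cudaMemcpyHostToDevice));
  // Non-fatal on failure: logs the error and continues.
  ET_CUDA_LOG_WARN(cudaDeviceSynchronize());
}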
6 changes: 6 additions & 0 deletions backends/aoti/slim/c10/cuda/TARGETS
@@ -0,0 +1,6 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
16 changes: 16 additions & 0 deletions backends/aoti/slim/c10/cuda/targets.bzl
@@ -0,0 +1,16 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
"""Define targets for SlimTensor CUDA exception handling module."""

runtime.cxx_library(
name = "exception",
exported_headers = [
"Exception.h",
],
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
"//executorch/backends/aoti/slim/c10/macros:macros",
"//executorch/runtime/platform:platform",
],
)
141 changes: 133 additions & 8 deletions backends/aoti/slim/core/Storage.h
@@ -10,12 +10,18 @@

#include <cstring>

#ifdef CUDA_AVAILABLE
#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
#include <executorch/backends/cuda/runtime/guard.h>
#endif

#include <executorch/backends/aoti/slim/c10/core/Device.h>
#include <executorch/backends/aoti/slim/c10/core/ScalarType.h>
#include <executorch/backends/aoti/slim/util/ArrayRefUtil.h>
#include <executorch/backends/aoti/slim/util/SharedPtr.h>
#include <executorch/backends/aoti/slim/util/SizeUtil.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

namespace executorch::backends::aoti::slim {

Expand All @@ -30,6 +36,10 @@ inline void noop(void*) {}
/// Default CPU device constant.
inline const c10::Device CPU_DEVICE = c10::Device(c10::DeviceType::CPU, 0);

/// Default CUDA device constant.
inline const c10::Device DEFAULT_CUDA_DEVICE =
c10::Device(c10::DeviceType::CUDA, 0);

/// DeviceTraits template for device-specific operations.
/// Device-specific implementations provide allocate(), free(), and memcpy().
template <c10::DeviceType D>
@@ -74,6 +84,119 @@ struct DeviceTraits<c10::DeviceType::CPU> {
}
};

#ifdef CUDA_AVAILABLE
/// CUDA specialization of DeviceTraits.
/// Provides CUDA memory allocation and copy operations using
/// cudaMallocAsync/cudaFreeAsync with proper stream handling.
///
/// IMPORTANT: Callers are expected to set the correct CUDA device and stream
/// using CUDAStreamGuard before calling these methods. This is consistent
/// with PyTorch's CUDACachingAllocator design pattern where the allocator
/// assumes the caller has already set the correct device context.
template <>
struct DeviceTraits<c10::DeviceType::CUDA> {
/// Allocates CUDA device memory on the current stream.
/// Uses cudaMallocAsync for asynchronous allocation on the stream
/// that is currently set via CUDAStreamGuard, similar to how
/// PyTorch's CUDACachingAllocator works.
///
/// NOTE: Caller must ensure the correct device is already set via
/// CUDAStreamGuard. This function does NOT create a device guard internally.
///
/// @param nbytes Number of bytes to allocate.
/// @param device The target CUDA device (used to get the stream).
/// @return Pointer to allocated device memory.
static void* allocate(size_t nbytes, const c10::Device& device) {
// Get the current stream for this device (set by CUDAStreamGuard if any)
// This follows PyTorch's pattern where the allocator assumes the caller
// has already set the correct device via CUDAStreamGuard.
auto stream_result =
executorch::backends::cuda::getCurrentCUDAStream(device.index());
ET_CHECK_MSG(
stream_result.ok(),
"Failed to get current CUDA stream for device %d",
static_cast<int>(device.index()));

cudaStream_t stream = stream_result.get();
void* data = nullptr;
ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream));
return data;
}

/// Frees CUDA device memory on the current stream.
/// @param ptr Pointer to device memory to free.
static void free(void* ptr) {
// Get the current stream for the current device (an index of -1 selects
// the currently active device)
auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
if (stream_result.ok()) {
ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
} else {
// Fallback to synchronous free if we can't get the stream
ET_CUDA_LOG_WARN(cudaFree(ptr));
}
}

/// Copies memory between CPU and CUDA or CUDA and CUDA.
/// @param dst Destination pointer.
/// @param src Source pointer.
/// @param nbytes Number of bytes to copy.
/// @param dst_device Destination device.
/// @param src_device Source device.
static void memcpy(
void* dst,
const void* src,
size_t nbytes,
const c10::Device& dst_device,
const c10::Device& src_device) {
cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;

if (src_device.is_cpu()) {
direction = cudaMemcpyHostToDevice;
} else if (dst_device.is_cpu()) {
direction = cudaMemcpyDeviceToHost;
} else {
ET_CHECK_MSG(
src_device.index() == dst_device.index(),
"CUDA memcpy across different device indices not supported: %d != %d",
static_cast<int>(src_device.index()),
static_cast<int>(dst_device.index()));
}

ET_CUDA_CHECK(cudaMemcpy(dst, src, nbytes, direction));
}
};
#else
/// CUDA stub used when CUDA_AVAILABLE is not defined.
/// allocate() and memcpy() abort with an error message; free() only logs
/// the error.
template <>
struct DeviceTraits<c10::DeviceType::CUDA> {
static void* allocate(size_t nbytes, const c10::Device& device) {
(void)nbytes;
(void)device;
ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
}

static void free(void* ptr) {
(void)ptr;
ET_LOG(Error, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
}

static void memcpy(
void* dst,
const void* src,
size_t nbytes,
const c10::Device& dst_device,
const c10::Device& src_device) {
(void)dst;
(void)src;
(void)nbytes;
(void)dst_device;
(void)src_device;
ET_CHECK_MSG(false, "Build with CUDA_AVAILABLE=1 to enable CUDA support");
}
};
#endif // CUDA_AVAILABLE

/**
* MaybeOwningStorage - A storage class that manages tensor data memory.
*
@@ -93,17 +216,19 @@ struct DeviceTraits<c10::DeviceType::CPU> {
 class MaybeOwningStorage {
  public:
   /// Constructs owning storage with allocated memory.
-  /// @param device The device for storage (must be CPU).
+  /// @param device The device for storage (CPU or CUDA).
   /// @param nbytes Number of bytes to allocate.
   MaybeOwningStorage(const c10::Device& device, size_t nbytes)
       : device_(device), capacity_(nbytes), is_owning_(true) {
-    ET_CHECK_MSG(
-        device.is_cpu(),
-        "Only CPU device is currently supported, got: %s",
-        device.str().c_str());
-
-    data_ = DeviceTraits<c10::DeviceType::CPU>::allocate(nbytes, device);
-    deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
+    if (device.is_cpu()) {
+      data_ = DeviceTraits<c10::DeviceType::CPU>::allocate(nbytes, device);
+      deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
+    } else if (device.is_cuda()) {
+      data_ = DeviceTraits<c10::DeviceType::CUDA>::allocate(nbytes, device);
+      deleter_ = DeviceTraits<c10::DeviceType::CUDA>::free;
+    } else {
+      ET_CHECK_MSG(false, "Unsupported device type: %s", device.str().c_str());
+    }
   }
 
   /// Default constructor is deleted - storage must have a device.
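
To make the DeviceTraits contract concrete, a minimal sketch of the intended call pattern (not code from this PR; it assumes the caller has already established the device/stream context, e.g. via CUDAStreamGuard from guard.h):

#include <executorch/backends/aoti/slim/core/Storage.h>

using namespace executorch::backends::aoti::slim;

void cuda_storage_example() {
  // Allocates 1 KiB on CUDA device 0 via cudaMallocAsync on the current
  // stream; the device context is assumed to be set by the caller.
  void* dev =
      DeviceTraits<c10::DeviceType::CUDA>::allocate(1024, DEFAULT_CUDA_DEVICE);

  float host[256] = {};  // 256 * 4 bytes = 1 KiB
  // Host-to-device copy; the direction is derived from the device arguments.
  DeviceTraits<c10::DeviceType::CUDA>::memcpy(
      dev, host, sizeof(host), DEFAULT_CUDA_DEVICE, CPU_DEVICE);

  // Frees asynchronously on the current stream, falling back to a
  // synchronous cudaFree if the stream cannot be obtained.
  DeviceTraits<c10::DeviceType::CUDA>::free(dev);

  // The same allocation, owned and automatically freed by the storage class:
  MaybeOwningStorage storage(DEFAULT_CUDA_DEVICE, 1024);
}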
4 changes: 3 additions & 1 deletion backends/aoti/slim/core/targets.bzl
@@ -17,10 +17,12 @@ def define_common_targets():
             "//executorch/backends/aoti/slim/util:shared_ptr",
             "//executorch/backends/aoti/slim/util:size_util",
             "//executorch/runtime/platform:platform",
+            "//executorch/backends/aoti/slim/c10/cuda:exception",
+            "//executorch/backends/cuda/runtime:guard",
         ],
     )
 
-    # Header-only library for SlimTensor
+    # Header-only library for SlimTensor (CPU-only for now)
     runtime.cxx_library(
         name = "slimtensor",
         headers = [
37 changes: 28 additions & 9 deletions backends/aoti/slim/core/test/targets.bzl
@@ -1,17 +1,36 @@
load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def get_backend_mode():
"""Get the supported backend mode of slimtensor."""
return ["cuda", "cpu"]

def define_common_targets():
"""Define test targets for SlimTensor core module."""

runtime.cxx_test(
name = "test_storage",
srcs = [
"test_storage.cpp",
],
deps = [
"//executorch/backends/aoti/slim/core:storage",
],
)
# GPU storage test with CUDA support
for backend_mode in get_backend_mode():
backend_suffix = "_" + backend_mode if backend_mode == "cuda" else ""

backend_kwargs = {
"external_deps": [("cuda", None, "cuda-lazy")],
"preprocessor_flags": ["-DCUDA_AVAILABLE=1"],
"keep_gpu_sections": True,
"remote_execution": re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
} if backend_mode == "cuda" else {}

runtime.cxx_test(
name = "test_storage" + backend_suffix,
srcs = [
"test_storage.cpp",
],
deps = [
"//executorch/backends/aoti/slim/core:storage",
],
**backend_kwargs
)

runtime.cxx_test(
name = "test_slimtensor_basic",
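
As a sketch of what the CUDA-gated test variant can now exercise (a hypothetical test body, assuming GoogleTest as conventionally used by runtime.cxx_test; the actual contents of test_storage.cpp are not shown in this diff):

#include <gtest/gtest.h>

#include <executorch/backends/aoti/slim/core/Storage.h>

using namespace executorch::backends::aoti::slim;

#ifdef CUDA_AVAILABLE
// Compiled only in the test_storage_cuda variant, where -DCUDA_AVAILABLE=1
// selects the real cudaMallocAsync-backed DeviceTraits specialization
// instead of the aborting stub.
TEST(StorageTest, CudaAllocateAndFree) {
  void* p =
      DeviceTraits<c10::DeviceType::CUDA>::allocate(64, DEFAULT_CUDA_DEVICE);
  EXPECT_NE(p, nullptr);
  DeviceTraits<c10::DeviceType::CUDA>::free(p);
}
#endif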