Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ runtime.cxx_library(
],
)

runtime.cxx_library(
name = "runtime_shims_slim",
srcs = [
"shims/memory_slim.cpp",
],
headers = [
"shims/memory_slim.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
visibility = ["@EXECUTORCH_CLIENTS"],
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
deps = [
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
"//executorch/backends/aoti/slim/factory:from_blob",
"//executorch/backends/aoti:common_shims",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
],
nvcc_flags = get_nvcc_arch_args() + [
"-_NVCC_HOST_COMPILER_FLAG_",
"gcc",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)

runtime.cxx_library(
name = "cuda_backend",
srcs = [
Expand Down
81 changes: 81 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/memory_slim.h>

#include <executorch/backends/aoti/slim/factory/Empty.h>
#include <executorch/backends/aoti/slim/factory/FromBlob.h>
#include <executorch/backends/aoti/slim/util/ArrayRefUtil.h>
#include <executorch/runtime/platform/assert.h>

namespace executorch::backends::cuda {

namespace c10 = executorch::backends::aoti::slim::c10;
using c10::Device;
using c10::DeviceIndex;
using c10::DeviceType;
using c10::ScalarType;
using executorch::backends::aoti::slim::empty_strided;
using executorch::backends::aoti::slim::from_blob;
using executorch::backends::aoti::slim::IntArrayRef;

extern "C" {

AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size) {
// Unused parameters
(void)layout;
(void)opaque_metadata;
(void)opaque_metadata_size;

ET_CHECK_OR_RETURN_ERROR(
data != nullptr,
InvalidArgument,
"aoti_torch_create_tensor_from_blob_v2: data is null");

ET_CHECK_OR_RETURN_ERROR(
ret_new_tensor != nullptr,
InvalidArgument,
"aoti_torch_create_tensor_from_blob_v2: ret_new_tensor is null");

ET_CHECK_OR_RETURN_ERROR(
!(sizes_ptr == nullptr && ndim > 0),
InvalidArgument,
"aoti_torch_create_tensor_from_blob_v2: sizes_ptr is null but ndim > 0");

IntArrayRef sizes(sizes_ptr, static_cast<size_t>(ndim));
IntArrayRef strides(strides_ptr, static_cast<size_t>(ndim));

// Create the SlimTensor using from_blob (non-owning)
*ret_new_tensor = new Tensor(from_blob(
data,
sizes,
strides,
static_cast<ScalarType>(dtype),
Device(
static_cast<DeviceType>(device_type),
static_cast<DeviceIndex>(device_index)),
storage_offset));

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
62 changes: 62 additions & 0 deletions backends/cuda/runtime/shims/memory_slim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

#include <executorch/backends/aoti/export.h>
#include <executorch/backends/aoti/slim/core/SlimTensor.h>
#include <executorch/backends/aoti/slim/core/SlimTensorView-incl.h>
#include <executorch/runtime/core/error.h>

namespace executorch::backends::cuda {

using executorch::runtime::Error;
using AOTITorchError = Error;
using Tensor = executorch::backends::aoti::slim::SlimTensor;

extern "C" {

/**
* Creates a tensor object from an existing memory blob without copying the
* data. The tensor will wrap the provided memory and will not take ownership of
* it. When the tensor is deleted, the original memory will remain valid and
* must be freed by the caller.
*
* @param data Pointer to the memory blob to wrap (must not be null)
* @param ndim Number of dimensions in the tensor
* @param sizes_ptr Pointer to array of dimension sizes
* @param strides_ptr Pointer to array of strides for each dimension
* @param storage_offset Storage offset in number of elements
* @param dtype Data type identifier (matches PyTorch scalar types)
* @param device_type Device type (CPU=0, CUDA=1)
* @param device_index Device index
* @param ret_new_tensor Output parameter for the created tensor
* @param layout Tensor layout identifier (0=strided)
* @param opaque_metadata Optional metadata pointer (can be null)
* @param opaque_metadata_size Size of opaque metadata in bytes
* @return AOTITorchError error code (Error::Ok on success)
*/
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
void* data,
int64_t ndim,
const int64_t* sizes_ptr,
const int64_t* strides_ptr,
int64_t storage_offset,
int32_t dtype,
int32_t device_type,
int32_t device_index,
Tensor** ret_new_tensor,
int32_t layout,
const uint8_t* opaque_metadata,
int64_t opaque_metadata_size);

} // extern "C"

} // namespace executorch::backends::cuda
31 changes: 31 additions & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,40 @@ def cuda_shim_cpp_unittest(name):
),
)

def cuda_shim_slim_cpp_unittest(name):
"""Unittest for SlimTensor-based shim functions."""
cpp_unittest(
name = "test_" + name + "_slim",
srcs = [
"test_" + name + "_slim.cpp",
],
deps = [
"//executorch/backends/cuda/runtime:runtime_shims_slim",
"//executorch/backends/aoti:common_shims",
"//executorch/backends/aoti/slim/core:slimtensor",
"//executorch/backends/aoti/slim/factory:empty",
"//executorch/backends/aoti/slim/factory:from_blob",
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
],

external_deps = [
("cuda", None, "cuda-lazy"),
],
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
keep_gpu_sections = True,
remote_execution = re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
)

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.

The directory containing this targets.bzl file should also contain both
TARGETS and BUCK files that call this function.
"""
# Original ETensor-based shim tests, will be removed after migration
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
Expand All @@ -41,3 +69,6 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

# SlimTensor-based shim tests
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
Loading
Loading