diff --git a/Cargo.toml b/Cargo.toml index 9ef07b3e..bf431ba1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ image = { version = "0.24.5", optional = true } clap = { version = "4.2.4", features = ["derive"], optional = true } serde_json = { version = "1.0.96", optional = true } memmap2 = { version = "0.6.1", optional = true } +dlpark = { version = "0.2.2", default-features = false } [dev-dependencies] anyhow = "1" diff --git a/src/wrappers/tensor.rs b/src/wrappers/tensor.rs index 3d5cf7fc..748da451 100644 --- a/src/wrappers/tensor.rs +++ b/src/wrappers/tensor.rs @@ -6,10 +6,12 @@ use super::{ kind::Kind, }; use crate::TchError; +use dlpark::ffi::DLManagedTensor; use libc::{c_char, c_int, c_void}; use std::borrow::Borrow; use std::io::{Read, Seek, Write}; use std::path::Path; +use std::ptr::NonNull; use torch_sys::io::ReadStream; use torch_sys::*; @@ -338,6 +340,28 @@ impl Tensor { Tensor::f_run_backward(tensors, inputs, keep_graph, create_graph).unwrap() } + pub fn f_to_dlpack(self) -> Result<NonNull<DLManagedTensor>, TchError> { + let ptr = unsafe_torch_err!(at_to_dlpack(self.c_tensor)); + Ok(unsafe { NonNull::new_unchecked(ptr) }) + } + + /// Convert `tch::Tensor` to dlpack. + /// If you want to access original tensor, please use `shallow_clone` to make a shared view first. + pub fn to_dlpack(self) -> NonNull<DLManagedTensor> { + self.f_to_dlpack().unwrap() + } + + pub fn f_from_dlpack(src: NonNull<DLManagedTensor>) -> Result<Self, TchError> { + let ptr = unsafe_torch_err!(at_from_dlpack(src.as_ptr().cast())); + Ok(unsafe { Self::from_ptr(ptr) }) + } + + /// Convert dlpack to `tch::Tensor`. + pub fn from_dlpack(src: NonNull<DLManagedTensor>) -> Self { + // Using ownership to prevent accessing the pointer twice. + Self::f_from_dlpack(src).unwrap() + } + /// Copies `numel` elements from `self` to `dst`.
pub fn f_copy_data_u8(&self, dst: &mut [u8], numel: usize) -> Result<(), TchError> { let elt_size_in_bytes = self.f_kind()?.elt_size_in_bytes(); diff --git a/tests/tensor_tests.rs b/tests/tensor_tests.rs index 031c83db..9ca0d7b7 100644 --- a/tests/tensor_tests.rs +++ b/tests/tensor_tests.rs @@ -1,7 +1,9 @@ use anyhow::Result; +use dlpark::ffi::DLManagedTensor; use half::f16; use std::convert::{TryFrom, TryInto}; use std::f32; +use std::ptr::NonNull; use tch::{Device, TchError, Tensor}; mod test_utils; @@ -487,3 +489,25 @@ fn convert_ndarray() { let array_3d: ndarray::ArrayD<i64> = t_3d.as_ref().try_into().unwrap(); assert_eq!(array_3d.as_slice(), ndarray::array![[[0, 1], [2, 3]], [[4, 5], [6, 7]]].as_slice()); } + +#[test] +fn convert_dlpack() { + let t1: Tensor = Tensor::from_slice(&[0, 1, 2, 3]); + let dlpack = t1.shallow_clone().to_dlpack(); + let t2 = Tensor::from_dlpack(dlpack); + assert!(t1.allclose(&t2, 1e-5, 1e-5, false)); + assert_eq!(t1.data_ptr(), t2.data_ptr()); +} + +#[test] +fn from_vec_as_dlpack() { + let v: Vec<i64> = vec![0, 1, 2, 3, 4]; + let v_ptr = v.as_ptr(); + // TODO: upgrade dlpark version + let dlpack: DLManagedTensor = dlpark::tensor::ManagerCtx::from(v).into(); + let t1 = Tensor::from_dlpack(NonNull::from(&dlpack)); + let t2 = Tensor::arange(5, tch::kind::INT64_CPU); + assert!(t1.allclose(&t2, 1e-5, 1e-5, false)); + // Check if zero copy + assert_eq!(t1.data_ptr(), v_ptr as *const std::ffi::c_void as *mut std::ffi::c_void); +} diff --git a/torch-sys/Cargo.toml b/torch-sys/Cargo.toml index 3015668c..11b92e1c 100644 --- a/torch-sys/Cargo.toml +++ b/torch-sys/Cargo.toml @@ -13,6 +13,7 @@ categories = ["external-ffi-bindings", "science"] license = "MIT/Apache-2.0" [dependencies] +dlpark = { version = "0.2.2", default-features = false } libc = "0.2.0" [build-dependencies] diff --git a/torch-sys/libtch/torch_api.cpp b/torch-sys/libtch/torch_api.cpp index be7cdb6a..28a01620 100644 --- a/torch-sys/libtch/torch_api.cpp +++ b/torch-sys/libtch/torch_api.cpp
@@ -7,6 +7,7 @@ #include #include #include +#include <ATen/DLConvertor.h> #include #include #include @@ -667,6 +668,14 @@ void at_run_backward(tensor *tensors, ) } +DLManagedTensor* at_to_dlpack(tensor src) { + return at::toDLPack(*src); +} + +tensor at_from_dlpack(DLManagedTensor* src) { + return new torch::Tensor(at::fromDLPack(src)); +} + optimizer ato_adam(double learning_rate, double beta1, double beta2, diff --git a/torch-sys/libtch/torch_api.h b/torch-sys/libtch/torch_api.h index e203bd21..0ed71fd3 100644 --- a/torch-sys/libtch/torch_api.h +++ b/torch-sys/libtch/torch_api.h @@ -4,6 +4,7 @@ #ifdef __cplusplus #include +#include <ATen/dlpack.h> #include using namespace std; extern thread_local char *torch_last_err; @@ -114,6 +115,9 @@ void at_run_backward(tensor *tensors, int keep_graph, int create_graph); +DLManagedTensor* at_to_dlpack(tensor src); +tensor at_from_dlpack(DLManagedTensor* src); + optimizer ato_adam(double learning_rate, double beta1, double beta2, diff --git a/torch-sys/src/lib.rs b/torch-sys/src/lib.rs index d7229afd..063e9daf 100644 --- a/torch-sys/src/lib.rs +++ b/torch-sys/src/lib.rs @@ -4,6 +4,7 @@ pub mod io; pub mod python; mod traits; +use dlpark::ffi::DLManagedTensor; use libc::{c_char, c_int, c_uchar, c_void, size_t}; pub use traits::{DoubleList, IntList, IntListOption}; @@ -60,6 +61,8 @@ extern "C" { keep_graph: c_int, create_graph: c_int, ); + pub fn at_to_dlpack(src: *mut C_tensor) -> *mut DLManagedTensor; + pub fn at_from_dlpack(src: *mut DLManagedTensor) -> *mut C_tensor; pub fn at_copy_data( arg: *mut C_tensor, vs: *const c_void,