Make shape() and stride() return &[i64] to avoid data copy #749

Open · wants to merge 4 commits into base: main
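Summary: Tensor::size() and Tensor::stride() previously allocated a fresh Vec<i64> on every call, filled from a caller-provided buffer on the C side. They now return a &[i64] that borrows the tensor's own shape metadata, so querying shapes and strides no longer copies. A minimal before/after sketch of the API change (variable names are illustrative, not from the PR):

    // Before: each call allocated and returned an owned vector.
    let dims: Vec<i64> = tensor.size();

    // After: the shape is borrowed from the tensor, no allocation.
    let dims: &[i64] = tensor.size();

    // Callers that mutate the dimensions now take an explicit copy.
    let mut dims = tensor.size().to_vec();
    dims.push(1);

Read-only call sites simply drop the former & borrow or .as_slice() adaptor, as the diff below shows.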
2 changes: 1 addition & 1 deletion examples/custom-optimizer/sparse_adam.rs
@@ -59,7 +59,7 @@ impl SparseAdam {
             .unwrap()
             .trainable_variables
             .iter()
-            .map(|x| Buffer::new(&x.tensor.size()))
+            .map(|x| Buffer::new(x.tensor.size()))
             .collect();

         SparseAdam { lr, beta1, beta2, eps, force_sparse, vars, buffers }
4 changes: 2 additions & 2 deletions examples/llama/main.rs
@@ -164,7 +164,7 @@ impl CausalSelfAttention {
     }

     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Tensor {
-        let mut dims = x.size();
+        let mut dims = x.size().to_vec();
         let v = dims.pop().unwrap();
         dims.push(v / 2);
         dims.push(2);
@@ -177,7 +177,7 @@
         let im = &re_x * &im_f + &im_x * &re_f;
         let rope = Tensor::cat(&[&re, &im], -1);
         // TODO: Add the flatten op.
-        let mut dims = rope.size();
+        let mut dims = rope.size().to_vec();
         let v1 = dims.pop().unwrap();
         let v2 = dims.pop().unwrap();
         dims.push(v1 * v2);
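The two hunks above show the migration pattern for call sites that mutate the returned dimensions: an explicit .to_vec() takes an owned copy. Purely read-only call sites, as in the remaining files, just drop the old borrow.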
2 changes: 1 addition & 1 deletion examples/stable-diffusion/main.rs
@@ -919,7 +919,7 @@ impl Module for AttentionBlock {

         let xs = attention_probs.matmul(&value_states);
         let xs = xs.permute([0, 2, 1, 3]).contiguous();
-        let mut new_xs_shape = xs.size();
+        let mut new_xs_shape = xs.size().to_vec();
         new_xs_shape.pop();
         new_xs_shape.pop();
         new_xs_shape.push(self.channels);
4 changes: 2 additions & 2 deletions src/nn/init.rs
@@ -180,7 +180,7 @@ impl Init {
                let _ = tensor.uniform_(lo, up);
            }
            Init::Kaiming { dist, fan, non_linearity } => {
-               let fan = fan.for_weight_dims(&tensor.size());
+               let fan = fan.for_weight_dims(tensor.size());
                let gain = non_linearity.gain();
                let std = gain / (fan as f64).sqrt();
                match dist {
@@ -197,7 +197,7 @@
                tensor.copy_(&(tensor.randn_like() * stdev + mean));
            }
            Init::Orthogonal { gain } => {
-               let q = f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device()).unwrap();
+               let q = f_init(Init::Orthogonal { gain }, tensor.size(), tensor.device()).unwrap();
                crate::no_grad(|| tensor.view_as(&q).copy_(&q));
            }
        }
4 changes: 2 additions & 2 deletions src/nn/var_store.rs
@@ -638,7 +638,7 @@ impl<'a> Path<'a> {
     /// The variable uses a float tensor initialized by copying some
     /// given tensor.
     pub fn f_var_copy(&self, name: &str, t: &Tensor) -> Result<Tensor, TchError> {
-        let mut v = self.f_zeros(name, &t.size())?;
+        let mut v = self.f_zeros(name, t.size())?;
         crate::no_grad(|| v.f_copy_(t))?;
         Ok(v)
     }
@@ -802,7 +802,7 @@ impl<'a> Entry<'a> {

     /// Returns the existing entry if, otherwise create a new variable.
     pub fn or_var_copy(self, tensor: &Tensor) -> Tensor {
-        let mut v = self.or_zeros(&tensor.size());
+        let mut v = self.or_zeros(tensor.size());
         crate::no_grad(|| v.copy_(tensor));
         v
     }
4 changes: 2 additions & 2 deletions src/tensor/display.rs
@@ -60,7 +60,7 @@ impl std::fmt::Debug for Tensor {
            | Kind::ComplexFloat
            | Kind::ComplexDouble => (false, false),
        };
-       match (self.size().as_slice(), is_int, is_float) {
+       match (self.size(), is_int, is_float) {
            ([], true, false) => write!(f, "[{}]", i64::try_from(self).unwrap()),
            ([s], true, false) if *s < 10 => {
                write!(f, "{:?}", Vec::<i64>::try_from(self).unwrap())
@@ -166,7 +166,7 @@ trait TensorFormatter {
        let size = t.size();
        let edge_items = po.edge_items as i64;
        write!(f, "[")?;
-       match size.as_slice() {
+       match size {
            [] => self.fmt(Self::value(t), max_w, f)?,
            [v] if summarize && *v > 2 * edge_items => {
                for v in Self::values(&t.slice(0, None, Some(edge_items), 1)).into_iter() {
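Because size() now yields a slice directly, the .as_slice() adaptor disappears and slice patterns match the return value as-is.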
2 changes: 1 addition & 1 deletion src/tensor/mod.rs
@@ -227,7 +227,7 @@ impl Tensor {
     /// [N1, ..., Nk, labels]. The returned tensor uses float values.
     /// Elements of the input vector are expected to be between 0 and labels-1.
     pub fn onehot(&self, labels: i64) -> Tensor {
-        Tensor::zeros([self.size(), vec![labels]].concat(), (Kind::Float, self.device()))
+        Tensor::zeros([self.size(), &[labels]].concat(), (Kind::Float, self.device()))
             .scatter_value_(-1, &self.unsqueeze(-1).to_kind(Kind::Int64), 1.0)
     }

2 changes: 1 addition & 1 deletion src/tensor/npy.rs
@@ -208,7 +208,7 @@ impl crate::Tensor {
         f.write_all(NPY_MAGIC_STRING)?;
         f.write_all(&[1u8, 0u8])?;
         let kind = self.f_kind()?;
-        let header = Header { descr: kind, fortran_order: false, shape: self.size() };
+        let header = Header { descr: kind, fortran_order: false, shape: self.size().to_vec() };
         let mut header = header.to_string()?;
         let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
         for _ in 0..pad % 16 {
2 changes: 1 addition & 1 deletion src/vision/image.rs
@@ -37,7 +37,7 @@ pub fn load_from_memory(img_data: &[u8]) -> Result<Tensor, TchError> {
 /// 0 to 255.
 pub fn save<T: AsRef<Path>>(t: &Tensor, path: T) -> Result<(), TchError> {
     let t = t.to_kind(crate::Kind::Uint8);
-    match t.size().as_slice() {
+    match t.size() {
         [1, _, _, _] => save_hwc(&chw_to_hwc(&t.squeeze_dim(0)).to_device(Device::Cpu), path),
         [_, _, _] => save_hwc(&chw_to_hwc(&t).to_device(Device::Cpu), path),
         sz => Err(TchError::FileFormat(format!("unexpected size for image tensor {sz:?}"))),
2 changes: 1 addition & 1 deletion src/vision/imagenet.rs
@@ -1153,7 +1153,7 @@ pub const CLASSES: [&str; 1000] = [

 /// Returns the top k classes as well as the associated scores.
 pub fn top(tensor: &Tensor, k: i64) -> Vec<(f64, String)> {
-    let tensor = match tensor.size().as_slice() {
+    let tensor = match tensor.size() {
         [CLASS_COUNT] => tensor.shallow_clone(),
         [1, CLASS_COUNT] => tensor.view((CLASS_COUNT,)),
         [1, 1, CLASS_COUNT] => tensor.view((CLASS_COUNT,)),
38 changes: 18 additions & 20 deletions src/wrappers/tensor.rs
@@ -80,112 +80,110 @@ impl Tensor {
     }

     /// Returns the shape of the input tensor.
-    pub fn size(&self) -> Vec<i64> {
+    pub fn size(&self) -> &[i64] {
         let dim = unsafe_torch!(at_dim(self.c_tensor));
-        let mut sz = vec![0i64; dim];
-        unsafe_torch!(at_shape(self.c_tensor, sz.as_mut_ptr()));
-        sz
+        let ptr = unsafe_torch!(at_shape(self.c_tensor));
+        unsafe { std::slice::from_raw_parts(ptr, dim) }
     }

     /// Returns the tensor size for single dimension tensors.
     pub fn size1(&self) -> Result<i64, TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0] => Ok(s0),
             size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
         }
     }

     /// Returns the tensor sizes for two dimension tensors.
     pub fn size2(&self) -> Result<(i64, i64), TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0, s1] => Ok((s0, s1)),
             size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
         }
     }

     /// Returns the tensor sizes for three dimension tensors.
     pub fn size3(&self) -> Result<(i64, i64, i64), TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0, s1, s2] => Ok((s0, s1, s2)),
             size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
         }
     }

     /// Returns the tensor sizes for four dimension tensors.
     pub fn size4(&self) -> Result<(i64, i64, i64, i64), TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
             size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
         }
     }

     /// Returns the tensor sizes for five dimension tensors.
     pub fn size5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
             size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
         }
     }

     /// Returns the tensor sizes for six dimension tensors.
     pub fn size6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
-        match self.size().as_slice() {
+        match self.size() {
             &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
             size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
         }
     }

     /// Returns the stride of the input tensor.
-    pub fn stride(&self) -> Vec<i64> {
+    pub fn stride(&self) -> &[i64] {
         let dim = unsafe_torch!(at_dim(self.c_tensor));
-        let mut sz = vec![0i64; dim];
-        unsafe_torch!(at_stride(self.c_tensor, sz.as_mut_ptr()));
-        sz
+        let ptr = unsafe_torch!(at_stride(self.c_tensor));
+        unsafe { std::slice::from_raw_parts(ptr, dim) }
     }

     /// Returns the tensor strides for single dimension tensors.
     pub fn stride1(&self) -> Result<i64, TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0] => Ok(s0),
             size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
         }
     }

     /// Returns the tensor strides for two dimension tensors.
     pub fn stride2(&self) -> Result<(i64, i64), TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0, s1] => Ok((s0, s1)),
             size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
         }
     }

     /// Returns the tensor strides for three dimension tensors.
     pub fn stride3(&self) -> Result<(i64, i64, i64), TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0, s1, s2] => Ok((s0, s1, s2)),
             size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
         }
     }

     /// Returns the tensor strides for four dimension tensors.
     pub fn stride4(&self) -> Result<(i64, i64, i64, i64), TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
             size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
         }
     }

     /// Returns the tensor strides for five dimension tensors.
     pub fn stride5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
             size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
         }
     }

     /// Returns the tensor strides for six dimension tensors.
     pub fn stride6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
-        match self.stride().as_slice() {
+        match self.stride() {
             &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
             size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
         }
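Since the returned slice now borrows from &self, the borrow checker ensures it cannot outlive the tensor. A small usage sketch of the new accessors (assumes the usual tch imports; this example is not part of the PR):

    use tch::{Device, Kind, Tensor};

    fn main() {
        let t = Tensor::zeros([2, 3], (Kind::Float, Device::Cpu));
        let dims: &[i64] = t.size(); // borrows t, no allocation
        assert_eq!(dims, &[2, 3]);
        let mut owned = t.size().to_vec(); // copy only when mutation is needed
        owned.push(1);
        assert_eq!(owned, vec![2, 3, 1]);
    }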
18 changes: 8 additions & 10 deletions torch-sys/libtch/torch_api.cpp
@@ -143,18 +143,16 @@ size_t at_dim(tensor t) {
   return -1;
 }

-void at_shape(tensor t, int64_t *dims) {
-  PROTECT(
-    int i = 0;
-    for (int64_t dim : t->sizes()) dims[i++] = dim;
-  )
+int64_t *at_shape(tensor t) {
+  // Follow https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp#L256
+  PROTECT(return const_cast<int64_t*>(t->sizes().data());)
+  return nullptr;
 }

-void at_stride(tensor t, int64_t *dims) {
-  PROTECT(
-    int i = 0;
-    for (int64_t dim: t->strides()) dims[i++] = dim;
-  )
+int64_t *at_stride(tensor t) {
+  // Follow https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/DLConvertor.cpp#L256
+  PROTECT(return const_cast<int64_t*>(t->strides().data());)
+  return nullptr;
 }

 int at_scalar_type(tensor t) {
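On the C++ side, at_shape and at_stride now return the pointer backing ATen's sizes()/strides() arrays instead of copying into a caller-provided buffer, following the DLPack converter referenced in the comments. That pointer is only valid while the tensor is alive and its shape metadata has not been reallocated, which is the contract the Rust wrapper encodes by tying the slice lifetime to &self.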
4 changes: 2 additions & 2 deletions torch-sys/libtch/torch_api.h
@@ -43,8 +43,8 @@ int at_is_sparse(tensor);
 int at_is_contiguous(tensor);
 int at_device(tensor);
 size_t at_dim(tensor);
-void at_shape(tensor, int64_t *);
-void at_stride(tensor, int64_t *);
+int64_t *at_shape(tensor);
+int64_t *at_stride(tensor);
 int at_scalar_type(tensor);

 void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);
4 changes: 2 additions & 2 deletions torch-sys/src/lib.rs
@@ -41,8 +41,8 @@ extern "C" {
     pub fn at_dim(arg: *mut C_tensor) -> size_t;
     pub fn at_get(arg: *mut C_tensor, index: c_int) -> *mut C_tensor;
     pub fn at_requires_grad(arg: *mut C_tensor) -> c_int;
-    pub fn at_shape(arg: *mut C_tensor, sz: *mut i64);
-    pub fn at_stride(arg: *mut C_tensor, sz: *mut i64);
+    pub fn at_shape(arg: *mut C_tensor) -> *mut i64;
+    pub fn at_stride(arg: *mut C_tensor) -> *mut i64;
     pub fn at_double_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> f64;
     pub fn at_int64_value_at_indexes(arg: *mut C_tensor, idx: *const i64, idx_len: c_int) -> i64;
     pub fn at_get_num_interop_threads() -> c_int;
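One detail worth noting: the C++ functions const_cast away ATen's const, and the extern declarations expose the result as *mut i64, but the safe wrapper only ever reads through the pointer (std::slice::from_raw_parts produces an immutable slice), so the tensor's metadata cannot be mutated through this path.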