Skip to content

Commit b056ca5

Browse files
authored
Call Tensor::alias() when returning an input Tensor (#1433)
* Call `Tensor::alias()` when concatenating a list of count 1 * multi_dot also * broadcast_tensors * Add release notes
1 parent 3701fc5 commit b056ca5

File tree

4 files changed

+4
-3
lines changed

4 files changed

+4
-3
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ __API Changes__:
1212

1313
#1374 Add accumulate to index_put_<br/>
1414
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
15+
Returning an input tensor has been corrected, is now `alias()`.<br/>
1516

1617
# NuGet Version 0.105.0
1718

src/TorchSharp/LinearAlgebra.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ public static Tensor multi_dot(IList<Tensor> tensors)
440440
throw new ArgumentException(nameof(tensors));
441441
}
442442
if (tensors.Count == 1) {
443-
return tensors[0];
443+
return tensors[0].alias();
444444
}
445445

446446
using (var parray = new PinnedArray<IntPtr>()) {

src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static Tensor cat(IList<Tensor> tensors, long dim = 0)
4040
switch (tensors.Count)
4141
{
4242
case <=0: throw new ArgumentException(nameof(tensors));
43-
case 1: return tensors[0];
43+
case 1: return tensors[0].alias();
4444
}
4545

4646
using var parray = new PinnedArray<IntPtr>();

src/TorchSharp/Tensor/torch.OtherOperations.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static IList<Tensor> broadcast_tensors(params Tensor[] tensors)
6363
throw new ArgumentException(nameof(tensors));
6464
}
6565
if (tensors.Length == 1) {
66-
return tensors;
66+
return new Tensor[] { tensors[0].alias() };
6767
}
6868

6969
IntPtr[] ptrArray;

0 commit comments

Comments
 (0)