Skip to content

Commit 5681853

Browse files
committed
updated expand.SymInt API in PT/XLA
1 parent d9e19d0 commit 5681853

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4655,7 +4655,7 @@ TEST_F(AtenXlaTensorTest, TestExpandSymInt) {
46554655
torch::Tensor xla_y = torch::nonzero(xla_x);
46564656
c10::SymInt xla_y0_size = xla_y.sym_sizes()[0];
46574657
torch::Tensor xla_a = CopyToDevice(a, device);
4658-
torch::Tensor xla_b = xla_a.expand(
4658+
torch::Tensor xla_b = xla_a.expand_symint(
46594659
c10::SymIntArrayRef({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}),
46604660
/*implicit=*/false);
46614661
AllClose(b, xla_b);

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,7 @@ at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
13291329
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(size)));
13301330
}
13311331

1332-
at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
1332+
at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self,
13331333
c10::SymIntArrayRef size, bool implicit) {
13341334
XLA_FN_COUNTER("xla::");
13351335
SymIntElements size_elements = SymIntElements(size);

0 commit comments

Comments
 (0)