diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index a5ac7dcb6ba6..282b43e0dd45 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2186,7 +2186,10 @@ at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) { return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP(nonzero)>::call(self); } - return bridge::AtenFromXlaTensor(XLATensor::nonzero(self_tensor)); + std::vector dynamic_shapes_ = + torch::lazy::compute_shape_nonzero(self); + return bridge::AtenFromXlaTensor( + XLATensor::nonzero(self_tensor, dynamic_shapes_[0])); } at::Tensor XLANativeFunctions::norm(const at::Tensor& self, diff --git a/torch_xla/csrc/ops/nonzero.cpp b/torch_xla/csrc/ops/nonzero.cpp index 72c2524ad617..d9f040d4a262 100644 --- a/torch_xla/csrc/ops/nonzero.cpp +++ b/torch_xla/csrc/ops/nonzero.cpp @@ -21,13 +21,15 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input) { } // namespace -NonZero::NonZero(const torch::lazy::Value& input) - : XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input}, +NonZero::NonZero(const torch::lazy::Value& input, + const torch::lazy::Shape& dynamic_shape) + : XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input}, dynamic_shape, NodeOutputShape(input), - /*num_outputs=*/2) {} + /*num_outputs=*/2), + dynamic_shape_(dynamic_shape) {} torch::lazy::NodePtr NonZero::Clone(torch::lazy::OpList operands) const { - return torch::lazy::MakeNode(operands.at(0)); + return torch::lazy::MakeNode(operands.at(0), dynamic_shape_); } XlaOpVector NonZero::Lower(LoweringContext* loctx) const { diff --git a/torch_xla/csrc/ops/nonzero.h b/torch_xla/csrc/ops/nonzero.h index ae1e3148833e..7198f9262ff2 100644 --- a/torch_xla/csrc/ops/nonzero.h +++ b/torch_xla/csrc/ops/nonzero.h @@ -9,11 +9,15 @@ namespace torch_xla { // it gets its own IR node class. class NonZero : public XlaNode { public: - NonZero(const torch::lazy::Value& input); + NonZero(const torch::lazy::Value& input, + const torch::lazy::Shape& dynamic_shape); torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; XlaOpVector Lower(LoweringContext* loctx) const override; + + private: + torch::lazy::Shape dynamic_shape_; }; } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 4f6085c2caf4..9981cdbb448f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -890,7 +890,8 @@ class XLATensor : public c10::intrusive_ptr_target { const XLATensorPtr& score_threshold, const XLATensorPtr& iou_threshold, int64_t output_size); - static XLATensorPtr nonzero(const XLATensorPtr& input); + static XLATensorPtr nonzero(const XLATensorPtr& input, + const torch::lazy::Shape& dynamic_shape); static XLATensorPtr norm(const XLATensorPtr& input, const c10::optional& p, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 7679f27228c6..8be25bb5456d 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2003,9 +2003,10 @@ std::pair XLATensor::nms( at::ScalarType::Int)); } -XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input) { +XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input, + const torch::lazy::Shape& dynamic_shape) { torch::lazy::NodePtr node = - torch::lazy::MakeNode(input->GetIrValue()); + torch::lazy::MakeNode(input->GetIrValue(), dynamic_shape); return input->CreateFrom(torch::lazy::Value(node, 0), at::ScalarType::Long); }