diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 8ed0f4b85b3d..4ce30f81fee1 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -4686,6 +4686,48 @@ TEST_F(AtenXlaTensorTest, TestExpandAs) {
   ExpectCounterChanged("xla::expand", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestExpandSymInt) {
+  torch::Tensor x = torch::rand({5});
+  torch::Tensor y = torch::nonzero(x);
+  int64_t y0_size = y.sizes()[0];
+  torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = a.expand({y0_size, 3, 4}, /*implicit=*/false);
+
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_x = CopyToDevice(x, device);
+    torch::Tensor xla_y = torch::nonzero(xla_x);
+    std::cout << "*Ran nonzero" << std::endl;
+    c10::SymInt xla_y0_size = xla_y.sym_sizes()[0];
+    std::cout << "*xla_y.sym_sizes()[0] " << xla_y.sym_sizes()[0]
+              << " is_symbolic: " << xla_y.sym_sizes()[0].is_symbolic()
+              << std::endl;
+    std::cout << "*xla_y.sym_sizes()[1] " << xla_y.sym_sizes()[1]
+              << " is_symbolic: " << xla_y.sym_sizes()[1].is_symbolic()
+              << std::endl;
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    std::cout << "*xla_a" << std::endl;
+    torch::Tensor xla_b = xla_a.expand_symint(
+        c10::SymIntArrayRef({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}),
+        /*implicit=*/false);
+    std::cout << "*Ran expand_symint" << std::endl;
+    std::cout << "*xla_b.sym_sizes()[0]: " << xla_b.sym_sizes()[0]
+              << " is_symbolic: " << xla_b.sym_sizes()[0].is_symbolic()
+              << std::endl;
+    std::cout << "*xla_b.sym_sizes()[1]: " << xla_b.sym_sizes()[1]
+              << " is_symbolic: " << xla_b.sym_sizes()[1].is_symbolic()
+              << std::endl;
+    std::cout << "*xla_b.sym_sizes()[2]: " << xla_b.sym_sizes()[2]
+              << " is_symbolic: " << xla_b.sym_sizes()[2].is_symbolic()
+              << std::endl;
+    AllClose(b, xla_b);
+    std::cout << "*Ran allclose" << std::endl;
+    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+    std::cout << "*Ran ExpectCounterNotChanged" << std::endl;
+    ExpectCounterChanged("xla::expand_symint", cpp_test::GetIgnoredCounters());
+    std::cout << "*Ran ExpectCounterChanged" << std::endl;
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestEye) {
   int n = 5;
   ForEachDevice([&](const torch::Device& device) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index a1082fc7c24f..2506dc837522 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1228,6 +1228,32 @@ at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
       bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(size)));
 }
 
+at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self,
+                                             c10::SymIntArrayRef size,
+                                             bool implicit) {
+  XLA_FN_COUNTER("xla::");
+  SymIntElements size_elements = SymIntElements(size);
+  // Replace a concrete -1 dim with the true shape value.
+  std::vector<c10::SymInt> _sizes = torch::lazy::ToVector<c10::SymInt>(size);
+  int64_t num_new_dimensions = _sizes.size() - self.dim();
+  std::vector<int64_t> padded_self(num_new_dimensions, 0);
+  padded_self.insert(padded_self.end(), self.sizes().begin(),
+                     self.sizes().end());
+  for (const auto idx : c10::irange(_sizes.size())) {
+    if (!_sizes[idx].is_symbolic() && _sizes[idx].expect_int() == -1) {
+      size_elements.SetUpperBound(idx, padded_self[idx]);
+    }
+  }
+  at::ScalarType size_type = self.scalar_type();
+  torch::lazy::Shape shape_ = torch::lazy::Shape(size_type, {5, 3, 4});
+  torch::lazy::Shape dynamic_shape_ =
+      shape_.with_symbolic_dims(std::vector<bool>{true, false, false});
+  return bridge::AtenFromXlaTensor(XLATensor::expand_symint(
+      bridge::GetXlaTensor(self), size_elements.GetNodes(),
+      size_elements.GetUpperBounds(), size_elements.GetDynamicDims(),
+      dynamic_shape_));
+}
+
 at::Tensor& XLANativeFunctions::exponential_(
     at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
   XLA_FN_COUNTER("xla::");
@@ -2061,12 +2087,26 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
 at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  // Initially make XLA handled nonzero() handling experimental, and opt-in.
-  if (!DebugUtil::ExperimentEnabled("nonzero")) {
-    return at::native::call_fallback_fn<&xla_cpu_fallback,
-                                        ATEN_OP(nonzero)>::call(self);
-  }
-  return bridge::AtenFromXlaTensor(XLATensor::nonzero(self_tensor));
+
+  /*
+   * REMOVE THIS SECTION TO ENABLE CREATING DYNAMIC SHAPES FOR POC
+   * TODO: REMOVE THIS SECTION AFTER POC SUCCEEDS:
+   * https://github.com/pytorch/xla/pull/3558
+   *
+   * // Initially make XLA handled nonzero() handling experimental, and opt-in.
+   * if (!DebugUtil::ExperimentEnabled("nonzero")) {
+   *   return at::native::call_fallback_fn<&xla_cpu_fallback,
+   *                                       ATEN_OP(nonzero)>::call(self);
+   * }
+   */
+  at::ScalarType size_type = self.scalar_type();
+  torch::lazy::Shape shape_ = torch::lazy::Shape(
+      size_type, {xla::ShapeUtil::ElementsIn(self_tensor->shape()),
+                  self_tensor->shape().get().rank()});
+  torch::lazy::Shape dynamic_shape_ =
+      shape_.with_symbolic_dims(std::vector<bool>{true, false});
+  return bridge::AtenFromXlaTensor(
+      XLATensor::nonzero(self_tensor, dynamic_shape_));
 }
 
 at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp
index b5d0ef2457b8..0db62dfaa673 100644
--- a/torch_xla/csrc/data_ops.cpp
+++ b/torch_xla/csrc/data_ops.cpp
@@ -109,6 +109,20 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
   return XlaHelpers::DynamicReshape(input, output_sizes);
 }
 
+xla::XlaOp SetDimensionSizes(xla::XlaOp input,
+                             absl::Span<const xla::XlaOp> output_sizes) {
+  std::cout << "In SetDimensionSizes" << std::endl;
+  for (int i = 0; i < output_sizes.size(); i++) {
+    std::cout << "In SetDimensionSizes loop " << i << std::endl;
+    xla::Shape dim_shape = XlaHelpers::ShapeOfXlaOp(output_sizes[i]);
+    if (dim_shape.is_dynamic()) {
+      std::cout << "xla::SetDimensionSize " << i << std::endl;
+      input = xla::SetDimensionSize(input, output_sizes[i], i);
+    }
+  }
+  return input;
+}
+
 xla::XlaOp BuildExpand(xla::XlaOp input,
                        absl::Span<const int64_t> output_sizes) {
   auto input_sizes = XlaHelpers::SizesOfXlaOp(input);
diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h
index 07f773605fb2..21a575e07275 100644
--- a/torch_xla/csrc/data_ops.h
+++ b/torch_xla/csrc/data_ops.h
@@ -32,6 +32,10 @@ xla::XlaOp SqueezeTrivialDimension(xla::XlaOp input, int64_t dim);
 // Squeezes out the trivial (size 1) dimensions of the input.
 xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
 
+// Updates the output dynamic dimensions based on the input sizes.
+xla::XlaOp SetDimensionSizes(xla::XlaOp input,
+                             absl::Span<const xla::XlaOp> output_sizes);
+
 // Creates a new tensor with the singleton dimensions expanded to the specified
 // output sizes.
 xla::XlaOp BuildExpand(xla::XlaOp input,
diff --git a/torch_xla/csrc/ops/expand_dynamic.cpp b/torch_xla/csrc/ops/expand_dynamic.cpp
new file mode 100644
index 000000000000..d41c57b05ff5
--- /dev/null
+++ b/torch_xla/csrc/ops/expand_dynamic.cpp
@@ -0,0 +1,63 @@
+#include "torch_xla/csrc/ops/expand_dynamic.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "torch/csrc/lazy/core/helpers.h"
+#include "torch/csrc/lazy/core/util.h"
+#include "torch_xla/csrc/data_ops.h"
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/infer_output_shape.h"
+
+namespace torch_xla {
+namespace {
+xla::Shape NodeOutputShape(const torch::lazy::Value& input,
+                           const std::vector<int64_t> upper_bounds,
+                           const std::vector<bool> dynamic_dims) {
+  return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(),
+                                   {upper_bounds}, {dynamic_dims});
+}
+
+std::vector<torch::lazy::Value> GetValues(
+    const torch::lazy::Value& input,
+    const std::vector<torch::lazy::Value> dimensions) {
+  std::vector<torch::lazy::Value> values = dimensions;
+  values.insert(values.begin(), input);
+  return values;
+}
+
+}  // namespace
+
+ExpandDynamic::ExpandDynamic(const torch::lazy::Value& input,
+                             const std::vector<torch::lazy::Value>& dimensions,
+                             const std::vector<int64_t> upper_bounds,
+                             const std::vector<bool> dynamic_dims,
+                             const torch::lazy::Shape& dynamic_shapes)
+    : XlaNode(
+          torch::lazy::OpKind(at::aten::expand), GetValues(input, dimensions),
+          {dynamic_shapes},
+          [&]() { return NodeOutputShape(input, upper_bounds, dynamic_dims); },
+          /*num_outputs=*/1, torch::lazy::MHash(upper_bounds)),
+      upper_bounds_(std::move(upper_bounds)),
+      dynamic_dims_(std::move(dynamic_dims)) {}
+
+XlaOpVector ExpandDynamic::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  std::vector<xla::XlaOp> size_ops;
+  for (int i = 1; i < operands().size(); i++) {
+    size_ops.push_back(loctx->GetOutputOp(operand(i)));
+  }
+  xla::XlaOp output =
+      SetDimensionSizes(BuildExpand(input, upper_bounds_), size_ops);
+  return ReturnOp(output, loctx);
+}
+
+std::string ExpandDynamic::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", size=(" << absl::StrJoin(upper_bounds_, ", ")
+     << ")"
+     << ", dynamic_dims=(" << absl::StrJoin(dynamic_dims_, ", ") << ")";
+  return ss.str();
+}
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/expand_dynamic.h b/torch_xla/csrc/ops/expand_dynamic.h
new file mode 100644
index 000000000000..5292c712a88a
--- /dev/null
+++ b/torch_xla/csrc/ops/expand_dynamic.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <vector>
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class ExpandDynamic : public XlaNode {
+ public:
+  ExpandDynamic(const torch::lazy::Value& input,
+                const std::vector<torch::lazy::Value>& dimensions,
+                const std::vector<int64_t> upper_bounds,
+                const std::vector<bool> dynamic_dims,
+                const torch::lazy::Shape& dynamic_shapes);
+
+  std::string ToString() const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  const std::vector<int64_t>& size() const { return upper_bounds_; };
+
+  const bool IsDynamic(int index) const { return dynamic_dims_[index]; };
+
+ private:
+  std::vector<int64_t> upper_bounds_;
+  std::vector<bool> dynamic_dims_;
+};
+
+}  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/nonzero.cpp b/torch_xla/csrc/ops/nonzero.cpp
index 72c2524ad617..d9f040d4a262 100644
--- a/torch_xla/csrc/ops/nonzero.cpp
+++ b/torch_xla/csrc/ops/nonzero.cpp
@@ -21,13 +21,15 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input) {
 
 }  // namespace
 
-NonZero::NonZero(const torch::lazy::Value& input)
-    : XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input},
+NonZero::NonZero(const torch::lazy::Value& input,
+                 const torch::lazy::Shape& dynamic_shape)
+    : XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input}, dynamic_shape,
               NodeOutputShape(input),
-              /*num_outputs=*/2) {}
+              /*num_outputs=*/2),
+      dynamic_shape_(dynamic_shape) {}
 
 torch::lazy::NodePtr NonZero::Clone(torch::lazy::OpList operands) const {
-  return torch::lazy::MakeNode<NonZero>(operands.at(0));
+  return torch::lazy::MakeNode<NonZero>(operands.at(0), dynamic_shape_);
 }
 
 XlaOpVector NonZero::Lower(LoweringContext* loctx) const {
diff --git a/torch_xla/csrc/ops/nonzero.h b/torch_xla/csrc/ops/nonzero.h
index ae1e3148833e..7198f9262ff2 100644
--- a/torch_xla/csrc/ops/nonzero.h
+++ b/torch_xla/csrc/ops/nonzero.h
@@ -9,11 +9,15 @@ namespace torch_xla {
 // it gets its own IR node class.
 class NonZero : public XlaNode {
  public:
-  NonZero(const torch::lazy::Value& input);
+  NonZero(const torch::lazy::Value& input,
+          const torch::lazy::Shape& dynamic_shape);
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
 
   XlaOpVector Lower(LoweringContext* loctx) const override;
+
+ private:
+  torch::lazy::Shape dynamic_shape_;
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 990ed7d4a36b..7632d45b448d 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -21,6 +21,7 @@
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_util.h"
 #include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/dynamic_ir.h"
 #include "torch_xla/csrc/view.h"
 #include "torch_xla/csrc/xla_sharding_util.h"
 
@@ -567,6 +568,13 @@ class XLATensor : public c10::intrusive_ptr_target {
   static XLATensorPtr expand(const XLATensorPtr& input,
                              std::vector<int64_t> size);
 
+  static XLATensorPtr expand_symint(
+      const XLATensorPtr& input,
+      const std::vector<torch::lazy::NodePtr>& size_nodes,
+      const std::vector<int64_t> upper_bounds,
+      const std::vector<bool> dynamic_dims,
+      const torch::lazy::Shape dynamic_shapes);
+
   static void exponential_(XLATensorPtr& input, double lambd);
 
   // Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
@@ -881,7 +889,8 @@ class XLATensor : public c10::intrusive_ptr_target {
                                    const XLATensorPtr& score_threshold,
                                    const XLATensorPtr& iou_threshold,
                                    int64_t output_size);
-  static XLATensorPtr nonzero(const XLATensorPtr& input);
+  static XLATensorPtr nonzero(const XLATensorPtr& input,
+                              const torch::lazy::Shape& dynamic_shape);
 
   static XLATensorPtr norm(const XLATensorPtr& input,
                            const c10::optional<at::Scalar>& p,
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index b968a0569a32..a9e9b8747fed 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -6,11 +6,15 @@
 #include "tensorflow/compiler/xla/xla_client/computation_client.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
+#include "torch/csrc/lazy/backend/backend_interface.h"
+#include "torch/csrc/lazy/core/tensor.h"
 #include "torch/csrc/lazy/core/tensor_util.h"
 #include "torch/csrc/lazy/core/util.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/ir_builder.h"
 #include "torch_xla/csrc/layout_manager.h"
+#include "torch_xla/csrc/ops/dynamic_ir.h"
 #include "torch_xla/csrc/tensor_util.h"
 
 namespace torch_xla {
 
@@ -115,9 +119,10 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
 }
 
 c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
-  auto sizes = sizes_custom();
-  return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()),
-                             sizes.size());
+  const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
+  return c10::SymIntArrayRef(
+      reinterpret_cast<const c10::SymInt*>(sym_sizes_.data()),
+      sym_sizes_.size());
 }
 
 c10::SymInt XLATensorImpl::sym_numel_custom() const {
@@ -178,6 +183,35 @@ void XLATensorImpl::SetupSizeProperties() {
   }
 }
 
+void XLATensorImpl::SetupSymSizeProperties() {
+  size_t generation = tensor_->generation();
+  if (generation != generation_) {
+    // Fill up the basic dimension data members which the base class
+    // implementation uses in its APIs.
+    auto shape = tensor_->shape();
+    auto rank = tensor_->shape().get().rank();
+    c10::SmallVector sym_sizes;
+    numel_ = 1;
+    XLAIrBuilder a = XLAIrBuilder();
+    for (auto i : c10::irange(rank)) {
+      if (tensor_->shape().get().is_dynamic_dimension(i)) {
+        auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
+        auto symint_node =
+            c10::make_intrusive(dim_node);
+        auto sn = symint_node->toSymInt();
+        sym_sizes_.push_back(sn);
+        /* TODO(miladm): verify numel_ calculation after adding a dynamic op
+         */
+        numel_ *= dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
+      } else {
+        sym_sizes_.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
+        numel_ *= tensor_->shape().get().dimensions(i);
+      }
+    }
+    generation_ = generation;
+  }
+}
+
 caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
   return c10::scalarTypeToTypeMeta(tensor.dtype());
 }
diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h
index ce8be9389594..41e84638c010 100644
--- a/torch_xla/csrc/tensor_impl.h
+++ b/torch_xla/csrc/tensor_impl.h
@@ -3,8 +3,13 @@
 #include
 #include
 #include
+#include
+#include
+#include
+#include
 
 #include "torch_xla/csrc/tensor.h"
+#include "torch_xla/csrc/xla_backend_impl.h"
 
 namespace torch_xla {
 
@@ -52,10 +57,12 @@ class XLATensorImpl : public c10::TensorImpl {
 
  private:
   void SetupSizeProperties();
+  void SetupSymSizeProperties();
 
   static caffe2::TypeMeta GetTypeMeta(const XLATensor& tensor);
 
   XLATensorPtr tensor_;
+  std::vector<c10::SymInt> sym_sizes_;
   size_t generation_ = 0;
 };
 
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index d0a5f86a2d58..8df84b24c490 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -50,6 +50,7 @@
 #include "torch_xla/csrc/ops/diagonal.h"
 #include "torch_xla/csrc/ops/discrete_uniform.h"
 #include "torch_xla/csrc/ops/expand.h"
+#include "torch_xla/csrc/ops/expand_dynamic.h"
 #include "torch_xla/csrc/ops/exponential.h"
 #include "torch_xla/csrc/ops/flip.h"
 #include "torch_xla/csrc/ops/gather.h"
@@ -1164,6 +1165,21 @@ XLATensorPtr XLATensor::expand(const XLATensorPtr& input,
       GetExpandDimensions(input_shape.get(), std::move(size))));
 }
 
+XLATensorPtr XLATensor::expand_symint(
+    const XLATensorPtr& input,
+    const std::vector<torch::lazy::NodePtr>& size_nodes,
+    const std::vector<int64_t> upper_bounds,
+    const std::vector<bool> dynamic_dims,
+    const torch::lazy::Shape dynamic_shapes) {
+  std::vector<torch::lazy::Value> size_values;
+  for (auto& size_node : size_nodes) {
+    size_values.push_back(torch::lazy::Value(size_node, 0));
+  }
+  return input->CreateFrom(torch::lazy::MakeNode<ExpandDynamic>(
+      input->GetIrValue(), size_values, std::move(upper_bounds),
+      std::move(dynamic_dims), dynamic_shapes));
+}
+
 void XLATensor::exponential_(XLATensorPtr& input, double lambd) {
   auto input_shape = input->shape();
   input->SetInPlaceIrValue(torch::lazy::MakeNode<Exponential>(
@@ -1930,9 +1946,10 @@ std::pair<XLATensorPtr, XLATensorPtr> XLATensor::nms(
                                  at::ScalarType::Int));
 }
 
-XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input) {
+XLATensorPtr XLATensor::nonzero(const XLATensorPtr& input,
+                                const torch::lazy::Shape& dynamic_shape) {
   torch::lazy::NodePtr node =
-      torch::lazy::MakeNode<NonZero>(input->GetIrValue());
+      torch::lazy::MakeNode<NonZero>(input->GetIrValue(), dynamic_shape);
   return input->CreateFrom(torch::lazy::Value(node, 0), at::ScalarType::Long);
 }
 
diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h
index 21c1d5213544..acf408adc874 100644
--- a/torch_xla/csrc/torch_util.h
+++ b/torch_xla/csrc/torch_util.h
@@ -17,6 +17,7 @@ namespace torch_xla {
 struct SymIntElements {
  public:
   SymIntElements(c10::SymInt& size) { SetSymIntNodeElements(size); }
+
   SymIntElements(c10::SymIntArrayRef& size) {
     std::vector<c10::SymInt> _sizes = torch::lazy::ToVector<c10::SymInt>(size);
     for (auto& _size : _sizes) {
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index f23d01ad6635..58ce986af7e2 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -162,6 +162,7 @@ supported:
   - empty.SymInt
   - empty_strided
   - expand
+  - expand.SymInt
   - exponential_
   - eye.m_out
   - eye.out
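For readers tracing the new lowering path, the sketch below illustrates (outside the diff) the kind of HLO that `ExpandDynamic::Lower` emits by calling `BuildExpand` followed by `SetDimensionSizes`: the input is first broadcast to the upper-bound shape, then each dynamic dimension is re-marked with `xla::SetDimensionSize`. This is a minimal sketch under assumptions, not part of the change: the names `builder`, `input`, and `y0_size` are illustrative placeholders, `BuildExpand` is approximated with a plain `BroadcastInDim`, and the dynamic size operand (produced by a SizeNode in the real lowering) is modeled as an S32 parameter.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hedged sketch: roughly the HLO for expand_symint({y0, 3, 4}) where
// dimension 0 is symbolic with upper bound 5, as in the test above.
xla::XlaComputation BuildExpandSymIntSketch() {
  xla::XlaBuilder builder("expand_symint_sketch");
  // The 3x4 input being expanded.
  xla::XlaOp input = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}), "input");
  // Upper-bound expand, standing in for BuildExpand(input, upper_bounds_).
  xla::XlaOp expanded = xla::BroadcastInDim(
      input, /*out_dim_size=*/{5, 3, 4}, /*broadcast_dimensions=*/{1, 2});
  // Dynamic size operand for dimension 0 (a SizeNode in the real lowering).
  xla::XlaOp y0_size = xla::Parameter(
      &builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "y0_size");
  // SetDimensionSizes marks dimension 0 of the expanded value as dynamic.
  xla::XlaOp result = xla::SetDimensionSize(expanded, y0_size, /*dimension=*/0);
  return builder.Build(result).ValueOrDie();
}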