Skip to content

Extend nonzero to provide dynamism to its torch::lazy::shape member #3715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4686,6 +4686,48 @@ TEST_F(AtenXlaTensorTest, TestExpandAs) {
ExpectCounterChanged("xla::expand", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestExpandSymInt) {
torch::Tensor x = torch::rand({5});
torch::Tensor y = torch::nonzero(x);
int64_t y0_size = y.sizes()[0];
torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = a.expand({y0_size, 3, 4}, /*implicit=*/false);

ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_x = CopyToDevice(x, device);
torch::Tensor xla_y = torch::nonzero(xla_x);
std::cout << "*Ran nonzero" << std::endl;
c10::SymInt xla_y0_size = xla_y.sym_sizes()[0];
std::cout << "*xla_y.sym_sizes()[0] " << xla_y.sym_sizes()[0]
<< " is_symbolic: " << xla_y.sym_sizes()[0].is_symbolic()
<< std::endl;
std::cout << "*xla_y.sym_sizes()[1] " << xla_y.sym_sizes()[1]
<< " is_symbolic: " << xla_y.sym_sizes()[1].is_symbolic()
<< std::endl;
torch::Tensor xla_a = CopyToDevice(a, device);
std::cout << "*xla_a" << std::endl;
torch::Tensor xla_b = xla_a.expand_symint(
c10::SymIntArrayRef({xla_y0_size, c10::SymInt(3), c10::SymInt(4)}),
/*implicit=*/false);
std::cout << "*Ran expand_symint" << std::endl;
std::cout << "*xla_b.sym_sizes()[0]: " << xla_b.sym_sizes()[0]
<< " is_symbolic: " << xla_b.sym_sizes()[0].is_symbolic()
<< std::endl;
std::cout << "*xla_b.sym_sizes()[1]: " << xla_b.sym_sizes()[1]
<< " is_symbolic: " << xla_b.sym_sizes()[1].is_symbolic()
<< std::endl;
std::cout << "*xla_b.sym_sizes()[2]: " << xla_b.sym_sizes()[2]
<< " is_symbolic: " << xla_b.sym_sizes()[2].is_symbolic()
<< std::endl;
AllClose(b, xla_b);
std::cout << "*Ran allclose" << std::endl;
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
std::cout << "*Ran ExpectCounterNotChanged" << std::endl;
ExpectCounterChanged("xla::expand_symint", cpp_test::GetIgnoredCounters());
std::cout << "*Ran ExpectCounterChanged" << std::endl;
});
}

TEST_F(AtenXlaTensorTest, TestEye) {
int n = 5;
ForEachDevice([&](const torch::Device& device) {
Expand Down
52 changes: 46 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1228,6 +1228,32 @@ at::Tensor XLANativeFunctions::expand(const at::Tensor& self,
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(size)));
}

at::Tensor XLANativeFunctions::expand_symint(const at::Tensor& self,
c10::SymIntArrayRef size,
bool implicit) {
XLA_FN_COUNTER("xla::");
SymIntElements size_elements = SymIntElements(size);
// Replace -1 concrete int dim with the true shape value
std::vector<c10::SymInt> _sizes = torch::lazy::ToVector<c10::SymInt>(size);
int64_t num_new_dimensions = _sizes.size() - self.dim();
std::vector<int64_t> padded_self(num_new_dimensions, 0);
padded_self.insert(padded_self.end(), self.sizes().begin(),
self.sizes().end());
for (const auto idx : c10::irange(_sizes.size())) {
if (!_sizes[idx].is_symbolic() && _sizes[idx].expect_int() == -1) {
size_elements.SetUpperBound(idx, padded_self[idx]);
}
}
at::ScalarType size_type = self.scalar_type();
torch::lazy::Shape shape_ = torch::lazy::Shape(size_type, {5, 3, 4});
torch::lazy::Shape dynamic_shape_ =
shape_.with_symbolic_dims(std::vector<bool>{true, false, false});
return bridge::AtenFromXlaTensor(XLATensor::expand_symint(
bridge::GetXlaTensor(self), size_elements.GetNodes(),
size_elements.GetUpperBounds(), size_elements.GetDynamicDims(),
dynamic_shape_));
}

at::Tensor& XLANativeFunctions::exponential_(
at::Tensor& self, double lambd, c10::optional<at::Generator> generator) {
XLA_FN_COUNTER("xla::");
Expand Down Expand Up @@ -2061,12 +2087,26 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::nll_loss_forward(
at::Tensor XLANativeFunctions::nonzero(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
// Initially make XLA handled nonzero() handling experimental, and opt-in.
if (!DebugUtil::ExperimentEnabled("nonzero")) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(nonzero)>::call(self);
}
return bridge::AtenFromXlaTensor(XLATensor::nonzero(self_tensor));

/*
* REMOVE THIS SECTION TO ENABLE CREATING DYNAMIC SHAPES FOR POC
* TODO: REMOVE THIS SECTION AFTER POC SUCCEEDS:
* https://github.com/pytorch/xla/pull/3558
*
* // Initially make XLA handled nonzero() handling experimental, and opt-in.
* if (!DebugUtil::ExperimentEnabled("nonzero")) {
* return at::native::call_fallback_fn<&xla_cpu_fallback,
* ATEN_OP(nonzero)>::call(self);
* }
*/
at::ScalarType size_type = self.scalar_type();
torch::lazy::Shape shape_ = torch::lazy::Shape(
size_type, {xla::ShapeUtil::ElementsIn(self_tensor->shape()),
self_tensor->shape().get().rank()});
torch::lazy::Shape dynamic_shape_ =
shape_.with_symbolic_dims(std::vector<bool>{true, false});
return bridge::AtenFromXlaTensor(
XLATensor::nonzero(self_tensor, dynamic_shape_));
}

at::Tensor XLANativeFunctions::norm(const at::Tensor& self,
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/data_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) {
return XlaHelpers::DynamicReshape(input, output_sizes);
}

xla::XlaOp SetDimensionSizes(xla::XlaOp input,
absl::Span<const xla::XlaOp> output_sizes) {
std::cout << "In SetDimensionSizes" << std::endl;
for (int i = 0; i < output_sizes.size(); i++) {
std::cout << "In SetDimensionSizes loop " << i << std::endl;
xla::Shape dim_shape = XlaHelpers::ShapeOfXlaOp(output_sizes[i]);
if (dim_shape.is_dynamic()) {
std::cout << "xla::SetDimensionSize " << i << std::endl;
input = xla::SetDimensionSize(input, output_sizes[i], i);
}
}
return input;
}

xla::XlaOp BuildExpand(xla::XlaOp input,
absl::Span<const int64_t> output_sizes) {
auto input_sizes = XlaHelpers::SizesOfXlaOp(input);
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/data_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ xla::XlaOp SqueezeTrivialDimension(xla::XlaOp input, int64_t dim);
// Squeezes out the trivial (size 1) dimensions of the input.
xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);

// Update Output Dynamic Dimensions based on input size()
xla::XlaOp SetDimensionSizes(xla::XlaOp input,
absl::Span<const xla::XlaOp> output_sizes);

// Creates a new tensor with the singleton dimensions expanded to the specified
// output sizes.
xla::XlaOp BuildExpand(xla::XlaOp input,
Expand Down
63 changes: 63 additions & 0 deletions torch_xla/csrc/ops/expand_dynamic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "torch_xla/csrc/ops/expand_dynamic.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/core/helpers.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"

namespace torch_xla {
namespace {
xla::Shape NodeOutputShape(const torch::lazy::Value& input,
const std::vector<int64_t> upper_bounds,
const std::vector<bool> dynamic_dims) {
return xla::ShapeUtil::MakeShape(GetXlaShape(input).element_type(),
{upper_bounds}, {dynamic_dims});
}

std::vector<torch::lazy::Value> GetValues(
const torch::lazy::Value& input,
const std::vector<torch::lazy::Value> dimensions) {
std::vector<torch::lazy::Value> values = dimensions;
values.insert(values.begin(), input);
return values;
}

} // namespace

ExpandDynamic::ExpandDynamic(const torch::lazy::Value& input,
const std::vector<torch::lazy::Value>& dimensions,
const std::vector<int64_t> upper_bounds,
const std::vector<bool> dynamic_dims,
const torch::lazy::Shape& dynamic_shapes)
: XlaNode(
torch::lazy::OpKind(at::aten::expand), GetValues(input, dimensions),
{dynamic_shapes},
[&]() { return NodeOutputShape(input, upper_bounds, dynamic_dims); },
/*num_outputs=*/1, torch::lazy::MHash(upper_bounds)),
upper_bounds_(std::move(upper_bounds)),
dynamic_dims_(std::move(dynamic_dims)) {}

XlaOpVector ExpandDynamic::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
std::vector<xla::XlaOp> size_ops;
for (int i = 1; i < operands().size(); i++) {
size_ops.push_back(loctx->GetOutputOp(operand(i)));
}
xla::XlaOp output =
SetDimensionSizes(BuildExpand(input, upper_bounds_), size_ops);
return ReturnOp(output, loctx);
}

std::string ExpandDynamic::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", size=(" << absl::StrJoin(upper_bounds_, ", ")
<< ")"
<< ", dynamic_dims=(" << absl::StrJoin(dynamic_dims_, ", ") << ")";
return ss.str();
}

} // namespace torch_xla
30 changes: 30 additions & 0 deletions torch_xla/csrc/ops/expand_dynamic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class ExpandDynamic : public XlaNode {
public:
ExpandDynamic(const torch::lazy::Value& input,
const std::vector<torch::lazy::Value>& dimensions,
const std::vector<int64_t> upper_bounds,
const std::vector<bool> dynamic_dims,
const torch::lazy::Shape& dynamic_shapes);

std::string ToString() const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

const std::vector<int64_t>& size() const { return upper_bounds_; };

const bool IsDynamic(int index) const { return dynamic_dims_[index]; };

private:
std::vector<int64_t> upper_bounds_;
std::vector<bool> dynamic_dims_;
};

} // namespace torch_xla
10 changes: 6 additions & 4 deletions torch_xla/csrc/ops/nonzero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& input) {

} // namespace

NonZero::NonZero(const torch::lazy::Value& input)
: XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input},
NonZero::NonZero(const torch::lazy::Value& input,
const torch::lazy::Shape& dynamic_shape)
: XlaNode(torch::lazy::OpKind(at::aten::nonzero), {input}, dynamic_shape,
NodeOutputShape(input),
/*num_outputs=*/2) {}
/*num_outputs=*/2),
dynamic_shape_(dynamic_shape) {}

torch::lazy::NodePtr NonZero::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<NonZero>(operands.at(0));
return torch::lazy::MakeNode<NonZero>(operands.at(0), dynamic_shape_);
}

XlaOpVector NonZero::Lower(LoweringContext* loctx) const {
Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/nonzero.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@ namespace torch_xla {
// it gets its own IR node class.
class NonZero : public XlaNode {
public:
NonZero(const torch::lazy::Value& input);
NonZero(const torch::lazy::Value& input,
const torch::lazy::Shape& dynamic_shape);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

private:
torch::lazy::Shape dynamic_shape_;
};

} // namespace torch_xla
11 changes: 10 additions & 1 deletion torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/view.h"
#include "torch_xla/csrc/xla_sharding_util.h"

Expand Down Expand Up @@ -567,6 +568,13 @@ class XLATensor : public c10::intrusive_ptr_target {
static XLATensorPtr expand(const XLATensorPtr& input,
std::vector<int64_t> size);

static XLATensorPtr expand_symint(
const XLATensorPtr& input,
const std::vector<torch::lazy::NodePtr>& size_nodes,
const std::vector<int64_t> upper_bounds,
const std::vector<bool> dynamic_dims,
const torch::lazy::Shape dynamic_shapes);

static void exponential_(XLATensorPtr& input, double lambd);

// Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
Expand Down Expand Up @@ -881,7 +889,8 @@ class XLATensor : public c10::intrusive_ptr_target {
const XLATensorPtr& score_threshold, const XLATensorPtr& iou_threshold,
int64_t output_size);

static XLATensorPtr nonzero(const XLATensorPtr& input);
static XLATensorPtr nonzero(const XLATensorPtr& input,
const torch::lazy::Shape& dynamic_shape);

static XLATensorPtr norm(const XLATensorPtr& input,
const c10::optional<at::Scalar>& p,
Expand Down
40 changes: 37 additions & 3 deletions torch_xla/csrc/tensor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@

#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include "torch/csrc/lazy/core/tensor.h"
#include "torch/csrc/lazy/core/tensor_util.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir_builder.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
Expand Down Expand Up @@ -115,9 +119,10 @@ at::IntArrayRef XLATensorImpl::sizes_custom() const {
}

c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
auto sizes = sizes_custom();
return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()),
sizes.size());
const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
return c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(sym_sizes_.data()),
sym_sizes_.size());
}

c10::SymInt XLATensorImpl::sym_numel_custom() const {
Expand Down Expand Up @@ -178,6 +183,35 @@ void XLATensorImpl::SetupSizeProperties() {
}
}

void XLATensorImpl::SetupSymSizeProperties() {
size_t generation = tensor_->generation();
if (generation != generation_) {
// Fill up the basic dimension data members which the base class
// implementation uses in its APIs.
auto shape = tensor_->shape();
auto rank = tensor_->shape().get().rank();
c10::SmallVector<c10::SymInt, 5> sym_sizes;
numel_ = 1;
XLAIrBuilder a = XLAIrBuilder();
for (auto i : c10::irange(rank)) {
if (tensor_->shape().get().is_dynamic_dimension(i)) {
auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
auto symint_node =
c10::make_intrusive<torch::lazy::SymIntNodeImpl>(dim_node);
auto sn = symint_node->toSymInt();
sym_sizes_.push_back(sn);
/*TODO(miladm): verify numel_ calculation after adding a dynamic op
*/
numel_ *= dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
} else {
sym_sizes_.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
numel_ *= tensor_->shape().get().dimensions(i);
}
}
generation_ = generation;
}
}

caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
return c10::scalarTypeToTypeMeta(tensor.dtype());
}
Expand Down
Loading