[POC] Support of dynamic shapes for the fill_ op #3100


Closed
miladm wants to merge 27 commits
Commits (27)
3da7f3e
initial implementation of fill lowering + dynamic shapes API calls - WIP
miladm Aug 25, 2021
ed37f47
cleanup of debug code
miladm Aug 25, 2021
46a2d40
fill lowering update
miladm Aug 25, 2021
8e0fdb5
debug code cleanup
miladm Aug 26, 2021
f199b17
debug code cleanup
miladm Aug 26, 2021
f4beeb4
adding a basic test command set - added a test case
miladm Aug 26, 2021
a232267
test framework update
miladm Aug 26, 2021
4b680e4
improved test
miladm Aug 26, 2021
ebd35bb
moved BuildExpand above the Get/SetDimensionSize calls
miladm Aug 26, 2021
ddd5619
dynamic shapes are fully supported for fill_
miladm Aug 27, 2021
dcd45b1
Improved the initial implementation by replacing the Fill op support …
miladm Aug 28, 2021
60c7cc4
code cleaning, removal of Fill op lowering class
miladm Aug 30, 2021
06ba039
code cleanup
miladm Aug 30, 2021
64adec2
merge with master
miladm Aug 30, 2021
5bbe822
linter fix
miladm Aug 30, 2021
67681ed
fixed code review comments
miladm Aug 31, 2021
7448526
linter fix
miladm Aug 31, 2021
73ee915
adding the PyTorch overriding API as a temp independent script
miladm Sep 14, 2021
ceab958
adding the PyTorch overriding API as a temp independent script
miladm Sep 14, 2021
81fa967
cleanup of scrap code
miladm Sep 14, 2021
63090c0
added device type parameters, __new__ method
miladm Sep 16, 2021
d055591
integrated _get_xla_tensor_dimension_size API
miladm Sep 17, 2021
8affbca
integrated _get_xla_tensor_dimension_size API
miladm Sep 17, 2021
943bac0
added support for torch_xla._XLAC._xla_dynamic_expand
miladm Sep 17, 2021
4d4038c
dummy commit
miladm Sep 17, 2021
cb08879
added namedtuple, removed DynamicTensorXLASize, commented optional co…
miladm Sep 22, 2021
e305cac
support for passing dynamic size of self.t_ instead of the tensor itself
miladm Oct 12, 2021
103 changes: 103 additions & 0 deletions experimental/xla_tensor.py
@@ -0,0 +1,103 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import logging
import functools
import time
from collections import namedtuple

debug = False
SIMPLE_TEST = True
time_torch_function = 0

class XLATensor(torch.Tensor):

  def __new__(cls, data, **kwargs):
    return torch.Tensor._make_subclass(
        cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))

It seems like you probably don't need the t_ below since you are storing it here anyway? (given you are using _make_subclass and not _make_wrapper_subclass)
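As a minimal sketch of this suggestion (hypothetical class name, not part of the PR): _make_subclass already backs the instance with the given tensor's storage, so the subclass can drop the extra field and operate on self directly.

import torch

class LeanXLATensor(torch.Tensor):

  def __new__(cls, data, **kwargs):
    # _make_subclass returns an instance of cls that shares the given
    # tensor's data, so no separate self.t_ field is required.
    return torch.Tensor._make_subclass(
        cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))

# e.g. torch_xla._XLAC._get_xla_tensor_dimension_size(x, 0) could then
# take the subclass instance x directly instead of x.t_.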


  def __init__(self, data, **kwargs):
    super().__init__()
    # if debug: print('[__init__]')
    self.t_ = torch.as_tensor(data, dtype=torch.float32, **kwargs)
Collaborator:
I think this can be removed once we pass the actual dynamic shape as a tensor, right?

Collaborator Author:
With the current XLATensor inheritance configuration, we need to keep self.t_ since it's the only variable holding the data for this object.

Collaborator:
A PyTorch tensor with XLA device type will use pt/xla's tensor_impl, which carries a C++ XLATensor. The C++ XLATensor holds a Data object, which carries one of:

  1. ComputationData --> a handle referring to an allocation on the XRT server side
  2. IR --> a pending computation
  3. at::Tensor --> a CPU tensor that we need to upload to the XRT server at a future time
  4. View

The Python XLATensor should do the same thing; it should not always carry Python data. We need a way to hook into the tensor_impl like we do when we set device=xm.xla_device().
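Expressed as a rough, illustrative Python sketch (names are hypothetical; the real structure is the C++ XLATensor's Data object):

from dataclasses import dataclass
from typing import Any, Union

@dataclass
class ComputationData:
  handle: int  # handle referring to an allocation on the XRT server side

@dataclass
class PendingIR:
  node: Any  # pending computation

@dataclass
class CpuTensor:
  tensor: Any  # CPU tensor to upload to the XRT server at a future time

@dataclass
class View:
  base: Any  # alias into another tensor's data

Data = Union[ComputationData, PendingIR, CpuTensor, View]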

Collaborator Author (@miladm, Sep 20, 2021):
Thanks for the input! I have a couple of follow-ups that I will ask you about offline.

QQ: should we consider broadening the scope of the python XLATensor (using a tensor_impl-like module) now or after we have additional clarity on whether this methodology is viable for our goal?

Collaborator:
Let's get it working first before doing any benchmarking. You can work with Ed and the PyTorch team to figure out how to make

torch.tensor([123], device='xla:0')

return a Python XLATensor while still using pt/xla's tensor_impl.

Collaborator:
With the code as it is today, self should already be a valid device='xla' tensor. I wonder if you were trying to work around some other problem with t_.

    self.xla_tensor_sizes_ = namedtuple("XLATensorSizes",
                                        "static_size dynamic_size")

  def size(self):
    # if debug: print('[size] ')
    static_size = self.t_.size()
    dynamic_size = []
    for i, _ in enumerate(self.t_.shape):
      dynamic_size.append(
          torch_xla._XLAC._get_xla_tensor_dimension_size(self.t_, i))
    return self.xla_tensor_sizes_(static_size, dynamic_size[0])

  def expand(self, size):
    # if debug: print('[expand]')
    return torch_xla._XLAC._xla_dynamic_expand(self.t_, size.dynamic_size,
                                               size.dynamic_size.int(),
                                               size.static_size)

  def nonzero(self):
    return torch.nonzero(self.t_)

  def __repr__(self):
    # if debug: print ('[__repr__]')
    return "XLATensor:\n{}\n".format(self.t_)
Collaborator:
XLATensor should always be a tensor with XLA device type. This means it will be a lazy tensor and should not carry CPU storage. I think for a regular PyTorch tensor with XLA device type, we would call to_cpu() to get the scalar value of the tensor.


  @classmethod
  def __torch_function__(cls, func, types, args=(), kwargs=None):
Collaborator:
This isn't doing anything nontrivial, so we can get rid of it. (If we do end up needing to override a torch namespace function, that will be a global tax everywhere; see pytorch/pytorch#62888, cc @anjali411. NB: that issue is about __torch_dispatch__, but the need here is for __torch_function__.)

@anjali411 (Oct 5, 2021):
It seems like XLATensor might not need to be a tensor subclass at all. I am not sure if the size() overload is expected to be in C++. If yes, then it still needs to be a tensor subclass, and we should use _make_wrapper_subclass if the size needs to be used in C++ as defined above. (cc @albanD)

Also, @ezyang, the overhead here won't be as bad because we don't call from C++ to Python for every single function, right?
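For reference, a minimal sketch of the _make_wrapper_subclass pattern mentioned here (hypothetical names; the wrapper holds no storage of its own and keeps the real tensor in a wrapped field):

import torch

class WrapperXLATensor(torch.Tensor):

  @staticmethod
  def __new__(cls, elem):
    # Creates a storage-less shell that reports elem's size/dtype/device;
    # the actual data stays in `elem` and must be unwrapped when ops are
    # intercepted via __torch_function__ / __torch_dispatch__.
    r = torch.Tensor._make_wrapper_subclass(
        cls, elem.size(), dtype=elem.dtype, device=elem.device)
    r.elem = elem
    return r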

Collaborator:
There are a few prototypes floating around, but in the version where torch.empty(device='xla') transparently returns an XLATensor, it needs to be a tensor subclass.

    if kwargs is None:
      kwargs = {}
    args = [a.t_ if hasattr(a, 't_') else a for a in args]
    result = func(*args, **kwargs)
    return result

if SIMPLE_TEST == True:
  a = XLATensor([[0, 0, 0], [1, 1, 0]], device=xm.xla_device())
  b = XLATensor([[1], [2]], device=xm.xla_device())

  idxs = a.nonzero()
  idxs_t = XLATensor([[1]], device=xm.xla_device())
  idxs_t.t_ = idxs
  print('a.nonzero(): ', idxs)

  y = b.expand(idxs_t.size())
  print(f"b.expand(): {y}")
  y_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(y, 0)
  print("[RUNNING] _get_xla_tensor_dimension_size(y, 0)\n", y_dim0_shape)
  print("[RUNNING] _get_xla_tensors_hlo([y])\n",
        torch_xla._XLAC._get_xla_tensors_hlo([y]))
  print("[RUNNING] _get_xla_tensors_text([y])\n",
        torch_xla._XLAC._get_xla_tensors_text([y]))
else:
  NUM = 500

  time_size = 0
  time_expand = 0
  for i in range(NUM):
    t = torch.tensor([[1], [2], [3]], device=xm.xla_device())
    o = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
                     device=xm.xla_device())

    tic = time.perf_counter()
    b = o.size()
    toc = time.perf_counter()
    time_size += toc - tic

    tic = time.perf_counter()
    tt = t.expand(b)
    toc = time.perf_counter()
    time_expand += toc - tic
  print(f"size() time {time_size:0.4f} seconds")
  print(f"expand() time {time_expand:0.4f} seconds")

  time_size = 0
  time_expand = 0
  for i in range(NUM):
    t = XLATensor([[1], [2], [3]], device=xm.xla_device())
    o = XLATensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
                  device=xm.xla_device())

    tic = time.perf_counter()
    b = o.size()
    toc = time.perf_counter()
    time_size += toc - tic

    tic = time.perf_counter()
    tt = t.expand(b)
    toc = time.perf_counter()
    time_expand += toc - tic
  print(f"size() time {time_size:0.4f} seconds")
  print(f"expand() time {time_expand:0.4f} seconds")
  print(f"torch_function() time {time_torch_function:0.4f} seconds")
50 changes: 50 additions & 0 deletions test/test_dynamic_shapes.py
@@ -0,0 +1,50 @@
# Run:
# XLA_EXPERIMENTAL="nonzero:masked_select" python3 test_dynamic_shapes.py
#

import torch
import torch_xla


class TestDynamicShapes(object):

  def __init__(self):
    self.device = 'xla:0'

  def runTest(self):
    t1 = torch.tensor([1, 0, 2, 0, 0, 1, 3], device=self.device)

    t2 = torch.nonzero(t1)
    t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0)
    print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape)
    print("[RUNNING] _get_xla_tensors_text([t2])\n",
          torch_xla._XLAC._get_xla_tensors_text([t2]))
    print("[RUNNING] _get_xla_tensors_hlo([t2])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t2]))
    assert t2_dim0_shape.item() == 4

    t3 = torch.fill_(t2, 10)
    print(f"size of t3 {t3.size()}, size of t2 {t2.size()}")
    t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0)
    t3_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t3, 0)
    print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape)
    print("[RUNNING] _get_xla_tensor_dimension_size(t3, 0)\n", t3_dim0_shape)
    print("[RUNNING] _get_xla_tensors_text([t2])\n",
          torch_xla._XLAC._get_xla_tensors_text([t2]))
    print("[RUNNING] _get_xla_tensors_text([t3])\n",
          torch_xla._XLAC._get_xla_tensors_text([t3]))
    print("[RUNNING] _get_xla_tensors_hlo([t2])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t2]))
    print("[RUNNING] _get_xla_tensors_hlo([t3])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t3]))
    assert t2_dim0_shape.item() == 4
    assert t3_dim0_shape.item() == 4

    print('t1: ', t1)
    print('t2: ', t2)
    print('t3: ', t3)


if __name__ == "__main__":
  test = TestDynamicShapes()
  test.runTest()
2 changes: 1 addition & 1 deletion third_party/tensorflow
Submodule tensorflow updated 4671 files
26 changes: 26 additions & 0 deletions torch_xla/csrc/data_ops.cpp
@@ -12,11 +12,16 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "tensorflow/compiler/xla/xla_client/xla_util.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/tensor_util.h"

#include <chrono>
using namespace std::chrono;

namespace torch_xla {
namespace {

@@ -112,6 +117,27 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
xla::util::Iota<xla::int64>(output_sizes.size()));
}

xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
                              xla::XlaOp dynamic_target,
                              xla::Shape dynamic_shapes) {
  // auto start = high_resolution_clock::now();
  xla::XlaOp output = BuildExpand(static_input, dynamic_shapes.dimensions());
  bool seen_dynamic = false;  // Limit support to one dynamic dimension.
  for (int i = 0; i < dynamic_shapes.rank(); ++i) {
    if (dynamic_shapes.is_dynamic_dimension(i)) {
      XLA_CHECK(seen_dynamic == false);
      seen_dynamic = true;
      output = xla::SetDimensionSize(
          output, MaybeConvertTo(dynamic_target, xla::PrimitiveType::S32), i);
    }
  }
  // auto stop = high_resolution_clock::now();
  // auto duration = duration_cast<microseconds>(stop - start);
  // std::cout << "Time taken dynamic shape (BuildDynamicShape): "
  //           << duration.count() << " microseconds" << std::endl;
  return output;
}

Collaborator:
Nit: BuildDynamicExpandAs, since the target is a value tensor, not a shape tensor.

std::vector<xla::int64> BuildSqueezedDimensions(
    absl::Span<const xla::int64> dimensions, xla::int64 squeeze_dim) {
  std::vector<xla::int64> output_dimensions;
5 changes: 5 additions & 0 deletions torch_xla/csrc/data_ops.h
@@ -37,6 +37,11 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
                       absl::Span<const xla::int64> output_sizes);

// Dynamic Shape version of BuildExpand()
xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
                              xla::XlaOp dynamic_target,
                              xla::Shape dynamic_shapes);

Collaborator:
Ditto.

std::vector<xla::int64> BuildSqueezedDimensions(
    absl::Span<const xla::int64> dimensions, xla::int64 squeeze_dim);

15 changes: 15 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -648,6 +648,17 @@ absl::flat_hash_map<std::string, absl::variant<int>> ConvertDictToMap(
  return map;
}

at::Tensor GetXlaDynamicExpand(const at::Tensor& tensor,
                               const at::Tensor& dynamic_size_tensors,
                               xla::int64 dynamic_size_values,
                               std::vector<xla::int64> static_size) {
  XLATensor xtensor = bridge::GetXlaTensor(tensor);
  XLATensor xdynamicsizetensors = bridge::GetXlaTensor(dynamic_size_tensors);
  at::Tensor dynamic_expand_tensor =
      bridge::AtenFromXlaTensor(XLATensor::dynamic_expand(
          xtensor, xdynamicsizetensors, dynamic_size_values, static_size));
  return dynamic_expand_tensor;
}

void BuildProfilerSubmodule(py::module* m) {
  py::module profiler = m->def_submodule("profiler", "Profiler integration");
  py::class_<xla::profiler::ProfilerServer,
@@ -1094,6 +1105,10 @@ void InitXlaModuleBindings(py::module m) {
m.def("_run_xrt_local_service", [](xla::uint64 service_port) {
xla::ComputationClient::RunLocalService(service_port);
});
m.def("_xla_dynamic_expand",
[](const at::Tensor& tensor, const at::Tensor& dynamic_size_tensors, xla::int64 dynamic_size_values, std::vector<xla::int64> static_size) {
return GetXlaDynamicExpand(tensor, dynamic_size_tensors, dynamic_size_values, static_size);
});

  BuildProfilerSubmodule(&m);
}
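For context, the experimental script above invokes this binding roughly as follows (a sketch; variable names are illustrative, and the arguments mirror XLATensor.expand, which passes the dynamic size both as a tensor and as a plain value):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.tensor([[1]], device=xm.xla_device())
nz = torch.nonzero(torch.tensor([1, 0, 2, 0], device=xm.xla_device()))
# 0-d XLA tensor holding the dynamic size of nz along dim 0.
dyn = torch_xla._XLAC._get_xla_tensor_dimension_size(nz, 0)
# Expand t to nz's static (upper-bound) shape, then mark dim 0 dynamic.
out = torch_xla._XLAC._xla_dynamic_expand(t, dyn, dyn.int(), list(nz.size()))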
47 changes: 46 additions & 1 deletion torch_xla/csrc/ops/ops.cpp
@@ -5,6 +5,7 @@
#include "tensorflow/compiler/xla/client/lib/logdet.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
@@ -882,6 +883,50 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& input,
                      const std::vector<xla::int64> static_size,
                      const xla::int64 dynamic_size,
                      const Value& dynamic_target) {
  xla::Shape input_shape = input.shape();
  auto lower_fn = [input_shape, static_size, dynamic_size](
                      const Node& node,
                      LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp static_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp dynamic_target = loctx->GetOutputOp(node.operand(1));
    std::vector<bool> dynamic_dims;
    std::vector<xla::int64> dims;
    for (int i = 0; i < static_size.size(); i++) {
      if (i == 0) {
        dims.push_back(dynamic_size);
        dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false);
      } else {
        dims.push_back(static_size[i]);
        dynamic_dims.push_back(false);
      }
    }
    xla::Shape target_shape =
        xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims);
    xla::XlaOp dynamic_output =
        BuildDynamicExpand(static_input, dynamic_target, target_shape);
    return node.ReturnOp(dynamic_output, loctx);
  };
  auto shape_fn = [input_shape, static_size, dynamic_size](
                      absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    std::vector<bool> dynamic_dims;
    std::vector<xla::int64> dims;
    for (int i = 0; i < static_size.size(); i++) {
      if (i == 0) {
        dims.push_back(dynamic_size);
        dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false);
      } else {
        dims.push_back(static_size[i]);
        dynamic_dims.push_back(false);
      }
    }
    xla::Shape target_shape =
        xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims);
    return BuildDynamicExpand(operands[0], operands[1], target_shape);
  };
  return GenericOp(OpKind(at::aten::expand), {input, dynamic_target},
                   [&]() {
                     return InferOutputShape(
                         {input.shape(), dynamic_target.shape()}, shape_fn);
                   },
                   std::move(lower_fn));
}

Contributor:
Probably "const std::vector<xla::int64> static_size" was meant to be passed by ref? It also seems to be passed by value into shape_fn.

} // namespace ops
} // namespace ir
} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -226,6 +226,10 @@ NodePtr LogicalOr(const Value& input, const Value& other);

NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
                 const Value& neginf);

NodePtr DynamicExpand(const Value& static_input,
                      const std::vector<xla::int64> static_size_values,
                      const xla::int64 dynamic_size_values,
                      const Value& dynamic_target);

} // namespace ops
} // namespace ir
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
@@ -513,6 +513,8 @@ class XLATensor {

static XLATensor expand(const XLATensor& input, std::vector<xla::int64> size);

  static XLATensor dynamic_expand(const XLATensor& input,
                                  const XLATensor& dynamic_size_tensors,
                                  const xla::int64 dynamic_size,
                                  std::vector<xla::int64> static_size);

static XLATensor expm1(const XLATensor& input);

static void exponential_(XLATensor& input, double lambd);
Expand Down
19 changes: 17 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1183,6 +1183,17 @@ XLATensor XLATensor::expand(const XLATensor& input,
GetExpandDimensions(input_shape.get(), std::move(size))));
}

XLATensor XLATensor::dynamic_expand(const XLATensor& input,
                                    const XLATensor& dynamic_size_tensors,
                                    const xla::int64 dynamic_size,
                                    std::vector<xla::int64> static_size) {
  return input.CreateFrom(ir::ops::DynamicExpand(
      input.GetIrValue(), std::move(static_size), dynamic_size,
      dynamic_size_tensors.GetIrValue()));
}

XLATensor XLATensor::expm1(const XLATensor& input) {
  return input.CreateFrom(ir::ops::Expm1(input.GetIrValue()));
}
@@ -1211,8 +1222,12 @@ void XLATensor::eye_out(XLATensor& out, xla::int64 lines, xla::int64 cols) {
}

void XLATensor::fill_(XLATensor& input, const at::Scalar& value) {
  ir::Value constant = GetIrValueForScalar(value, input.GetDevice());
  constant = ir::ops::DynamicExpand(constant,
                                    {-1},  // TODO: fix this line
                                    -1,    // TODO: fix this line
                                    input.GetIrValue());
  input.SetInPlaceIrValue(std::move(constant));
}

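For context, the observable behavior this fill_ rewrite is after, mirroring test_dynamic_shapes.py above (a sketch of the intended end state; the TODO placeholders above mean the lowering is not final):

import torch
import torch_xla

t1 = torch.tensor([1, 0, 2, 0, 0, 1, 3], device='xla:0')
t2 = torch.nonzero(t1)    # static upper-bound shape [7, 1], dynamic dim0 = 4
t3 = torch.fill_(t2, 10)  # now lowers to a scalar + DynamicExpand
# The filled tensor preserves t2's dynamic dimension size.
assert torch_xla._XLAC._get_xla_tensor_dimension_size(t3, 0).item() == 4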