[POC] Support of dynamic shapes for the fill_ op #3100
Changes from all commits
@@ -0,0 +1,103 @@
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import logging
import functools
import time
from collections import namedtuple

debug = False
SIMPLE_TEST = True
time_torch_function = 0


class XLATensor(torch.Tensor):
  def __new__(cls, data, **kwargs):
    return torch.Tensor._make_subclass(
        cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))

  def __init__(self, data, **kwargs):
    super().__init__()
    # if debug: print('[__init__]')
    self.t_ = torch.as_tensor(data, dtype=torch.float32, **kwargs)
Review comment: I think this can be removed once we pass the actual dynamic shape as a tensor, right?
Review comment: With the current …
Review comment: A PyTorch tensor with XLA device type will use pt/xla's tensor_impl, which carries a C++ XLATensor. C++ … Python XLATensor should do the same thing; it should not always carry Python data. We need a way to hook in the …
Review comment: Thanks for the input! I have a couple of follow-ups that I will ask you offline. QQ: should we consider broadening the scope of the Python …
Review comment: Let's get it working first before doing any benchmark. You can work with Ed and the PyTorch team to figure out how to make … to be a Python …
Review comment: With the code as it is today, …
    self.xla_tensor_sizes_ = namedtuple("XLATensorSizes",
                                        "static_size dynamic_size")

  def size(self):
    # if debug: print('[size] ')
    static_size = self.t_.size()
    dynamic_size = []
    for i, _ in enumerate(self.t_.shape):
      dynamic_size.append(
          torch_xla._XLAC._get_xla_tensor_dimension_size(self.t_, i))
    # Only the first dimension's dynamic size is returned; this POC supports a
    # single dynamic dimension.
    return self.xla_tensor_sizes_(static_size, dynamic_size[0])

  def expand(self, size):
    # if debug: print('[expand]')
    return torch_xla._XLAC._xla_dynamic_expand(self.t_, size.dynamic_size,
                                               size.dynamic_size.int(),
                                               size.static_size)

  def nonzero(self):
    return torch.nonzero(self.t_)

  def __repr__(self):
    # if debug: print('[__repr__]')
    return "XLATensor:\n{}\n".format(self.t_)
Review comment: XLATensor should always be a tensor with XLA device type. This means that it will be a … Lazy tensors should not carry CPU storage. I think for a regular PyTorch tensor with XLA device type, we will call a …
  @classmethod
  def __torch_function__(cls, func, types, args=(), kwargs=None):
    global time_torch_function
    tic = time.perf_counter()
    if kwargs is None:
      kwargs = {}
    args = [a.t_ if hasattr(a, 't_') else a for a in args]
    result = func(*args, **kwargs)
    # Accumulate the time spent in __torch_function__ dispatch; this is the
    # value reported as "torch_function() time" at the end of the benchmark.
    time_torch_function += time.perf_counter() - tic
    return result

Review comment: This isn't doing anything nontrivial, so we can get rid of it. (If we do end up needing to override a torch namespace function, that will be a global tax everywhere; see pytorch/pytorch#62888, cc @anjali411, but nb the issue is about …)
Review comment: It seems like … Also, @ezyang, the overhead here won't be as bad because we don't call from C++ to Python for every single function, right?
Review comment: There are a few prototypes floating around, but in the version where torch.empty(device='xla') transparently returns an XLATensor, it needs to be a tensor subclass.
if SIMPLE_TEST == True:
  a = XLATensor([[0, 0, 0], [1, 1, 0]], device=xm.xla_device())
  b = XLATensor([[1], [2]], device=xm.xla_device())

  idxs = a.nonzero()
  idxs_t = XLATensor([[1]], device=xm.xla_device())
  idxs_t.t_ = idxs
  print('a.nonzero(): ', idxs)

  y = b.expand(idxs_t.size())
  print(f"b.expand(): {y}")
  y_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(y, 0)
  print("[RUNNING] _get_xla_tensor_dimension_size(y, 0)\n", y_dim0_shape)
  print("[RUNNING] _get_xla_tensors_hlo([y])\n",
        torch_xla._XLAC._get_xla_tensors_hlo([y]))
  print("[RUNNING] _get_xla_tensors_text([y])\n",
        torch_xla._XLAC._get_xla_tensors_text([y]))
else:
  NUM = 500

  time_size = 0
  time_expand = 0
  for i in range(NUM):
    t = torch.tensor([[1], [2], [3]], device=xm.xla_device())
    o = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
                     device=xm.xla_device())

    tic = time.perf_counter()
    b = o.size()
    toc = time.perf_counter()
    time_size += toc - tic

    tic = time.perf_counter()
    tt = t.expand(b)
    toc = time.perf_counter()
    time_expand += toc - tic
  print(f"size() time {time_size:0.4f} seconds")
  print(f"expand() time {time_expand:0.4f} seconds")

  time_size = 0
  time_expand = 0
  for i in range(NUM):
    t = XLATensor([[1], [2], [3]], device=xm.xla_device())
    o = XLATensor([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
                  device=xm.xla_device())

    tic = time.perf_counter()
    b = o.size()
    toc = time.perf_counter()
    time_size += toc - tic

    tic = time.perf_counter()
    tt = t.expand(b)
    toc = time.perf_counter()
    time_expand += toc - tic
  print(f"size() time {time_size:0.4f} seconds")
  print(f"expand() time {time_expand:0.4f} seconds")
  print(f"torch_function() time {time_torch_function:0.4f} seconds")
@@ -0,0 +1,50 @@
# Run:
# XLA_EXPERIMENTAL="nonzero:masked_select" python3 <test_dynamic_shapes.py>
#

import torch
import torch_xla


class TestDynamicShapes(object):

  def __init__(self):
    self.device = 'xla:0'

  def runTest(self):
    t1 = torch.tensor([1, 0, 2, 0, 0, 1, 3], device=self.device)

    t2 = torch.nonzero(t1)
    t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0)
    print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape)
    print("[RUNNING] _get_xla_tensors_text([t2])\n",
          torch_xla._XLAC._get_xla_tensors_text([t2]))
    print("[RUNNING] _get_xla_tensors_hlo([t2])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t2]))
    assert t2_dim0_shape.item() == 4

    t3 = torch.fill_(t2, 10)
    print(f"size of t3 {t3.size()}, size of t2 {t2.size()}")
    t2_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t2, 0)
    t3_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(t3, 0)
    print("[RUNNING] _get_xla_tensor_dimension_size(t2, 0)\n", t2_dim0_shape)
    print("[RUNNING] _get_xla_tensor_dimension_size(t3, 0)\n", t3_dim0_shape)
    print("[RUNNING] _get_xla_tensors_text([t2])\n",
          torch_xla._XLAC._get_xla_tensors_text([t2]))
    print("[RUNNING] _get_xla_tensors_text([t3])\n",
          torch_xla._XLAC._get_xla_tensors_text([t3]))
    print("[RUNNING] _get_xla_tensors_hlo([t2])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t2]))
    print("[RUNNING] _get_xla_tensors_hlo([t3])\n",
          torch_xla._XLAC._get_xla_tensors_hlo([t3]))
    assert t2_dim0_shape.item() == 4
    assert t3_dim0_shape.item() == 4

    print('t1: ', t1)
    print('t2: ', t2)
    print('t3: ', t3)


if __name__ == "__main__":
  test = TestDynamicShapes()
  test.runTest()
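The run command at the top of this test also enables masked_select through the same XLA_EXPERIMENTAL flag. A small companion sketch, not part of the diff, assuming masked_select likewise produces a bounded dynamic leading dimension that can be queried with _get_xla_tensor_dimension_size.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
t1 = torch.tensor([1, 0, 2, 0, 0, 1, 3], device=device)

# masked_select is the other op enabled by XLA_EXPERIMENTAL="nonzero:masked_select".
# Its output has an upper-bound (static) size of 7 here, while the dynamic size
# is the number of elements where the mask is true (4 for this input).
sel = torch.masked_select(t1, t1.ne(0))
sel_dim0 = torch_xla._XLAC._get_xla_tensor_dimension_size(sel, 0)
print("masked_select dynamic dim0:", sel_dim0)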
@@ -12,11 +12,16 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "tensorflow/compiler/xla/xla_client/xla_util.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/tensor_util.h"

#include <chrono>
using namespace std::chrono;

namespace torch_xla {
namespace {
@@ -112,6 +117,27 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
                       xla::util::Iota<xla::int64>(output_sizes.size()));
}

xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
                              xla::XlaOp dynamic_target,
                              xla::Shape dynamic_shapes) {
  // auto start = high_resolution_clock::now();
  xla::XlaOp output = BuildExpand(static_input, dynamic_shapes.dimensions());
  bool seen_dynamic = false;  // Limit support to one dynamic dimension.
  for (int i = 0; i < dynamic_shapes.rank(); ++i) {
    if (dynamic_shapes.is_dynamic_dimension(i)) {
      XLA_CHECK(seen_dynamic == false);
      seen_dynamic = true;
      output = xla::SetDimensionSize(
          output, MaybeConvertTo(dynamic_target, xla::PrimitiveType::S32), i);
    }
  }
  // auto stop = high_resolution_clock::now();
  // auto duration = duration_cast<microseconds>(stop - start);
  // std::cout << "Time taken dynamic shape (BuildDynamicShape): "
  //           << duration.count() << " microseconds" << std::endl;
  return output;
}

Review comment (on the BuildDynamicExpand signature): Nit, …

std::vector<xla::int64> BuildSqueezedDimensions(
    absl::Span<const xla::int64> dimensions, xla::int64 squeeze_dim) {
  std::vector<xla::int64> output_dimensions;
@@ -37,6 +37,11 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
                       absl::Span<const xla::int64> output_sizes);

// Dynamic shape version of BuildExpand().
xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
                              xla::XlaOp dynamic_target,
                              xla::Shape dynamic_shapes);

Review comment (on the BuildDynamicExpand declaration): ditto.

std::vector<xla::int64> BuildSqueezedDimensions(
    absl::Span<const xla::int64> dimensions, xla::int64 squeeze_dim);
@@ -5,6 +5,7 @@
#include "tensorflow/compiler/xla/client/lib/logdet.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
@@ -882,6 +883,50 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
      input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& input,
                      const std::vector<xla::int64> static_size,
                      const xla::int64 dynamic_size,
                      const Value& dynamic_target) {
  xla::Shape input_shape = input.shape();
  auto lower_fn = [input_shape, static_size, dynamic_size](
                      const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp static_input = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp dynamic_target = loctx->GetOutputOp(node.operand(1));
    std::vector<bool> dynamic_dims;
    std::vector<xla::int64> dims;
    for (int i = 0; i < static_size.size(); i++) {
      if (i == 0) {
        // Only the leading dimension may be dynamic in this POC.
        dims.push_back(dynamic_size);
        dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false);
      } else {
        dims.push_back(static_size[i]);
        dynamic_dims.push_back(false);
      }
    }
    xla::Shape target_shape =
        xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims);
    xla::XlaOp dynamic_output =
        BuildDynamicExpand(static_input, dynamic_target, target_shape);
    return node.ReturnOp(dynamic_output, loctx);
  };
  auto shape_fn = [input_shape, static_size, dynamic_size](
                      absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    std::vector<bool> dynamic_dims;
    std::vector<xla::int64> dims;
    for (int i = 0; i < static_size.size(); i++) {
      if (i == 0) {
        dims.push_back(dynamic_size);
        dynamic_dims.push_back(dynamic_size < static_size[i] ? true : false);
      } else {
        dims.push_back(static_size[i]);
        dynamic_dims.push_back(false);
      }
    }
    xla::Shape target_shape =
        xla::ShapeUtil::MakeShape(xla::S32, dims, dynamic_dims);
    return BuildDynamicExpand(operands[0], operands[1], target_shape);
  };
  return GenericOp(OpKind(at::aten::expand), {input, dynamic_target},
                   [&]() {
                     return InferOutputShape(
                         {input.shape(), dynamic_target.shape()}, shape_fn);
                   },
                   std::move(lower_fn));
}

Review comment (on the DynamicExpand signature): probably const std::vector<xla::int64> static_size was meant to be passed by ref? Also, it seems to be passed by value into shape_fn?

} // namespace ops
} // namespace ir
} // namespace torch_xla
Review comment (on the t_ attribute in the first file): It seems like you probably don't need the t_ below since you are storing it here anyway? (Given you are using _make_subclass and not _make_wrapper_subclass.)
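A minimal sketch of that suggestion, not part of the PR: since _make_subclass already stores the tensor data on the returned instance, the methods can work on self directly instead of a separate t_ attribute. XLATensorNoWrap is a hypothetical name, and the sketch assumes _get_xla_tensor_dimension_size accepts the subclass instance as-is.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from collections import namedtuple

XLATensorSizes = namedtuple("XLATensorSizes", "static_size dynamic_size")


class XLATensorNoWrap(torch.Tensor):
  """Variant of the POC subclass that drops the extra t_ attribute."""

  def __new__(cls, data, **kwargs):
    # _make_subclass already stores the data on the returned instance,
    # so no wrapped tensor needs to be kept in __init__.
    return torch.Tensor._make_subclass(
        cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))

  def size(self):
    static_size = super().size()
    # Mirror the POC: only the leading dimension's dynamic size is queried.
    dynamic_size = torch_xla._XLAC._get_xla_tensor_dimension_size(self, 0)
    return XLATensorSizes(static_size, dynamic_size)


# Usage mirrors the SIMPLE_TEST block in the first file.
a = XLATensorNoWrap([[0, 0, 0], [1, 1, 0]], device=xm.xla_device())
print(a.size())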