[POC] Support of dynamic shapes for the fill_ op #3100
Conversation
…with ExpandAsDynamicShape support
f68e515 to 64adec2
You should probably do a git pull --all in your xla dir and rebase this branch. You should also do a submodule sync and update. (For this PR you might need to revert the TF side change, but if it is not intended to merge I am fine with it.) It seems like the local TF version you have is out of date, and every time you submit a PR it will try to also update the TF version.
torch_xla/csrc/tensor_methods.cpp
Outdated
@@ -1213,6 +1213,7 @@ void XLATensor::eye_out(XLATensor& out, xla::int64 lines, xla::int64 cols) {
void XLATensor::fill_(XLATensor& input, const at::Scalar& value) {
  ir::Value constant =
      GetIrValueForScalar(value, input.shape(), input.GetDevice());
  constant = ir::ops::ExpandAsDynamicShapes(constant, input.GetIrValue());
I think you shouldn't pass input.shape() to GetIrValueForScalar in the line above here. What happens is that GetIrValueForScalar will do a static expand to input.shape(), which is not necessary.
addressed
torch_xla/csrc/ops/ops.cpp
Outdated
@@ -882,6 +882,27 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
                 input.shape(), std::move(lower_fn));
}

NodePtr ExpandAsDynamicShapes(const Value& static_input,
DynamicExpand might be a better name since we also have DynamicReshape in here. In the future we should just make expand support dynamic shapes, but that would require a frontend API change.
I agree and it makes sense.
Mostly LGTM, some naming nits
@@ -112,6 +112,23 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
                       xla::util::Iota<xla::int64>(output_sizes.size()));
}

xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
Nit: BuildDynamicExpandAs, since the target is a value tensor, not a shape tensor.
@@ -37,6 +37,10 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
                       absl::Span<const xla::int64> output_sizes);

// Dynamic Shape version of BuildExpand()
xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
ditto
torch_xla/csrc/ops/ops.cpp
Outdated
@@ -882,6 +882,26 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
                 input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& static_input, const Value& dynamic_target) {
nit DynamicExpandAs
some ideas
experimental/xla_tensor.py
Outdated
debug = False
BYPASS_XLA = False  # Performance profiling experimentation

class DynamicTensorXLASize(object):
One thing I am not sure about right now is whether we need DynamicTensorXLASize as a class, which I think is more expensive than a namedtuple (which is immutable).
size() doesn't accept a tuple return type. I will try namedtuple.
I made this change. It turns out this optimization leads to a small improvement in speedup.
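For reference, a minimal sketch of the namedtuple variant discussed above; the field names here are illustrative only, not taken from the PR:

```python
import collections

# Hypothetical field names for the size container; the real DynamicTensorXLASize
# may carry different fields.
DynamicTensorXLASize = collections.namedtuple(
    'DynamicTensorXLASize', ['static_size', 'dynamic_size'])

s = DynamicTensorXLASize(static_size=(3, 1), dynamic_size=(3, 1))
print(s.static_size)  # attribute access, like the class version
print(s[0])           # still an immutable tuple underneath
```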
def __init__(self, data, **kwargs):
    super().__init__()
    if debug: print('[__init__]')
    self.t_ = torch.as_tensor(data, dtype=torch.float32, **kwargs)
I think this can be removed once we pass the actual dynamic shape as a tensor, right?
With the current XLATensor inheritance configuration, we need to keep self.t_ since it's the only variable holding the data for this object.
A PyTorch tensor with XLA device type will use pt/xla's tensor_impl, which carries a C++ XLATensor. The C++ XLATensor holds a Data object which carries one of:
- ComputationData --> a handle referring to an allocation on the XRT server side
- IR --> a pending computation
- at::Tensor --> a CPU tensor that we need to upload to the XRT server at a future time
- View
The Python XLATensor should do the same thing; it should not always carry Python-side data. We need a way to hook into the tensor_impl like we do when we set device=xm.xla_device().
Thanks for the input! I have a couple of follow-ups that I will ask you offline. QQ: should we consider broadening the scope of the Python XLATensor (using a tensor_impl-like module) now, or after we have additional clarity on whether this methodology is viable for our goal?
Let's get it working first before doing any benchmarks. You can work with Ed and the PyTorch team to figure out how to make torch.tensor([123], device='xla:0') be a Python XLATensor while still using pt/xla's tensor_impl.
With the code as it is today, self should already be a valid device='xla' tensor. I wonder if you were trying to work around some other problem with t_.
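As a small illustration of the point above (assuming a working torch_xla install): a tensor created on the XLA device is already lazy and backed by pt/xla's tensor_impl, so no separate Python-side copy of the data should be required.

```python
import torch
import torch_xla.core.xla_model as xm

# The tensor below is backed by the C++ XLATensor through pt/xla's tensor_impl;
# the Python object does not need to hold its own copy of the data (t_).
t = torch.tensor([[1.0], [2.0], [3.0]], device=xm.xla_device())
print(t.device)  # an XLA device, e.g. xla:0 or xla:1 depending on the runtime
print(t.cpu())   # materializes the lazy value back on the host
```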
experimental/xla_tensor.py
Outdated
def size(self):
    if debug: print('[size] ')
    static_size = super(XLATensor, self).size()
    if self.device == xm.xla_device():
After we figure out how to pass the device to XLATensor, we should check whether a dynamic shape is included to decide which function to call (instead of checking for the device).
That's a fair point. Can we check for the dynamic shape condition without calling xla::GetDimensionSize() first?
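A rough sketch of that dispatch, branching on a hypothetical dynamic-shape marker recorded on the tensor rather than on the device; the class name, the _dynamic_dims attribute, and the dynamic branch are all placeholders, not code from the PR:

```python
import torch

class SizeDispatchSketch(torch.Tensor):
    """Sketch only: pick the size() path from a dynamic-shape marker rather
    than comparing self.device against xm.xla_device()."""

    def size(self, *args):
        static_size = super().size(*args)
        dynamic_dims = getattr(self, '_dynamic_dims', ())  # hypothetical marker
        if dynamic_dims:
            # Dynamic path: a real implementation would defer the
            # xla::GetDimensionSize() query until the value is actually needed,
            # and would return something like DynamicTensorXLASize here.
            return (static_size, tuple(dynamic_dims))
        return static_size  # static path: plain torch.Size
```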
def __repr__(self):
    if debug: print('[__repr__]')
    return "XLATensor:\n{}\n".format(self.t_)
XLATensor should always be a tensor with the XLA device type. This means it will be a lazy tensor and should not carry CPU storage. I think for a regular PyTorch tensor with XLA device type, we call to_cpu() to get the scalar value of the tensor.
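A sketch of that idea, assuming the subclass instance itself is the device tensor; LazyReprSketch is a hypothetical name, and .tolist() is used only to keep the example from recursing into __repr__:

```python
import torch

class LazyReprSketch(torch.Tensor):
    def __repr__(self):
        # Materialize on the host at print time instead of formatting a
        # separately stored CPU copy (self.t_).
        host_values = self.detach().cpu().tolist()
        return "XLATensor:\n{}\n".format(host_values)
```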
experimental/xla_tensor.py
Outdated
    return func(*args, **kwargs)

if debug: print('[main] make b')
t = XLATensor([[1], [2], [3]], device=xm.xla_device())
I think Ed also mentioned this in your issues: ideally we should just do torch.tensor and auto-cast all tensors to XLATensor when the device is an XLA device in the backend.
Yes. Though, AFAIU, such an auto-cast requires additional support from PyTorch that they kept as an open topic to investigate on their side. I'd like to seek some clarity from them next time we meet. Do you have a different understanding @JackCaoG?
…de (e.g. debug) to improve speedup, this commit assumes all input tensors have dynamic shape, added performance profiling code
return "XLATensor:\n{}\n".format(self.t_) | ||
|
||
@classmethod | ||
def __torch_function__(cls, func, types, args=(), kwargs=None): |
This isn't doing anything nontrivial, so we can get rid of it. (If we do end up needing to override a torch namespace function, that will be a global tax everywhere; see pytorch/pytorch#62888, cc @anjali411. Note the issue is about __torch_dispatch__, but the need here is for __torch_function__.)
It seems like XLATensor might not need to be a tensor subclass at all. I am not sure if the size() overload is expected to be in C++. If yes, then it still needs to be a tensor subclass, and we should use _make_wrapper_subclass if the size needs to be used in C++ as defined above. (cc @albanD)
Also, @ezyang, the overhead here won't be as bad because we don't call from C++ to Python for every single function, right?
There are a few prototypes floating around, but in the version where torch.empty(device='xla') transparently returns an XLATensor, it needs to be a tensor subclass.
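For context, a minimal sketch of the _make_wrapper_subclass route mentioned above (class name and structure are illustrative; a real implementation would also need __torch_dispatch__ or similar to route ops to the wrapped data):

```python
import torch

class WrapperSketch(torch.Tensor):
    """Sketch only: the wrapper records sizes/dtype/device as metadata without
    aliasing the wrapped tensor's storage, so size() queries (including ones
    coming from C++) see the sizes registered here."""

    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device)

    def __init__(self, elem):
        self.elem = elem  # keep the real data next to the metadata-only wrapper

w = WrapperSketch(torch.randn(3, 1))
print(w.size())  # metadata query works without touching elem's storage
```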
class XLATensor(torch.Tensor):
    def __new__(cls, data, **kwargs):
        return torch.Tensor._make_subclass(cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))
It seems like you probably don't need the t_ below since you are storing it here anyway? (Given you are using _make_subclass and not _make_wrapper_subclass.)
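A small illustration of that point: with _make_subclass the returned object itself carries the storage, so a separate t_ copy would be redundant (SubclassSketch is a hypothetical stand-in for the experimental XLATensor):

```python
import torch

class SubclassSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, data, **kwargs):
        # The instance returned here *is* the tensor holding the data.
        return torch.Tensor._make_subclass(
            cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))

t = SubclassSketch([[1.0], [2.0], [3.0]])
print(t.size())  # data and metadata live on the instance directly
print(t + 1)     # ordinary ops work without going through a t_ attribute
```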
@@ -882,26 +883,50 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
                 input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& static_input, const Value& dynamic_target) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
NodePtr DynamicExpand(const Value& input, const std::vector<xla::int64> static_size, const xla::int64 dynamic_size, const Value& dynamic_target)
Probably const std::vector<xla::int64> static_size was meant to be passed by reference? It also seems to be passed by value into shape_fn.
An up-to-date version of this PR is at #3558.
Op lowering of fill_ to study the requirements for supporting dynamic shapes in PT/XLA.