
[POC] Support of dynamic shapes for the fill_ op #3100


Closed
Wants to merge 27 commits

Conversation

miladm (Collaborator) commented Aug 25, 2021

Op lowering of fill_ to study the requirements for supporting dynamic shapes in PT/XLA.

@miladm miladm added the DO_NOT_MERGE Not for merging. label Aug 25, 2021
@miladm miladm requested a review from JackCaoG August 25, 2021 03:50
@miladm miladm self-assigned this Aug 25, 2021
@miladm miladm force-pushed the fill_base_dynamic_shapes branch from f68e515 to 64adec2 Compare August 30, 2021 16:48
JackCaoG (Collaborator) left a comment

You should probably do a git pull --all in your xla dir and rebase this branch. You should also do a submodule sync and update. (For this PR you might need to revert the TF-side change, but if it is not intended to merge, I am fine with it.) It seems like the local TF version you have is out of date, so every time you submit a PR it will also try to update the TF version.

@@ -1213,6 +1213,7 @@ void XLATensor::eye_out(XLATensor& out, xla::int64 lines, xla::int64 cols) {
void XLATensor::fill_(XLATensor& input, const at::Scalar& value) {
ir::Value constant =
GetIrValueForScalar(value, input.shape(), input.GetDevice());
constant = ir::ops::ExpandAsDynamicShapes(constant, input.GetIrValue());
Collaborator:

I think you shouldn't pass input.shape() to GetIrValueForScalar in the line above. What happens is that GetIrValueForScalar will do a static expand to input.shape(), which is not necessary.

Collaborator (Author):

addressed

@@ -882,6 +882,27 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
input.shape(), std::move(lower_fn));
}

NodePtr ExpandAsDynamicShapes(const Value& static_input,
Collaborator:

DynamicExpand might be a better name, since we also have DynamicReshape in here. In the future we should just make expand support dynamic shapes, but that would require a frontend API change.

Collaborator (Author):

I agree and it makes sense.

@miladm miladm requested a review from JackCaoG August 31, 2021 00:33
JackCaoG (Collaborator) left a comment

Mostly LGTM, some naming nits

@@ -112,6 +112,23 @@ xla::XlaOp BuildExpand(xla::XlaOp input,
xla::util::Iota<xla::int64>(output_sizes.size()));
}

xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
Collaborator:

Nit: BuildDynamicExpandAs, since the target is a value tensor, not a shape tensor.

@@ -37,6 +37,10 @@ xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input);
xla::XlaOp BuildExpand(xla::XlaOp input,
absl::Span<const xla::int64> output_sizes);

// Dynamic Shape version of BuildExpand()
xla::XlaOp BuildDynamicExpand(xla::XlaOp static_input,
Collaborator:

ditto

@@ -882,6 +882,26 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& static_input, const Value& dynamic_target) {
Collaborator:

Nit: DynamicExpandAs.

JackCaoG (Collaborator) left a comment

some ideas

debug = False
BYPASS_XLA = False #Performance profiling experimentation

class DynamicTensorXLASize(object):
Collaborator:

One thing I am not sure about right now is whether we need DynamicTensorXLASize to be a class, which I think is more expensive than a namedtuple (which is immutable).

Collaborator (Author):

size() doesn't accept a tuple return type. I will try namedtuple.

miladm (Collaborator, Author) commented Sep 22, 2021:

I made this change. It turns out this optimization leads to a small speedup improvement.
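
For illustration, a minimal sketch of the namedtuple version suggested above (the field names here are assumptions, not the PR's actual ones):

    from collections import namedtuple

    # Hypothetical immutable replacement for the DynamicTensorXLASize class;
    # field names are illustrative only.
    DynamicTensorXLASize = namedtuple('DynamicTensorXLASize',
                                      ['static_size', 'dynamic_size'])

    s = DynamicTensorXLASize(static_size=(3, 1), dynamic_size=3)
    print(s.static_size, s.dynamic_size)  # (3, 1) 3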

def __init__(self, data, **kwargs):
    super().__init__()
    if debug: print('[__init__]')
    self.t_ = torch.as_tensor(data, dtype=torch.float32, **kwargs)
Collaborator:

I think this can be removed once we pass the actual dynamic shape as a tensor, right?

Collaborator (Author):

With the current XLATensor inheritance configuration, we need to keep self.t_ since it's the only variable holding the data for this object.

Collaborator:

A PyTorch tensor with XLA device type will use pt/xla's tensor_impl, which carries a C++ XLATensor. The C++ XLATensor holds a Data object which carries one of:

  1. ComputationData --> a handle referring to an allocation on the XRT server side
  2. IR --> a pending computation
  3. at::Tensor --> a CPU tensor that we need to upload to the XRT server at a future time
  4. View

The Python XLATensor should do the same thing; it should not always carry Python data. We need a way to hook into the tensor_impl like we do when we set device=xm.xla_device().
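
Purely as an illustration of the Data alternatives listed above, a hypothetical Python-side mirror (none of these names exist in pt/xla; exactly one field would be populated at a time):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class XLATensorData:
        # Hypothetical mirror of the four alternatives above.
        computation_data: Optional[Any] = None  # handle to an XRT-side allocation
        ir_value: Optional[Any] = None          # pending (lazy) computation
        cpu_tensor: Optional[Any] = None        # CPU tensor to upload to the XRT server later
        view: Optional[Any] = None              # view into another tensor's data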

miladm (Collaborator, Author) commented Sep 20, 2021:

Thanks for the input! I have a couple of follow-ups that I will ask you about offline.

QQ: should we consider broadening the scope of the Python XLATensor (using a tensor_impl-like module) now, or after we have additional clarity on whether this methodology is viable for our goal?

Collaborator:

Let's get it working first before doing any benchmarking. You can work with Ed and the PyTorch team to figure out how to make

torch.tensor([123], device='xla:0')

be a Python XLATensor while still using pt/xla's tensor_impl.

Collaborator:

With the code as it is today, self should already be a valid device='xla' tensor. I wonder if you were trying to work around some other problem with t_.

def size(self):
    if debug: print('[size] ')
    static_size = super(XLATensor, self).size()
    if self.device == xm.xla_device():
Collaborator:

After we figure out how to pass the device to XLATensor, we should check whether a dynamic shape is included to decide which function to call (instead of checking for the device).

Collaborator (Author):

That's a fair point. Can we check for the dynamic shape condition without calling xla::GetDimensionSize() first?
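
A rough sketch of that dispatch (_has_dynamic_dims and _dynamic_size are hypothetical placeholders, not existing pt/xla APIs; answering the query without an eager xla::GetDimensionSize() call is exactly the open question here):

    import torch

    class XLATensor(torch.Tensor):
        def size(self, *args):
            # Dispatch on a dynamic-shape query instead of comparing devices.
            # _has_dynamic_dims and _dynamic_size() stand in for whatever
            # pt/xla would expose to answer this cheaply.
            if getattr(self, '_has_dynamic_dims', False):
                return self._dynamic_size()
            return super().size(*args)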


def __repr__(self):
    if debug: print ('[__repr__]')
    return "XLATensor:\n{}\n".format(self.t_)
Collaborator:

XLATensor should always be a tensor with XLA device type. This means it will be a lazy tensor and should not carry CPU storage. I think for a regular PyTorch tensor with XLA device type, we call a to_cpu() to get the scalar value of the tensor.
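
For illustration, a __repr__ along those lines that materializes the value on the host instead of reading a stored self.t_ (a sketch, not the PR's code):

    import torch

    class XLATensor(torch.Tensor):
        def __repr__(self):
            # Transfer to the host and print as a plain tensor; a lazy tensor
            # would not keep a separate CPU copy of its data in self.t_.
            host = self.cpu().as_subclass(torch.Tensor)
            return "XLATensor:\n{}\n".format(host)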

return func(*args, **kwargs)

if debug: print ('[main] make b')
t = XLATensor([[1], [2], [3]], device=xm.xla_device())
Collaborator:

I think Ed also mentioned this in your issue: ideally we should just do torch.tensor and auto-cast all tensors to XLATensor in the backend when the device is an XLA device.

miladm (Collaborator, Author) commented Sep 20, 2021:

Yes. Though, AFAIU, such an auto-cast requires additional support from PyTorch that they kept as an open topic to investigate on their side. I'd like to seek some clarity from them next time we meet. Do you have a different understanding @JackCaoG?

…de (e.g. debug) to improve speedup, this commit assumes all input tensors have dynamic shape, added performance profiling code
return "XLATensor:\n{}\n".format(self.t_)

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
Collaborator:

This isn't doing anything nontrivial, so we can get rid of it. (If we do end up needing to override a torch namespace function, that will be a global tax everywhere; see pytorch/pytorch#62888, cc @anjali411. NB: that issue is about __torch_dispatch__, but the need here is for __torch_function__.)

anjali411 commented Oct 5, 2021:

It seems like XLATensor might not need to be a tensor subclass at all. I am not sure if the size() overload is expected to be in C++. If it is, then it still needs to be a tensor subclass, and we should use _make_wrapper_subclass if the size needs to be used in C++ as defined above. (cc @albanD)

Also, @ezyang, the overhead here won't be as bad because we don't call from C++ to Python for every single function, right?

Collaborator:

There are a few prototypes floating around, but in the version where torch.empty(device='xla') transparently returns an XLATensor, it needs to be a tensor subclass.


class XLATensor(torch.Tensor):
    def __new__(cls, data, **kwargs):
        return torch.Tensor._make_subclass(cls, torch.as_tensor(data, dtype=torch.float32, **kwargs))


It seems like you probably don't need the t_ below since you are storing it here anyway? (given you are using _make_subclass and not _make_wrapper_subclass)
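
For contrast with the _make_subclass version above, a sketch of the _make_wrapper_subclass pattern mentioned earlier (names and structure are illustrative; a real implementation would also override __torch_dispatch__ or __torch_function__ to route ops to the wrapped tensor):

    import torch

    class WrapperXLATensor(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            # The wrapper advertises size/dtype/device to C++ without owning
            # storage; the real data lives in .elem.
            r = torch.Tensor._make_wrapper_subclass(
                cls, elem.size(), dtype=elem.dtype,
                device=elem.device, requires_grad=elem.requires_grad)
            r.elem = elem
            return r

    t = WrapperXLATensor(torch.zeros(3, 1))
    print(t.size())  # torch.Size([3, 1]), answered from the wrapper's metadata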

@@ -882,26 +883,50 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
input.shape(), std::move(lower_fn));
}

NodePtr DynamicExpand(const Value& static_input, const Value& dynamic_target) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
NodePtr DynamicExpand(const Value& input, const std::vector<xla::int64> static_size, const xla::int64 dynamic_size, const Value& dynamic_target) {
Contributor:

Probably

, const std::vector<xla::int64> static_size

was meant to be passed by reference? It also seems to be passed by value into shape_fn.

@miladm miladm added this to the Dynamic Shape milestone Feb 10, 2022
@miladm miladm changed the title Support of dynamic shapes for the fill_ op [POC] Support of dynamic shapes for the fill_ op Feb 25, 2022
@miladm miladm marked this pull request as draft May 15, 2022 22:14
miladm (Collaborator, Author) commented May 21, 2022

An up-to-date version of this PR is at #3558.
Closing this PR.

@miladm miladm closed this May 21, 2022
@miladm miladm added the dynamism Dynamic Shape Features label May 21, 2022
Labels
DO_NOT_MERGE (Not for merging.), dynamism (Dynamic Shape Features)
5 participants