
Enable dynamic shape for XLATensorImpl::sym_sizes_custom() #3829

Open · wants to merge 38 commits into master

Commits (38)
5c74f05
adding dynamism to sym_sizes
miladm Aug 4, 2022
842a9f8
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
ae04631
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
455e845
linter, removed incorrectly merged code
miladm Aug 7, 2022
3af640f
fixed the custom call site
miladm Aug 11, 2022
54e5e51
adding dynamism to sym_sizes
miladm Aug 4, 2022
5bdd12a
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
cdbe67c
linter, removed incorrectly merged code
miladm Aug 7, 2022
4f11f40
fixed the custom call site
miladm Aug 11, 2022
d1323ce
verify functionality without dynamism flag
miladm Aug 12, 2022
68311c0
linter
miladm Aug 12, 2022
875cf56
adding dynamism to sym_sizes
miladm Aug 4, 2022
d01ec69
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
b4ca291
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
6e1f802
linter, removed incorrectly merged code
miladm Aug 7, 2022
22cc04c
fixed the custom call site
miladm Aug 11, 2022
a2824d7
adding dynamism to sym_sizes
miladm Aug 4, 2022
3185ad7
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
0cf3c6f
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
165a47a
linter, removed incorrectly merged code
miladm Aug 7, 2022
76515e5
fixed the custom call site
miladm Aug 11, 2022
dfaba65
linter
miladm Aug 12, 2022
8c0957e
adding dynamism to sym_sizes
miladm Aug 4, 2022
de19ad5
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
0976230
linter, removed incorrectly merged code
miladm Aug 7, 2022
2546b08
fixed the custom call site
miladm Aug 11, 2022
d27fb73
adding dynamism to sym_sizes
miladm Aug 4, 2022
8382326
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
f499465
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
8dbd82a
linter, removed incorrectly merged code
miladm Aug 7, 2022
1054d43
fixed the custom call site
miladm Aug 11, 2022
601c192
adding dynamism to sym_sizes
miladm Aug 4, 2022
38d746a
local python tests pass on this version of the code - needs code cleanup
miladm Aug 7, 2022
101bc84
added SetupSymSizeProperties and removed debug code
miladm Aug 7, 2022
122ba52
linter
miladm Aug 12, 2022
cad3ba9
corrections after rebase
miladm Aug 13, 2022
7be76b8
linter
miladm Aug 13, 2022
5037965
cleanup test
miladm Aug 13, 2022
44 changes: 43 additions & 1 deletion torch_xla/csrc/tensor_impl.cpp
@@ -6,11 +6,15 @@

#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include "torch/csrc/lazy/core/tensor.h"
#include "torch/csrc/lazy/core/tensor_util.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir_builder.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
@@ -110,7 +114,11 @@ void XLATensorImpl::shallow_copy_from(
}

at::IntArrayRef XLATensorImpl::sizes_custom() const {
const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
if (true) { /* TODO(@miladm): replace this with a flag */
const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
Collaborator:
@Krovatkin Is the plan to always use symsize even for the static ints?

Contributor:
No, in his later PR, @miladm is using is_dynamic to decide whether we need to create symint nodes or static ints:

https://github.com/pytorch/xla/pull/3909/files#diff-c4e1dd39b63d78af7c207b2d48ac29553d74214b1c185ae34e084dd2f583879eR197

Contributor:
In this PR, @miladm is forcing both sym_sizes_ and sizes_and_strides_ to be filled with upper bounds. There are no real symint nodes yet.

} else {
const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
}
return sizes_default();
}
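The `if (true)` placeholder above is explicitly marked with a TODO to be replaced by a flag. As a rough sketch only — the flag name `XLA_DYNAMIC_SHAPES` and the use of `std::getenv` are assumptions for illustration, not part of this PR — the gated dispatch could eventually look something like this:

```cpp
// Sketch only: "XLA_DYNAMIC_SHAPES" is a hypothetical environment flag used
// for illustration; this PR still hard-codes the branch with `if (true)`.
#include <cstdlib>

at::IntArrayRef XLATensorImpl::sizes_custom() const {
  // Cache the lookup; leaving the flag unset keeps the existing static path.
  static const bool use_dynamic_shapes =
      std::getenv("XLA_DYNAMIC_SHAPES") != nullptr;
  if (use_dynamic_shapes) {
    const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
  } else {
    const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  }
  return sizes_default();
}
```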

@@ -178,6 +186,40 @@ void XLATensorImpl::SetupSizeProperties() {
}
}

void XLATensorImpl::SetupSymSizeProperties() {
size_t generation = tensor_->generation();
if (generation != generation_) {
// Fill up the basic dimension data members which the base class
// implementation uses in its APIs.
auto shape = tensor_->shape();
auto rank = tensor_->shape().get().rank();
c10::SmallVector<c10::SymInt, 5> sym_sizes;
numel_ = 1;
for (auto i : c10::irange(rank)) {
// if (tensor_->shape().get().is_dynamic_dimension(i)) {
// XLAIrBuilder a = XLAIrBuilder();
// auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
// auto* sn =
// dynamic_cast<torch::lazy::SymIntNodeImpl*>(dim_node.get());
// sym_sizes.push_back(sn->toSymInt());
// /*TODO(miladm): verify numel_ calculation after adding a dynamic op
// */ numel_ *=
// dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
// } else {
sym_sizes.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
numel_ *= tensor_->shape().get().dimensions(i);
// }
}
sizes_and_strides_.set_sizes(sym_sizes);
Collaborator:
@Krovatkin From what I can tell, this PR switches to setting sym_sizes for tensorImpl's sizes_and_strides_. I am guessing upstream already supports setting only sym_sizes?

Contributor:

> From what I can tell, this PR switches to setting sym_sizes for tensorImpl's sizes_and_strides_.

Kind of. This PR is the first step in enabling real dynamic shapes. Note that we have two separate storages here:

  • sizes_and_strides_
  • sym_sizes_

Unfortunately, we can't just make the C++ sizes() start throwing a "NotImplemented" exception; otherwise we would break a lot of code. We want sizes() to return upper bounds.
This is why @miladm is populating set_sizes_and_strides with concrete integers wrapped in SymInts. You could think of this logic as just saving upper bounds; we merely need to wrap them in SymInts since upstream now uses the unified type SymInt.

Now the doubly unfortunate part. We know that upstream can store real SymIntNodes in sizes_and_strides_ wrapped in SymInts. However, we can't take advantage of that: if we stored real SymIntNodes in sizes_and_strides_ and someone called sizes() in C++, that would convert the SymIntNodes to ints, which would trigger materialization.

So we decided to have separate storage for sym_sizes(), namely sym_sizes_.

This way, when someone calls size() in Python, we call XLATensor::sym_sizes_custom, which returns sym_sizes_, which may contain both concrete ints and SymIntNodes.

Now, if a user uses one of those SymIntNodes in Python, it will trigger materialization of the SymIntNode (since we do want the exact result, at least according to our discussions).

Phew....
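A toy sketch of the two-storage idea described above (not the real TensorImpl layout; the struct and member names are made up for illustration): C++ callers of `sizes()` keep getting plain upper-bound integers, while `sym_sizes()` hands back `c10::SymInt`s that could later wrap real SymIntNodes for dynamic dimensions.

```cpp
#include <cstdint>
#include <vector>

#include <c10/core/SymInt.h>

// Illustrative only: mirrors the "two separate storages" described above.
struct DualSizeStorage {
  std::vector<int64_t> upper_bounds_;   // stands in for sizes_and_strides_
  std::vector<c10::SymInt> sym_sizes_;  // stands in for sym_sizes_

  // C++ path: always concrete ints (upper bounds), so nothing materializes.
  const std::vector<int64_t>& sizes() const { return upper_bounds_; }

  // Python-facing path: today these SymInts wrap concrete ints; a later PR
  // can swap in real SymIntNodes for dynamic dimensions.
  const std::vector<c10::SymInt>& sym_sizes() const { return sym_sizes_; }
};
```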

Contributor:
@JackCaoG the part I'm not 100% sure about is why we need to populate sizes_and_strides_ in addition to sym_sizes_ when sym_sizes is called. I'd think that when someone calls sizes() it would update sizes_and_strides_, and when someone calls sym_sizes() it would set sym_sizes_.
Presumably, @miladm ran into some issues, so we do need to set both when sym_sizes() is called.

auto updated_strides = torch::lazy::ComputeArrayStrides(
torch::lazy::ToVector<int64_t>(shape.get().dimensions()));
for (int i = 0; i < updated_strides.size(); i++) {
sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
}
generation_ = generation;
}
}
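For reference, the stride update at the end of SetupSymSizeProperties() relies on torch::lazy::ComputeArrayStrides, which (assuming it follows the usual dense row-major rule) makes each stride the product of all dimensions to its right. A standalone sketch of that rule, independent of the lazy-tensor helpers:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Dense row-major strides: stride[i] = product of sizes[i+1..]; the innermost
// dimension gets stride 1. For shape {2, 3, 4} this yields {12, 4, 1}.
std::vector<int64_t> DenseRowMajorStrides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (size_t i = sizes.size(); i > 1; --i) {
    strides[i - 2] = strides[i - 1] * sizes[i - 1];
  }
  return strides;
}

int main() {
  for (int64_t s : DenseRowMajorStrides({2, 3, 4})) {
    std::cout << s << " ";  // prints: 12 4 1
  }
  std::cout << "\n";
  return 0;
}
```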

caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
return c10::scalarTypeToTypeMeta(tensor.dtype());
}
6 changes: 6 additions & 0 deletions torch_xla/csrc/tensor_impl.h
@@ -3,8 +3,13 @@
#include <ATen/Tensor.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/trie.h>

#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/xla_backend_impl.h"

namespace torch_xla {

@@ -52,6 +57,7 @@ class XLATensorImpl : public c10::TensorImpl {

private:
void SetupSizeProperties();
void SetupSymSizeProperties();

static caffe2::TypeMeta GetTypeMeta(const XLATensor& tensor);
