Enable dynamic shape for XLATensorImpl::sym_sizes_custom()
#3829
@@ -6,11 +6,15 @@
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include "torch/csrc/lazy/core/tensor.h"
#include "torch/csrc/lazy/core/tensor_util.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir_builder.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/ops/dynamic_ir.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
@@ -110,7 +114,11 @@ void XLATensorImpl::shallow_copy_from(
 }

 at::IntArrayRef XLATensorImpl::sizes_custom() const {
-  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
+  if (true) { /* TODO(@miladm): replace this with a flag */
+    const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
+  } else {
+    const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
+  }
   return sizes_default();
 }
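The hard-coded if (true) above is a placeholder for the flag mentioned in the TODO. As an illustration only (not part of this PR), here is a minimal sketch of what such a toggle could look like, assuming a hypothetical environment variable XLA_USE_SYM_SIZE_PROPERTIES:

#include <cstdlib>
#include <string>

// Hypothetical helper; the real flag name and plumbing are not defined in this PR.
static bool UseSymSizeProperties() {
  // Read the environment once and cache the result.
  static const bool enabled = [] {
    const char* value = std::getenv("XLA_USE_SYM_SIZE_PROPERTIES");
    return value != nullptr && std::string(value) != "0";
  }();
  return enabled;
}

at::IntArrayRef XLATensorImpl::sizes_custom() const {
  if (UseSymSizeProperties()) {
    const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
  } else {
    const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  }
  return sizes_default();
}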
@@ -178,6 +186,40 @@ void XLATensorImpl::SetupSizeProperties() {
  }
}

void XLATensorImpl::SetupSymSizeProperties() {
  size_t generation = tensor_->generation();
  if (generation != generation_) {
    // Fill up the basic dimension data members which the base class
    // implementation uses in its APIs.
    auto shape = tensor_->shape();
    auto rank = tensor_->shape().get().rank();
    c10::SmallVector<c10::SymInt, 5> sym_sizes;
    numel_ = 1;
    for (auto i : c10::irange(rank)) {
      // if (tensor_->shape().get().is_dynamic_dimension(i)) {
      //   XLAIrBuilder a = XLAIrBuilder();
      //   auto dim_node = a.MakeSizeNode(tensor_->GetIrValue(), i);
      //   auto* sn =
      //       dynamic_cast<torch::lazy::SymIntNodeImpl*>(dim_node.get());
      //   sym_sizes.push_back(sn->toSymInt());
      //   /*TODO(miladm): verify numel_ calculation after adding a dynamic op
      //   */ numel_ *=
      //       dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
      // } else {
      sym_sizes.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
      numel_ *= tensor_->shape().get().dimensions(i);
      // }
    }
    sizes_and_strides_.set_sizes(sym_sizes);
@Krovatkin From what I can tell this PR switches to set ...

Kind of. This PR is the first step in enabling real dynamic shapes. Note, we have two separate storages here: sym_sizes_ and sizes_and_strides_.

Unfortunately, we can't just make C++ ... Now the doubly unfortunate part. We know that upstream can store real SymIntNodes in sizes_and_strides_ wrapped in SymInts. However, we can't take advantage of that, because if we store real SymIntNodes in sizes_and_strides_ and someone calls sizes() in C++, this would trigger conversions to ints on the SymIntNodes, which would trigger materialization. So we decided to have separate storage for sym_sizes_. This way, when someone calls sizes() from C++, the SymIntNodes are left untouched. Now if a user uses one of those SymIntNodes in Python, it will trigger materialization of the SymIntNode (since we do want the exact result, at least according to our discussions). Phew....

@JackCaoG the part I'm not 100% sure about is why we need to populate ...
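To make the trade-off described in this discussion concrete, here is a simplified, self-contained illustration (not part of the diff, and not the real XLATensorImpl code) of keeping plain upper-bound sizes and symbolic sizes in separate storages, so that a C++ sizes() call never touches the symbolic nodes:

#include <cstdint>
#include <vector>

// Stand-in for a symbolic dimension. In the real code this would be a
// torch::lazy::SymIntNode backed by an XLA SizeNode.
struct FakeSymDim {
  int64_t upper_bound;
  // Asking for the concrete value is what forces materialization
  // (potentially executing the graph to learn the exact size).
  int64_t materialize() const { return upper_bound; }
};

class TensorImplSketch {
 public:
  explicit TensorImplSketch(std::vector<int64_t> upper_bounds)
      : sizes_(upper_bounds) {
    for (int64_t bound : upper_bounds) {
      sym_sizes_.push_back(FakeSymDim{bound});
    }
  }

  // Cheap C++ path: returns static upper bounds and never touches the
  // symbolic storage, so no materialization can be triggered from here.
  const std::vector<int64_t>& sizes() const { return sizes_; }

  // Symbolic path: callers (e.g. Python) that pull a concrete value out of
  // an element are the ones that pay for materialization.
  const std::vector<FakeSymDim>& sym_sizes() const { return sym_sizes_; }

 private:
  std::vector<int64_t> sizes_;         // plays the role of sizes_and_strides_
  std::vector<FakeSymDim> sym_sizes_;  // plays the role of sym_sizes_
};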
    auto updated_strides = torch::lazy::ComputeArrayStrides(
        torch::lazy::ToVector<int64_t>(shape.get().dimensions()));
    for (int i = 0; i < updated_strides.size(); i++) {
      sizes_and_strides_.stride_at_unchecked(i) = updated_strides[i];
    }
    generation_ = generation;
  }
}

caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
  return c10::scalarTypeToTypeMeta(tensor.dtype());
}
@Krovatkin Is the plan to always use symsize even for the static ints?
No, in his later PR @miladm is using is_dynamic to decide whether we need to create symint nodes or static ints: https://github.com/pytorch/xla/pull/3909/files#diff-c4e1dd39b63d78af7c207b2d48ac29553d74214b1c185ae34e084dd2f583879eR197
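The shape of that per-dimension decision is already visible in the commented-out branch of SetupSymSizeProperties in this diff. Roughly, a sketch of the intended loop body once the dynamic path is enabled (based on the commented-out code above, not the exact code from #3909):

for (auto i : c10::irange(rank)) {
  if (tensor_->shape().get().is_dynamic_dimension(i)) {
    // Dynamic dimension: build a SizeNode and hand out a SymInt backed by it.
    XLAIrBuilder builder;
    auto dim_node = builder.MakeSizeNode(tensor_->GetIrValue(), i);
    auto* sym_node = dynamic_cast<torch::lazy::SymIntNodeImpl*>(dim_node.get());
    sym_sizes.push_back(sym_node->toSymInt());
    // TODO(miladm): verify numel_ calculation after adding a dynamic op.
    numel_ *= dynamic_cast<SizeNode*>(dim_node.get())->getStaticValue();
  } else {
    // Static dimension: a plain integer wrapped in a SymInt is enough.
    sym_sizes.push_back(c10::SymInt(tensor_->shape().get().dimensions(i)));
    numel_ *= tensor_->shape().get().dimensions(i);
  }
}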
In this PR, @miladm is forcing both sym_sizes_ and sizes_and_strides_ to be filled with upper bounds. There are no real symint nodes yet.