
Enable dynamic shape for XLATensorImpl::sym_sizes_custom() #3829


Open · wants to merge 38 commits into master

Conversation

miladm (Collaborator) commented Aug 4, 2022

Enable dynamic shape for XLATensorImpl::sym_sizes_custom()

miladm (Collaborator, Author) commented Aug 4, 2022

Here is the reference to `is_dynamic_dimension`.

miladm (Collaborator, Author) commented Aug 7, 2022

Local tests pass for `python ../test/test_view_ops.py -v TestViewOpsXLA.test_view_copy_xla` as shown below, though they fail on CI:

```
2022-08-07 23:30:32.889533: W 1463191 tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
2022-08-07 23:30:32.889591: W 1463191 tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
test_view_copy_xla (__main__.TestViewOpsXLA) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.233s

OK
```

```
@@ -110,7 +114,11 @@ void XLATensorImpl::shallow_copy_from(
}

at::IntArrayRef XLATensorImpl::sizes_custom() const {
  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  if (true) { /* TODO(@miladm): replace this with a flag */
    const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
```
Collaborator:

@Krovatkin Is the plan to always use sym_sizes even for the static ints?

Contributor:

No. In his later PR, @miladm uses is_dynamic to decide whether we need to create SymIntNodes or static ints:

https://github.com/pytorch/xla/pull/3909/files#diff-c4e1dd39b63d78af7c207b2d48ac29553d74214b1c185ae34e084dd2f583879eR197
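For context, here is a minimal sketch of that idea. The helper `MakeSymIntFromSizeNode` is hypothetical, not the actual #3909 code; only `is_dynamic_dimension` and `dimensions` come from the real `xla::Shape` API:

```cpp
// Sketch only: decide per dimension whether to create a symbolic node or
// a plain static int, based on the XLA shape's dynamic-dimension flags.
std::vector<c10::SymInt> sym_sizes;
const xla::Shape& shape = tensor_->shape().get();
for (int64_t i = 0; i < shape.rank(); ++i) {
  if (shape.is_dynamic_dimension(i)) {
    // Dynamic dimension: back the SymInt with a node that can lazily
    // materialize the exact runtime size when Python asks for it.
    sym_sizes.push_back(MakeSymIntFromSizeNode(tensor_, i));  // hypothetical helper
  } else {
    // Static dimension: wrap the concrete integer directly.
    sym_sizes.push_back(c10::SymInt(shape.dimensions(i)));
  }
}
```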

Contributor:

In this PR, @miladm is forcing both `sym_sizes_` and `sizes_and_strides_` to be filled with upper bounds; there are no real SymIntNodes yet (see the sketch below).
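A rough sketch of what that amounts to (simplified, not the literal PR diff; `shape` stands for the tensor's `xla::Shape`, as in the surrounding code):

```cpp
// Both storages hold the same static upper bounds for now: for a dynamic
// dimension, xla::Shape::dimensions(i) is its upper bound. The values are
// wrapped in c10::SymInt only because upstream now uses SymInt as the
// unified size type.
std::vector<c10::SymInt> upper_bounds;
for (int64_t i = 0; i < shape.rank(); ++i) {
  upper_bounds.push_back(c10::SymInt(shape.dimensions(i)));
}
sym_sizes_ = upper_bounds;                   // served via sym_sizes_custom()
sizes_and_strides_.set_sizes(upper_bounds);  // served via sizes_custom()
```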

```
    numel_ *= tensor_->shape().get().dimensions(i);
    // }
  }
  sizes_and_strides_.set_sizes(sym_sizes);
```
Collaborator:

@Krovatkin From what I can tell, this PR switches from setting sizes to setting sym_sizes for the TensorImpl's `sizes_and_strides_`. I am guessing upstream already supports setting only SymInt sizes?

Contributor:

> From what I can tell, this PR switches from setting sizes to setting sym_sizes for the TensorImpl's `sizes_and_strides_`.

Kind of. This PR is the first step in enabling real dynamic shapes. Note that we have two separate storages here:

- `sizes_and_strides_`
- `sym_sizes_`

Unfortunately, we can't just make C++ sizes() start throwing a "NotImplemented" exception; otherwise we would break a lot of code. We want sizes() to return upper bounds.
This is why @miladm is populating set_sizes_and_strides with concrete integers wrapped in SymInts. You could think of this logic as just saving upper bounds; we only need to wrap them in SymInts because upstream now uses the unified type SymInt.

Now the doubly unfortunate part. We know that upstream can store real SymIntNodes in sizes_and_strides_, wrapped in SymInts. However, we can't take advantage of that: if we store real SymIntNodes in sizes_and_strides_ and someone calls sizes() in C++, that would trigger conversions from SymIntNodes to ints, which would trigger materialization.

So we decided to have separate storage for sym_sizes(), namely sym_sizes_.

This way, when someone calls size() in Python, we would call XLATensorImpl::sym_sizes_custom, which returns sym_sizes_, which may contain both concrete ints and SymIntNodes.

Now, if a user uses one of those SymIntNodes in Python, it will trigger materialization of the SymIntNode (since we do want the exact result, at least according to our discussions).

Phew....
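To make the two read paths concrete, here is a simplified sketch (not the exact torch_xla code) of how the C++ and Python entry points stay separate:

```cpp
// C++ callers keep getting plain ints: only static upper bounds live in
// sizes_and_strides_, so nothing on this path can accidentally
// materialize a symbolic value.
at::IntArrayRef XLATensorImpl::sizes_custom() const {
  const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
  return sizes_default();
}

// Python's tensor.size() ends up here instead: sym_sizes_ may mix
// concrete ints and SymIntNodes, and using a symbolic entry as an int in
// Python is what triggers materialization of the exact value.
c10::SymIntArrayRef XLATensorImpl::sym_sizes_custom() const {
  const_cast<XLATensorImpl*>(this)->SetupSymSizeProperties();
  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
}
```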

Contributor:

@JackCaoG the part I'm not 100% sure about is why we need to populate sizes_and_strides_ in addition to sym_sizes_ when sym_sizes is called. I'd think that when someone calls sizes() it would update sizes_and_strides_, and when someone calls sym_sizes() it would set sym_sizes_.
Presumably, @miladm ran into some issues, so we do need to set both when sym_sizes is called.
