
[Dynamic Shape] Reshape does not work with dynamic input + symint target size #4201

Open
@JackCaoG


🐛 Bug

I would expect the code below to emit a meaningful error message, since we don't support reshape.symint yet (in fact, I don't know whether such an op exists).

import torch
import torch_xla
import torch_xla.core.xla_model as xm

import os
os.environ["XLA_EXPERIMENTAL"] = "nonzero:masked_select:masked_scatter"

dev = xm.xla_device()
size1 = 5
size2 = 2
t1 = torch.zeros([size1, size2], device=dev)
t1[3][0] = 1
# t2 has size [<=10, 2]
t2 = torch.nonzero(t1)
t3 = t2.reshape([t2.size(0) * t2.size(1)])

However, I get:

Traceback (most recent call last):
  File "test/test_dynamic_fail.py", line 16, in <module>
    t3 = t2.reshape([t2.size(0) * t2.size(1)])
RuntimeError: RuntimeError: NYI

and I'm not sure what the real error is under the hood.
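For reference, here is a minimal sketch of the same steps in eager mode on CPU. It assumes no XLA device and no XLA_EXPERIMENTAL flag, so it only shows the result the XLA path should eventually produce. nonzero finds a single nonzero element at index (3, 0), so t2 has an actual size of [1, 2], within the [<=10, 2] bound. The reshape target therefore evaluates to 1 * 2 = 2.

```python
import torch

# Same input as the repro above, but on the default CPU device.
t1 = torch.zeros([5, 2])
t1[3][0] = 1

# On CPU, nonzero returns a concretely sized tensor: one row per nonzero element.
t2 = torch.nonzero(t1)

# The target size is a plain Python int here, so no SymInt is involved.
t3 = t2.reshape([t2.size(0) * t2.size(1)])

print(t2.shape)  # torch.Size([1, 2])
print(t3)        # tensor([3, 0])
```

On XLA, t2.size(0) is a dynamic dimension instead, so the reshape target becomes a symbolic size. That is the case that currently fails with NYI.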

@miladm @vanbasten23 @Krovatkin @wconstab

Labels: BLOCKED; dynamism (Dynamic Shape Features); triaged (This issue has been reviewed by the triage team and the appropriate priority assigned.)