🐛 Bug
I'm trying to get an existing model running under pytorch/XLA that uses the construct a[mask] = b frequently, and this seems to be a bottleneck. My guess is that it (as far as I can tell, unnecessarily) becomes something like "create a slice with a dynamic shape, then write to it", which triggers recompilation. I'm currently running against my CPU.
To Reproduce
I tried to demonstrate the issue with a microbenchmark:
import timeit

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

DEST_SHAPE = (100000,)
THRESHOLD = 0.5
ITERS = 1000

def test_store_getitem(device):
    # Masked store via advanced indexing: dest[mask] = scalar.
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD
    dest[mask] = 1.0
    if device != 'cpu':
        xm.mark_step()

def test_store_where(device):
    # Equivalent masked store via torch.where; all shapes stay static.
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD
    dest = torch.where(mask, 1.0, dest)
    if device != 'cpu':
        xm.mark_step()

def main():
    xla_device = xm.xla_device()
    cpu_device = 'cpu'
    scope = dict(globals(), **locals())

    print('CPU: getitem')
    print(timeit.timeit('test_store_getitem(cpu_device)', number=ITERS, globals=scope))
    print('CPU: where')
    print(timeit.timeit('test_store_where(cpu_device)', number=ITERS, globals=scope))
    print('XLA/CPU: getitem')
    print(timeit.timeit('test_store_getitem(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())
    print('XLA/CPU: where')
    print(timeit.timeit('test_store_where(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())

if __name__ == '__main__':
    main()
Output (times only):
CPU: getitem
1.335198831744492
CPU: where
1.4740506513044238
XLA/CPU: getitem
84.45449290703982
XLA/CPU: where
1.3696353798732162
Grabbing some snippets from the two metrics reports, the first (after the getitem run) shows:
Counter: CachedCompile
  Value: 505
Counter: CreateXlaTensor
  Value: 11000
Counter: DestroyXlaTensor
  Value: 11000
Counter: DeviceDataCacheMiss
  Value: 1
Counter: OpByOpCompileCacheMiss
  Value: 12
Counter: UncachedCompile
  Value: 495
Metric: CompileTime
  TotalSamples: 496
  Accumulator: 01m11s847ms893.248us
  ValueRate: 840ms091.405us / second
  Rate: 5.88149 / second
  Percentiles: 1%=136ms016.889us; 5%=138ms843.705us; 10%=139ms704.928us; 20%=140ms990.267us; 50%=142ms292.030us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us
and the second (after the where run) shows:
Counter: CachedCompile
  Value: 1504
Counter: CreateXlaTensor
  Value: 19000
Counter: DestroyXlaTensor
  Value: 19000
Counter: DeviceDataCacheMiss
  Value: 1
Counter: OpByOpCompileCacheMiss
  Value: 12
Counter: UncachedCompile
  Value: 496
Metric: CompileTime
  TotalSamples: 497
  Accumulator: 01m11s980ms339.237us
  ValueRate: 840ms942.688us / second
  Rate: 5.88123 / second
  Percentiles: 1%=136ms856.889us; 5%=138ms782.989us; 10%=139ms595.450us; 20%=140ms977.095us; 50%=142ms271.090us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us
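As a quicker check than diffing full reports, the compile counters can be sampled around a single call. A rough sketch, assuming met.counter_value returns None until the counter first fires (hence the or 0):

import torch_xla.debug.metrics as met

def fresh_compiles(fn, *args):
    # Number of uncached (fresh) compilations triggered by a single call.
    before = met.counter_value('UncachedCompile') or 0
    fn(*args)
    after = met.counter_value('UncachedCompile') or 0
    return after - before

# In my runs, getitem compiles afresh whenever the mask selects a size it
# hasn't seen before, while where compiles once and then hits the cache:
print(fresh_compiles(test_store_getitem, xla_device))
print(fresh_compiles(test_store_where, xla_device))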
Expected behavior
Ideally, the getitem and where variants should have roughly equal performance.
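In the meantime, a workaround that works for my scalar-RHS case (I haven't verified it avoids the dynamic-shape path on every backend, so treat this as a sketch) is torch.Tensor.masked_fill_, which keeps all shapes static:

import torch

def masked_store_scalar(dest, mask, value):
    # Static-shape replacement for dest[mask] = value when value is a scalar;
    # masked_fill_ never materializes a data-dependent slice.
    dest.masked_fill_(mask, value)
    return dest

For a tensor RHS of matching shape, dest = torch.where(mask, src, dest) is the analogous rewrite.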
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU
- torch_xla version: torch_xla/version.py shows:
# Autogenerated file, do not edit!
__version__ = '1.14'
__xla_gitrev__ = 'f790bc8ac411a8e6903b89adf7610b812996537b'
__torch_gitrev__ = '7b0d577c226fae78f377b26feab4122c4203ad59'