Missed optimization: a[mask] = b -> a = torch.where(mask, b, a) #4248

Open
@pscollins

🐛 Bug

I'm trying to get an existing model running under PyTorch/XLA. The model uses the construct a[mask] = b frequently, and that construct appears to be a bottleneck. My guess is that it gets lowered (unnecessarily, as far as I can tell) to something like "create a slice with a dynamic shape, then write to it", which triggers recompilation. I'm currently running against the CPU backend.
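To make the "dynamic shape" guess concrete, here is a minimal sketch in plain PyTorch (no XLA involved). It only shows that boolean-mask indexing produces a result whose length depends on the mask's contents, while torch.where always produces a tensor of the same shape as its inputs. Whether this is actually what causes the recompilation inside torch_xla is an assumption, and the helper name is made up for illustration.

```python
import torch


def indexed_and_where_shapes(mask, dest):
    """Return (shape of dest[mask], shape of the torch.where result).

    Hypothetical helper, for illustration only.
    """
    # Boolean indexing gathers only the True positions, so the result is
    # 1-D with length equal to the number of True entries in the mask.
    picked = dest[mask]
    # torch.where is elementwise, so the result always has dest's shape.
    blended = torch.where(mask, torch.ones_like(dest), dest)
    return tuple(picked.shape), tuple(blended.shape)


dest = torch.zeros(6)
sparse_mask = torch.tensor([True, False, False, False, False, False])
dense_mask = torch.tensor([True, True, True, True, False, False])

print(indexed_and_where_shapes(sparse_mask, dest))  # ((1,), (6,))
print(indexed_and_where_shapes(dense_mask, dest))   # ((4,), (6,))
```

With a random mask, as in the benchmark in this report, the number of True entries changes on nearly every call, so any compiled graph keyed on that length would rarely be reusable.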

To Reproduce

I tried to demonstrate the issue with a microbenchmark:

import timeit

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


DEST_SHAPE = (100000,)
THRESHOLD = .5

ITERS = 1000

def test_store_getitem(device):
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD

    dest[mask] = 1.0
    if device != 'cpu':
        xm.mark_step()

def test_store_where(device):
    dest = torch.rand(DEST_SHAPE, device=device)
    mask = torch.rand(DEST_SHAPE, device=device) > THRESHOLD

    dest = torch.where(mask, 1.0, dest)
    if device != 'cpu':
        xm.mark_step()


def main():
    xla_device = xm.xla_device()
    cpu_device = 'cpu'

    scope = dict(globals(), **locals())
    print('CPU: getitem')
    print(timeit.timeit('test_store_getitem(cpu_device)', number=ITERS, globals=scope))
    print('CPU: where')
    print(timeit.timeit('test_store_where(cpu_device)', number=ITERS, globals=scope))

    print('XLA/CPU: getitem')
    print(timeit.timeit('test_store_getitem(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())
    print('XLA/CPU: where')
    print(timeit.timeit('test_store_where(xla_device)', number=ITERS, globals=scope))
    print(met.metrics_report())


if __name__ == '__main__':
    main()

Output (times only; each value is the total in seconds for 1000 iterations):

CPU: getitem
1.335198831744492
CPU: where
1.4740506513044238
XLA/CPU: getitem                                                                                                                                                                                                   
84.45449290703982 
XLA/CPU: where                                                                                                                                                                                                     
1.3696353798732162 

Here are some snippets from the two metrics reports. The first report, printed after the getitem run, shows:

Counter: CachedCompile
  Value: 505                                                                                                                                                                                                       
Counter: CreateXlaTensor                                                                                                                                                                                           
  Value: 11000                                                                                                                                                                                                     
Counter: DestroyXlaTensor
  Value: 11000               
Counter: DeviceDataCacheMiss
  Value: 1         
Counter: OpByOpCompileCacheMiss                                                                          
  Value: 12                                                                                                                                                                                                        
Counter: UncachedCompile                                                                                                                                                                                           
  Value: 495                                                                                                                                                                                                       
Metric: CompileTime                                                                                                                                                                                                
  TotalSamples: 496                                                                                                                                                                                                
  Accumulator: 01m11s847ms893.248us                                                                                                                                                                                
  ValueRate: 840ms091.405us / second
  Rate: 5.88149 / second                                                                                 
  Percentiles: 1%=136ms016.889us; 5%=138ms843.705us; 10%=139ms704.928us; 20%=140ms990.267us; 50%=142ms292.030us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us

and the second report, printed after the where run (the counters are cumulative across both runs), shows:

Counter: CachedCompile
  Value: 1504                                                                                                                                                                                                      
Counter: CreateXlaTensor                                                                                                                                                                                           
  Value: 19000                                                                                                                                                                                                     
Counter: DestroyXlaTensor
  Value: 19000               
Counter: DeviceDataCacheMiss
  Value: 1         
Counter: OpByOpCompileCacheMiss                                                                          
  Value: 12                                                                                                                                                                                                        
Counter: UncachedCompile                                                                                                                                                                                           
  Value: 496                                                                                                                                                                                                       
Metric: CompileTime                                                                                                                                                                                                
  TotalSamples: 497                                                                                                                                                                                                
  Accumulator: 01m11s980ms339.237us                                                                                                                                                                                
  ValueRate: 840ms942.688us / second
  Rate: 5.88123 / second                                                                                 
  Percentiles: 1%=136ms856.889us; 5%=138ms782.989us; 10%=139ms595.450us; 20%=140ms977.095us; 50%=142ms271.090us; 80%=145ms245.633us; 90%=149ms023.811us; 95%=151ms380.062us; 99%=156ms496.857us
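As a rough reading of the two reports above, the arithmetic can be sketched as follows. The parser assumes the duration format is exactly what appears in these reports (m, s, ms, us components), and parse_duration is a made-up helper, not part of torch_xla. All numbers are copied from the output above.

```python
import re

# Seconds per unit, for the units that appear in the reports above.
_UNIT_SECONDS = {"m": 60.0, "s": 1.0, "ms": 1e-3, "us": 1e-6}


def parse_duration(text):
    """Convert a string like '01m11s847ms893.248us' to seconds.

    Hypothetical helper; 'ms' and 'us' are listed before 'm' and 's'
    so the longer unit names win in the alternation.
    """
    parts = re.findall(r"([\d.]+)(ms|us|m|s)", text)
    return sum(float(number) * _UNIT_SECONDS[unit] for number, unit in parts)


# CompileTime accumulators from the first and second reports.
compile_after_getitem = parse_duration("01m11s847ms893.248us")
compile_after_where = parse_duration("01m11s980ms339.237us")

# Wall-clock time of the XLA/CPU getitem run, from the timing output.
getitem_wall = 84.45449290703982

compile_share = compile_after_getitem / getitem_wall
mean_compile_ms = compile_after_getitem / 496 * 1e3  # 496 = TotalSamples
where_added_ms = (compile_after_where - compile_after_getitem) * 1e3
where_added_uncached = 496 - 495  # UncachedCompile, second minus first

print(f"compile share of getitem run: {compile_share:.1%}")
print(f"mean time per compile sample: {mean_compile_ms:.1f} ms")
print(f"compile time added by where run: {where_added_ms:.1f} ms")
print(f"uncached compiles added by where run: {where_added_uncached}")
```

On these numbers, compilation accounts for roughly 85% of the getitem run's wall time, while the where run adds a single uncached compile on top of it.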

Expected behavior

Ideally I think that the getitem and where variants should have roughly equal performance.
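Until then, a user-side rewrite along the lines of the title could look like the sketch below. masked_assign is a hypothetical helper, not an existing torch or torch_xla function. It is only checked here against plain CPU PyTorch, and I have not verified how it behaves under XLA beyond the where timing above.

```python
import torch


def masked_assign(dest, mask, value):
    """Out-of-place stand-in for `dest[mask] = value` (hypothetical helper).

    `value` may be a Python scalar or a tensor broadcastable to dest.shape.
    For a full-shape tensor this matches `dest[mask] = value[mask]`,
    which is NOT the same as `dest[mask] = value`.
    """
    if not torch.is_tensor(value):
        # Wrap the scalar so both branches of torch.where share a dtype.
        value = torch.as_tensor(value, dtype=dest.dtype, device=dest.device)
    return torch.where(mask, value, dest)


torch.manual_seed(0)
dest = torch.rand(1000)
mask = torch.rand(1000) > 0.5
src = torch.rand(1000)

# Reference results computed with boolean-mask assignment.
expected_scalar = dest.clone()
expected_scalar[mask] = 1.0
expected_tensor = dest.clone()
expected_tensor[mask] = src[mask]

print(torch.equal(masked_assign(dest, mask, 1.0), expected_scalar))
print(torch.equal(masked_assign(dest, mask, src), expected_tensor))
# For scalar values, Tensor.masked_fill is another fixed-shape option.
print(torch.equal(dest.masked_fill(mask, 1.0), expected_scalar))
```

Note that the helper returns a new tensor rather than mutating dest, so callers that rely on in-place semantics (for example, other views of the same storage) would need to be adjusted.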

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU
  • torch_xla version: torch_xla/version.py shows:
# Autogenerated file, do not edit!
__version__ = '1.14'
__xla_gitrev__ = 'f790bc8ac411a8e6903b89adf7610b812996537b'
__torch_gitrev__ = '7b0d577c226fae78f377b26feab4122c4203ad59'

Labels

dynamism (Dynamic Shape Features), nostale (Do not consider for staleness)