Skip to content

codegen error: RuntimeError: producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared INTERNAL ASSERT FAILED #2559

@jjsjann123

Description

@jjsjann123

🐛 Describe the bug

This issue was triggered by many of the PyTorch slow_tests.

%kernel {
T25_l[ ithreadIdx.x212{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS207{4}, iUS209{1}, iUR211{4} ]
   = T4_g[ iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS215{4}, iS217{1}, iS219{4} ];
T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]
   = T25_l[ ithreadIdx.x212{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS207{4}, iUS209{1}, iUR211{4} ];
T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )
   = T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ];
T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
   = T2_g[ iS108{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS103{4}, iS105{1}, iS107{4} ];
T22_l[ 0 ]
   = T1_g[ 0 ];
T6_l[ bthreadIdx.x92{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS87{4}, bUS89{1}, bS91{4} ]
   = broadcast( T22_l[ 0 ] )
T7_l[ ithreadIdx.x84{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS79{4}, iUS81{1}, iS83{4} ] ca_pos( 4 )
   = where(T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]
  , T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
  , T6_l[ bthreadIdx.x92{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS87{4}, bUS89{1}, bS91{4} ]);
T27_l[ ithreadIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS70{4}rf, rUS72{1}rf, rS74{4}rf ] ca_pos( 1 ) produce_pos( 4 )
   = reduction( T7_l[ ithreadIdx.x84{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS79{4}, iUS81{1}, iS83{4} ] ca_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T8_l[ rthreadIdx.x76{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
   = reduction( T27_l[ ithreadIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS70{4}rf, rUS72{1}rf, rS74{4}rf ] ca_pos( 1 ) produce_pos( 4 ), op = add, initial value = double(0), 
allreduce = false )
T10_l[ bthreadIdx.x156{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS151{4}, bUS153{1}, bS155{4} ]
   = broadcast( T8_l[ rthreadIdx.x76{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 ) )
T24_l[ bthreadIdx.x140{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS135{4}, bUS137{1}, bUR139{4} ]
   = T3_g[ bS148{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS143{4}, bS145{1}, bS147{4} ];
T9_l[ bthreadIdx.x132{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS127{4}, bUS129{1}, bS131{4} ]
   = (float)(T24_l[ bthreadIdx.x140{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS135{4}, bUS137{1}, bUR139{4} ]);
T11_l[ bthreadIdx.x124{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS119{4}, bUS121{1}, bS123{4} ]
   = T10_l[ bthreadIdx.x156{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS151{4}, bUS153{1}, bS155{4} ]
   / T9_l[ bthreadIdx.x132{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS127{4}, bUS129{1}, bS131{4} ];
T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 )
   = T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
   - T11_l[ bthreadIdx.x124{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS119{4}, bUS121{1}, bS123{4} ];
T13_l[ ithreadIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS159{4}, iUS161{1}, iS163{4} ] ca_pos( 4 ) produce_pos( 4 )
   = T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 )
   * T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 );
T15_l[ bthreadIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS175{4}, bUS177{1}, bS179{4} ]
   = broadcast( T22_l[ 0 ] )
T16_l[ ithreadIdx.x172{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS167{4}, iUS169{1}, iS171{4} ] ca_pos( 4 ) produce_pos( 4 )
   = where(T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )
  , T13_l[ ithreadIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS159{4}, iUS161{1}, iS163{4} ] ca_pos( 4 ) produce_pos( 4 )
  , T15_l[ bthreadIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS175{4}, bUS177{1}, bS179{4} ]);
T28_l[ ithreadIdx.x231{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS226{4}rf, rUS228{1}rf, rS230{4}rf ] ca_pos( 1 ) produce_pos( 4 )
   = reduction( T16_l[ ithreadIdx.x172{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS167{4}, iUS169{1}, iS171{4} ] ca_pos( 4 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T17_l[ rthreadIdx.x232{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
   = reduction( T28_l[ ithreadIdx.x231{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS226{4}rf, rUS228{1}rf, rS230{4}rf ] ca_pos( 1 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T21_l[ 0 ]
   = T0_g[ 0 ];
T18_l[ 0 ]
   = T17_l[ rthreadIdx.x232{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
   / T21_l[ 0 ];
T19_l[ 0 ]
   = T18_l[ 0 ];
T26_l[ 0 ]
   = sqrtf(T19_l[ 0 ]);
T20_g[ 0 ]
   = T26_l[ 0 ];

TransformPrinter :
T4_g[ iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS215{4}, iS217{1}, iS219{4} ]
 root domain : (iS6{i5},iS7{i6},bS8{1})
 contiguity: 1 1
  Merge: iS7{i6} and bS8{1} -> iS213{( i6 * 1 )}
  Merge: iS6{i5} and iS213{( i6 * 1 )} -> iS214{( i5 * ( i6 * 1 ) )}
  Outer split: iS214{( i5 * ( i6 * 1 ) )} by factor 4 -> iS215{4}, iS216{( ceilDiv(( i5 * ( i6 * 1 ) ), 4) )}, start offset: 0, stop offset: 0
  Outer split: iS216{( ceilDiv(( i5 * ( i6 * 1 ) ), 4) )} by factor 1 -> iS217{1}, iS218{( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) )}, start offset: 0, stop offset: 0
  Outer split: iS218{( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) )} by factor 4 -> iS219{4}, iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, start offset: 0, stop offset: 0

RuntimeError: producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch-jit/third_party/nvfuser/csrc/lower_sync_information.cpp":772, please report a bug to PyTorch. Inconsistent parallelization found between TV5 (T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( T2.size[0] * ( T2.size[1] * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]) and TV14(T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( T2.size[0] * ( T2.size[1] * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)

Repro script using the Python API:

import torch
from nvfuser import FusionDefinition, DataType


# Concrete CUDA tensors fed to the fusion; inputs[k] is registered as Tk
# inside nvfuser_fusion via fd.from_pytorch(inputs[k]).
# The tensors are created in the same order as before, so random draws match.
inputs = []
# inputs[0]: 0-dim float scalar tensor (used as the final divisor, T0).
inputs.append(torch.tensor(1.5, device="cuda"))
# inputs[1]: 0-dim float scalar tensor (the `where` fallback value, T1).
inputs.append(torch.tensor(2.5, device="cuda"))
# inputs[2]: dense (5, 5, 5) float tensor (T2).
inputs.append(torch.randn(5, 5, 5, device='cuda'))
# inputs[3]: (1, 1, 1) integer tensor, later cast to float (T3).
inputs.append(torch.randint(0, 23, (1, 1, 1), device="cuda"))
# inputs[4]: boolean mask of shape (5, 5, 1); the trailing size-1 dim is kept
# (not expanded) so it is a broadcast dimension in the fusion.
# NOTE: torch.randint's `high` is exclusive, so randint(0, 1, ...) is always 0
# and this mask is all False. Left as-is to keep the repro unchanged.
# Disabled variant with the mask expanded to (5, 5, 5):
#   torch.randint(0, 1, (5, 5), device="cuda").bool().unsqueeze(-1).expand(5, 5, 5)
inputs.append(torch.randint(0, 1, (5, 5), device="cuda").bool().unsqueeze(-1))

# Show the mask's layout (shape and strides) that the fusion sees.
print(inputs[4].shape)
print(inputs[4].stride())

def nvfuser_fusion(fd : FusionDefinition) -> None :
    """Record the fusion that reproduces the codegen assertion in this issue.

    The op sequence is:
      1. A masked full sum of T2, using scalar T1 where the mask is False.
      2. That sum divided by T3, cast to float.
      3. A subtract-and-square of T2.
      4. A second masked full sum.
      5. Division by T0 and a final sqrt.
    Overall it appears to resemble a masked standard deviation.

    The boolean mask T4 is copied twice with ``fd.ops.set``: first into T5,
    then into T14. The assertion message in this issue reports
    "Inconsistent parallelization found between TV5 ... and TV14".
    NOTE(review): the Python names are assumed to match the IR tensor
    numbering; confirm.

    Args:
        fd: The FusionDefinition being recorded. Inputs come from the
            module-level ``inputs`` list, and T20 is registered as the
            single fusion output.

    NOTE(review): keep the op order exactly as written. The kernel dump in
    this issue came from this definition, and reordering ops may change the
    IR/scheduling so the bug no longer reproduces.
    """
    # Register inputs[0..4] as fusion inputs T0..T4, in order.
    T0 = fd.from_pytorch(inputs[0])
    T1 = fd.from_pytorch(inputs[1])
    T2 = fd.from_pytorch(inputs[2])
    T3 = fd.from_pytorch(inputs[3])
    T4 = fd.from_pytorch(inputs[4])
    # Disabled alternative: declare the mask symbolically instead of deriving it
    # from the concrete tensor.
    #T4 = fd.define_tensor((-1, -1, 1), (True, True), dtype=DataType.Bool)
    # First copy of the mask; used directly by the first `where` below.
    T5 = fd.ops.set(T4)
    # Take T2 where the mask is True, otherwise T1; then sum with no explicit
    # axes. The kernel dump shows T8 as a reduction over all elements.
    T7 = fd.ops.where(T5, T2, T1)
    T8 = fd.ops.sum(T7)
    # Divide the sum by T3 after casting it to float.
    T10 = fd.ops.cast(T3, DataType.Float)
    T11 = fd.ops.div(T8, T10)
    # Squared difference of T2 from that quotient.
    T12 = fd.ops.sub(T2, T11)
    T13 = fd.ops.mul(T12, T12)
    # Second copy of the mask, made from T5 (not T4); used by the second `where`.
    T14 = fd.ops.set(T5)
    # Second masked full sum, of the squared differences.
    T16 = fd.ops.where(T14, T13, T1)
    T17 = fd.ops.sum(T16)
    # Divide by T0, copy, take sqrt, and expose the result as the output.
    T18 = fd.ops.div(T17, T0)
    T19 = fd.ops.set(T18)
    T20 = fd.ops.sqrt(T19)
    fd.add_output(T20)

# Record the fusion into a FusionDefinition. (The original script also
# constructed an unused `f = FusionDefinition()`; it was never referenced and
# has been removed.)
with FusionDefinition() as fd:
    nvfuser_fusion(fd)

# Execute the same definition several times, presumably to exercise repeated
# executions of the compiled fusion. Only the last result is kept in `o`.
for _ in range(5):
    o = fd.execute(inputs)

# Dump the recorded definition and the first output tensor.
print(fd)
print(o[0])

Versions

Reproduces on the devel branch.

Metadata

Metadata

Assignees

Labels

No labels
No labels

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions