forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
🐛 Describe the bug
This issue was observed in a number of PyTorch slow tests.
%kernel {
T25_l[ ithreadIdx.x212{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS207{4}, iUS209{1}, iUR211{4} ]
= T4_g[ iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS215{4}, iS217{1}, iS219{4} ];
T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]
= T25_l[ ithreadIdx.x212{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS207{4}, iUS209{1}, iUR211{4} ];
T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )
= T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ];
T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
= T2_g[ iS108{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS103{4}, iS105{1}, iS107{4} ];
T22_l[ 0 ]
= T1_g[ 0 ];
T6_l[ bthreadIdx.x92{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS87{4}, bUS89{1}, bS91{4} ]
= broadcast( T22_l[ 0 ] )
T7_l[ ithreadIdx.x84{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS79{4}, iUS81{1}, iS83{4} ] ca_pos( 4 )
= where(T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]
, T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
, T6_l[ bthreadIdx.x92{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS87{4}, bUS89{1}, bS91{4} ]);
T27_l[ ithreadIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS70{4}rf, rUS72{1}rf, rS74{4}rf ] ca_pos( 1 ) produce_pos( 4 )
= reduction( T7_l[ ithreadIdx.x84{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS79{4}, iUS81{1}, iS83{4} ] ca_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T8_l[ rthreadIdx.x76{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
= reduction( T27_l[ ithreadIdx.x75{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS70{4}rf, rUS72{1}rf, rS74{4}rf ] ca_pos( 1 ) produce_pos( 4 ), op = add, initial value = double(0),
allreduce = false )
T10_l[ bthreadIdx.x156{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS151{4}, bUS153{1}, bS155{4} ]
= broadcast( T8_l[ rthreadIdx.x76{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 ) )
T24_l[ bthreadIdx.x140{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS135{4}, bUS137{1}, bUR139{4} ]
= T3_g[ bS148{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS143{4}, bS145{1}, bS147{4} ];
T9_l[ bthreadIdx.x132{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS127{4}, bUS129{1}, bS131{4} ]
= (float)(T24_l[ bthreadIdx.x140{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS135{4}, bUS137{1}, bUR139{4} ]);
T11_l[ bthreadIdx.x124{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS119{4}, bUS121{1}, bS123{4} ]
= T10_l[ bthreadIdx.x156{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS151{4}, bUS153{1}, bS155{4} ]
/ T9_l[ bthreadIdx.x132{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS127{4}, bUS129{1}, bS131{4} ];
T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 )
= T23_l[ ithreadIdx.x100{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS95{4}, iUS97{1}, iUR99{4} ]
- T11_l[ bthreadIdx.x124{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS119{4}, bUS121{1}, bS123{4} ];
T13_l[ ithreadIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS159{4}, iUS161{1}, iS163{4} ] ca_pos( 4 ) produce_pos( 4 )
= T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 )
* T12_l[ ithreadIdx.x116{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS111{4}, iUS113{1}, iS115{4} ] ca_pos( 4 );
T15_l[ bthreadIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS175{4}, bUS177{1}, bS179{4} ]
= broadcast( T22_l[ 0 ] )
T16_l[ ithreadIdx.x172{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS167{4}, iUS169{1}, iS171{4} ] ca_pos( 4 ) produce_pos( 4 )
= where(T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )
, T13_l[ ithreadIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( i0 * ( i1 * i2 ) ), 4) ), 1) ), 4) )}, iS159{4}, iUS161{1}, iS163{4} ] ca_pos( 4 ) produce_pos( 4 )
, T15_l[ bthreadIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 1) ), 4) )}, bS175{4}, bUS177{1}, bS179{4} ]);
T28_l[ ithreadIdx.x231{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS226{4}rf, rUS228{1}rf, rS230{4}rf ] ca_pos( 1 ) produce_pos( 4 )
= reduction( T16_l[ ithreadIdx.x172{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}, iS167{4}, iUS169{1}, iS171{4} ] ca_pos( 4 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T17_l[ rthreadIdx.x232{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
= reduction( T28_l[ ithreadIdx.x231{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )}rf, rS226{4}rf, rUS228{1}rf, rS230{4}rf ] ca_pos( 1 ) produce_pos( 4 ), op = add, initial value = double(0), allreduce = false )
T21_l[ 0 ]
= T0_g[ 0 ];
T18_l[ 0 ]
= T17_l[ rthreadIdx.x232{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * i2 ) ), 4) ), 1) ), 4) )} ] produce_pos( 1 )
/ T21_l[ 0 ];
T19_l[ 0 ]
= T18_l[ 0 ];
T26_l[ 0 ]
= sqrtf(T19_l[ 0 ]);
T20_g[ 0 ]
= T26_l[ 0 ];
TransformPrinter :
T4_g[ iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, iS215{4}, iS217{1}, iS219{4} ]
root domain : (iS6{i5},iS7{i6},bS8{1})
contiguity: 1 1
Merge: iS7{i6} and bS8{1} -> iS213{( i6 * 1 )}
Merge: iS6{i5} and iS213{( i6 * 1 )} -> iS214{( i5 * ( i6 * 1 ) )}
Outer split: iS214{( i5 * ( i6 * 1 ) )} by factor 4 -> iS215{4}, iS216{( ceilDiv(( i5 * ( i6 * 1 ) ), 4) )}, start offset: 0, stop offset: 0
Outer split: iS216{( ceilDiv(( i5 * ( i6 * 1 ) ), 4) )} by factor 1 -> iS217{1}, iS218{( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) )}, start offset: 0, stop offset: 0
Outer split: iS218{( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) )} by factor 4 -> iS219{4}, iS220{( ceilDiv(( ceilDiv(( ceilDiv(( i5 * ( i6 * 1 ) ), 4) ), 1) ), 4) )}, start offset: 0, stop offset: 0
RuntimeError: producer->getMemoryType() == MemoryType::Global || producer->getMemoryType() == MemoryType::Shared INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch-jit/third_party/nvfuser/csrc/lower_sync_information.cpp":772, please report a bug to PyTorch. Inconsistent parallelization found between TV5 (T5_l[ ithreadIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( T2.size[0] * ( T2.size[1] * 1 ) ), 4) ), 1) ), 4) )}, iS199{4}, iUS201{1}, iS203{4} ]) and TV14(T14_l[ ithreadIdx.x188{( ceilDiv(( ceilDiv(( ceilDiv(( T2.size[0] * ( T2.size[1] * 1 ) ), 4) ), 1) ), 4) )}, iS183{4}, iUS185{1}, iS187{4} ] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Repro script with python API:
import torch
from nvfuser import FusionDefinition, DataType

# Inputs for the repro. Indices below match the fd.from_pytorch(...) calls in
# nvfuser_fusion: inputs[0] -> T0, inputs[1] -> T1, etc.
inputs = [
    torch.tensor(1.5, device="cuda"),   # T0: 0-d scalar, divisor in the final div (T18)
    torch.tensor(2.5, device="cuda"),   # T1: 0-d scalar, else-value of both where() calls
    torch.randn(5, 5, 5, device='cuda'),            # T2: dense data tensor
    torch.randint(0, 23, (1, 1, 1), device="cuda"), # T3: cast to float, denominator for T11
    # Alternative mask that was tried (fully expanded instead of broadcastable):
    #torch.randint(0, 1, (5, 5), device="cuda").bool().unsqueeze(-1).expand(5, 5, 5),
    torch.randint(0, 1, (5, 5), device="cuda").bool().unsqueeze(-1),  # T4: (5,5,1) bool mask
]
# Show the mask's shape/strides — the broadcast last dimension is what the
# failing fusion has to resolve during lowering.
print(inputs[4].shape)
print(inputs[4].stride())
def nvfuser_fusion(fd : FusionDefinition) -> None :
    """Build the fusion that triggers the lowering INTERNAL ASSERT.

    Computes a masked mean/variance-style pattern: where(mask, x, c) -> sum
    -> div -> (x - mean)^2 -> where(mask, ...) -> sum -> div -> sqrt.
    The error text reports inconsistent parallelization between TV5 and TV14,
    i.e. the two copies of the boolean mask (T5 and T14 below) — presumably
    the broadcast last dim of the mask is involved; confirm against the dump.
    """
    T0 = fd.from_pytorch(inputs[0])
    T1 = fd.from_pytorch(inputs[1])
    T2 = fd.from_pytorch(inputs[2])
    T3 = fd.from_pytorch(inputs[3])
    T4 = fd.from_pytorch(inputs[4])
    # Alternative way to define the mask that was tried:
    #T4 = fd.define_tensor((-1, -1, 1), (True, True), dtype=DataType.Bool)
    T5 = fd.ops.set(T4)            # copy of the mask — TV5 in the error message
    T7 = fd.ops.where(T5, T2, T1)  # mask-select data vs. scalar fill
    T8 = fd.ops.sum(T7)
    T10 = fd.ops.cast(T3, DataType.Float)
    T11 = fd.ops.div(T8, T10)      # masked sum / count -> mean-like value
    T12 = fd.ops.sub(T2, T11)
    T13 = fd.ops.mul(T12, T12)     # squared deviation
    T14 = fd.ops.set(T5)           # second copy of the mask — TV14 in the error
    T16 = fd.ops.where(T14, T13, T1)
    T17 = fd.ops.sum(T16)
    T18 = fd.ops.div(T17, T0)
    T19 = fd.ops.set(T18)
    T20 = fd.ops.sqrt(T19)         # final std-like scalar output
    fd.add_output(T20)
# Build the fusion and execute it repeatedly; the assert fires during lowering.
# NOTE(review): the original script also created an unused
# `f = FusionDefinition()` before the `with` block — dead code, removed.
with FusionDefinition() as fd:
    nvfuser_fusion(fd)
for _ in range(5):
    o = fd.execute(inputs)
print(fd)
print(o[0])
Versions
Reproduces on the nvfuser devel branch.
Metadata
Metadata
Assignees
Labels
No labels