Describe the bug
RecursionError: maximum recursion depth exceeded when compiling a ResNet-101-style model (Bottleneck blocks, layers = [3, 4, 23, 3]) with torch.compile(model, backend="hidet"). The recursion happens inside hidet's GraphVisitor while it traverses the traced graph; see the traceback below.
To Reproduce
The following small, self-contained script reproduces the error:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param stride: Stride for the 3x3 convolutional layer
        :param downsample: Downsample layer for the shortcut connection
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # 1x1 -> 3x3 -> 1x1 convolutions with a residual shortcut
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return self.relu(out)


class Model(nn.Module):
    def __init__(self, layers, num_classes=1000):
        """
        :param layers: Number of Bottleneck blocks in each of the four stages
        :param num_classes: Number of output classes
        """
        super(Model, self).__init__()
        self.in_channels = 64
        # standard ResNet stem: 7x7 conv, batch norm, ReLU, 3x3 max-pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        block = Bottleneck
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Input tensor, shape (batch_size, 3, height, width)
        :return: Output tensor, shape (batch_size, num_classes)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# Test code
batch_size = 10
height = 224
width = 224
layers = [3, 4, 23, 3]  # ResNet-101 configuration
num_classes = 1000

def get_inputs():
    return [torch.rand(batch_size, 3, height, width, device=torch.device("cuda"))]

def get_init_inputs():
    return [layers, num_classes]
import hidet
hidet.torch.dynamo_config.search_space(2) # tune each tunable operator
model = Model(*get_init_inputs())
model.eval()
model.to(torch.device("cuda"))
model = torch.compile(model, backend="hidet")
model(*get_inputs())

File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
return self.visit(obj)
^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 35, in visit
self.visit_Tensor(obj)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 54, in visit_Tensor
self(tensor.trace[0])
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
return self.visit(obj)
^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 33, in visit
self.visit_Operator(obj)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 158, in visit_Operator
GraphVisitor.visit_Operator(self, op)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 49, in visit_Operator
self(inp)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
return self.visit(obj)
^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 35, in visit
self.visit_Tensor(obj)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 54, in visit_Tensor
self(tensor.trace[0])
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
return self.visit(obj)
^^^^^^^^^^^^^^^
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 33, in visit
self.visit_Operator(obj)
File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 157, in visit_Operator
self.usage[inp].append((op, idx))
~~~~~~~~~~^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='hidet' raised:
RecursionError: maximum recursion depth exceeded
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
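For context: the recursion happens in hidet's GraphVisitor (hidet/graph/graph_utils/functors.py), which walks the traced graph recursively and uses several Python stack frames per operator, so a graph as deep as ResNet-101 exceeds the default recursion limit of 1000. As a temporary, unverified workaround sketch (not a fix in hidet itself), raising the interpreter's recursion limit lets the traversal go deeper:

import sys
sys.setrecursionlimit(10000)  # arbitrary value; it just needs to cover the depth of the traced graph

placed before the torch.compile(model, backend="hidet") call above.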
Environment
- OS: Ubuntu 22.04
- GPU: A800
- Others: NVIDIA GPU Driver 580.82.07