
[Bug] RecursionError: maximum recursion depth exceeded in ResNet #504

@Qi-Zhan

Description

Describe the bug
RecursionError: maximum recursion depth exceeded in ResNet

To Reproduce
Steps to reproduce the behavior. A small and reproducible script would be very helpful.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        """
        :param in_channels: Number of input channels
        :param out_channels: Number of output channels
        :param stride: Stride for the 3x3 convolutional layer
        :param downsample: Downsample layer for the shortcut connection
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Model(nn.Module):
    def __init__(self, layers, num_classes=1000):
        """
        :param layers: Number of Bottleneck blocks in each of the four stages
        :param num_classes: Number of output classes
        """
        super(Model, self).__init__()
        block = Bottleneck
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: Input tensor, shape (batch_size, 3, height, width)
        :return: Output tensor, shape (batch_size, num_classes)
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

# Test code
batch_size = 10
height = 224
width = 224
layers = [3, 4, 23, 3]
num_classes = 1000

def get_inputs():
    return [torch.rand(batch_size, 3, height, width, device=torch.device("cuda"))]

def get_init_inputs():
    return [layers, num_classes]


import hidet
hidet.torch.dynamo_config.search_space(2)  # tune each tunable operator
model = Model(*get_init_inputs())
model.eval()
model.to(torch.device("cuda"))
model = torch.compile(model, backend="hidet")
model(*get_inputs())

Running this script fails inside hidet's graph tracing with the following (truncated) traceback:
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
    return self.visit(obj)
           ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 35, in visit
    self.visit_Tensor(obj)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 54, in visit_Tensor
    self(tensor.trace[0])
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
    return self.visit(obj)
           ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 33, in visit
    self.visit_Operator(obj)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 158, in visit_Operator
    GraphVisitor.visit_Operator(self, op)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 49, in visit_Operator
    self(inp)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
    return self.visit(obj)
           ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 35, in visit
    self.visit_Tensor(obj)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 54, in visit_Tensor
    self(tensor.trace[0])
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 24, in __call__
    return self.visit(obj)
           ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 33, in visit
    self.visit_Operator(obj)
  File "/root/miniconda3/lib/python3.12/site-packages/hidet/graph/graph_utils/functors.py", line 157, in visit_Operator
    self.usage[inp].append((op, idx))
    ~~~~~~~~~~^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='hidet' raised:
RecursionError: maximum recursion depth exceeded

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
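
The recursion appears to come from hidet's GraphVisitor walking the traced graph depth-first (visit_Tensor -> visit_Operator -> inputs), one Python frame per graph node, so a graph as deep as ResNet-101 (layers = [3, 4, 23, 3]) exhausts CPython's default recursion limit of 1000 frames. A minimal sketch of a possible workaround, assuming the traversal depth is the only problem (not a confirmed fix):

import sys

# Assumption: the traced graph is merely deeper than CPython's default limit of
# 1000 frames, not infinitely recursive; raising the limit before torch.compile
# triggers hidet's tracing lets the visitor reach the bottom of the graph.
sys.setrecursionlimit(10000)

model = torch.compile(model, backend="hidet")
model(*get_inputs())

This only papers over the issue for even deeper models; an iterative (worklist-based) traversal in hidet/graph/graph_utils/functors.py would remove the depth limit entirely.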

Environment

  • OS: Ubuntu 22.04
  • GPU: A800
  • Others: NVIDIA GPU Driver 580.82.07
