diff --git a/pytorch_to_returnn/torch/nn/modules/module.py b/pytorch_to_returnn/torch/nn/modules/module.py
index 1f9e920..312dc24 100644
--- a/pytorch_to_returnn/torch/nn/modules/module.py
+++ b/pytorch_to_returnn/torch/nn/modules/module.py
@@ -526,14 +526,24 @@ def direct_returnn_layer_call(cls) -> bool:
     """
     if not cls.has_torch_forward():
       return True
+
+    # BFS over the base-class graph instead of the old single-inheritance
+    # walk (which asserted len(__bases__) == 1). `visited` guards against
+    # diamond inheritance revisits.
+    queue = [cls]
+    visited = set()
+    while len(queue) > 0:
+      base = queue.pop(0)
+      if base in visited:
+        continue
+      visited.add(base)
+
+      # `object` and non-Module mixins carry neither hook; skip them.
+      if base is object:
+        continue
+      if not issubclass(base, Module):
+        continue
       if cls.create_returnn_layer_dict != base.create_returnn_layer_dict:
         return True
       elif cls.forward != base.forward:
         return False
-      assert len(base.__bases__) == 1, "Not implemented otherwise"
-      base = base.__bases__[0]
+      # Traverse upward from the node just examined (base.__bases__, not
+      # cls.__bases__ — the latter would stall the BFS at depth one).
+      queue += [b for b in base.__bases__ if issubclass(b, Module)]
     return True
 
   def check_returnn_layer(self, layer: LayerBase):