diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 353d5695f7..06ce3366cd 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -1222,12 +1222,15 @@ def _construct_template(self, parent_get_layer): from collections import OrderedDict from returnn.util.basic import StringIO, BehaviorVersion from returnn.tf.network import NetworkConstructionDependencyLoopException, DataNotFound + from returnn.tf.network import _DelayedConstructionException # The stack trace is not so interesting for these exceptions. skip_stack_trace_exception_types = ( NetworkConstructionDependencyLoopException,) # These Exceptions always indicate incorrect construction, so fail directly instead of collecting them - fail_directly_exception_types = (DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied) + fail_directly_exception_types = ( + DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied, + _DelayedConstructionException) # noinspection PyShadowingNames def _parent_get_layer(layer_name): @@ -1644,8 +1647,11 @@ def __call__(lself, name, is_prev_time_frame=False): # And keep the remaining ones for potential later reports. self._template_construction_exceptions = [s.text for s in ConstructCtx.collected_exceptions.values()] - except Exception: - print("%r: exception constructing template network (for deps and data shapes)" % self) + except _DelayedConstructionException: + raise + + except Exception as exc: + print("%r: %s while constructing template network (for deps and data shapes)" % (self, type(exc).__name__)) from pprint import pprint print("Most recent construction stack:") if ConstructCtx.most_recent: diff --git a/returnn/tf/network.py b/returnn/tf/network.py index 68f578ab55..a41362732d 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -323,9 +323,24 @@ class _NetworkConstructionStack: Used to keep the recursive construction state of :function:`TFNetwork.construct_layer`. """ - def __init__(self): + # This assumes that we do single-threaded net construction. + # For multi-threading (if this would ever be realistic for net construction), + # we would need this to be a thread local. + # We still have a stack here for flat_construction() because we need to be nested + # for things like CondLayer or RecLayer. + _flat_construction_stack = [] # type: typing.List[_NetworkConstructionStack] + + def __init__(self, network): + """ + :param TFNetwork network: + """ + self.network = network self.layers = [] # type: typing.List[str] - self.in_flat_construct_count = 0 + self.flat_construct_stack = [] # type: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any]]] + + def __repr__(self): + return "<%s %r (cur stack size: %i)>" % ( + self.__class__.__name__, self.network.name, len(self.flat_construct_stack)) def append(self, layer_name): """ @@ -340,23 +355,70 @@ def remove(self, layer_name): """ self.layers.remove(layer_name) - def flat_construct(self, initial): + def on_construct_layer_call(self, exc): """ - :param _DelayedConstructionException initial: + This covers the whole flat construction logic. + If this returns None, it means that the normal construction should follow. + If a layer is returned, this can directly be returned. + Otherwise, this will not return but throw the exception which is handled outside. + + :param _DelayedConstructionException exc: + :rtype: LayerBase|None """ - self.in_flat_construct_count += 1 - queue = [initial] # type: typing.List[_DelayedConstructionException] + cls = self.__class__ + if not cls._flat_construction_stack or cls._flat_construction_stack[-1] is not self: + return self._flat_construct(exc) + + assert exc.network is self.network + if self.flat_construct_stack: + if exc.layer_name == self.flat_construct_stack[-1][0]: + return None # continue with construction + + existing_in_stack = [entry for entry in self.flat_construct_stack if exc.layer_name == entry[0]] + if existing_in_stack: + raise NetworkConstructionDependencyLoopException( + layer_name=exc.layer_name, constructing_layers=[entry[0] for entry in self.flat_construct_stack], + net_dict=existing_in_stack[0][1]["net_dict"], network=self.network) + + raise exc + + def _flat_construct(self, initial_exc): + """ + :param _DelayedConstructionException initial_exc: + :rtype: LayerBase + """ + cls = self.__class__ + assert initial_exc.network is self.network + stack = self.flat_construct_stack + initial = (initial_exc.layer_name, initial_exc.other_kwargs) + stack.append(initial) + stack_init_idx = len(stack) - 1 + cls._flat_construction_stack.append(self) try: - while queue: + while stack: try: - res = queue[-1].delayed_construction() - if queue[-1] is initial: + stack_top_idx = len(stack) - 1 + top = stack[stack_top_idx] + layer_name, other_kwargs = top + res = self.network.construct_layer(name=layer_name, **other_kwargs) + assert stack_top_idx == len(stack) - 1 + stack.pop(-1) + if top is initial: return res - queue.pop(-1) except _DelayedConstructionException as delayed_exc: - queue.append(delayed_exc) + # See on_construct_layer_call(). + assert delayed_exc.network is self.network # we should be in another flat_construct() otherwise + stack.append((delayed_exc.layer_name, delayed_exc.other_kwargs)) + except Exception as exc: + attr = "_RETURNN_layer_construction_stack" + if not hasattr(exc, attr): + setattr(exc, attr, []) + getattr(exc, attr).extend([(self.network, layer_name) for (layer_name, _) in stack]) + raise finally: - self.in_flat_construct_count -= 1 + top_stack = cls._flat_construction_stack.pop(-1) + assert top_stack is self + del stack[stack_init_idx:] assert False, "we should not get here" @@ -458,7 +520,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None, self.extra_nets = {} # type: typing.Dict[str,TFNetwork] self.subnets = {} # type: typing.Dict[str,Subnetwork] self._selected_train_layers = None - self._construction_stack = _NetworkConstructionStack() + self._construction_stack = _NetworkConstructionStack(self) self.layers_desc = {} # type: typing.Dict[str,typing.Dict[str]] self.layers = {} # type: typing.Dict[str,LayerBase] self.losses_dict = {} # type: typing.Dict[str,LossHolder] @@ -786,17 +848,6 @@ def get_layer(src_name): self.used_data_keys.update(extra_net.used_data_keys) return created_layers - def _flat_construction_enabled(self): - """ - :return: whether to use flat construction algorithm in :func:`construct_layer`. - Use this if you get stack overflow errors, such as: - ``Fatal Python error: Cannot recover from stack overflow`` - or - ``RuntimeError: maximum recursion depth exceeded``. - :rtype: bool - """ - return self.get_config().bool("flat_net_construction", False) - def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_existing=True): """ This triggers the construction of the layer `name` if it is not constructed yet. @@ -829,6 +880,9 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ return self.get_layer(name) except (LayerNotFound, DataNotFound): pass # ok, we will try to construct it then + delayed_exc = _DelayedConstructionException( + network=self, layer_name=name, # make sure that we have all the original args + other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) if not get_layer: get_layer = GetLayer(network=self, add_layer_func=add_layer) full_name = name @@ -918,15 +972,6 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_name=full_name, network=self) return sub_layer - if self._flat_construction_enabled(): - delayed_exc = _DelayedConstructionException( - network=self, layer_name=name, - other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing)) - if not self._construction_stack.in_flat_construct_count: - return self._construction_stack.flat_construct(delayed_exc) - if self._construction_stack.layers: - raise delayed_exc - layer_desc = layer_desc.copy() layer_desc.pop("class") # Note about name: @@ -936,10 +981,14 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_ layer_desc["_network"] = net layer_desc["_name"] = base_name name_with_prefix = ("%s:%s" % (extra_prefix, name)) if extra_prefix else name - if name_with_prefix in self._construction_stack.layers: - raise NetworkConstructionDependencyLoopException( - layer_name=name_with_prefix, constructing_layers=self._construction_stack.layers, - net_dict=net_dict, network=self) + + # Note: We don't want to raise this earlier here in this function + # because certain exceptions such as LayerNotFound should directly be raised + # because some other code tests for this + # (e.g. checking the loss checking for layer "classes" and then layer "data:classes"). + _constructed_layer = self._construction_stack.on_construct_layer_call(delayed_exc) + if _constructed_layer: + return _constructed_layer self._construction_stack.append(name_with_prefix) try: # This call would also resolve dependencies, and e.g. recursively then create them (via get_layer calls). @@ -3102,6 +3151,9 @@ def add_templated_layer(name, layer_class, **layer_desc): if layer.get("is_output_layer"): get_templated_layer(layer_name) + except _DelayedConstructionException: + raise + except Exception as exc: # Merge the exception message + further debug information all together into a single exception, # which we will raise. @@ -3624,16 +3676,7 @@ def __init__(self, network, layer_name, other_kwargs): self.other_kwargs = other_kwargs def __repr__(self): - return "%s(layer_name=%r)" % (self.__class__.__name__, self.layer_name) - - def delayed_construction(self): - """ - Call :func:`TFNetwork.construct_layer` again now. - - :rtype: LayerBase - """ - print("Delayed flat layer construction:", self.layer_name, file=log.v5) - return self.network.construct_layer(name=self.layer_name, **self.other_kwargs) + return "<%s %r/%r>" % (self.__class__.__name__, self.network.name, self.layer_name) class LayerNotFound(NetworkLayerException): diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index dc724adaa2..b02f992448 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -5640,7 +5640,6 @@ def test_flat_net_construction(): "data": (n_in, 2), "classes": (n_out, 1), }, - "flat_net_construction": True, "debug_print_layer_output_template": True, }) print("Creating network...")