Skip to content

Always enable flat net construction #1002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
12 changes: 9 additions & 3 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,12 +1222,15 @@ def _construct_template(self, parent_get_layer):
from collections import OrderedDict
from returnn.util.basic import StringIO, BehaviorVersion
from returnn.tf.network import NetworkConstructionDependencyLoopException, DataNotFound
from returnn.tf.network import _DelayedConstructionException
# The stack trace is not so interesting for these exceptions.
skip_stack_trace_exception_types = (
NetworkConstructionDependencyLoopException,)

# These Exceptions always indicate incorrect construction, so fail directly instead of collecting them
fail_directly_exception_types = (DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied)
fail_directly_exception_types = (
DataNotFound, LayerNotFound, BehaviorVersion.RequirementNotSatisfied,
_DelayedConstructionException)

# noinspection PyShadowingNames
def _parent_get_layer(layer_name):
Expand Down Expand Up @@ -1644,8 +1647,11 @@ def __call__(lself, name, is_prev_time_frame=False):
# And keep the remaining ones for potential later reports.
self._template_construction_exceptions = [s.text for s in ConstructCtx.collected_exceptions.values()]

except Exception:
print("%r: exception constructing template network (for deps and data shapes)" % self)
except _DelayedConstructionException:
raise

except Exception as exc:
print("%r: %s while constructing template network (for deps and data shapes)" % (self, type(exc).__name__))
from pprint import pprint
print("Most recent construction stack:")
if ConstructCtx.most_recent:
Expand Down
137 changes: 90 additions & 47 deletions returnn/tf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,24 @@ class _NetworkConstructionStack:
Used to keep the recursive construction state of :func:`TFNetwork.construct_layer`.
"""

def __init__(self):
# This assumes that we do single-threaded net construction.
# For multi-threading (if this would ever be realistic for net construction),
# we would need this to be a thread local.
# We still have a stack here for _flat_construct() because its calls can be nested,
# e.g. for things like CondLayer or RecLayer.
_flat_construction_stack = [] # type: typing.List[_NetworkConstructionStack]

def __init__(self, network):
"""
:param TFNetwork network:
"""
self.network = network
self.layers = [] # type: typing.List[str]
self.in_flat_construct_count = 0
self.flat_construct_stack = [] # type: typing.List[typing.Tuple[str, typing.Dict[str, typing.Any]]]

def __repr__(self):
return "<%s %r (cur stack size: %i)>" % (
self.__class__.__name__, self.network.name, len(self.flat_construct_stack))

def append(self, layer_name):
"""
Expand All @@ -340,23 +355,70 @@ def remove(self, layer_name):
"""
self.layers.remove(layer_name)

def flat_construct(self, initial):
def on_construct_layer_call(self, exc):
"""
:param _DelayedConstructionException initial:
This covers the whole flat construction logic.
If this returns None, the normal construction should follow.
If a layer is returned, the caller can return it directly.
Otherwise, this does not return but raises an exception, which is handled outside.

:param _DelayedConstructionException exc:
:rtype: LayerBase|None
"""
self.in_flat_construct_count += 1
queue = [initial] # type: typing.List[_DelayedConstructionException]
cls = self.__class__
if not cls._flat_construction_stack or cls._flat_construction_stack[-1] is not self:
return self._flat_construct(exc)

assert exc.network is self.network
if self.flat_construct_stack:
if exc.layer_name == self.flat_construct_stack[-1][0]:
return None # continue with construction

existing_in_stack = [entry for entry in self.flat_construct_stack if exc.layer_name == entry[0]]
if existing_in_stack:
raise NetworkConstructionDependencyLoopException(
layer_name=exc.layer_name, constructing_layers=[entry[0] for entry in self.flat_construct_stack],
net_dict=existing_in_stack[0][1]["net_dict"], network=self.network)

raise exc

def _flat_construct(self, initial_exc):
"""
:param _DelayedConstructionException initial_exc:
:rtype: LayerBase
"""
cls = self.__class__
assert initial_exc.network is self.network
stack = self.flat_construct_stack
initial = (initial_exc.layer_name, initial_exc.other_kwargs)
stack.append(initial)
stack_init_idx = len(stack) - 1
cls._flat_construction_stack.append(self)
try:
while queue:
while stack:
try:
res = queue[-1].delayed_construction()
if queue[-1] is initial:
stack_top_idx = len(stack) - 1
top = stack[stack_top_idx]
layer_name, other_kwargs = top
res = self.network.construct_layer(name=layer_name, **other_kwargs)
assert stack_top_idx == len(stack) - 1
stack.pop(-1)
if top is initial:
return res
queue.pop(-1)
except _DelayedConstructionException as delayed_exc:
queue.append(delayed_exc)
# See on_construct_layer_call().
assert delayed_exc.network is self.network # we should be in another flat_construct() otherwise
stack.append((delayed_exc.layer_name, delayed_exc.other_kwargs))
except Exception as exc:
attr = "_RETURNN_layer_construction_stack"
if not hasattr(exc, attr):
setattr(exc, attr, [])
getattr(exc, attr).extend([(self.network, layer_name) for (layer_name, _) in stack])
raise
finally:
self.in_flat_construct_count -= 1
top_stack = cls._flat_construction_stack.pop(-1)
assert top_stack is self
del stack[stack_init_idx:]
assert False, "we should not get here"


Expand Down Expand Up @@ -458,7 +520,7 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
self.extra_nets = {} # type: typing.Dict[str,TFNetwork]
self.subnets = {} # type: typing.Dict[str,Subnetwork]
self._selected_train_layers = None
self._construction_stack = _NetworkConstructionStack()
self._construction_stack = _NetworkConstructionStack(self)
self.layers_desc = {} # type: typing.Dict[str,typing.Dict[str]]
self.layers = {} # type: typing.Dict[str,LayerBase]
self.losses_dict = {} # type: typing.Dict[str,LossHolder]
Expand Down Expand Up @@ -786,17 +848,6 @@ def get_layer(src_name):
self.used_data_keys.update(extra_net.used_data_keys)
return created_layers

def _flat_construction_enabled(self):
"""
:return: whether to use flat construction algorithm in :func:`construct_layer`.
Use this if you get stack overflow errors, such as:
``Fatal Python error: Cannot recover from stack overflow``
or
``RuntimeError: maximum recursion depth exceeded``.
:rtype: bool
"""
return self.get_config().bool("flat_net_construction", False)

def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_existing=True):
"""
This triggers the construction of the layer `name` if it is not constructed yet.
Expand Down Expand Up @@ -829,6 +880,9 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_
return self.get_layer(name)
except (LayerNotFound, DataNotFound):
pass # ok, we will try to construct it then
delayed_exc = _DelayedConstructionException(
network=self, layer_name=name, # make sure that we have all the original args
other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing))
if not get_layer:
get_layer = GetLayer(network=self, add_layer_func=add_layer)
full_name = name
Expand Down Expand Up @@ -918,15 +972,6 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_
layer_name=full_name, network=self)
return sub_layer

if self._flat_construction_enabled():
delayed_exc = _DelayedConstructionException(
network=self, layer_name=name,
other_kwargs=dict(net_dict=net_dict, get_layer=get_layer, add_layer=add_layer, check_existing=check_existing))
if not self._construction_stack.in_flat_construct_count:
return self._construction_stack.flat_construct(delayed_exc)
if self._construction_stack.layers:
raise delayed_exc

layer_desc = layer_desc.copy()
layer_desc.pop("class")
# Note about name:
Expand All @@ -936,10 +981,14 @@ def construct_layer(self, net_dict, name, get_layer=None, add_layer=None, check_
layer_desc["_network"] = net
layer_desc["_name"] = base_name
name_with_prefix = ("%s:%s" % (extra_prefix, name)) if extra_prefix else name
if name_with_prefix in self._construction_stack.layers:
raise NetworkConstructionDependencyLoopException(
layer_name=name_with_prefix, constructing_layers=self._construction_stack.layers,
net_dict=net_dict, network=self)

# Note: We don't want to raise this earlier in this function
# because certain exceptions such as LayerNotFound should be raised directly,
# as some other code tests for them
# (e.g. the loss checking for layer "classes" and then layer "data:classes").
_constructed_layer = self._construction_stack.on_construct_layer_call(delayed_exc)
if _constructed_layer:
return _constructed_layer
self._construction_stack.append(name_with_prefix)
try:
# This call would also resolve dependencies, and e.g. recursively then create them (via get_layer calls).
Expand Down Expand Up @@ -3102,6 +3151,9 @@ def add_templated_layer(name, layer_class, **layer_desc):
if layer.get("is_output_layer"):
get_templated_layer(layer_name)

except _DelayedConstructionException:
raise

except Exception as exc:
# Merge the exception message + further debug information all together into a single exception,
# which we will raise.
Expand Down Expand Up @@ -3624,16 +3676,7 @@ def __init__(self, network, layer_name, other_kwargs):
self.other_kwargs = other_kwargs

def __repr__(self):
return "%s(layer_name=%r)" % (self.__class__.__name__, self.layer_name)

def delayed_construction(self):
"""
Call :func:`TFNetwork.construct_layer` again now.

:rtype: LayerBase
"""
print("Delayed flat layer construction:", self.layer_name, file=log.v5)
return self.network.construct_layer(name=self.layer_name, **self.other_kwargs)
return "<%s %r/%r>" % (self.__class__.__name__, self.network.name, self.layer_name)


class LayerNotFound(NetworkLayerException):
Expand Down
1 change: 0 additions & 1 deletion tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5640,7 +5640,6 @@ def test_flat_net_construction():
"data": (n_in, 2),
"classes": (n_out, 1),
},
"flat_net_construction": True,
"debug_print_layer_output_template": True,
})
print("Creating network...")
Expand Down