Skip to content

Commit bcd0360

Browse files
committed
initial support for debug eager mode
#22
1 parent 4b76e4c commit bcd0360

File tree

4 files changed

+107
-2
lines changed

4 files changed

+107
-2
lines changed

.github/workflows/main.yml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
- TEST=nn_conformer
2828
- TEST=nn_container
2929
- TEST=nn_conv
30+
- TEST=nn_debug_eager_mode
3031
- TEST=nn_loop
3132
- TEST=nn_loss
3233
- TEST=nn_masked_computation

nn/base.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,12 @@ def initial(self, value: Optional[Union[nn.Tensor, RawTensorTypes, nn.init.Varia
526526
else:
527527
self.layer_dict.pop("init_by_layer", None)
528528
self.layer_dict["init"] = value
529+
if nn.NameCtx.current_ctx().root.debug_eager_mode:
530+
if isinstance(value, nn.Tensor):
531+
assert value.data.placeholder is not None
532+
self.data.placeholder = value.data.placeholder
533+
else:
534+
self.data.placeholder = tf.broadcast_to(tf.convert_to_tensor(value), self.data.batch_shape)
529535

530536
@property
531537
def weight_decay(self) -> float:
@@ -638,11 +644,54 @@ def get_extern_data(data: Data) -> Tensor:
638644
assert scope.extern_data[data.name] is data
639645
if data.have_batch_axis():
640646
if not scope.global_batch:
641-
scope.global_batch = data.batch if data.batch else nn.BatchInfo.make_global_batch_info(-1)
647+
if data.batch:
648+
scope.global_batch = data.batch
649+
elif scope.root.debug_eager_mode:
650+
scope.global_batch = nn.BatchInfo.make_global_batch_info(
651+
tf.constant(3, name="global_batch")) # https://xkcd.com/221/, but prime
652+
else:
653+
scope.global_batch = nn.BatchInfo.make_global_batch_info(-1)
642654
if not data.batch:
643655
data.batch = scope.global_batch
644656
root_layer_name = f"data:{data.name}"
645-
return _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
657+
out = _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
658+
if scope.root.debug_eager_mode:
659+
out.data.placeholder = _make_random_tf_tensor_for_returnn_data(out.data)
660+
return out
661+
662+
663+
def _make_random_tf_tensor_for_returnn_data(data: Data) -> tf.Tensor:
    """
    Creates a dummy TF tensor with random content matching the given data template.
    Used in debug eager mode to fill in placeholders.

    The shape is derived from ``data.dim_tags``:
    the batch dim uses ``data.batch.dim``, static dims use their fixed dimension,
    and dynamic dims use the max over their ``dyn_size_ext`` placeholder.
    If a dynamic dim has no ``dyn_size_ext`` (or it has no placeholder) yet,
    it is created here (recursively via this function) and stored on the dim tag,
    i.e. this function has side effects on the dim tags.

    The content depends on the dtype:

    * integer and sparse: uniform random indices in ``[0, data.dim)``.
    * integer and not sparse: values clipped into ``[1, c]``
      where ``c = abs(hash(data.name)) % 21 + 3``.
      The first element (in flattened order) is always exactly ``c``.
    * floating: standard normal samples.

    Other dtypes are not implemented (assertion).

    :param data: template describing dims, dtype, sparsity and batch info
    :return: random tensor for the template
    """
    dims = []
    for tag in data.dim_tags:
        if tag.is_batch_dim():
            assert data.batch
            dims.append(data.batch.dim)
            continue
        if tag.dimension is not None:
            dims.append(tag.dimension)
            continue
        # Dynamic dim: need (possibly dummy) sizes to derive the padded length.
        tag.complete_dyn_size()
        if tag.dyn_size_ext is None:
            assert data.batch
            tag.dyn_size_ext = Data(
                name=f"{data.name}_dummy_dyn_size_ext",
                dim_tags=[nn.batch_dim],
                dtype=data.size_dtype,
                batch=data.batch)
        if tag.dyn_size_ext.placeholder is None:
            tag.dyn_size_ext.placeholder = _make_random_tf_tensor_for_returnn_data(tag.dyn_size_ext)
        dims.append(tf.reduce_max(tag.dyn_size_ext.placeholder))

    tf_dtype = tf.as_dtype(data.dtype)

    if not tf_dtype.is_integer:
        assert tf_dtype.is_floating  # not implemented otherwise
        return tf.random.normal(shape=dims, dtype=tf_dtype)

    if data.sparse:
        return tf.random.uniform(shape=dims, dtype=tf_dtype, minval=0, maxval=data.dim)

    # Non-sparse integers.
    # NOTE(review): the built-in hash of a str is salted per process (unless PYTHONHASHSEED is set),
    # so this constant presumably differs between runs -- confirm whether that matters here.
    upper = abs(hash(data.name)) % 21 + 3
    shape_tf = tf.convert_to_tensor(dims)
    upper_tf = tf.constant(upper, name="dummy_random_const", dtype=tf_dtype)
    filled = tf.broadcast_to(upper_tf, shape_tf)
    noise = tf.random.uniform(shape=shape_tf, minval=0, maxval=2 ** 31 - 1, dtype=tf_dtype)
    # The divisor for the first flat element is 1, so its noise becomes 0 and it stays at `upper`.
    divisor = tf.minimum(tf.range(0, tf.size(filled), dtype=tf_dtype) + 1, upper_tf - 2)
    noise = noise % tf.reshape(divisor, shape_tf)
    return tf.clip_by_value(filled - noise, 1, upper_tf)
646695

647696

648697
def _get_raw_layer_by_name(name: str, *, scope: Optional[nn.NameCtx] = None, data: Data) -> Tensor:
@@ -684,6 +733,7 @@ def _data_from_layer_dict(layer_dict: LayerDictRaw) -> Data:
684733
})
685734
BehaviorVersion.set(min_returnn_behavior_version)
686735
ctx = nn.NameCtx.top()
736+
root_ctx = ctx.root
687737
inside_rec_time_dim = None
688738
control_flow_ctx = None
689739
while ctx:
@@ -755,4 +805,14 @@ def _map_layer_dict_elem(value):
755805
msg += ")"
756806
raise ReturnnConstructTemplateException(msg) from exc
757807

808+
if root_ctx.debug_eager_mode:
809+
# See TFNetwork._create_layer.
810+
layer_desc["output"] = out_data
811+
out_data = layer_class.fixup_out_data(**layer_desc)
812+
out_data.sanity_check(ignore_placeholder=True)
813+
layer = layer_class(**layer_desc)
814+
layer.post_init(layer_desc)
815+
layer.output.sanity_check()
816+
out_data = layer.output
817+
758818
return out_data

nn/naming.py

+14
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(self, *,
197197
self.children = {} # type: Dict[str, NameCtx]
198198
self.extern_data = {} # type: Dict[str, nn.Data] # only for the root name ctx
199199
self.global_batch = None # type: Optional[nn.BatchInfo] # only for the root name ctx
200+
self.debug_eager_mode = False # only for the root name ctx
200201
self.marked_outputs = [] # type: List[nn.Tensor]
201202
self.marked_losses = [] # type: List[nn.Tensor]
202203
self.parent = parent if parent is not NotSpecified else self.current_ctx()
@@ -597,6 +598,19 @@ def _get_unique_name(self, suggested_name: Optional[str] = None) -> str:
597598
return name_
598599
i += 1
599600

601+
def enable_debug_eager_mode(self):
    """
    For debugging.

    Enables TF eager mode.
    Also, all layers will directly be created, and then due to TF eager mode directly evaluated.

    This must be called on the root name ctx (asserted),
    because the ``debug_eager_mode`` flag is only used on the root name ctx.
    """
    assert self.is_root
    # Local import: TF is only needed once this debug mode is actually requested.
    import tensorflow as tf
    # NOTE(review): per the TF docs, eager execution has to be enabled at program startup,
    # before any graph ops are created -- callers should enable this mode first thing.
    tf.compat.v1.enable_eager_execution()
    # (A bare `self.global_batch` expression statement was here before. It had no effect and was removed.)
    self.debug_eager_mode = True
613+
600614

601615
class ReturnnConfigSerializer:
602616
"""

tests/test_nn_debug_eager_mode.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
debug eager mode test
3+
"""
4+
5+
from __future__ import annotations
6+
7+
from . import _setup_test_env # noqa
8+
import typing
9+
10+
if typing.TYPE_CHECKING:
11+
from .. import nn
12+
else:
13+
from returnn_common import nn # noqa
14+
15+
16+
# Enables it globally now.
17+
nn.NameCtx.current_ctx().root.enable_debug_eager_mode()
18+
19+
20+
def test_simple_linear():
    """
    nn.Linear in debug eager mode:
    extern data, the parameters and the output should all have actual values right away.
    """
    time_dim = nn.SpatialDim("time")
    in_dim = nn.FeatureDim("in", 5)
    data = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, in_dim]))
    assert data.data.placeholder is not None

    out_dim = nn.FeatureDim("lin", 10)
    lin = nn.Linear(out_dim)
    out = lin(data)

    # Parameters get their value assigned directly.
    assert lin.weight.data.placeholder is not None
    assert lin.bias.data.placeholder is not None

    # The output was directly evaluated, and is neither empty nor all-zero.
    assert out.data.placeholder is not None
    out_np = out.data.placeholder.numpy()
    assert out_np.size > 0
    assert (out_np != 0).any()

0 commit comments

Comments
 (0)