@@ -526,6 +526,12 @@ def initial(self, value: Optional[Union[nn.Tensor, RawTensorTypes, nn.init.Varia
     else:
       self.layer_dict.pop("init_by_layer", None)
       self.layer_dict["init"] = value
+    if nn.NameCtx.current_ctx().root.debug_eager_mode:
+      if isinstance(value, nn.Tensor):
+        assert value.data.placeholder is not None
+        self.data.placeholder = value.data.placeholder
+      else:
+        self.data.placeholder = tf.broadcast_to(tf.convert_to_tensor(value), self.data.batch_shape)
 
   @property
   def weight_decay(self) -> float:
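In debug eager mode, the branch added above materializes the assigned initial value right away: a plain Python value is broadcast to the parameter's full shape and stored as the TF placeholder. A minimal standalone sketch of that step, with shape and value chosen purely for illustration:

import tensorflow as tf  # TF2, eager execution by default

param_shape = (3, 5)  # stands in for self.data.batch_shape
init_value = 0.5      # as in `param.initial = 0.5`
placeholder = tf.broadcast_to(tf.convert_to_tensor(init_value), param_shape)
assert placeholder.shape == param_shape  # every entry now equals init_value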
@@ -638,11 +644,54 @@ def get_extern_data(data: Data) -> Tensor:
   assert scope.extern_data[data.name] is data
   if data.have_batch_axis():
     if not scope.global_batch:
-      scope.global_batch = data.batch if data.batch else nn.BatchInfo.make_global_batch_info(-1)
+      if data.batch:
+        scope.global_batch = data.batch
+      elif scope.root.debug_eager_mode:
+        scope.global_batch = nn.BatchInfo.make_global_batch_info(
+          tf.constant(3, name="global_batch"))  # https://xkcd.com/221/, but prime
+      else:
+        scope.global_batch = nn.BatchInfo.make_global_batch_info(-1)
     if not data.batch:
       data.batch = scope.global_batch
   root_layer_name = f"data:{data.name}"
-  return _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
+  out = _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
+  if scope.root.debug_eager_mode:
+    out.data.placeholder = _make_random_tf_tensor_for_returnn_data(out.data)
+  return out
+
+
+def _make_random_tf_tensor_for_returnn_data(data: Data) -> tf.Tensor:
+  shape = []
+  for dim in data.dim_tags:
+    if dim.is_batch_dim():
+      assert data.batch
+      shape.append(data.batch.dim)
+    elif dim.dimension is not None:
+      shape.append(dim.dimension)
+    else:
+      dim.complete_dyn_size()
+      if dim.dyn_size_ext is None:
+        assert data.batch
+        dim.dyn_size_ext = Data(
+          name=f"{data.name}_dummy_dyn_size_ext", dim_tags=[nn.batch_dim], dtype=data.size_dtype, batch=data.batch)
+      if dim.dyn_size_ext.placeholder is None:
+        dim.dyn_size_ext.placeholder = _make_random_tf_tensor_for_returnn_data(dim.dyn_size_ext)
+      shape.append(tf.reduce_max(dim.dyn_size_ext.placeholder))
+  dtype = tf.as_dtype(data.dtype)
+  if dtype.is_integer:
+    if data.sparse:
+      return tf.random.uniform(shape=shape, dtype=dtype, minval=0, maxval=data.dim)
+    else:
+      c = abs(hash(data.name)) % 21 + 3
+      shape = tf.convert_to_tensor(shape)
+      c_tf = tf.constant(c, name="dummy_random_const", dtype=dtype)
+      rnd = tf.broadcast_to(c_tf, shape)
+      rnd_diff = tf.random.uniform(shape=shape, minval=0, maxval=2 ** 31 - 1, dtype=dtype)
+      rnd_diff = rnd_diff % tf.reshape(tf.minimum(tf.range(0, tf.size(rnd), dtype=dtype) + 1, c_tf - 2), shape)
+      rnd = tf.clip_by_value(rnd - rnd_diff, 1, c_tf)
+      return rnd
+  assert dtype.is_floating  # not implemented otherwise
+  return tf.random.normal(shape=shape, dtype=dtype)
 
 
 def _get_raw_layer_by_name(name: str, *, scope: Optional[nn.NameCtx] = None, data: Data) -> Tensor:
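The dense-integer branch of _make_random_tf_tensor_for_returnn_data above is the least obvious part: it fills the tensor with values in [1, c], where c is a small per-name constant, so that e.g. dummy sequence lengths stay positive and bounded. A standalone sketch of that trick with a fixed shape and constant (both chosen here only for illustration):

import tensorflow as tf

shape = tf.convert_to_tensor([3, 5])  # illustrative static shape
c = 7  # stands in for abs(hash(data.name)) % 21 + 3
c_tf = tf.constant(c, dtype=tf.int32)
rnd = tf.broadcast_to(c_tf, shape)  # start with every entry == c
# subtract a per-position random offset, bounded so the result stays in [1, c]
rnd_diff = tf.random.uniform(shape=shape, minval=0, maxval=2 ** 31 - 1, dtype=tf.int32)
rnd_diff = rnd_diff % tf.reshape(tf.minimum(tf.range(0, tf.size(rnd), dtype=tf.int32) + 1, c_tf - 2), shape)
rnd = tf.clip_by_value(rnd - rnd_diff, 1, c_tf)
assert int(tf.reduce_min(rnd)) >= 1 and int(tf.reduce_max(rnd)) <= c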
@@ -684,6 +733,7 @@ def _data_from_layer_dict(layer_dict: LayerDictRaw) -> Data:
   })
   BehaviorVersion.set(min_returnn_behavior_version)
   ctx = nn.NameCtx.top()
+  root_ctx = ctx.root
   inside_rec_time_dim = None
   control_flow_ctx = None
   while ctx:
@@ -755,4 +805,14 @@ def _map_layer_dict_elem(value):
     msg += ")"
     raise ReturnnConstructTemplateException(msg) from exc
 
+  if root_ctx.debug_eager_mode:
+    # See TFNetwork._create_layer.
+    layer_desc["output"] = out_data
+    out_data = layer_class.fixup_out_data(**layer_desc)
+    out_data.sanity_check(ignore_placeholder=True)
+    layer = layer_class(**layer_desc)
+    layer.post_init(layer_desc)
+    layer.output.sanity_check()
+    out_data = layer.output
+
   return out_data
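As a rough end-to-end usage sketch: with debug_eager_mode set on the root name context, extern data and layer outputs carry concrete (random) TF values while the network is being defined, so they can be printed and inspected. The snippet below is hypothetical; the Data/dim constructors mirror the code above, but the way the flag is enabled here (setting the attribute directly) is an assumption, and the library may offer a dedicated helper for it.

import tensorflow as tf
from returnn_common import nn

nn.NameCtx.current_ctx().root.debug_eager_mode = True  # assumption: enable the flag directly on the root scope

time_dim = nn.SpatialDim("time")
feat_dim = nn.FeatureDim("feature", 5)
x = nn.get_extern_data(nn.Data("data", dim_tags=[nn.batch_dim, time_dim, feat_dim]))
# With debug_eager_mode, the placeholder is already a concrete random tensor (batch size 3 per the code above):
print(x.data.placeholder.shape)
print(float(tf.reduce_mean(x.data.placeholder)))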