@@ -526,6 +526,12 @@ def initial(self, value: Optional[Union[nn.Tensor, RawTensorTypes, nn.init.Varia
526526 else :
527527 self .layer_dict .pop ("init_by_layer" , None )
528528 self .layer_dict ["init" ] = value
529+ if nn .NameCtx .current_ctx ().root .debug_eager_mode :
530+ if isinstance (value , nn .Tensor ):
531+ assert value .data .placeholder is not None
532+ self .data .placeholder = value .data .placeholder
533+ else :
534+ self .data .placeholder = tf .broadcast_to (tf .convert_to_tensor (value ), self .data .batch_shape )
529535
530536 @property
531537 def weight_decay (self ) -> float :
@@ -638,11 +644,54 @@ def get_extern_data(data: Data) -> Tensor:
638644 assert scope .extern_data [data .name ] is data
639645 if data .have_batch_axis ():
640646 if not scope .global_batch :
641- scope .global_batch = data .batch if data .batch else nn .BatchInfo .make_global_batch_info (- 1 )
647+ if data .batch :
648+ scope .global_batch = data .batch
649+ elif scope .root .debug_eager_mode :
650+ scope .global_batch = nn .BatchInfo .make_global_batch_info (
651+ tf .constant (3 , name = "global_batch" )) # https://xkcd.com/221/, but prime
652+ else :
653+ scope .global_batch = nn .BatchInfo .make_global_batch_info (- 1 )
642654 if not data .batch :
643655 data .batch = scope .global_batch
644656 root_layer_name = f"data:{ data .name } "
645- return _get_raw_layer_by_name (root_layer_name , scope = scope , data = data )
657+ out = _get_raw_layer_by_name (root_layer_name , scope = scope , data = data )
658+ if scope .root .debug_eager_mode :
659+ out .data .placeholder = _make_random_tf_tensor_for_returnn_data (out .data )
660+ return out
661+
662+
def _make_random_tf_tensor_for_returnn_data(data: Data) -> tf.Tensor:
    """
    Create a dummy TF tensor which matches the dim tags and dtype of ``data``.

    The debug-eager-mode code path of ``get_extern_data`` uses this
    to fill the placeholder of extern data with actual values.

    Side effect: for every dynamic dim without ``dyn_size_ext``, a dummy ``dyn_size_ext``
    (shape [batch]) is created and assigned on the dim tag, and a missing
    ``dyn_size_ext.placeholder`` is filled via a recursive call to this function.

    :param data: template. ``data.batch`` must be set if there is a batch dim,
        or if a dynamic dim still needs a dummy ``dyn_size_ext``.
    :return: tensor of dtype ``data.dtype``:
        sparse integer: uniform in [0, data.dim);
        other integer: values in [1, c] where the first entry (in flattened order) is exactly c,
        for some constant c in [3, 23] derived from ``data.name``;
        floating: standard normal.
    :raises NotImplementedError: if the dtype is neither integer nor floating.
    """
    import zlib  # local import: only needed for the name-derived dummy constant below

    shape = []
    for dim in data.dim_tags:
        if dim.is_batch_dim():
            assert data.batch
            shape.append(data.batch.dim)
        elif dim.dimension is not None:
            shape.append(dim.dimension)
        else:
            dim.complete_dyn_size()
            if dim.dyn_size_ext is None:
                assert data.batch
                dim.dyn_size_ext = Data(
                    name=f"{data.name}_dummy_dyn_size_ext",
                    dim_tags=[nn.batch_dim], dtype=data.size_dtype, batch=data.batch)
            if dim.dyn_size_ext.placeholder is None:
                dim.dyn_size_ext.placeholder = _make_random_tf_tensor_for_returnn_data(dim.dyn_size_ext)
            # The padded size of a dynamic dim is the max over all sequence lengths.
            shape.append(tf.reduce_max(dim.dyn_size_ext.placeholder))

    dtype = tf.as_dtype(data.dtype)
    if dtype.is_integer:
        if data.sparse:
            return tf.random.uniform(shape=shape, dtype=dtype, minval=0, maxval=data.dim)
        # Name-derived constant in [3, 23].
        # Use CRC32 instead of the builtin hash(): str hashes are salted per process,
        # which would make this value differ between runs.
        c = zlib.crc32(str(data.name).encode("utf8")) % 21 + 3
        if shape:
            shape = tf.convert_to_tensor(shape)
        else:
            # tf.convert_to_tensor([]) would infer float32, which is not a valid shape tensor.
            shape = tf.zeros([0], dtype=tf.int32)
        c_tf = tf.constant(c, name="dummy_random_const", dtype=dtype)
        rnd = tf.broadcast_to(c_tf, shape)
        rnd_diff = tf.random.uniform(shape=shape, minval=0, maxval=2 ** 31 - 1, dtype=dtype)
        # Entry i (flattened) gets a diff in [0, min(i + 1, c - 2)).
        # For i == 0 the modulus is 1, so the diff is 0 and that entry stays exactly c.
        rnd_diff = rnd_diff % tf.reshape(
            tf.minimum(tf.range(0, tf.size(rnd), dtype=dtype) + 1, c_tf - 2), shape)
        rnd = tf.clip_by_value(rnd - rnd_diff, 1, c_tf)
        return rnd

    if not dtype.is_floating:
        # Explicit exception instead of an assert, so the check survives `python -O`.
        raise NotImplementedError(
            f"_make_random_tf_tensor_for_returnn_data: dtype {data.dtype!r} not implemented")
    return tf.random.normal(shape=shape, dtype=dtype)
646695
647696
648697def _get_raw_layer_by_name (name : str , * , scope : Optional [nn .NameCtx ] = None , data : Data ) -> Tensor :
@@ -684,6 +733,7 @@ def _data_from_layer_dict(layer_dict: LayerDictRaw) -> Data:
684733 })
685734 BehaviorVersion .set (min_returnn_behavior_version )
686735 ctx = nn .NameCtx .top ()
736+ root_ctx = ctx .root
687737 inside_rec_time_dim = None
688738 control_flow_ctx = None
689739 while ctx :
@@ -755,4 +805,14 @@ def _map_layer_dict_elem(value):
755805 msg += ")"
756806 raise ReturnnConstructTemplateException (msg ) from exc
757807
808+ if root_ctx .debug_eager_mode :
809+ # See TFNetwork._create_layer.
810+ layer_desc ["output" ] = out_data
811+ out_data = layer_class .fixup_out_data (** layer_desc )
812+ out_data .sanity_check (ignore_placeholder = True )
813+ layer = layer_class (** layer_desc )
814+ layer .post_init (layer_desc )
815+ layer .output .sanity_check ()
816+ out_data = layer .output
817+
758818 return out_data
0 commit comments