@@ -183,10 +183,11 @@ def setup():
183
183
print (" def make_layer_dict(self," , file = f )
184
184
if sig .has_source_param ():
185
185
print (f" { sig .get_module_call_source_param_code_str ()} ," , file = f )
186
- if sig .has_recurrent_state ():
187
- print (f" { sig .get_module_call_state_param_code_str ()} ," , file = f )
188
- if sig .has_module_call_args ():
186
+ if sig .has_module_call_args () or sig .has_recurrent_state ():
189
187
print (" *," , file = f )
188
+ if sig .has_recurrent_state ():
189
+ print (f" { sig .get_module_call_state_param_code_str ('state' )} ," , file = f )
190
+ print (f" { sig .get_module_call_state_param_code_str ('initial_state' )} ," , file = f )
190
191
for param in sig .get_module_call_args ():
191
192
print (f" { param .get_module_param_code_str ()} ," , file = f )
192
193
print (" ) -> LayerDictRaw:" , file = f )
@@ -209,7 +210,8 @@ def setup():
209
210
if sig .has_module_call_args () or sig .has_recurrent_state ():
210
211
print (" args = {" , file = f )
211
212
if sig .has_recurrent_state ():
212
- print (f" 'initial_state': state," , file = f )
213
+ print (f" 'state': state," , file = f )
214
+ print (f" 'initial_state': initial_state," , file = f )
213
215
for param in sig .get_module_call_args ():
214
216
print (f" '{ param .returnn_name } ': { param .get_module_param_name ()} ," , file = f )
215
217
print (" }" , file = f )
@@ -243,10 +245,11 @@ def setup():
243
245
if sig .has_source_param ():
244
246
print (f"{ prefix } { sig .get_module_call_source_param_code_str ()} ," , file = f )
245
247
args .append ("source" )
246
- if sig .has_recurrent_state ():
247
- print (f"{ prefix } { sig .get_module_call_state_param_code_str ()} ," , file = f )
248
- args .append ("state" )
249
248
print (f"{ prefix } *," , file = f )
249
+ if sig .has_recurrent_state ():
250
+ print (f"{ prefix } { sig .get_module_call_state_param_code_str ('state' )} ," , file = f )
251
+ print (f"{ prefix } { sig .get_module_call_state_param_code_str ('initial_state' )} ," , file = f )
252
+ args .extend (("state" , "initial_state" ))
250
253
mod_args = sig .get_all_derived_args ()
251
254
if mod_args :
252
255
for param in mod_args :
@@ -267,7 +270,8 @@ def setup():
267
270
if sig .has_source_param ():
268
271
print (f" { sig .get_module_call_source_docstring ()} " , file = f )
269
272
if sig .has_recurrent_state ():
270
- print (f" { sig .get_module_call_state_docstring ()} " , file = f )
273
+ print (f" { sig .get_module_call_state_docstring ('state' )} " , file = f )
274
+ print (f" { sig .get_module_call_state_docstring ('initial_state' )} " , file = f )
271
275
for param in mod_args :
272
276
print (param .get_module_param_docstring (indent = " " ), file = f )
273
277
print (" :param str|None name:" , file = f )
@@ -286,15 +290,16 @@ def setup():
286
290
module_call_args .append (param )
287
291
if sig .has_source_param () and not module_call_args :
288
292
if sig .has_recurrent_state ():
289
- print (f" return mod(source, state, name=name)" , file = f )
293
+ print (f" return mod(source, state=state, initial_state=initial_state , name=name)" , file = f )
290
294
else :
291
295
print (f" return mod(source, name=name)" , file = f )
292
296
else :
293
297
print (f" return mod(" , file = f )
294
298
if sig .has_source_param ():
295
299
print (" source," , file = f )
296
300
if sig .has_recurrent_state ():
297
- print (" state," , file = f )
301
+ print (" state=state," , file = f )
302
+ print (" initial_state=initial_state," , file = f )
298
303
for param in module_call_args :
299
304
print (f" { param .get_module_param_name ()} ={ param .get_module_param_name ()} ," , file = f )
300
305
print (" name=name)" , file = f )
@@ -391,19 +396,19 @@ def get_module_call_source_docstring(self):
391
396
s += " source:"
392
397
return s
393
398
394
def get_module_call_state_param_code_str(self, param_name: str) -> str:
    """
    Code for one recurrent-state parameter of a generated call signature.

    :param str param_name: name of the generated parameter.
      The generator code in this file passes ``"state"`` and ``"initial_state"``.
    :return: ``<param_name>: <annotation> = NotSpecified``, without a trailing comma
    :rtype: str
    :raises AssertionError: if the layer has no recurrent state
    """
    assert self.has_recurrent_state()
    # The result is pasted verbatim into a generated signature, so there must be
    # no whitespace before the name or between the name and the colon.
    return f"{param_name}: Optional[Union[LayerRef, Dict[str, LayerRef], NotSpecified]] = NotSpecified"
400
405
401
def get_module_call_state_docstring(self, param_name: str) -> str:
    """
    Code for the docstring line of one recurrent-state parameter.

    :param str param_name: name of the documented parameter.
      The generator code in this file passes ``"state"`` and ``"initial_state"``.
    :return: a single ``:param <type> <param_name>:`` line, without indentation
    :rtype: str
    :raises AssertionError: if the layer has no recurrent state
    """
    assert self.has_recurrent_state()
    # Keep the documented type in line with the generated annotation
    # Optional[Union[LayerRef, Dict[str, LayerRef], NotSpecified]] for the same parameter.
    return f":param LayerRef|dict[str,LayerRef]|NotSpecified|None {param_name}:"
407
412
408
413
def has_module_init_args (self ) -> bool :
409
414
"""
@@ -477,12 +482,12 @@ def has_recurrent_state(self) -> bool:
477
482
Inside a loop::
478
483
479
484
mod = Module(...)
480
- out, state = mod(in, prev_state)
485
+ out, state = mod(in, state= prev_state)
481
486
482
487
Outside a loop::
483
488
484
489
mod = Module(...)
485
- out, last_state = mod(in, [initial_state])
490
+ out, last_state = mod(in, [initial_state=initial_state ])
486
491
"""
487
492
if (
488
493
getattr (self .layer_class .get_rec_initial_extra_outputs , "__func__" )
@@ -500,7 +505,7 @@ def has_recurrent_state(self) -> bool:
500
505
"n_out" , "out_type" , "sources" , "target" , "loss" , "loss_" , "size_target" ,
501
506
"name_scope" , "reuse_params" ,
502
507
"rec_previous_layer" , "control_dependencies_on_output" ,
503
- "initial_state" , "initial_output" ,
508
+ "state" , " initial_state" , "initial_output" ,
504
509
"extra_deps" , "collocate_with" ,
505
510
"batch_norm" ,
506
511
"is_output_layer" , "register_as_extern_data" ,
@@ -862,7 +867,7 @@ def get_module_class_name_for_layer_class(sig: LayerSignature) -> str:
862
867
name = name [:- len ("Layer" )]
863
868
if name .startswith ("_" ):
864
869
return name
865
- if layer_class .layer_class in LayersHidden or sig .is_functional ():
870
+ if layer_class .layer_class in LayersHidden or sig .is_functional () or sig . has_recurrent_state () :
866
871
return "_" + name # we make a public function for it, but the module is hidden
867
872
return name
868
873
0 commit comments