Skip to content

Commit 5f590e2

Browse files
committed
distinction between state and initial_state
#31
1 parent 6277f86 commit 5f590e2

File tree

2 files changed

+89
-48
lines changed

2 files changed

+89
-48
lines changed

nn/_generate_layers.py

+23-18
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,11 @@ def setup():
183183
print(" def make_layer_dict(self,", file=f)
184184
if sig.has_source_param():
185185
print(f" {sig.get_module_call_source_param_code_str()},", file=f)
186-
if sig.has_recurrent_state():
187-
print(f" {sig.get_module_call_state_param_code_str()},", file=f)
188-
if sig.has_module_call_args():
186+
if sig.has_module_call_args() or sig.has_recurrent_state():
189187
print(" *,", file=f)
188+
if sig.has_recurrent_state():
189+
print(f" {sig.get_module_call_state_param_code_str('state')},", file=f)
190+
print(f" {sig.get_module_call_state_param_code_str('initial_state')},", file=f)
190191
for param in sig.get_module_call_args():
191192
print(f" {param.get_module_param_code_str()},", file=f)
192193
print(" ) -> LayerDictRaw:", file=f)
@@ -209,7 +210,8 @@ def setup():
209210
if sig.has_module_call_args() or sig.has_recurrent_state():
210211
print(" args = {", file=f)
211212
if sig.has_recurrent_state():
212-
print(f" 'initial_state': state,", file=f)
213+
print(f" 'state': state,", file=f)
214+
print(f" 'initial_state': initial_state,", file=f)
213215
for param in sig.get_module_call_args():
214216
print(f" '{param.returnn_name}': {param.get_module_param_name()},", file=f)
215217
print(" }", file=f)
@@ -243,10 +245,11 @@ def setup():
243245
if sig.has_source_param():
244246
print(f"{prefix}{sig.get_module_call_source_param_code_str()},", file=f)
245247
args.append("source")
246-
if sig.has_recurrent_state():
247-
print(f"{prefix}{sig.get_module_call_state_param_code_str()},", file=f)
248-
args.append("state")
249248
print(f"{prefix}*,", file=f)
249+
if sig.has_recurrent_state():
250+
print(f"{prefix}{sig.get_module_call_state_param_code_str('state')},", file=f)
251+
print(f"{prefix}{sig.get_module_call_state_param_code_str('initial_state')},", file=f)
252+
args.extend(("state", "initial_state"))
250253
mod_args = sig.get_all_derived_args()
251254
if mod_args:
252255
for param in mod_args:
@@ -267,7 +270,8 @@ def setup():
267270
if sig.has_source_param():
268271
print(f" {sig.get_module_call_source_docstring()}", file=f)
269272
if sig.has_recurrent_state():
270-
print(f" {sig.get_module_call_state_docstring()}", file=f)
273+
print(f" {sig.get_module_call_state_docstring('state')}", file=f)
274+
print(f" {sig.get_module_call_state_docstring('initial_state')}", file=f)
271275
for param in mod_args:
272276
print(param.get_module_param_docstring(indent=" "), file=f)
273277
print(" :param str|None name:", file=f)
@@ -286,15 +290,16 @@ def setup():
286290
module_call_args.append(param)
287291
if sig.has_source_param() and not module_call_args:
288292
if sig.has_recurrent_state():
289-
print(f" return mod(source, state, name=name)", file=f)
293+
print(f" return mod(source, state=state, initial_state=initial_state, name=name)", file=f)
290294
else:
291295
print(f" return mod(source, name=name)", file=f)
292296
else:
293297
print(f" return mod(", file=f)
294298
if sig.has_source_param():
295299
print(" source,", file=f)
296300
if sig.has_recurrent_state():
297-
print(" state,", file=f)
301+
print(" state=state,", file=f)
302+
print(" initial_state=initial_state,", file=f)
298303
for param in module_call_args:
299304
print(f" {param.get_module_param_name()}={param.get_module_param_name()},", file=f)
300305
print(" name=name)", file=f)
@@ -391,19 +396,19 @@ def get_module_call_source_docstring(self):
391396
s += " source:"
392397
return s
393398

394-
def get_module_call_state_param_code_str(self):
399+
def get_module_call_state_param_code_str(self, param_name: str):
395400
"""
396401
Code for `state` param
397402
"""
398403
assert self.has_recurrent_state()
399-
return "state: Optional[Union[LayerRef, List[LayerRef], Tuple[LayerRef], NotSpecified]] = NotSpecified"
404+
return f"{param_name}: Optional[Union[LayerRef, Dict[str, LayerRef], NotSpecified]] = NotSpecified"
400405

401-
def get_module_call_state_docstring(self):
406+
def get_module_call_state_docstring(self, param_name: str):
402407
"""
403408
Code for docstring of `source` param
404409
"""
405410
assert self.has_recurrent_state()
406-
return ":param LayerRef|list[LayerRef]|tuple[LayerRef]|NotSpecified|None state:"
411+
return f":param LayerRef|list[LayerRef]|tuple[LayerRef]|NotSpecified|None {param_name}:"
407412

408413
def has_module_init_args(self) -> bool:
409414
"""
@@ -477,12 +482,12 @@ def has_recurrent_state(self) -> bool:
477482
Inside a loop::
478483
479484
mod = Module(...)
480-
out, state = mod(in, prev_state)
485+
out, state = mod(in, state=prev_state)
481486
482487
Outside a loop::
483488
484489
mod = Module(...)
485-
out, last_state = mod(in, [initial_state])
490+
out, last_state = mod(in, [initial_state=initial_state])
486491
"""
487492
if (
488493
getattr(self.layer_class.get_rec_initial_extra_outputs, "__func__")
@@ -500,7 +505,7 @@ def has_recurrent_state(self) -> bool:
500505
"n_out", "out_type", "sources", "target", "loss", "loss_", "size_target",
501506
"name_scope", "reuse_params",
502507
"rec_previous_layer", "control_dependencies_on_output",
503-
"initial_state", "initial_output",
508+
"state", "initial_state", "initial_output",
504509
"extra_deps", "collocate_with",
505510
"batch_norm",
506511
"is_output_layer", "register_as_extern_data",
@@ -862,7 +867,7 @@ def get_module_class_name_for_layer_class(sig: LayerSignature) -> str:
862867
name = name[:-len("Layer")]
863868
if name.startswith("_"):
864869
return name
865-
if layer_class.layer_class in LayersHidden or sig.is_functional():
870+
if layer_class.layer_class in LayersHidden or sig.is_functional() or sig.has_recurrent_state():
866871
return "_" + name # we make a public function for it, but the module is hidden
867872
return name
868873

0 commit comments

Comments
 (0)