
RETURNN layers with hidden state should make it explicit #31


Closed
albertz opened this issue Oct 13, 2021 · 12 comments

albertz commented Oct 13, 2021

As discussed in #16, RETURNN layers with (hidden) state (e.g. RecLayer with unit="lstm") should make the state explicit in the API. E.g. the Rec module should get two arguments, input and prev_state, and return output and state. So the usage would look like this in a loop:

lstm = Lstm(...)
with Loop() as loop:
  ...
  out, loop.state.lstm = lstm(x, loop.state.lstm)

Or like this outside a loop (using default initial state, ignoring last state):

lstm = Lstm(...)
out, _ = lstm(x)

This applies to all RETURNN layers with rec hidden state, and to further modules like Lstm.
See RETURNN layers with rec hidden state.
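As a rough signature sketch of this interface (simplified and standalone; the actual returnn_common classes may look different):

from typing import Any, Optional, Tuple

class RecModule:
  """Pattern for modules with hidden state: called with the input and the
  previous (or initial) state, returning the output and the new state."""

  def __call__(self, source: Any, state: Optional[Any] = None) -> Tuple[Any, Any]:
    # state=None means: use the default initial state.
    raise NotImplementedError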

Relevant modules here:

  • _Rec based, e.g. Lstm (only one so far)
  • window
  • cumsum
  • ken_lm_state
  • edit_distance_table
  • unmask
  • _TwoDLSTM
  • cum_concat

albertz commented Oct 20, 2021

This is partly implemented now. All recurrent layers return a tuple (output, state), and accept a state (initial or previous state) argument.

The returned state is obtained via GetLastHiddenStateLayer. This is incomplete:

  • It needs to have a specific n_out.
  • For cases where the state is more structured (e.g. LSTM with (h,c)), this should return the same structure.

This might need special implementations for each case. Not sure.
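For illustration, a sketch of the kind of net dict fragment the current wrapping maps to for an LSTM (option names from memory and may differ; the point is the explicit n_out and the lost (h, c) structure):

net_dict = {
  "lstm": {"class": "rec", "unit": "lstm", "n_out": 512, "from": "data"},
  # Returned state via get_last_hidden_state. n_out must be given explicitly;
  # for an LSTM this is the concatenated (h, c), i.e. 2 * 512 here, and the
  # (h, c) structure is not preserved.
  "lstm_state": {"class": "get_last_hidden_state", "from": "lstm", "n_out": 1024},
}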


albertz commented Oct 20, 2021

Another issue: some layers like WindowLayer would often not use the state and then become unnatural to use. I'm not sure about the best way to solve this.

Maybe we can make two versions of the module, one which is not recurrent and one which is recurrent.

It is also unclear whether this should be automatically generated, or whether we should handle these cases explicitly.


albertz commented Oct 27, 2021

We should be a bit more specific and list the layers or modules which need custom hidden state, e.g. Lstm, or Rec in general. Basically every rec layer is affected, but we should specify the hidden state for each.

We should also list the layers where we might consider non-recurrent variants.

It should be collected by editing the main post.


albertz commented Oct 27, 2021

It's not totally obvious.

CumsumLayer and CumConcatLayer have state = output, so one single return is enough, not a tuple.
Same for RecLayer with many other unit types, such as GRU, RNN, etc.

RecLayer with LSTM will return (h, c), where h = output.

state = output or h = output holds only inside the loop.
Outside the loop, state or h is just the last state, while output is the whole seq. But the whole seq then includes h or state.

The initial_state argument however needs to be consistent with the returned state.

So, you could argue, RecLayer inside the loop could just return (output, c) instead of (output, (h, c)), but the initial_state argument is (h, c) anyway.
Note that in PyTorch, LSTM (operating on the seq) returns (output, (h, c)).
In PyTorch, LSTMCell (operating on the frame) just returns (h, c). Although its prev state input argument is (h, c).
So it's a bit inconsistent: for an arbitrary PyTorch module with state, you cannot tell generically what you need to keep from the returned output to feed back as the prev state input in the next frame. There is no generic API as we intend here.
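For reference, a small PyTorch snippet illustrating the two conventions (shapes chosen arbitrarily):

import torch

lstm = torch.nn.LSTM(input_size=10, hidden_size=20)
cell = torch.nn.LSTMCell(input_size=10, hidden_size=20)

seq = torch.randn(7, 3, 10)  # (time, batch, feature)
frame = torch.randn(3, 10)  # (batch, feature)

# On the sequence: returns (output, (h, c)), where output is the whole seq of h.
output, (h, c) = lstm(seq)

# On a single frame: returns just (h, c), while its prev-state input is also (h, c).
h1, c1 = cell(frame, (h[-1], c[-1]))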

albertz added a commit that referenced this issue Oct 27, 2021

albertz commented Oct 28, 2021

I introduced the explicit LayerState (which currently derives from dict, although that might change), so the code can explicitly test whether some layer returned a state.

By convention, I suggest that the return value in such a case should be a tuple, and the last item should always be this LayerState instance.

Note that LayerState also prevents the user from accidentally triggering automatic concat, as in this example:

x = rnn(...)
y = linear(x)

If x were a tuple (out, hidden), where hidden is just a layer ref, this might do an automatic concat (#40, #41).
When hidden is a LayerState, such automatic concat cannot happen.
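A minimal sketch of the idea, not the actual returnn_common implementation (the attribute access and the helper are assumptions):

class LayerState(dict):
  """Marks (possibly nested) layer state; by convention the last item of the return tuple."""

  def __getattr__(self, item):
    return self[item]

def returned_state(res):
  """Return the explicit state if a module call returned one, else None."""
  if isinstance(res, tuple) and isinstance(res[-1], LayerState):
    return res[-1]
  return None

# E.g. for out, hidden = rnn(...): code which would otherwise auto-concat
# tuple elements (#40, #41) can recognize the LayerState and skip it.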


albertz commented Nov 2, 2021

Passing state to a RETURNN layer is a bit tricky. It is not really the same as initial_state: initial_state defines the initial state only (both when the loop is optimized out and when it runs inside the loop), not the state per frame. See rwth-i6/returnn#732.

So we need a new mechanism on the RETURNN side to pass the state per frame to such layers, via a new state option (rwth-i6/returnn#732).

But we also want to be able to make use of the RETURNN rec layer automatic optimization, which is tricky as well. We need to figure out whether the layer uses its own prev state as the input state. If that is the case, it would set initial_state on the layer, pointing to the initial argument from the State object. Otherwise it would use the new state layer option.

However, at the time we call the module (layer maker), it cannot know this. Only when it returns, and when State.assign happens, can it figure out that the state argument passed to the module was the same.

So this means we need to delay the layer dict creation, or post-edit it: to apply the optimization, we remove the state argument in this case and replace it with initial_state.

This applies not only to the Lstm module but to any hierarchical module where the state might be nested, which makes it more complicated because it might not be just a single layer dict.

We can recursively go through a State.assign value and check, for each referred layer, whether it is a Layer whose layer dict has class: get_last_hidden_state; then resolve the name to find that layer and its layer dict, and check whether its state argument points to the same layer or layers.
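A rough sketch of that check; the names (the layer_dict attribute, own_layers) are simplified placeholders, not the actual returnn_common API:

def state_is_own_prev_state(assigned_state, own_layers):
  """True if every layer referred to by the assigned state is a
  get_last_hidden_state layer reading from one of this module's own layers."""
  if isinstance(assigned_state, dict):  # e.g. a nested LayerState
    return all(state_is_own_prev_state(v, own_layers) for v in assigned_state.values())
  if isinstance(assigned_state, (list, tuple)):
    return all(state_is_own_prev_state(v, own_layers) for v in assigned_state)
  layer_dict = getattr(assigned_state, "layer_dict", None)  # single layer ref
  if not layer_dict or layer_dict.get("class") != "get_last_hidden_state":
    return False
  # The layer it reads from must belong to this module, and that layer's
  # state argument must point back to the state passed into the module.
  return layer_dict.get("from") in own_layers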

albertz added a commit that referenced this issue Nov 2, 2021
albertz added a commit that referenced this issue Nov 3, 2021
Fix test_rec_inner_lstm.
Also see #31.

albertz commented Nov 3, 2021

Note that commit 15d67a9 now implements the mentioned optimization, more or less just as outlined.

We still cannot pass state to a RETURNN layer though. For that, we still need to clarify rwth-i6/returnn#732.


albertz commented Nov 4, 2021

Another problem: the behavior of such modules (e.g. Lstm) is maybe not really well defined. We want them to work both on a single step and on a whole sequence, depending on the context.

So, how do we figure out which case we have, single step or sequence? We could simply check whether there is an outer loop context. However, that would mean we could never use an Lstm inside a loop applied to a sequence.

Note that RETURNN determines this based on the input: if the input has a time axis, it will operate on the sequence, otherwise it will do a single step. However, to be able to replicate such logic here, we need the shape information (#47).

Or we make it more explicit by having two variants (just like PyTorch), e.g. Lstm and LstmCell.
But then we would need that for all other wrapped recurrent layers as well, e.g. WindowLayer and so on?


albertz commented Nov 4, 2021

I tend towards the solution of just having separate modules Lstm and LstmCell, so there is never any ambiguity about that. LstmCell would set the layer state option (rwth-i6/returnn#732), Lstm would set the layer initial_state option.

The same question then also arises for the other recurrent layers.
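A sketch of the resulting net-dict difference; the per-frame state option only exists as proposed in rwth-i6/returnn#732, so the exact option names are tentative:

# Lstm, operating on the sequence; only the initial state can be set:
lstm_layer = {"class": "rec", "unit": "lstm", "n_out": 512,
              "from": "data", "initial_state": "initial_state_ref"}

# LstmCell, operating on a single frame inside the loop; the state is fed per frame:
lstm_cell_layer = {"class": "rec", "unit": "lstm", "n_out": 512,
                   "from": "frame_input", "state": "prev_state_ref"}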


albertz commented Nov 4, 2021

Note that 5f590e2 now implements a distinction between the options state and initial_state. This already partly solves it, because state can only be used when operating on a single frame, while initial_state can only be used when operating on the sequence.


albertz commented Nov 5, 2021

If we go with two variants for rec modules, one for the per-step operation (inside a loop), another for operating on a sequence (outside a loop, although you could also use it inside one if there is another separate time axis), then here are suggestions; a short usage sketch follows after the notes below:

  • LSTM (on seq) and LSTMStep (one step)
  • window (on seq, without state) and window_step (one step)
  • cumsum (on seq, without state), no per-step version, as it is so trivial
  • ken_lm_state_step (one step), no seq-version so far (KenLmStateLayer does not work on sequence (e.g. when optimized) returnn#736)
  • edit_distance_table_step (one step), no seq-version so far
  • unmask (on seq), not sure about the recurrent per-step variant, which is a no-op... but you might anyway want or need it with masked computation... Also see Masked computation wrapper #23. Skip the per-step variant for now.
  • (will be done later: TwoDLSTM (on seq), TwoDLSTMStep (one step))
  • cum_concat_step (one step), no seq-version as this is not allowed (for setting a new dim tag, there is reinterpret_data)

Notes:

  • Changed Lstm to LSTM, to keep consistent with PyTorch. It also makes more sense.
  • Use "step" instead of "cell" (LSTMStep instead of LSTMCell). I think it is easier to understand. Keep this convention of "step" for all recurrent variants which operate one step.

albertz added a commit that referenced this issue Nov 5, 2021
albertz added a commit that referenced this issue Nov 5, 2021
albertz added a commit that referenced this issue Nov 5, 2021

albertz commented Nov 5, 2021

Ok this is mostly done now.
