Skip to content

Commit 259e223

Browse files
committed
RF PT lstm, use PackedSequence
#1120 (comment)
1 parent 82745f4 commit 259e223

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

returnn/tensor/_dim_extra.py

+13
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,19 @@ def get_uniq_collection(cls, tags, is_equal_opts=None):
16121612
res.append(tag)
16131613
return res
16141614

1615+
def get_size_tensor(self) -> _t.Tensor:
    """
    :return: ``dyn_size_ext`` if it is defined, otherwise the static ``size``
        wrapped into a scalar tensor via :func:`rf.convert_to_tensor`.
    :rtype: _t.Tensor
    """
    # Explicit None-check: we must not fall through to the static-size branch
    # merely because a defined dyn_size_ext happens to evaluate as falsy.
    if self.dyn_size_ext is not None:
        return self.dyn_size_ext

    # Imported lazily here to avoid a module-level import cycle with the frontend.
    import returnn.frontend as rf

    # Without dyn_size_ext, a static size must exist to produce a tensor from.
    assert self.size is not None
    return rf.convert_to_tensor(self.size, name="%s:size" % self.description)
1627+
16151628
def get_dim_value(self) -> Union[int, _t.RawTensorType]:
16161629
"""
16171630
Infers the dim this axis should have if unbroadcasted.

returnn/torch/frontend/_backend.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -1191,18 +1191,43 @@ def lstm(
11911191
state_h_raw = torch.reshape(state_h_raw, [1, batch_dim, out_dim.get_dim_value()])
11921192
state_c_raw = torch.reshape(state_c_raw, [1, batch_dim, out_dim.get_dim_value()])
11931193

1194+
sizes = spatial_dim.get_size_tensor()
1195+
sizes = sizes.copy_compatible_to(
1196+
Tensor("batch_dims", batch_dims, dtype=sizes.dtype), unbroadcast=True, check_sparse=False
1197+
)
1198+
sizes_raw = torch.reshape(sizes.raw_tensor, [batch_dim])
1199+
1200+
# See the code of torch.nn.LSTM for sorting the batch dims.
1201+
# We need pack_padded_sequence because otherwise the LSTM would process the padded frames as well,
1202+
# and we would get incorrect final states.
1203+
source_packed = torch.nn.utils.rnn.pack_padded_sequence(source_raw, sizes_raw, enforce_sorted=False)
1204+
state_h_raw = state_h_raw.index_select(dim=1, index=source_packed.sorted_indices)
1205+
state_c_raw = state_c_raw.index_select(dim=1, index=source_packed.sorted_indices)
1206+
11941207
out_raw, new_state_h_raw, new_state_c_raw = torch.lstm(
1195-
source_raw,
1208+
source_packed.data,
1209+
source_packed.batch_sizes,
11961210
(state_h_raw, state_c_raw),
11971211
lstm_params,
11981212
has_biases=has_biases,
11991213
num_layers=1,
12001214
dropout=0.0,
12011215
train=rf.get_run_ctx().train_flag,
12021216
bidirectional=False,
1203-
batch_first=False,
12041217
)
12051218

1219+
# Unsort the batch dims.
1220+
new_state_h_raw = new_state_h_raw.index_select(dim=1, index=source_packed.unsorted_indices)
1221+
new_state_c_raw = new_state_c_raw.index_select(dim=1, index=source_packed.unsorted_indices)
1222+
# Unpack the sequence.
1223+
output_packed = torch.nn.utils.rnn.PackedSequence(
1224+
out_raw,
1225+
batch_sizes=source_packed.batch_sizes,
1226+
sorted_indices=source_packed.sorted_indices,
1227+
unsorted_indices=source_packed.unsorted_indices,
1228+
)
1229+
out_raw = torch.nn.utils.rnn.pad_packed_sequence(output_packed)[0]
1230+
12061231
if len(batch_dims) != 1:
12071232
out_raw = torch.reshape(
12081233
out_raw,

0 commit comments

Comments
 (0)