@@ -1191,18 +1191,43 @@ def lstm(
1191
1191
state_h_raw = torch .reshape (state_h_raw , [1 , batch_dim , out_dim .get_dim_value ()])
1192
1192
state_c_raw = torch .reshape (state_c_raw , [1 , batch_dim , out_dim .get_dim_value ()])
1193
1193
1194
+ sizes = spatial_dim .get_size_tensor ()
1195
+ sizes = sizes .copy_compatible_to (
1196
+ Tensor ("batch_dims" , batch_dims , dtype = sizes .dtype ), unbroadcast = True , check_sparse = False
1197
+ )
1198
+ sizes_raw = torch .reshape (sizes .raw_tensor , [batch_dim ])
1199
+
1200
+ # See the code of torch.nn.LSTM for sorting the batch dims.
1201
+ # We need pack_padded_sequence because otherwise the LSTM would process the padded frames like real input,
1202
+ # and we would get incorrect final states.
1203
+ source_packed = torch .nn .utils .rnn .pack_padded_sequence (source_raw , sizes_raw , enforce_sorted = False )
1204
+ state_h_raw = state_h_raw .index_select (dim = 1 , index = source_packed .sorted_indices )
1205
+ state_c_raw = state_c_raw .index_select (dim = 1 , index = source_packed .sorted_indices )
1206
+
1194
1207
out_raw , new_state_h_raw , new_state_c_raw = torch .lstm (
1195
- source_raw ,
1208
+ source_packed .data ,
1209
+ source_packed .batch_sizes ,
1196
1210
(state_h_raw , state_c_raw ),
1197
1211
lstm_params ,
1198
1212
has_biases = has_biases ,
1199
1213
num_layers = 1 ,
1200
1214
dropout = 0.0 ,
1201
1215
train = rf .get_run_ctx ().train_flag ,
1202
1216
bidirectional = False ,
1203
- batch_first = False ,
1204
1217
)
1205
1218
1219
+ # Unsort the batch dims.
1220
+ new_state_h_raw = new_state_h_raw .index_select (dim = 1 , index = source_packed .unsorted_indices )
1221
+ new_state_c_raw = new_state_c_raw .index_select (dim = 1 , index = source_packed .unsorted_indices )
1222
+ # Unpack the sequence.
1223
+ output_packed = torch .nn .utils .rnn .PackedSequence (
1224
+ out_raw ,
1225
+ batch_sizes = source_packed .batch_sizes ,
1226
+ sorted_indices = source_packed .sorted_indices ,
1227
+ unsorted_indices = source_packed .unsorted_indices ,
1228
+ )
1229
+ out_raw = torch .nn .utils .rnn .pad_packed_sequence (output_packed )[0 ]
1230
+
1206
1231
if len (batch_dims ) != 1 :
1207
1232
out_raw = torch .reshape (
1208
1233
out_raw ,
0 commit comments