4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
@@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() {
 // Multi-batch for time_major input
 RuntimeShape LstmStepManager::InputShape() const {
   int batch_size = 1;
-  if (size_info_.time_major) {
+  if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
     batch_size = size_info_.batch_size;
   }
   const int dims[2] = {batch_size, size_info_.input_dimension};
@@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const {
 // Multi-batch for time_major input
 RuntimeShape LstmStepManager::StateShape() const {
   int batch_size = 1;
-  if (size_info_.time_major) {
+  if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
     batch_size = size_info_.batch_size;
   }
   const int dims[2] = {batch_size, size_info_.state_dimension};
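For reference, here is a minimal standalone sketch of the shape logic above, using hypothetical stand-ins (`SizeInfo`, `StepBatchSize`) for `LstmSizeInfo` and the body of `InputShape()`/`StateShape()`; it is not the TFLM code itself:

```cpp
#include <cstdint>

// Hypothetical stand-in for the kernel's LstmSizeInfo.
struct SizeInfo {
  bool time_major;
  int32_t batch_size;
  int32_t time_steps;
};

// Batch dimension used when building the per-step input/state shapes.
// Time-major inputs already processed all batches in one step; the new
// condition extends that to batch-major inputs whose sequence length is one,
// since a single time step over the whole batch is equivalent.
int32_t StepBatchSize(const SizeInfo& info) {
  if (info.time_major || (info.batch_size > 1 && info.time_steps == 1)) {
    return info.batch_size;
  }
  return 1;  // batch-major with multiple time steps: one batch per step
}
```

Batch-major inputs with more than one time step still get a batch dimension of 1 and are stepped through one batch at a time.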
29 changes: 20 additions & 9 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
@@ -666,6 +666,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
   int input_dimension = step_info.input_dimension();
   int state_dimension = step_info.state_dimension();
 
+  const auto& size_info = op_data.size_info;
+  if (size_info.batch_size > 1 && size_info.time_steps == 1) {
+    num_batches = size_info.batch_size;
+  }
+
   // Check offset validity to avoid memory overflow
   TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
                    tflite::micro::GetTensorShape(input).FlatSize());
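A rough sketch of what the new override amounts to, with hypothetical names (`EffectiveNumBatches`, `SizeInfo`) that are not part of the TFLM API:

```cpp
#include <cstdint>

// Hypothetical, simplified mirror of the fields read from op_data.size_info.
struct SizeInfo {
  int32_t batch_size;
  int32_t time_steps;
};

// Batch count used by one LstmStep invocation: overridden for the
// batch-major, single-time-step case so that a single call covers every
// batch; otherwise the step manager's own value is kept.
int32_t EffectiveNumBatches(int32_t step_manager_batches,
                            const SizeInfo& size_info) {
  if (size_info.batch_size > 1 && size_info.time_steps == 1) {
    return size_info.batch_size;
  }
  return step_manager_batches;
}

// The subsequent DCHECKs then validate, in effect, that
//   input_offset + num_batches * input_dimension <= input tensor flat size,
// so the larger num_batches also widens the region being checked.
```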
@@ -805,16 +810,22 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
     }
   } else {
     // batch first, unable to size the input data. single batch inference
-    for (int b = 0; b < size_info.batch_size; b++) {
-      for (int t = 0; t < size_info.time_steps; t++) {
-        lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
-            step_info, op_data, kernel_content, buffers);
-        // prepare for the next time step
-        step_info.UpdateTime();
+    if (size_info.batch_size > 1 && size_info.time_steps == 1) {
+      // Single time step: one multi-batch LstmStep call covers every batch.
+      lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+          step_info, op_data, kernel_content, buffers);
+    } else {
+      for (int b = 0; b < size_info.batch_size; b++) {
+        for (int t = 0; t < size_info.time_steps; t++) {
+          lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
+              step_info, op_data, kernel_content, buffers);
+          // prepare for the next time step
+          step_info.UpdateTime();
+        }
+        // prepare for the next batch
+        step_info.UpdateBatch();
+        step_info.ResetTime();
       }
-      // prepare for the next batch
-      step_info.UpdateBatch();
-      step_info.ResetTime();
     }
   }
   return kTfLiteOk;
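Taken together, the batch-major branch of `EvalLstm` now dispatches roughly as sketched below; `StepManager` and `RunStep` are illustrative placeholders for `LstmStepManager` and `lstm_internal::LstmStep`, not the real types:

```cpp
#include <cstdint>

// Hypothetical placeholder for LstmStepManager.
struct StepManager {
  void UpdateTime() {}   // advance data offsets to the next time step
  void UpdateBatch() {}  // advance data offsets to the next batch
  void ResetTime() {}    // rewind the time offset for the next batch
};

// Placeholder for one lstm_internal::LstmStep invocation.
void RunStep(StepManager& step_info) { (void)step_info; }

void EvalBatchMajor(StepManager& step_info, int32_t batch_size,
                    int32_t time_steps) {
  if (batch_size > 1 && time_steps == 1) {
    // New path: the step shapes carry the full batch, so a single call
    // processes every batch at once.
    RunStep(step_info);
  } else {
    // Original path: one batch at a time, stepping through time.
    for (int32_t b = 0; b < batch_size; ++b) {
      for (int32_t t = 0; t < time_steps; ++t) {
        RunStep(step_info);
        step_info.UpdateTime();
      }
      step_info.UpdateBatch();
      step_info.ResetTime();
    }
  }
}
```

The per-batch, per-time-step loop is untouched for batch-major inputs with more than one time step; only the single-time-step case takes the new multi-batch path.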