Skip to content

Commit 389e775

Browse files
suleshahidTFLM-bot
andauthored
Fix upstream sync (#2728)
* . * nl add * Sync from upstream TF. * fix include --------- Co-authored-by: TFLM-bot <[email protected]>
1 parent e440f0a commit 389e775

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

tensorflow/lite/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ cc_library(
88
srcs = ["array.cc"],
99
hdrs = ["array.h"],
1010
deps = [
11+
"//tensorflow/lite/c:common",
1112
"//tensorflow/lite/core/c:common",
1213
],
1314
)

tensorflow/lite/array.cc

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
#include "tensorflow/lite/array.h"
1717

18+
#include "tensorflow/lite/c/common.h"
19+
1820
namespace tflite {
1921
namespace array_internal {
2022

tensorflow/lite/kernels/internal/reference/batch_matmul.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
111111
const float* scaling_factors,
112112
const int32_t* input_offset, int32_t* row_sums,
113113
const RuntimeShape& output_shape, float* output_data,
114-
bool* compute_row_sums) {
114+
bool* compute_row_sums,
115+
const float* per_channel_scales) {
115116
const RuntimeShape extended_lhs_shape =
116117
RuntimeShape::ExtendedShape(5, lhs_shape);
117118
const RuntimeShape extended_rhs_shape =
@@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
188189
int32_t row_sum = woff_ptr2[i];
189190
total -= row_sum * batch_offset;
190191
int idx = lhs_rows * j + i;
191-
out_ptr[idx] += batch_scaling_factor * total;
192+
float scale = batch_scaling_factor;
193+
if (per_channel_scales) {
194+
scale *= per_channel_scales[i];
195+
}
196+
out_ptr[idx] += scale * total;
192197
}
193198
}
194199
}

0 commit comments

Comments
 (0)