Skip to content

Commit fca5ed3

Browse files
committed
Sync from upstream TF.
1 parent e86d97b commit fca5ed3

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tensorflow/lite/kernels/internal/reference/batch_matmul.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
111111
const float* scaling_factors,
112112
const int32_t* input_offset, int32_t* row_sums,
113113
const RuntimeShape& output_shape, float* output_data,
114-
bool* compute_row_sums) {
114+
bool* compute_row_sums,
115+
const float* per_channel_scales) {
115116
const RuntimeShape extended_lhs_shape =
116117
RuntimeShape::ExtendedShape(5, lhs_shape);
117118
const RuntimeShape extended_rhs_shape =
@@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
188189
int32_t row_sum = woff_ptr2[i];
189190
total -= row_sum * batch_offset;
190191
int idx = lhs_rows * j + i;
191-
out_ptr[idx] += batch_scaling_factor * total;
192+
float ss = batch_scaling_factor;
193+
if (per_channel_scales) {
194+
ss *= per_channel_scales[j];
195+
}
196+
out_ptr[idx] += ss * total;
192197
}
193198
}
194199
}

0 commit comments

Comments
 (0)