diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 57cce631fa0..dbe2cb22af8 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -8,6 +8,7 @@ cc_library(
     srcs = ["array.cc"],
     hdrs = ["array.h"],
     deps = [
+        "//tensorflow/lite/c:common",
         "//tensorflow/lite/core/c:common",
     ],
 )
diff --git a/tensorflow/lite/array.cc b/tensorflow/lite/array.cc
index 1b1ff2e4557..21d704a76c4 100644
--- a/tensorflow/lite/array.cc
+++ b/tensorflow/lite/array.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/lite/array.h"
 
+#include "tensorflow/lite/c/common.h"
+
 namespace tflite {
 namespace array_internal {
 
diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h
index 767ad6ab0af..d83696219c2 100644
--- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h
+++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h
@@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                         const float* scaling_factors,
                         const int32_t* input_offset, int32_t* row_sums,
                         const RuntimeShape& output_shape, float* output_data,
-                        bool* compute_row_sums) {
+                        bool* compute_row_sums,
+                        const float* per_channel_scales) {
   const RuntimeShape extended_lhs_shape =
       RuntimeShape::ExtendedShape(5, lhs_shape);
   const RuntimeShape extended_rhs_shape =
@@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
             int32_t row_sum = woff_ptr2[i];
             total -= row_sum * batch_offset;
             int idx = lhs_rows * j + i;
-            out_ptr[idx] += batch_scaling_factor * total;
+            float scale = batch_scaling_factor;
+            if (per_channel_scales) {
+              scale *= per_channel_scales[i];
+            }
+            out_ptr[idx] += scale * total;
           }
         }
       }