From 2268ddb6f2a8db8da7eb64888ed2af2228bdba08 Mon Sep 17 00:00:00 2001
From: ddavis-2015 <ddavis@bdti.com>
Date: Thu, 12 Dec 2024 10:45:41 -0600
Subject: [PATCH] feat(compression): implement tensor decompression in op
 depthwise conv

Implement tensor decompression in op depthwise conv. Extend tests
to validate operation on compressed tensors.
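
When USE_TFLM_COMPRESSION is defined, DepthwiseConvPrepare allocates
decompression scratch buffers for the weights and bias tensors, and the
Eval paths fetch decompressed data through the MicroContext-aware
GetTensorData / GetOptionalTensorData overloads. A rough sketch of the
pattern used by the reference and HiFi kernels (per-type details omitted;
the Vision kernel instead decompresses into temporary buffers during
prepare):

  // Prepare: record scratch buffer indices for compressed inputs.
  data->weights_scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(
          node, kDepthwiseConvWeightsTensor);

  // Eval: look up the compression metadata and decompress into the
  // scratch buffer on access.
  const CompressionTensorData* filter_comp_td =
      micro_context->GetTensorCompressionData(node,
                                              kDepthwiseConvWeightsTensor);
  const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
      micro_context, filter, filter_comp_td, data.weights_scratch_index);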

BUG=part of #2636
---
 .../lite/micro/kernels/depthwise_conv.cc      |  41 +-
 .../micro/kernels/depthwise_conv_common.cc    |  23 +-
 .../lite/micro/kernels/depthwise_conv_test.cc | 388 +++++++++++++++++-
 .../micro/kernels/xtensa/depthwise_conv.cc    |  34 +-
 .../kernels/xtensa/depthwise_conv_hifi.cc     |  37 +-
 .../kernels/xtensa/depthwise_conv_vision.cc   |  73 +++-
 tensorflow/lite/micro/micro_utils.h           |  13 +-
 7 files changed, 581 insertions(+), 28 deletions(-)
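
Reviewer note: the new tests exercise the compressed path through the
optional TestCompressionInfo arguments added to
ValidateDepthwiseConvGoldens(). A minimal sketch, abridged from the new
SimpleTestCompressed test below:

  tflite::testing::TestCompressionInfo<const float> filter_comp_info = {};
  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
  filter_comp_info.value_table = kBinQuantFilterValueTable;
  filter_comp_info.value_table_stride =
      std::extent<decltype(kBinQuantFilterValueTable)>::value;
  filter_comp_info.bit_width = 4;
  // bias_comp_info is filled in the same way, then both are forwarded:
  tflite::testing::TestDepthwiseConvFloat(
      input_shape, input_values, filter_shape,
      reinterpret_cast<const float*>(kBinQuantFilterData), bias_shape,
      reinterpret_cast<const float*>(kBinQuantBiasData), golden, output_shape,
      &conv_params, output_data, &filter_comp_info, &bias_comp_info);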

diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.cc b/tensorflow/lite/micro/kernels/depthwise_conv.cc
index fa55a705606..489e83f94f2 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -52,6 +52,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
           : nullptr;
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::DepthwiseConv(
@@ -59,9 +71,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(input),
           tflite::micro::GetTensorData<float>(input),
           tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+          tflite::micro::GetTensorData<float>(micro_context, filter,
+                                              filter_comp_td,
+                                              data.weights_scratch_index),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetOptionalTensorData<float>(
+              micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorData<float>(filter),
           tflite::micro::GetTensorShape(bias),
           tflite::micro::GetOptionalTensorData<float>(bias),
+#endif  // USE_TFLM_COMPRESSION
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
       break;
@@ -94,9 +115,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int8_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                   filter_comp_td,
+                                                   data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int32_t>(
+                  micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(filter),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int8_t>(output));
           break;
@@ -118,9 +148,18 @@ TfLiteStatus DepthwiseConvEval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int16_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(micro_context, filter,
+                                                   filter_comp_td,
+                                                   data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int64_t>(
+                  micro_context, bias, bias_comp_td, data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(filter),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int16_t>(output));
           break;
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
index 52804de3315..0813d2b028e 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_common.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -127,7 +127,9 @@ TfLiteStatus CalculateOpDataDepthwiseConv(
 
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
-  micro_context->DeallocateTempTfLiteTensor(bias);
+  if (has_bias) {
+    micro_context->DeallocateTempTfLiteTensor(bias);
+  }
   micro_context->DeallocateTempTfLiteTensor(output);
 
   return kTfLiteOk;
@@ -209,6 +211,23 @@ TfLiteStatus DepthwiseConvPrepare(TfLiteContext* context, TfLiteNode* node) {
       context, node, params, input_width, input_height, filter_width,
       filter_height, output_width, output_height, input->type, data));
 
+#ifdef USE_TFLM_COMPRESSION
+
+  // Compression scratch buffers.
+  // These will only be allocated if the tensor is compressed.
+  if (micro_context->IsTensorCompressed(node, kDepthwiseConvWeightsTensor) &&
+      filter->type == kTfLiteInt4) {
+    MicroPrintf("Compression not supported with INT4 tensors");
+    return kTfLiteError;
+  }
+  data->weights_scratch_index =
+      micro_context->AllocateDecompressionScratchBuffer(
+          node, kDepthwiseConvWeightsTensor);
+  data->bias_scratch_index = micro_context->AllocateDecompressionScratchBuffer(
+      node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
index b50b40ae6d6..adedcaeb04e 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/micro/kernels/depthwise_conv_test.cc
@@ -1,5 +1,5 @@
 
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <type_traits>
+
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
@@ -32,17 +34,99 @@ constexpr int kOutputTensorIndex = 3;
 constexpr int kMaxFilterChannels = 64;
 constexpr int kMaxBiasChannels = 64;
 
+#ifdef USE_TFLM_COMPRESSION
+
+constexpr size_t kDepthwiseConvMaxTensors = 4;
+constexpr size_t kDepthwiseConvMaxInputTensors = 3;
+
+// Common inputs and outputs (quantized multi channel).
+// data from TfLite test:
+// PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+static int kInputShapeQ1[] = {4, 1, 2, 3, 2};
+static constexpr float kInputDataQ1[] = {
+    // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
+    3,  2,   // batch = 0, y = 0, x = 0
+    1,  -1,  // batch = 0, y = 0, x = 1
+    -2, -3,  // batch = 0, y = 0, x = 2
+    4,  3,   // batch = 0, y = 1, x = 0
+    2,  -2,  // batch = 0, y = 1, x = 1
+    -3, -4,  // batch = 0, y = 1, x = 2
+};
+constexpr size_t kInputElementsQ1 = std::extent<decltype(kInputDataQ1)>::value;
+
+constexpr int kNumChannelsQ1 = 4;
+static int kFilterShapeQ1[] = {4, 1, 2, 2, 4};
+static constexpr float kFilterDataQ1[] = {
+    // This is a compact value table.  Original data is:
+    // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+    // depth multiplier = 2
+    // 1, 2, 3, 4,   y = 0, x = 0
+    // 3, 4, 5, 6,   y = 0, x = 1
+    // 7, 8, 5, 6,   y = 1, x = 0
+    // 3, 4, 1, 2,   y = 1, x = 1
+    1, 3, 7, 8, 2, 4, 1, 3, 5, 2, 4, 6,
+};
+constexpr size_t kFilterElementsQ1 =
+    std::extent<decltype(kFilterDataQ1)>::value;
+
+static int kBiasShapeQ1[] = {1, 4};
+static constexpr float kBiasDataQ1[] = {3, -2, 4, 6};
+constexpr size_t kBiasElementsQ1 = std::extent<decltype(kBiasDataQ1)>::value;
+
+static int kOutputShapeQ1[] = {4, 1, 1, 2, 4};
+static constexpr float kGoldenDataQ1[] = {43, 48, 21, 22, 3, -4, -30, -36};
+constexpr int kOutputElementsQ1 = std::extent<decltype(kGoldenDataQ1)>::value;
+
+// compressed filter data for kBinQuant scheme, matches kFilterDataQ1
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantFilterDataQ1[] = {0x15, 0x6A, 0x8A,
+                                                         0x60};
+constexpr int kBinQuantFilterBitWidthQ1 = 2;
+// compressed bias data for kBinQuant scheme, matches kBiasDataQ1
+// Align the tensor data the same as a Buffer in the schema
+alignas(16) constexpr uint8_t kBinQuantBiasDataQ1[] = {0x00};
+constexpr int kBinQuantBiasBitWidthQ1 = 1;
+
+#endif  // USE_TFLM_COMPRESSION
+
 // Creates a DepthwiseConv operator, calls it with the provided input tensors
 // and some default parameters, and compares the output with
 // expected_output_data.
 //
 // The tensors parameter contains both the input tensors as well as a
 // preallocated output tensor into which the output is stored.
-template <typename T>
+template <typename T, typename TF = void, typename TB = void>
 TfLiteStatus ValidateDepthwiseConvGoldens(
     const T* expected_output_data, int output_length,
     TfLiteDepthwiseConvParams* conv_params, float tolerance, int tensors_size,
-    TfLiteTensor* tensors) {
+    TfLiteTensor* tensors
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<TF>* filter_comp_info = nullptr,
+    const TestCompressionInfo<TB>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
+#ifdef USE_TFLM_COMPRESSION
+
+  TestCompressedList<kDepthwiseConvMaxInputTensors> tcl;
+  if (filter_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*filter_comp_info, tensors[kDepthwiseConvWeightsTensor],
+                     kDepthwiseConvWeightsTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  if (bias_comp_info != nullptr) {
+    TF_LITE_MICRO_EXPECT_EQ(
+        tcl.AddInput(*bias_comp_info, tensors[kDepthwiseConvBiasTensor],
+                     kDepthwiseConvBiasTensor),
+        kTfLiteOk);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+  const CompressedTensorList* comp_list_p = tcl.GetCompressedTensorList();
+
+#endif  // USE_TFLM_COMPRESSION
+
   int inputs_array_data[] = {3, 0, 1, 2};
   TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
   int outputs_array_data[] = {1, 3};
@@ -50,8 +134,12 @@ TfLiteStatus ValidateDepthwiseConvGoldens(
 
   const TFLMRegistration registration = Register_DEPTHWISE_CONV_2D();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
-                             outputs_array,
-                             reinterpret_cast<void*>(conv_params));
+                             outputs_array, reinterpret_cast<void*>(conv_params)
+#ifdef USE_TFLM_COMPRESSION
+                                                ,
+                             nullptr, comp_list_p
+#endif  // USE_TFLM_COMPRESSION
+  );
 
   int input_depth = tensors[0].dims->data[3];
   int output_depth = tensors[1].dims->data[3];
@@ -183,18 +271,93 @@ void TestDepthwiseConvQuantizedPerChannel(
       output_scale, output_zero_point, conv_params, filter_packed_type);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+template <typename TIO, typename TBIAS>
+TfLiteStatus TestDepthwiseConvQuantizedCompressed(
+    int* input_dims_data, const float* input_data, TIO* input_quantized,
+    float input_scale, int input_zero_point, int* output_dims_data,
+    const float* expected_output_data, TIO* expected_output_quantized,
+    TIO* output_quantized, float output_scale, int output_zero_point,
+    TfLiteDepthwiseConvParams* conv_params, const unsigned int tolerance,
+    const TestCompressionQuantizedInfo<int8_t>* filter_comp_info,
+    const TestCompressionQuantizedInfo<TBIAS>* bias_comp_info) {
+  // TODO(b/360169306): account for optional bias tensor
+  // bool null_bias = comp_info->bias_data == nullptr ? true : false;
+
+  TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
+  TfLiteIntArray* filter_dims = IntArrayFromInts(filter_comp_info->dims_data);
+  TfLiteIntArray* bias_dims = IntArrayFromInts(bias_comp_info->dims_data);
+  TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
+
+  TfLiteFloatArray* filter_scales =
+      FloatArrayFromFloats(filter_comp_info->scales);
+  TfLiteIntArray* filter_zero_points =
+      IntArrayFromInts(filter_comp_info->zero_points);
+  TfLiteFloatArray* bias_scales = FloatArrayFromFloats(bias_comp_info->scales);
+  TfLiteIntArray* bias_zero_points =
+      IntArrayFromInts(bias_comp_info->zero_points);
+
+  TfLiteAffineQuantization filter_quant = {};
+  TfLiteTensor filter_tensor = CreatePerChannelQuantizedTensor(
+      filter_comp_info->compressed, filter_dims, filter_scales,
+      filter_zero_points, &filter_quant, kDepthwiseConvQuantizedDimension,
+      false /* is_variable */, kTfLiteInt8);
+  // Value tables are always in channel order, therefore do not use the
+  // quantized dimension.
+  SymmetricPerChannelQuantize(
+      filter_comp_info->data, filter_comp_info->value_table,
+      filter_scales->size * filter_comp_info->value_table_stride,
+      filter_scales->size, filter_scales->data, 0 /* see comment above */);
+
+  TfLiteAffineQuantization bias_quant = {};
+  TfLiteTensor bias_tensor = CreatePerChannelQuantizedBiasTensor(
+      bias_comp_info->compressed, bias_dims, input_scale, filter_scales,
+      bias_scales, bias_zero_points, &bias_quant,
+      0 /* quantized dimension for bias tensor */, false /* is_variable */,
+      typeToTfLiteType<TBIAS>());
+  SymmetricPerChannelQuantize(
+      bias_comp_info->data, bias_comp_info->value_table,
+      bias_scales->size * bias_comp_info->value_table_stride, bias_scales->size,
+      bias_scales->data);
+
+  constexpr int tensors_size = kDepthwiseConvMaxTensors;
+  TfLiteTensor tensors[tensors_size] = {
+      CreateQuantizedTensor(input_data, input_quantized, input_dims,
+                            input_scale, input_zero_point),
+      filter_tensor,
+      bias_tensor,
+      CreateQuantizedTensor(output_quantized, output_dims, output_scale,
+                            output_zero_point),
+  };
+
+  const int output_dims_count = ElementCount(*output_dims);
+  Quantize(expected_output_data, expected_output_quantized, output_dims_count,
+           output_scale, output_zero_point);
+  return ValidateDepthwiseConvGoldens(
+      expected_output_quantized, output_dims_count, conv_params, tolerance,
+      tensors_size, tensors, filter_comp_info, bias_comp_info);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
+// TODO(ddavis-2015): is this still valid?
 // Xtensa kernels do not support float activations, and the corresponding tests
 // are disabled. As a result, helper functions that are only needed for float
 // kernel tests also need to be ifdef'd out to avoid build errors due to unused
 // functions.
 #if !defined(XTENSA)
-void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data,
-                            int* filter_dims_data, const float* filter_data,
-                            int* bias_dims_data, const float* bias_data,
-                            const float* expected_output_data,
-                            int* output_dims_data,
-                            TfLiteDepthwiseConvParams* conv_params,
-                            float* output_data) {
+void TestDepthwiseConvFloat(
+    int* input_dims_data, const float* input_data, int* filter_dims_data,
+    const float* filter_data, int* bias_dims_data, const float* bias_data,
+    const float* expected_output_data, int* output_dims_data,
+    TfLiteDepthwiseConvParams* conv_params, float* output_data
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const TestCompressionInfo<const float>* filter_comp_info = nullptr,
+    const TestCompressionInfo<const float>* bias_comp_info = nullptr
+#endif  // USE_TFLM_COMPRESSION
+) {
   TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
   TfLiteIntArray* filter_dims = IntArrayFromInts(filter_dims_data);
   TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
@@ -212,7 +375,12 @@ void TestDepthwiseConvFloat(int* input_dims_data, const float* input_data,
   };
 
   ValidateDepthwiseConvGoldens(expected_output_data, output_dims_count,
-                               conv_params, 1e-5, tensors_size, tensors);
+                               conv_params, 1e-5, tensors_size, tensors
+#ifdef USE_TFLM_COMPRESSION
+                               ,
+                               filter_comp_info, bias_comp_info
+#endif  // USE_TFLM_COMPRESSION
+  );
 }
 
 #endif  // !defined(XTENSA)
@@ -253,6 +421,60 @@ TF_LITE_MICRO_TEST(SimpleTest) {
       bias_values, golden, output_shape, &conv_params, output_data);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestCompressed) {
+  int input_shape[] = {4, 1, 3, 2, 2};
+  const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12};
+  int filter_shape[] = {4, 1, 2, 2, 4};
+  // Filter values:
+  // {1, 2, 3, 4, -9, 10,  -11, 12, 5, 6, 7, 8, 13, -14, 15,  -16}
+  // Align the tensor data the same as a Buffer in the schema
+  alignas(16) const uint8_t kBinQuantFilterData[] = {0x01, 0x23, 0xF8, 0xE9,
+                                                     0x45, 0x67, 0xAD, 0xBC};
+  const float kBinQuantFilterValueTable[] = {1,  2,  3,  4,  5,   6,   7,   8,
+                                             10, 12, 13, 15, -16, -14, -11, -9};
+  int bias_shape[] = {4, 1, 1, 1, 4};
+  const float bias_values[] = {1, 2, 3, 4};
+  // Align the tensor data the same as a Buffer in the schema
+  alignas(16) const uint8_t kBinQuantBiasData[] = {0x1B};
+  const float golden[] = {
+      71, -34, 99, -20, 91, -26, 127, -4,
+  };
+  int output_shape[] = {4, 1, 2, 1, 4};
+  const int output_dims_count = std::extent<decltype(golden)>::value;
+  float output_data[output_dims_count];
+
+  tflite::testing::TestCompressionInfo<const float> filter_comp_info = {};
+  tflite::testing::TestCompressionInfo<const float> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = kBinQuantFilterValueTable;
+  filter_comp_info.value_table_stride =
+      std::extent<decltype(kBinQuantFilterValueTable)>::value;
+  filter_comp_info.bit_width = 4;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_values;
+  bias_comp_info.value_table_stride = std::extent<decltype(bias_values)>::value;
+  bias_comp_info.bit_width = 2;
+
+  TfLiteDepthwiseConvParams conv_params;
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  tflite::testing::TestDepthwiseConvFloat(
+      input_shape, input_values, filter_shape,
+      reinterpret_cast<const float*>(kBinQuantFilterData), bias_shape,
+      reinterpret_cast<const float*>(kBinQuantBiasData), golden, output_shape,
+      &conv_params, output_data, &filter_comp_info, &bias_comp_info);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 TF_LITE_MICRO_TEST(SimpleTestRelu) {
   int input_shape[] = {4, 1, 3, 2, 2};
   const float input_values[] = {1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12};
@@ -1068,4 +1290,144 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16InputInt8Filter) {
       bias_quantized, output_shape, golden, golden_quantized, output_data,
       output_scale, output_zero_point, &conv_params);
 }
+
+#ifdef USE_TFLM_COMPRESSION
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt8Compressed) {
+  // data from TfLite test:
+  // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+  const float input_scale = 0.5f;
+  const float output_scale = 0.5f;
+  const int input_zero_point = -1;
+  const int output_zero_point = -1;
+  constexpr float filter_scales[] = {
+      tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f,
+  };
+  constexpr int filter_zero_points[] = {
+      tflite::testing::kNumChannelsQ1, 0, 0, 0, 0,
+  };
+  // bias scales and zero points will be computed
+  float bias_scales[std::extent<decltype(filter_scales)>::value] = {};
+  int bias_zero_points[std::extent<decltype(filter_scales)>::value] = {};
+
+  int8_t input_quantized[tflite::testing::kInputElementsQ1];
+  int8_t filter_quantized[tflite::testing::kFilterElementsQ1];
+  int32_t bias_quantized[tflite::testing::kBiasElementsQ1];
+  int8_t golden_quantized[tflite::testing::kOutputElementsQ1];
+  int8_t output_quantized[tflite::testing::kOutputElementsQ1];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int32_t> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = filter_quantized;
+  filter_comp_info.value_table_stride =
+      tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1;
+  filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride =
+      tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
+  bias_comp_info.data = tflite::testing::kBiasDataQ1;
+  bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
+  bias_comp_info.scales = bias_scales;
+  bias_comp_info.zero_points = bias_zero_points;
+
+  TfLiteDepthwiseConvParams conv_params = {};
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  // tolerance of 3 is approx. 2.0f
+  // TODO(ddavis-2015): why does the tolerance differ from TfLite test???
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      tflite::testing::TestDepthwiseConvQuantizedCompressed(
+          tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1,
+          input_quantized, input_scale, input_zero_point,
+          tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
+          golden_quantized, output_quantized, output_scale, output_zero_point,
+          &conv_params, 3, &filter_comp_info, &bias_comp_info));
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedPerChannelInt16Compressed) {
+  // data from TfLite test:
+  // PerChannelQuantizedDepthwiseConvolutionOpTest SimpleTestMixedOutputShift
+  const float input_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-4.0f, 4.0f);
+  const float output_scale =
+      tflite::testing::SymmetricScaleFromMinMax<int16_t>(-63.5f, 64.0f);
+  const int input_zero_point = 0;
+  const int output_zero_point = 0;
+  constexpr float filter_scales[] = {
+      tflite::testing::kNumChannelsQ1, 0.1f, 2.0f, 3.0f, 0.4f,
+  };
+  constexpr int filter_zero_points[] = {
+      tflite::testing::kNumChannelsQ1, 0, 0, 0, 0,
+  };
+  // bias scales and zero points will be computed
+  float bias_scales[std::extent<decltype(filter_scales)>::value] = {};
+  int bias_zero_points[std::extent<decltype(filter_scales)>::value] = {};
+
+  int16_t input_quantized[tflite::testing::kInputElementsQ1];
+  int8_t filter_quantized[tflite::testing::kFilterElementsQ1];
+  int64_t bias_quantized[tflite::testing::kBiasElementsQ1];
+  int16_t golden_quantized[tflite::testing::kOutputElementsQ1];
+  int16_t output_quantized[tflite::testing::kOutputElementsQ1];
+
+  tflite::testing::TestCompressionQuantizedInfo<int8_t> filter_comp_info = {};
+  tflite::testing::TestCompressionQuantizedInfo<int64_t> bias_comp_info = {};
+
+  filter_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  filter_comp_info.value_table = filter_quantized;
+  filter_comp_info.value_table_stride =
+      tflite::testing::kFilterElementsQ1 / tflite::testing::kNumChannelsQ1;
+  filter_comp_info.bit_width = tflite::testing::kBinQuantFilterBitWidthQ1;
+  filter_comp_info.compressed = tflite::testing::kBinQuantFilterDataQ1;
+  filter_comp_info.data = tflite::testing::kFilterDataQ1;
+  filter_comp_info.dims_data = tflite::testing::kFilterShapeQ1;
+  filter_comp_info.scales = filter_scales;
+  filter_comp_info.zero_points = filter_zero_points;
+
+  bias_comp_info.scheme = tflite::CompressionScheme::kBinQuant;
+  bias_comp_info.value_table = bias_quantized;
+  bias_comp_info.value_table_stride =
+      tflite::testing::kBiasElementsQ1 / tflite::testing::kNumChannelsQ1;
+  bias_comp_info.bit_width = tflite::testing::kBinQuantBiasBitWidthQ1;
+  bias_comp_info.compressed = tflite::testing::kBinQuantBiasDataQ1;
+  bias_comp_info.data = tflite::testing::kBiasDataQ1;
+  bias_comp_info.dims_data = tflite::testing::kBiasShapeQ1;
+  bias_comp_info.scales = bias_scales;
+  bias_comp_info.zero_points = bias_zero_points;
+
+  TfLiteDepthwiseConvParams conv_params = {};
+  conv_params.activation = kTfLiteActNone;
+  conv_params.dilation_width_factor = 1;
+  conv_params.dilation_height_factor = 1;
+  conv_params.stride_height = 1;
+  conv_params.stride_width = 1;
+
+  // tolerance of 512 is approx. 1.0f
+  TF_LITE_MICRO_EXPECT_EQ(
+      kTfLiteOk,
+      tflite::testing::TestDepthwiseConvQuantizedCompressed(
+          tflite::testing::kInputShapeQ1, tflite::testing::kInputDataQ1,
+          input_quantized, input_scale, input_zero_point,
+          tflite::testing::kOutputShapeQ1, tflite::testing::kGoldenDataQ1,
+          golden_quantized, output_quantized, output_scale, output_zero_point,
+          &conv_params, 512, &filter_comp_info, &bias_comp_info));
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
index 8536ff79507..838fdc0944e 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv.cc
@@ -1,5 +1,5 @@
 
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -93,6 +93,18 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
       context, op_data.reference_op_data.filter_buffer_index, filter);
 
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteInt8: {
       switch (filter_int8.type) {
@@ -111,9 +123,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int8_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(
+                  micro_context, &filter_int8, filter_comp_td,
+                  op_data.reference_op_data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int32_t>(
+                  micro_context, bias, bias_comp_td,
+                  op_data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(&filter_int8),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int8_t>(output));
 #endif  // defined(HIFI3) || defined(HIFI4) || defined(HIFI5)
@@ -136,9 +158,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
               tflite::micro::GetTensorShape(input),
               tflite::micro::GetTensorData<int16_t>(input),
               tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+              tflite::micro::GetTensorData<int8_t>(
+                  micro_context, &filter_int8, filter_comp_td,
+                  op_data.reference_op_data.weights_scratch_index),
+              tflite::micro::GetTensorShape(bias),
+              tflite::micro::GetOptionalTensorData<int64_t>(
+                  micro_context, bias, bias_comp_td,
+                  op_data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorData<int8_t>(&filter_int8),
               tflite::micro::GetTensorShape(bias),
               tflite::micro::GetOptionalTensorData<int64_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
               tflite::micro::GetTensorShape(output),
               tflite::micro::GetTensorData<int16_t>(output));
           break;
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
index 8c2052b23e7..09e84dee936 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_hifi.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -97,10 +97,22 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
                                    const TfLiteEvalTensor* filter,
                                    const TfLiteEvalTensor* bias,
                                    TfLiteEvalTensor* output) {
+#ifdef USE_TFLM_COMPRESSION
+
+  MicroContext* micro_context = GetMicroContext(context);
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+
+#endif  // USE_TFLM_COMPRESSION
+
   // If dilation is not required use the optimized NN Library kernel.
   // Otherwise call the reference implementation.
   if ((params.dilation_width_factor == 1) &&
-      (params.dilation_height_factor == 1)) {
+      (params.dilation_height_factor == 1) && bias != nullptr) {
     const int stride_width = params.stride_width;
     const int stride_height = params.stride_height;
     const int pad_width = data.reference_op_data.padding.width;
@@ -133,8 +145,17 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
     TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
 
     const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);
+#ifdef USE_TFLM_COMPRESSION
+    const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(
+        micro_context, filter, filter_comp_td,
+        data.reference_op_data.weights_scratch_index);
+    const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(
+        micro_context, bias, bias_comp_td,
+        data.reference_op_data.bias_scratch_index);
+#else   // USE_TFLM_COMPRESSION
     const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
     const int32_t* bias_data = tflite::micro::GetTensorData<int32_t>(bias);
+#endif  // USE_TFLM_COMPRESSION
     int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
 
     int32_t input_data_format = 0;
@@ -178,9 +199,19 @@ TfLiteStatus DepthwiseConvEvalHifi(TfLiteContext* context, TfLiteNode* node,
       tflite::micro::GetTensorShape(input),
       tflite::micro::GetTensorData<int8_t>(input),
       tflite::micro::GetTensorShape(filter),
+#ifdef USE_TFLM_COMPRESSION
+      tflite::micro::GetTensorData<int8_t>(
+          micro_context, filter, filter_comp_td,
+          data.reference_op_data.weights_scratch_index),
+      tflite::micro::GetTensorShape(bias),
+      tflite::micro::GetOptionalTensorData<int32_t>(
+          micro_context, bias, bias_comp_td,
+          data.reference_op_data.bias_scratch_index),
+#else   // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorData<int8_t>(filter),
       tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
+      tflite::micro::GetOptionalTensorData<int32_t>(bias),
+#endif  // USE_TFLM_COMPRESSION
       tflite::micro::GetTensorShape(output),
       tflite::micro::GetTensorData<int8_t>(output));
 
diff --git a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
index 35fa8cf1c1a..23e18dc8342 100644
--- a/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/depthwise_conv_vision.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
   TF_LITE_ENSURE(context, filter != nullptr);
   TfLiteTensor* bias =
       micro_context->AllocateTempInputTensor(node, kDepthwiseConvBiasTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
+  TF_LITE_ENSURE(context, bias != nullptr);
 
   // Dynamically allocate per-channel quantization parameters.
   const int num_channels =
@@ -135,18 +135,81 @@ TfLiteStatus DepthwiseConvPrepareVision(TfLiteContext* context,
     filter_int8 = *filter;
   }
 
+#ifdef USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = nullptr;
+  int32_t* bias_data = nullptr;
+
+  const CompressionTensorData* filter_comp_td =
+      micro_context->GetTensorCompressionData(node,
+                                              kDepthwiseConvWeightsTensor);
+  if (filter_comp_td != nullptr) {
+    const size_t filter_data_size =
+        NumElements(&filter_int8) * TfLiteTypeGetSize(kTfLiteInt8);
+    filter_data =
+        micro_context->AllocateTempBuffer(filter_data_size, sizeof(int8_t));
+    if (filter_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* filter_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
+    filter_data = static_cast<uint8_t*>(micro_context->DecompressTensorToBuffer(
+        *filter_eval, *filter_comp_td, filter_data));
+  } else {
+    filter_data = GetTensorData<uint8_t>(&filter_int8);
+  }
+
+  const CompressionTensorData* bias_comp_td =
+      micro_context->GetTensorCompressionData(node, kDepthwiseConvBiasTensor);
+  if (bias_comp_td != nullptr) {
+    const size_t bias_data_size =
+        NumElements(bias) * TfLiteTypeGetSize(kTfLiteInt32);
+    bias_data = reinterpret_cast<int32_t*>(
+        micro_context->AllocateTempBuffer(bias_data_size, sizeof(int32_t)));
+    if (bias_data == nullptr) {
+      return kTfLiteError;
+    }
+    const TfLiteEvalTensor* bias_eval =
+        tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor);
+    bias_data = static_cast<int32_t*>(micro_context->DecompressTensorToBuffer(
+        *bias_eval, *bias_comp_td, bias_data));
+  } else {
+    bias_data = GetTensorData<int32_t>(bias);
+  }
+
+  if (filter_data == nullptr || bias_data == nullptr) {
+    return kTfLiteError;
+  }
+
+#else  // USE_TFLM_COMPRESSION
+
+  uint8_t* filter_data = GetTensorData<uint8_t>(&filter_int8);
+  int32_t* bias_data = GetTensorData<int32_t>(bias);
+
+#endif  // USE_TFLM_COMPRESSION
+
   status = xiDepthwiseConvDoCoeffReorder(
       data->p_context, data->context_size,
       reinterpret_cast<uint8_t*>(data->reorder_coefficient_bias),
-      data->reorder_coefficient_bias_size,
-      const_cast<uint8_t*>(GetTensorData<uint8_t>(&filter_int8)),
-      const_cast<int32_t*>(GetTensorData<int32_t>(bias)));
+      data->reorder_coefficient_bias_size, filter_data, bias_data);
   if (status) {
     return kTfLiteError;
   }
   if (filter->type == kTfLiteInt4) {
     micro_context->DeallocateTempBuffer(GetTensorData<uint8_t>(&filter_int8));
   }
+
+#ifdef USE_TFLM_COMPRESSION
+
+  if (filter_comp_td) {
+    micro_context->DeallocateTempBuffer(filter_data);
+  }
+  if (bias_comp_td) {
+    micro_context->DeallocateTempBuffer(reinterpret_cast<uint8_t*>(bias_data));
+  }
+
+#endif  // USE_TFLM_COMPRESSION
+
   micro_context->DeallocateTempTfLiteTensor(output);
   micro_context->DeallocateTempTfLiteTensor(input);
   micro_context->DeallocateTempTfLiteTensor(filter);
diff --git a/tensorflow/lite/micro/micro_utils.h b/tensorflow/lite/micro/micro_utils.h
index 98ef81dc8ed..b362d3402bb 100644
--- a/tensorflow/lite/micro/micro_utils.h
+++ b/tensorflow/lite/micro/micro_utils.h
@@ -90,12 +90,19 @@ void SymmetricQuantize(const float* input, T* output, int num_elements,
 template <typename T>
 void SymmetricPerChannelQuantize(const float* input, T* output,
                                  int num_elements, int num_channels,
-                                 float* scales) {
+                                 float* scales,
+                                 size_t quantized_dimension = 0) {
   int elements_per_channel = num_elements / num_channels;
   for (int i = 0; i < num_channels; i++) {
     for (int j = 0; j < elements_per_channel; j++) {
-      output[i * elements_per_channel + j] = FloatToSymmetricQuantizedType<T>(
-          input[i * elements_per_channel + j], scales[i]);
+      size_t offset;
+      if (quantized_dimension == 0) {
+        offset = i * elements_per_channel + j;
+      } else {
+        offset = i + elements_per_channel * j;
+      }
+      output[offset] =
+          FloatToSymmetricQuantizedType<T>(input[offset], scales[i]);
     }
   }
 }
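
Note on the micro_utils.h change: SymmetricPerChannelQuantize() gains an
optional quantized_dimension argument whose default of 0 preserves the
previous channel-major behavior. The compression tests always pass 0 when
building value tables, since value tables are stored in channel order
regardless of the tensor's quantized dimension; the call used by
TestDepthwiseConvQuantizedCompressed() above looks like:

  SymmetricPerChannelQuantize(
      filter_comp_info->data, filter_comp_info->value_table,
      filter_scales->size * filter_comp_info->value_table_stride,
      filter_scales->size, filter_scales->data, 0 /* quantized_dimension */);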