CMSIS-NN: Add int4 kernel support to conv2d, depthwise and fully connected (#2314)

BUG=CMSIS-NN supports int4 packed weights, hence the "glue" code in TFLM is updated accordingly.

Co-authored-by: Adrian Lundell <[email protected]>
Co-authored-by: Ryan O'Shea <[email protected]>
Co-authored-by: Måns Nilsson <[email protected]>
mansnils authored Nov 14, 2023
1 parent ecb3b32 commit 5beca7c
Showing 7 changed files with 339 additions and 84 deletions.
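
Background for the change (illustration only, not part of the diff): int4 weights are stored packed, two 4-bit values per byte, and with this commit the CMSIS-NN s4 kernels consume the packed buffer directly, so the Prepare-time scratch buffer that previously unpacked the filter to int8 is removed below. The sketch assumes the earlier element sits in the low nibble, which matches how the TFLM reference kernels unpack int4 tensors; verify against the CMSIS-NN documentation for your version.

// Illustration only: unpack an int4-packed weight buffer into int8 values,
// assuming two elements per byte with the earlier element in the low nibble.
// This mirrors what the removed scratch-buffer path used to do at runtime.
#include <cstdint>
#include <vector>

std::vector<int8_t> UnpackInt4Weights(const int8_t* packed, int num_elements) {
  std::vector<int8_t> unpacked(num_elements);
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    // Sign-extend the 4-bit value into the int8 range [-8, 7].
    unpacked[i] = (nibble >= 8) ? static_cast<int8_t>(nibble - 16)
                                : static_cast<int8_t>(nibble);
  }
  return unpacked;
}
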
160 changes: 141 additions & 19 deletions tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -99,15 +99,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_dims.w = output->dims->data[2];
output_dims.c = output_shape.Dims(3);

if (filter->type == kTfLiteInt4) {
int filter_size =
RuntimeShape(filter->dims->size,
reinterpret_cast<const int32_t*>(filter->dims->data))
.FlatSize();
context->RequestScratchBufferInArena(
context, filter_size, &data->reference_op_data.filter_buffer_index);
}

if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const int num_channels = filter->dims->data[kConvQuantizedDimension];
data->reference_op_data.per_channel_output_multiplier =
@@ -168,6 +159,104 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalQuantizedPerChannelInt4(
TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
cmsis_nn_conv_params conv_params;
conv_params.dilation.h = params.dilation_height_factor;
conv_params.dilation.w = params.dilation_width_factor;

// Initialize cmsis_nn convolution parameters
conv_params.input_offset = -data.reference_op_data.input_zero_point;
conv_params.output_offset = data.reference_op_data.output_zero_point;
conv_params.stride.h = params.stride_height;
conv_params.stride.w = params.stride_width;
conv_params.padding.h = data.reference_op_data.padding.height;
conv_params.padding.w = data.reference_op_data.padding.width;
conv_params.activation.min = data.reference_op_data.output_activation_min;
conv_params.activation.max = data.reference_op_data.output_activation_max;

// Initialize cmsis_nn per channel quantization parameters
cmsis_nn_per_channel_quant_params quant_params;
quant_params.multiplier = const_cast<int32_t*>(
data.reference_op_data.per_channel_output_multiplier);
quant_params.shift =
const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);

RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);

// Consistency check.
TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (tflite::micro::GetOptionalTensorData<int32_t>(bias)) {
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
}

// Initialize cmsis_nn dimensions
// Input
cmsis_nn_dims input_dims;
input_dims.n = batch_size;
input_dims.h = input_shape.Dims(1);
input_dims.w = input_shape.Dims(2);
input_dims.c = input_depth;

// Filter
cmsis_nn_dims filter_dims;
filter_dims.n = output_depth;
filter_dims.h = filter_shape.Dims(1);
filter_dims.w = filter_shape.Dims(2);
filter_dims.c = input_depth;

// Bias
cmsis_nn_dims bias_dims;
bias_dims.n = 1;
bias_dims.h = 1;
bias_dims.w = 1;
bias_dims.c = output_depth;

// Output
cmsis_nn_dims output_dims;
output_dims.n = batch_size;
output_dims.h = output_shape.Dims(1);
output_dims.w = output_shape.Dims(2);
output_dims.c = output_depth;

// Initialize cmsis_nn context
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;

if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
// Note: ctx.size is currently not used in cmsis_nn.
// The buffer should be allocated in the Prepare function through
// arm_convolve_wrapper_s8_get_buffer_size
}

// arm_convolve_wrapper_s4 dispatches the optimized kernel accordingly with
// the parameters passed for convolutions with 4 bit weights
TFLITE_DCHECK_EQ(
arm_convolve_wrapper_s4(
&ctx, &conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
tflite::micro::GetTensorData<int8_t>(output)),
ARM_CMSIS_NN_SUCCESS);

return kTfLiteOk;
}

TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
const TfLiteConvParams& params,
const OpData& data,
@@ -364,6 +453,28 @@ TfLiteStatus EvalQuantizedPerChannel16x8(
return kTfLiteOk;
}

TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
: nullptr;
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);

TFLITE_DCHECK(node->builtin_data != nullptr);
const auto& params =
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));

return EvalQuantizedPerChannelInt4(context, node, params, data, input, filter,
bias, output);
}

TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kConvInputTensor);
@@ -381,11 +492,9 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
*(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
context, data.reference_op_data.filter_buffer_index, filter);

return EvalQuantizedPerChannel(context, node, params, data, input,
&filter_int8, bias, output);
return EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
}

TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
@@ -445,8 +554,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));

TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
context, data.reference_op_data.filter_buffer_index, filter);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(
context,
input->type == filter->type ||
(input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
(input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
"Hybrid models are not supported on TFLite Micro.");

switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32: {
@@ -463,20 +577,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetTensorShape(nullptr), nullptr);
break;
}
case kTfLiteInt8:
switch (filter_int8.type) {
case kTfLiteInt8: {
switch (filter->type) {
case kTfLiteInt4: {
return EvalQuantizedPerChannelInt4(context, node, params, data, input,
filter, bias, output);
}
case kTfLiteInt8: {
return EvalQuantizedPerChannel(context, node, params, data, input,
&filter_int8, bias, output);
filter, bias, output);
}
default: {
MicroPrintf("Filter type %s (%d) not supported.",
TfLiteTypeGetName(filter->type), filter->type);
return kTfLiteError;
}
}

break;
}
case kTfLiteInt16: {
if (bias == nullptr || bias->type == kTfLiteInt64) {
return EvalQuantizedPerChannel16x8(context, node, params, data, input,
@@ -516,6 +634,10 @@ TFLMRegistration Register_CONV_2D() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);
}

TFLMRegistration Register_CONV_2D_INT4() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt4);
}

TFLMRegistration Register_CONV_2D_INT8() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt8);
}
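A usage sketch for the new registrations (an assumption, not shown in this commit): downstream code that wants CONV_2D and DEPTHWISE_CONV_2D resolved directly to the int4-specific evals could pass the new TFLMRegistration objects to the op resolver, assuming the MicroMutableOpResolver overloads that accept an explicit registration.

// Hypothetical wiring; the AddConv2D/AddDepthwiseConv2D overloads that take an
// explicit registration, and the header paths, are assumptions to check
// against your TFLM version.
#include "tensorflow/lite/micro/kernels/conv.h"            // Register_CONV_2D_INT4 (path assumed)
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"  // Register_DEPTHWISE_CONV_2D_INT4 (path assumed)
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

void RegisterInt4ConvKernels(tflite::MicroMutableOpResolver<2>& resolver) {
  // Dispatch CONV_2D directly to the int4 eval added by this commit.
  resolver.AddConv2D(tflite::Register_CONV_2D_INT4());
  // Same idea for DEPTHWISE_CONV_2D (its registration is added later in this diff).
  resolver.AddDepthwiseConv2D(tflite::Register_DEPTHWISE_CONV_2D_INT4());
}
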
111 changes: 88 additions & 23 deletions tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -1,4 +1,4 @@
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -118,15 +118,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, num_channels * sizeof(int32_t)));
}

if (filter->type == kTfLiteInt4) {
int filter_size =
RuntimeShape(filter->dims->size,
reinterpret_cast<const int32_t*>(filter->dims->data))
.FlatSize();
context->RequestScratchBufferInArena(
context, filter_size, &data->reference_op_data.filter_buffer_index);
}

TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, data_type,
@@ -168,8 +159,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
dw_conv_params.dilation.h = params.dilation_height_factor;
dw_conv_params.dilation.w = params.dilation_width_factor;

const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
&dw_conv_params, &input_dims, &filter_dims, &output_dims);
int32_t buf_size = 0;
if (filter->type == kTfLiteInt8) {
buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
&dw_conv_params, &input_dims, &filter_dims, &output_dims);
} else if (filter->type == kTfLiteInt4) {
buf_size = arm_depthwise_conv_wrapper_s4_get_buffer_size(
&dw_conv_params, &input_dims, &filter_dims, &output_dims);
} else {
MicroPrintf("Filter type %s (%d) not supported.",
TfLiteTypeGetName(filter->type), filter->type);
return kTfLiteError;
}

if (buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
@@ -285,6 +286,43 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
ARM_CMSIS_NN_SUCCESS);
}

void EvalQuantizedPerChannelInt4(TfLiteContext* context, TfLiteNode* node,
const TfLiteDepthwiseConvParams& params,
const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
cmsis_nn_dw_conv_params dw_conv_params;
cmsis_nn_per_channel_quant_params quant_params;
cmsis_nn_dims input_dims;
cmsis_nn_dims filter_dims;
cmsis_nn_dims bias_dims;
cmsis_nn_dims output_dims;

PopulateDwConvParams(&dw_conv_params, &quant_params, &input_dims,
&filter_dims, &bias_dims, &output_dims, params, data,
input, filter, bias, output);

cmsis_nn_context ctx;
ctx.buf = nullptr;
/* 'size' is unused */
ctx.size = 0;

if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
}

TFLITE_DCHECK_EQ(
arm_depthwise_conv_wrapper_s4(
&ctx, &dw_conv_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
tflite::micro::GetTensorData<int8_t>(output)),
ARM_CMSIS_NN_SUCCESS);
}

void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node,
const TfLiteDepthwiseConvParams& params,
const OpData& data,
@@ -337,9 +375,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
: nullptr;

TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
context, data.reference_op_data.filter_buffer_index, filter);

switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32: {
tflite::reference_ops::DepthwiseConv(
@@ -355,10 +390,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt8:
switch (filter_int8.type) {
switch (filter->type) {
case kTfLiteInt8: {
EvalQuantizedPerChannel(context, node, params, data, input,
&filter_int8, bias, output);
EvalQuantizedPerChannel(context, node, params, data, input, filter,
bias, output);
break;
}
case kTfLiteInt4: {
EvalQuantizedPerChannelInt4(context, node, params, data, input,
filter, bias, output);
break;
}
default: {
@@ -399,11 +439,8 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
: nullptr;

TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
context, data.reference_op_data.filter_buffer_index, filter);

EvalQuantizedPerChannel(context, node, params, data, input, &filter_int8,
bias, output);
EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
output);
return kTfLiteOk;
}

@@ -431,6 +468,30 @@ TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}

TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);

const auto& params =
*(reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data));
const OpData& data = *(static_cast<OpData*>(node->user_data));

TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 3)
? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
: nullptr;

EvalQuantizedPerChannelInt4(context, node, params, data, input, filter, bias,
output);
return kTfLiteOk;
}

} // namespace

TFLMRegistration Register_DEPTHWISE_CONV_2D() {
@@ -445,4 +506,8 @@ TFLMRegistration Register_DEPTHWISE_CONV_2D_INT16() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt16x8);
}

TFLMRegistration Register_DEPTHWISE_CONV_2D_INT4() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt4);
}

} // namespace tflite
