diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index b387c6d7a3e..d3d15521c2c 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -99,15 +99,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   output_dims.w = output->dims->data[2];
   output_dims.c = output_shape.Dims(3);
 
-  if (filter->type == kTfLiteInt4) {
-    int filter_size =
-        RuntimeShape(filter->dims->size,
-                     reinterpret_cast<const int32_t*>(filter->dims->data))
-            .FlatSize();
-    context->RequestScratchBufferInArena(
-        context, filter_size, &data->reference_op_data.filter_buffer_index);
-  }
-
   if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
     const int num_channels = filter->dims->data[kConvQuantizedDimension];
     data->reference_op_data.per_channel_output_multiplier =
@@ -168,6 +159,104 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalQuantizedPerChannelInt4(
+    TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
+    const OpData& data, const TfLiteEvalTensor* input,
+    const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
+    TfLiteEvalTensor* output) {
+  cmsis_nn_conv_params conv_params;
+  conv_params.dilation.h = params.dilation_height_factor;
+  conv_params.dilation.w = params.dilation_width_factor;
+
+  // Initialize cmsis_nn convolution parameters
+  conv_params.input_offset = -data.reference_op_data.input_zero_point;
+  conv_params.output_offset = data.reference_op_data.output_zero_point;
+  conv_params.stride.h = params.stride_height;
+  conv_params.stride.w = params.stride_width;
+  conv_params.padding.h = data.reference_op_data.padding.height;
+  conv_params.padding.w = data.reference_op_data.padding.width;
+  conv_params.activation.min = data.reference_op_data.output_activation_min;
+  conv_params.activation.max = data.reference_op_data.output_activation_max;
+
+  // Initialize cmsis_nn per channel quantization parameters
+  cmsis_nn_per_channel_quant_params quant_params;
+  quant_params.multiplier = const_cast<int32_t*>(
+      data.reference_op_data.per_channel_output_multiplier);
+  quant_params.shift =
+      const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);
+
+  RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
+  RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
+  RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+  RuntimeShape bias_shape = tflite::micro::GetTensorShape(bias);
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(conv_params.activation.min, conv_params.activation.max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batch_size = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (tflite::micro::GetOptionalTensorData<int32_t>(bias)) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Initialize cmsis_nn dimensions
+  // Input
+  cmsis_nn_dims input_dims;
+  input_dims.n = batch_size;
+  input_dims.h = input_shape.Dims(1);
+  input_dims.w = input_shape.Dims(2);
+  input_dims.c = input_depth;
+
+  // Filter
+  cmsis_nn_dims filter_dims;
+  filter_dims.n = output_depth;
+  filter_dims.h = filter_shape.Dims(1);
+  filter_dims.w = filter_shape.Dims(2);
+  filter_dims.c = input_depth;
+
+  // Bias
+  cmsis_nn_dims bias_dims;
+  bias_dims.n = 1;
+  bias_dims.h = 1;
+  bias_dims.w = 1;
+  bias_dims.c = output_depth;
+
+  // Output
+  cmsis_nn_dims output_dims;
+  output_dims.n = batch_size;
+  output_dims.h = output_shape.Dims(1);
+  output_dims.w = output_shape.Dims(2);
+  output_dims.c = output_depth;
+
+  // Initialize cmsis_nn context
+  cmsis_nn_context ctx;
+  ctx.buf = nullptr;
+  ctx.size = 0;
+
+  if (data.buffer_idx > -1) {
+    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
+    // Note: ctx.size is currently not used in cmsis_nn.
+    // The buffer should be allocated in the Prepare function through
+    // arm_convolve_wrapper_s8_get_buffer_size
+  }
+
+  // arm_convolve_wrapper_s4 dispatches the optimized kernel according to the
+  // parameters passed for convolutions with 4-bit weights.
+  TFLITE_DCHECK_EQ(
+      arm_convolve_wrapper_s4(
+          &ctx, &conv_params, &quant_params, &input_dims,
+          tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
+          tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
+          tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
+          tflite::micro::GetTensorData<int8_t>(output)),
+      ARM_CMSIS_NN_SUCCESS);
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
                                      const TfLiteConvParams& params,
                                      const OpData& data,
@@ -364,6 +453,28 @@ TfLiteStatus EvalQuantizedPerChannel16x8(
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
+          : nullptr;
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  const auto& params =
+      *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  return EvalQuantizedPerChannelInt4(context, node, params, data, input,
+                                     filter, bias, output);
+}
+
 TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteEvalTensor* input =
       tflite::micro::GetEvalInput(context, node, kConvInputTensor);
@@ -381,11 +492,9 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
       *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
 
-  return EvalQuantizedPerChannel(context, node, params, data, input,
-                                 &filter_int8, bias, output);
+  return EvalQuantizedPerChannel(context, node, params, data, input, filter,
+                                 bias, output);
 }
 
 TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
@@ -445,8 +554,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
+  TF_LITE_ENSURE_EQ(context, input->type, output->type);
+  TF_LITE_ENSURE_MSG(
+      context,
+      input->type == filter->type ||
+          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+          (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
+      "Hybrid models are not supported on TFLite Micro.");
 
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
@@ -463,11 +577,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
     }
-    case kTfLiteInt8:
-      switch (filter_int8.type) {
+    case kTfLiteInt8: {
+      switch (filter->type) {
+        case kTfLiteInt4: {
+          return EvalQuantizedPerChannelInt4(context, node, params, data,
+                                             input, filter, bias, output);
+        }
         case kTfLiteInt8: {
           return EvalQuantizedPerChannel(context, node, params, data, input,
-                                         &filter_int8, bias, output);
+                                         filter, bias, output);
         }
         default: {
           MicroPrintf("Filter type %s (%d) not supported.",
@@ -475,8 +593,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           return kTfLiteError;
        }
      }
-
       break;
+    }
     case kTfLiteInt16: {
       if (bias == nullptr || bias->type == kTfLiteInt64) {
         return EvalQuantizedPerChannel16x8(context, node, params, data, input,
@@ -516,6 +634,10 @@ TFLMRegistration Register_CONV_2D() {
   return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
+TFLMRegistration Register_CONV_2D_INT4() {
+  return tflite::micro::RegisterOp(Init, Prepare, EvalInt4);
+}
+
 TFLMRegistration Register_CONV_2D_INT8() {
   return tflite::micro::RegisterOp(Init, Prepare, EvalInt8);
 }
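Note (not part of the patch): the s4 path above reads the filter exactly as it is stored in the flatbuffer, two 4-bit values per byte, which is the same layout the removed MakeUnpackedInt4Tensor/scratch-buffer path used to expand into one int8 per weight. A minimal sketch of that packing convention, assuming TFLite's low-nibble-first, sign-extended int4 layout; the helper names below are illustrative only.

#include <cstdint>
#include <vector>

// Illustrative only: packs sign-extended int4 values (-8..7) two per byte,
// first element in the low nibble; this is the layout that
// tflite::tensor_utils::UnpackDenseInt4IntoInt8 expands back to int8.
std::vector<uint8_t> PackInt4(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t nibble = static_cast<uint8_t>(values[i]) & 0x0F;
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}

// Reverse direction: sign-extend one nibble back to int8.
int8_t UnpackNibble(uint8_t byte, bool high) {
  const uint8_t nibble = high ? static_cast<uint8_t>(byte >> 4) : (byte & 0x0F);
  // Shift left, then arithmetic shift right to sign-extend the 4-bit value.
  return static_cast<int8_t>(static_cast<uint8_t>(nibble << 4)) >> 4;
}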
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
index 7b733b76afd..f30a9520831 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/depthwise_conv.cc
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -118,15 +118,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
         context, num_channels * sizeof(int32_t)));
   }
 
-  if (filter->type == kTfLiteInt4) {
-    int filter_size =
-        RuntimeShape(filter->dims->size,
-                     reinterpret_cast<const int32_t*>(filter->dims->data))
-            .FlatSize();
-    context->RequestScratchBufferInArena(
-        context, filter_size, &data->reference_op_data.filter_buffer_index);
-  }
-
   TF_LITE_ENSURE_STATUS(CalculateOpDataDepthwiseConv(
       context, node, params, input_width, input_height, filter_width,
       filter_height, output_width, output_height, data_type,
@@ -168,8 +159,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     dw_conv_params.dilation.h = params.dilation_height_factor;
     dw_conv_params.dilation.w = params.dilation_width_factor;
 
-    const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
-        &dw_conv_params, &input_dims, &filter_dims, &output_dims);
+    int32_t buf_size = 0;
+    if (filter->type == kTfLiteInt8) {
+      buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
+          &dw_conv_params, &input_dims, &filter_dims, &output_dims);
+    } else if (filter->type == kTfLiteInt4) {
+      buf_size = arm_depthwise_conv_wrapper_s4_get_buffer_size(
+          &dw_conv_params, &input_dims, &filter_dims, &output_dims);
+    } else {
+      MicroPrintf("Filter type %s (%d) not supported.",
+                  TfLiteTypeGetName(filter->type), filter->type);
+      return kTfLiteError;
+    }
 
     if (buf_size > 0) {
       TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
@@ -285,6 +286,43 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
       ARM_CMSIS_NN_SUCCESS);
 }
 
+void EvalQuantizedPerChannelInt4(TfLiteContext* context, TfLiteNode* node,
+                                 const TfLiteDepthwiseConvParams& params,
+                                 const OpData& data,
+                                 const TfLiteEvalTensor* input,
+                                 const TfLiteEvalTensor* filter,
+                                 const TfLiteEvalTensor* bias,
+                                 TfLiteEvalTensor* output) {
+  cmsis_nn_dw_conv_params dw_conv_params;
+  cmsis_nn_per_channel_quant_params quant_params;
+  cmsis_nn_dims input_dims;
+  cmsis_nn_dims filter_dims;
+  cmsis_nn_dims bias_dims;
+  cmsis_nn_dims output_dims;
+
+  PopulateDwConvParams(&dw_conv_params, &quant_params, &input_dims,
+                       &filter_dims, &bias_dims, &output_dims, params, data,
+                       input, filter, bias, output);
+
+  cmsis_nn_context ctx;
+  ctx.buf = nullptr;
+  /* 'size' is unused */
+  ctx.size = 0;
+
+  if (data.buffer_idx > -1) {
+    ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
+  }
+
+  TFLITE_DCHECK_EQ(
+      arm_depthwise_conv_wrapper_s4(
+          &ctx, &dw_conv_params, &quant_params, &input_dims,
+          tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
+          tflite::micro::GetTensorData<int8_t>(filter), &bias_dims,
+          tflite::micro::GetOptionalTensorData<int32_t>(bias), &output_dims,
+          tflite::micro::GetTensorData<int8_t>(output)),
+      ARM_CMSIS_NN_SUCCESS);
+}
+
 void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node,
                                  const TfLiteDepthwiseConvParams& params,
                                  const OpData& data,
@@ -337,9 +375,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
          : nullptr;
 
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
-
   switch (input->type) {  // Already know in/out types are same.
     case kTfLiteFloat32: {
       tflite::reference_ops::DepthwiseConv(
@@ -355,10 +390,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     case kTfLiteInt8:
-      switch (filter_int8.type) {
+      switch (filter->type) {
         case kTfLiteInt8: {
-          EvalQuantizedPerChannel(context, node, params, data, input,
-                                  &filter_int8, bias, output);
+          EvalQuantizedPerChannel(context, node, params, data, input, filter,
+                                  bias, output);
+          break;
+        }
+        case kTfLiteInt4: {
+          EvalQuantizedPerChannelInt4(context, node, params, data, input,
+                                      filter, bias, output);
           break;
         }
         default: {
@@ -399,11 +439,8 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
          ? tflite::micro::GetEvalInput(context, node, kDepthwiseConvBiasTensor)
          : nullptr;
 
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
-
-  EvalQuantizedPerChannel(context, node, params, data, input, &filter_int8,
-                          bias, output);
+  EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
+                          output);
   return kTfLiteOk;
 }
 
@@ -431,6 +468,30 @@ TfLiteStatus EvalInt16x8(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  const auto& params =
+      *(reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data));
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kDepthwiseConvOutputTensor);
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kDepthwiseConvInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kDepthwiseConvWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      (NumInputs(node) == 3)
+          ? tflite::micro::GetEvalInput(context, node,
+                                        kDepthwiseConvBiasTensor)
+          : nullptr;
+
+  EvalQuantizedPerChannelInt4(context, node, params, data, input, filter, bias,
+                              output);
+  return kTfLiteOk;
+}
+
 }  // namespace
 
 TFLMRegistration Register_DEPTHWISE_CONV_2D() {
@@ -445,4 +506,8 @@ TFLMRegistration Register_DEPTHWISE_CONV_2D_INT16() {
   return tflite::micro::RegisterOp(Init, Prepare, EvalInt16x8);
 }
 
+TFLMRegistration Register_DEPTHWISE_CONV_2D_INT4() {
+  return tflite::micro::RegisterOp(Init, Prepare, EvalInt4);
+}
+
 }  // namespace tflite
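Note (not part of the patch): dropping the unpack-to-scratch path is where the arena saving comes from. Previously Prepare requested a scratch buffer of FlatSize() bytes, one int8 per weight, just to hold the unpacked filter; the CMSIS-NN s4 wrappers instead read the packed weights in place at roughly half that footprint. Rough, illustrative arithmetic for a hypothetical 3x3x128 depthwise filter (numbers are assumptions, not taken from this change):

#include <cstdio>

// Illustrative only: the arena scratch the removed
// RequestScratchBufferInArena(context, filter_size, ...) call would have
// reserved, versus the packed int4 footprint read in place by the s4 kernels.
int main() {
  const int filter_elements = 3 * 3 * 128;                  // hypothetical
  const int unpacked_scratch_bytes = filter_elements;       // one int8 each
  const int packed_int4_bytes = (filter_elements + 1) / 2;  // two per byte
  std::printf("scratch no longer requested: %d bytes (packed weights: %d)\n",
              unpacked_scratch_bytes, packed_int4_bytes);
  return 0;
}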
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
index dbba5b27f58..2066ad6ed70 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
@@ -105,7 +105,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
     buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
   } else if (input->type == kTfLiteInt8 &&
-             data->reference_op_data.filter_zero_point == 0) {
+             data->reference_op_data.filter_zero_point == 0 &&
+             filter->type != kTfLiteInt4) {
     const RuntimeShape input_shape = GetTensorShape(input);
 
     TFLITE_DCHECK_GE(output_dim_count, 2);
@@ -134,41 +135,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
             context->AllocatePersistentBuffer(context, buf_size));
 
         int8_t* filter_data = GetTensorData<int8_t>(filter);
-
-        if (filter->type == kTfLiteInt4) {
-          size_t filter_size = GetTensorShape(filter).FlatSize();
-          int8_t* unpacked_filter_buf =
-              reinterpret_cast<int8_t*>(micro_context->AllocateTempBuffer(
-                  filter_size, tflite::MicroArenaBufferAlignment()));
-
-          tflite::tensor_utils::UnpackDenseInt4IntoInt8(
-              filter_data, filter_size, unpacked_filter_buf);
-          filter_data = unpacked_filter_buf;
-        }
-
         arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth,
                           filter_data);
 
-        if (filter->type == kTfLiteInt4) {
-          micro_context->DeallocateTempBuffer(
-              reinterpret_cast<uint8_t*>(filter_data));
-        }
-
         // Do not request a scratch buffer since using persistent memory
         buf_size = 0;
       }
     }
   }
 
-  if (filter->type == kTfLiteInt4) {
-    int filter_size =
-        RuntimeShape(filter->dims->size,
-                     reinterpret_cast<const int32_t*>(filter->dims->data))
-            .FlatSize();
-    context->RequestScratchBufferInArena(
-        context, filter_size, &data->reference_op_data.filter_buffer_index);
-  }
-
   if (buf_size > 0) {
     TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
         context, buf_size, &data->buffer_idx));
@@ -221,6 +196,49 @@ void PopulateCommonParams(TfLiteContext* context,
   }
 }
 
+TfLiteStatus EvalQuantizedInt4(TfLiteContext* context, TfLiteNode* node,
+                               const OpData& data,
+                               const TfLiteEvalTensor* input,
+                               const TfLiteEvalTensor* filter,
+                               const TfLiteEvalTensor* bias,
+                               TfLiteEvalTensor* output) {
+  const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
+  const int output_dim_count = output_shape.DimensionsCount();
+  TFLITE_DCHECK_GE(output_dim_count, 2);
+  TFLITE_DCHECK_LE(output_dim_count, 4);
+
+  cmsis_nn_per_tensor_quant_params quant_params;
+  cmsis_nn_dims input_dims;
+  cmsis_nn_dims filter_dims;
+  cmsis_nn_dims bias_dims;
+  cmsis_nn_dims output_dims;
+  cmsis_nn_context ctx;
+
+  PopulateCommonParams(context, &quant_params, &input_dims, &filter_dims,
+                       &bias_dims, &output_dims, &ctx, data);
+
+  const int32_t* bias_data =
+      tflite::micro::GetOptionalTensorData<int32_t>(bias);
+
+  cmsis_nn_fc_params fc_params;
+  fc_params.input_offset = -data.reference_op_data.input_zero_point;
+  fc_params.output_offset = data.reference_op_data.output_zero_point;
+  fc_params.filter_offset = 0;
+  fc_params.activation.min = data.reference_op_data.output_activation_min;
+  fc_params.activation.max = data.reference_op_data.output_activation_max;
+
+  TF_LITE_ENSURE_EQ(
+      context,
+      arm_fully_connected_s4(
+          &ctx, &fc_params, &quant_params, &input_dims,
+          tflite::micro::GetTensorData<int8_t>(input), &filter_dims,
+          tflite::micro::GetTensorData<int8_t>(filter), &bias_dims, bias_data,
+          &output_dims, tflite::micro::GetTensorData<int8_t>(output)),
+      ARM_CMSIS_NN_SUCCESS);
+
+  return kTfLiteOk;
+}
+
 TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
                                const OpData& data,
                                const TfLiteEvalTensor* input,
@@ -353,9 +371,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
 
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
-
   // Checks in Prepare ensure input, output and filter types are all the same.
   switch (input->type) {
     case kTfLiteFloat32: {
@@ -373,11 +388,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     case kTfLiteInt8: {
-      switch (filter_int8.type) {
+      switch (filter->type) {
+        case kTfLiteInt4:
+          return EvalQuantizedInt4(context, node, data, input, filter, bias,
+                                   output);
         case kTfLiteInt8:
           if (data.reference_op_data.filter_zero_point == 0) {
-            return EvalQuantizedInt8(context, node, data, input, &filter_int8,
-                                     bias, output);
+            return EvalQuantizedInt8(context, node, data, input, filter, bias,
+                                     output);
           } else {
             tflite::reference_integer_ops::FullyConnected(
                 FullyConnectedParamsQuantized(data.reference_op_data),
@@ -411,6 +429,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
+TfLiteStatus EvalInt4(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
+  const TfLiteEvalTensor* filter =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
+  const TfLiteEvalTensor* bias =
+      tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+
+  // Checks in Prepare ensure input, output and filter types are all the same.
+  if (input->type != kTfLiteInt8 && filter->type != kTfLiteInt4) {
+    MicroPrintf("Type %s (%d) not supported.", TfLiteTypeGetName(input->type),
+                input->type);
+    return kTfLiteError;
+  }
+
+  return EvalQuantizedInt4(context, node, data, input, filter, bias, output);
+}
+
 // Note that the current function names are not ideal at all (this EvalInt8
 // function internally calls EvalQuantizedInt8, and there is similar name
 // aliasing in the Eval function too). We will be attempting to have a more
@@ -437,11 +478,7 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
     return kTfLiteError;
   }
 
-  TfLiteEvalTensor filter_int8 = tflite::micro::MakeUnpackedInt4Tensor(
-      context, data.reference_op_data.filter_buffer_index, filter);
-
-  return EvalQuantizedInt8(context, node, data, input, &filter_int8, bias,
-                           output);
+  return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
 }
 
 TfLiteStatus EvalInt16(TfLiteContext* context, TfLiteNode* node) {
@@ -473,6 +510,10 @@ TFLMRegistration Register_FULLY_CONNECTED() {
   return tflite::micro::RegisterOp(Init, Prepare, Eval);
 }
 
+TFLMRegistration Register_FULLY_CONNECTED_INT4() {
+  return tflite::micro::RegisterOp(Init, Prepare, EvalInt4);
+}
+
 TFLMRegistration Register_FULLY_CONNECTED_INT8() {
   return tflite::micro::RegisterOp(Init, Prepare, EvalInt8);
 }
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
index b8a034b8ac1..0c8073f48f0 100644
--- a/tensorflow/lite/micro/kernels/conv.h
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -95,6 +95,15 @@ inline TFLMRegistration Register_CONV_2D_INT8REF() {
 }
 #endif  // defined(XTENSA)
 
+#if defined(CMSIS_NN)
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8 activations and int4 weights and uses the latency optimized
+// implementations.
+TFLMRegistration Register_CONV_2D_INT4();
+#else
+inline TFLMRegistration Register_CONV_2D_INT4() { return Register_CONV_2D(); }
+#endif  // defined(CMSIS_NN)
+
 #if defined(CMSIS_NN) || defined(XTENSA)
 // Returns a TFLMRegistration struct for kernel variant that only supports
 // int8 activations and int8 weights and uses the latency optimized
diff --git a/tensorflow/lite/micro/kernels/depthwise_conv.h b/tensorflow/lite/micro/kernels/depthwise_conv.h
index d8cc78db6ab..5f2d87efe28 100644
--- a/tensorflow/lite/micro/kernels/depthwise_conv.h
+++ b/tensorflow/lite/micro/kernels/depthwise_conv.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -65,6 +65,11 @@ TFLMRegistration Register_DEPTHWISE_CONV_2D_INT8();
 // implementations.
 TFLMRegistration Register_DEPTHWISE_CONV_2D_INT16();
 
+// Returns a TFLMRegistration struct for kernel variant that only supports
+// int8 activations and int4 weights and uses the latency optimized
+// implementations.
+TFLMRegistration Register_DEPTHWISE_CONV_2D_INT4();
+
 #else
 inline TFLMRegistration Register_DEPTHWISE_CONV_2D_INT8() {
   return Register_DEPTHWISE_CONV_2D();
 }
@@ -73,6 +78,11 @@ inline TFLMRegistration Register_DEPTHWISE_CONV_2D_INT8() {
 inline TFLMRegistration Register_DEPTHWISE_CONV_2D_INT16() {
   return Register_DEPTHWISE_CONV_2D();
 }
+
+inline TFLMRegistration Register_DEPTHWISE_CONV_2D_INT4() {
+  return Register_DEPTHWISE_CONV_2D();
+}
+
 #endif
 
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/fully_connected.h b/tensorflow/lite/micro/kernels/fully_connected.h
index 3fa6060c74a..8308838ec6d 100644
--- a/tensorflow/lite/micro/kernels/fully_connected.h
+++ b/tensorflow/lite/micro/kernels/fully_connected.h
@@ -1,4 +1,4 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -95,6 +95,10 @@ inline TFLMRegistration Register_FULLY_CONNECTED_INT8() {
 // int16.
 TFLMRegistration Register_FULLY_CONNECTED_INT16();
 
+// Returns a TFLMRegistration struct for a kernel variant that only supports
+// int8 activations with int4 packed weights.
+TFLMRegistration Register_FULLY_CONNECTED_INT4();
+
 #else
 // Note that while this block gets used for both reference and optimized kernels
 // that do not have any specialized implementations, the only goal here is to
@@ -105,6 +109,10 @@ inline TFLMRegistration Register_FULLY_CONNECTED_INT16() {
   return Register_FULLY_CONNECTED();
 }
 
+inline TFLMRegistration Register_FULLY_CONNECTED_INT4() {
+  return Register_FULLY_CONNECTED();
+}
+
 #endif
 
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
index 895ac59162e..aeaeb8ed0cd 100755
--- a/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
+++ b/tensorflow/lite/micro/tools/make/ext_libs/cmsis_nn_download.sh
@@ -47,9 +47,9 @@ if [ -d ${DOWNLOADED_CMSIS_NN_PATH} ]; then
   echo >&2 "${DOWNLOADED_CMSIS_NN_PATH} already exists, skipping the download."
 
 else
-  ZIP_PREFIX_NN="ca476254fecf8492021428162381adb76d1cad6e"
+  ZIP_PREFIX_NN="bfc54edb61e873039ec0857cacc40df36b1d644e"
   CMSIS_NN_URL="http://github.com/ARM-software/CMSIS-NN/archive/${ZIP_PREFIX_NN}.zip"
-  CMSIS_NN_MD5="272ef45ad69d8a35acc5d2fcba693cd6"
+  CMSIS_NN_MD5="944eb9c0060bb7f5eccb8841f1f62f2a"
 
   # wget is much faster than git clone of the entire repo. So we wget a specific
   # version and can then apply a patch, as needed.
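Usage note (not part of the patch): the new registrations are selected explicitly by the application. A minimal sketch, assuming the MicroMutableOpResolver Add* overloads that take an explicit TFLMRegistration; those signatures come from the existing TFLM headers, not from this diff. On targets built without CMSIS_NN the *_INT4 symbols fall back to the generic kernels via the inline definitions added in the headers above.

#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Illustrative only: registers the int4-weight kernel variants so that a
// model quantized with int8 activations and int4 packed weights resolves to
// the CMSIS-NN s4 wrappers exercised by this patch.
TfLiteStatus RegisterInt4Kernels(tflite::MicroMutableOpResolver<3>& resolver) {
  if (resolver.AddConv2D(tflite::Register_CONV_2D_INT4()) != kTfLiteOk ||
      resolver.AddDepthwiseConv2D(tflite::Register_DEPTHWISE_CONV_2D_INT4()) !=
          kTfLiteOk ||
      resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT4()) !=
          kTfLiteOk) {
    return kTfLiteError;
  }
  return kTfLiteOk;
}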