Skip to content

Commit 8c4978d

Browse files
ebrevdo and tensorflower-gardener
authored and committed
[TFLite] Port Bucketize op from TF (CPU only).
PiperOrigin-RevId: 404402135 Change-Id: I9d33c42302d00223a3b29ed40c16a49fe01d64f1
1 parent f2ebf3d commit 8c4978d

19 files changed

+556
-10
lines changed

RELEASE.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* `tf.lite`:
1818
* Where operation support is added for these data types
1919
'int32/uint32/int8/uint8/int64'
20+
* Add builtin support for `Bucketize` op on CPU.
2021

2122
*<INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
2223
*<IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>

tensorflow/compiler/mlir/lite/ir/tfl_ops.td

+22
Original file line numberDiff line numberDiff line change
@@ -5125,6 +5125,28 @@ broadcasted shape. `s0`, `s1` and `r0` are all integer vectors.
51255125
);
51265126
}
51275127

5128+
// TFLite dialect op "bucketize". Declared side-effect free and with the
// SameOperandsAndResultShape trait, i.e. the result has the same shape as the
// input tensor. The example in the description shows each input element being
// replaced by a bucket index derived from the `boundaries` list.
def TFL_BucketizeOp
5129+
: TFL_Op<"bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
5130+
let summary = "Bucketizes 'input' based on 'boundaries'.";
5131+
5132+
let description = [{
5133+
Example:
5134+
5135+
If the inputs are `boundaries = [0, 10, 100]` and
5136+
`input = [[-5, 10000][150, 10][5, 100]]`,
5137+
then the output will be `output = [[0, 3][3, 2][1, 3]]`.
5138+
}];
5139+
5140+
// Operands/attributes: `input` may be an f32, f64, i32 or i64 tensor;
// `boundaries` is a compile-time attribute holding an array of f32 values.
let arguments = (ins
5141+
TFL_TensorOf<[F32, F64, I32, I64]>:$input,
5142+
F32ArrayAttr:$boundaries
5143+
);
5144+
5145+
// Result: an i32 tensor of bucket indices.
let results = (outs
5146+
TFL_TensorOf<[I32]>:$output
5147+
);
5148+
}
5149+
51285150
#endif // TFL_OPS
51295151

51305152
// LINT.ThenChange(//tensorflow/lite/tools/versioning/op_version.cc)

tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir

+7
Original file line numberDiff line numberDiff line change
@@ -2233,3 +2233,10 @@ func @select_v2_with_high_dims_dynamic_shape_both_sides(%arg0: tensor<8x7x6x5x?x
22332233
// CHECK: return %[[SELECT_V2]] : tensor<8x7x6x5x?x3x2x1xf32>
22342234
}
22352235

2236+
// Legalization test: a tf.Bucketize op with three f32 boundaries on a 3x2 f32
// input is expected to be rewritten to tfl.bucketize with the same operand,
// the same boundaries attribute and the same 3x2 i32 result type.
func @Bucketize(%arg0: tensor<3x2xf32>) -> tensor<3x2xi32> {
2237+
%0 = "tf.Bucketize"(%arg0) {boundaries = [1.0 : f32, 10.0 : f32, 100.0 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32>
2238+
return %0: tensor<3x2xi32>
2239+
2240+
// CHECK-LABEL: Bucketize
2241+
// CHECK: "tfl.bucketize"(%arg0) {boundaries = [1.000000e+00 : f32, 1.000000e+01 : f32, 1.000000e+02 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32>
2242+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s
2+
3+
// Flatbuffer export test for tfl.bucketize. The function body at the bottom
// feeds a constant 3x2 f32 tensor into tfl.bucketize; the expectations below
// pin the textual dump of the exported model: a single BUCKETIZE operator
// code at version 1, three 3x2 tensors ("arg0", "Const", and the INT32
// "bucketize" result), one operator reading tensor 1 and writing tensor 2,
// and the model buffers/metadata.
func @main(tensor<3x2xf32>) -> tensor<3x2xi32> {
4+
^bb0(%arg0: tensor<3x2xf32>):
5+
// CHECK: {
6+
// CHECK-NEXT: version: 3,
7+
// CHECK-NEXT: operator_codes: [ {
8+
// CHECK-NEXT: deprecated_builtin_code: 127,
9+
// CHECK-NEXT: version: 1,
10+
// CHECK-NEXT: builtin_code: BUCKETIZE
11+
// CHECK-NEXT: } ],
12+
// CHECK-NEXT: subgraphs: [ {
13+
// CHECK-NEXT: tensors: [ {
14+
// CHECK-NEXT: shape: [ 3, 2 ],
15+
// CHECK-NEXT: buffer: 1,
16+
// CHECK-NEXT: name: "arg0",
17+
// CHECK-NEXT: quantization: {
18+
// CHECK-EMPTY:
19+
// CHECK-NEXT: }
20+
// CHECK-NEXT: }, {
21+
// CHECK-NEXT: shape: [ 3, 2 ],
22+
// CHECK-NEXT: buffer: 2,
23+
// CHECK-NEXT: name: "Const",
24+
// CHECK-NEXT: quantization: {
25+
// CHECK-EMPTY:
26+
// CHECK-NEXT: }
27+
// CHECK-NEXT: }, {
28+
// CHECK-NEXT: shape: [ 3, 2 ],
29+
// CHECK-NEXT: type: INT32,
30+
// CHECK-NEXT: buffer: 3,
31+
// CHECK-NEXT: name: "bucketize",
32+
// CHECK-NEXT: quantization: {
33+
// CHECK-EMPTY:
34+
// CHECK-NEXT: }
35+
// CHECK-NEXT: } ],
36+
// CHECK-NEXT: inputs: [ 0 ],
37+
// CHECK-NEXT: outputs: [ 2 ],
38+
// CHECK-NEXT: operators: [ {
39+
// CHECK-NEXT: inputs: [ 1 ],
40+
// CHECK-NEXT: outputs: [ 2 ]
41+
// CHECK-NEXT: } ],
42+
// CHECK-NEXT: name: "main"
43+
// CHECK-NEXT: } ],
44+
// CHECK-NEXT: description: "MLIR Converted.",
45+
// CHECK-NEXT: buffers: [ {
46+
// CHECK-EMPTY:
47+
// CHECK-NEXT: }, {
48+
// CHECK-EMPTY:
49+
// CHECK-NEXT: }, {
50+
// CHECK-NEXT: data: [ 0, 0, 160, 192, 0, 64, 28, 70, 0, 0, 22, 67, 0, 0, 32, 65, 0, 0, 160, 64, 0, 0, 200, 66 ]
51+
// CHECK-NEXT: }, {
52+
// CHECK-EMPTY:
53+
// CHECK-NEXT: }, {
54+
// CHECK-NEXT: data: [ 50, 46, 56, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
55+
// CHECK-NEXT: } ],
56+
// CHECK-NEXT: metadata: [ {
57+
// CHECK-NEXT: name: "min_runtime_version",
58+
// CHECK-NEXT: buffer: 4
59+
// CHECK-NEXT: } ],
60+
// CHECK-NEXT: signature_defs: [ ]
61+
// CHECK-NEXT: }
62+
// CHECK-EMPTY:
63+
64+
// Ops under test. The 24 data bytes expected for buffer 2 above are the six
// f32 values of this constant in little-endian order; the 16 bytes expected
// for buffer 4 begin with the ASCII text "2.8.0".
%0 = "tfl.pseudo_const" () {value = dense<[[-5.0, 10000.0], [150.0, 10.0], [5.0, 100.0]]> : tensor<3x2xf32>} : () -> tensor<3x2xf32> loc("Const")
65+
%1 = "tfl.bucketize"(%0) {boundaries = [0.0 : f32, 10.0 : f32, 100.0 : f32]} : (tensor<3x2xf32>) -> tensor<3x2xi32> loc("bucketize")
66+
return %1 : tensor<3x2xi32>
67+
}

tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td

+4
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,7 @@ def LegalizeComplexAbs : Pat<(TF_ComplexAbsOp $arg), (TFL_ComplexAbsOp $arg)>;
513513
def LegalizeReal : Pat<(TF_RealOp $arg), (TFL_RealOp $arg)>;
514514

515515
def LegalizeImag : Pat<(TF_ImagOp $arg), (TFL_ImagOp $arg)>;
516+
517+
// Rewrites TF_BucketizeOp into TFL_BucketizeOp, forwarding the input operand
// and the boundaries attribute unchanged. The source pattern constrains
// `boundaries` to be an F32ArrayAttr.
def LegalizeBucketize : Pat<
518+
(TF_BucketizeOp $input, F32ArrayAttr:$boundaries),
519+
(TFL_BucketizeOp $input, $boundaries)>;

tensorflow/lite/builtin_ops.h

+1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ typedef enum {
174174
kTfLiteBuiltinAssignVariable = 144,
175175
kTfLiteBuiltinBroadcastArgs = 145,
176176
kTfLiteBuiltinRandomStandardNormal = 146,
177+
kTfLiteBuiltinBucketize = 147,
177178
} TfLiteBuiltinOperator;
178179

179180
#ifdef __cplusplus

tensorflow/lite/c/builtin_op_data.h

+7
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,13 @@ typedef struct {
507507
int seed2;
508508
} TfLiteRandomParams;
509509

510+
// Builtin parameters for the Bucketize op.
typedef struct {
511+
// Number of float entries reachable through `boundaries`.
int num_boundaries;
512+
// This points to the memory stored in the model (flatbuffer),
513+
// and is not owned.
514+
const float* boundaries;
515+
} TfLiteBucketizeParams;
516+
510517
#ifdef __cplusplus
511518
} // extern "C"
512519
#endif // __cplusplus

tensorflow/lite/core/api/flatbuffer_conversions.cc

+13
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,19 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
797797
*builtin_data = params.release();
798798
return kTfLiteOk;
799799
}
800+
case BuiltinOperator_BUCKETIZE: {
  // Translates the flatbuffer BucketizeOptions into a TfLiteBucketizeParams.
  // The params do not copy the boundaries; they alias the vector stored in
  // the model buffer.
  auto params = safe_allocator.Allocate<TfLiteBucketizeParams>();
  TF_LITE_ENSURE(error_reporter, params != nullptr);
  if (const auto* bucketize_params =
          op->builtin_options_as_BucketizeOptions()) {
    const flatbuffers::Vector<float>* boundaries =
        bucketize_params->boundaries();
    // A flatbuffer vector field that the model omits is read back as a null
    // pointer. Reject such a model here instead of dereferencing null; the
    // early return releases `params` through its owning smart pointer.
    TF_LITE_ENSURE(error_reporter, boundaries != nullptr);
    params->num_boundaries = boundaries->size();
    params->boundaries = boundaries->data();
  }
  *builtin_data = params.release();
  return kTfLiteOk;
}
800813
// Below are the ops with no builtin_data structure.
801814
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
802815
// ok for now, since there is no call implementation either.

tensorflow/lite/core/shims/builtin_ops_list.inc

+1
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,4 @@ TFLITE_OP(Register_READ_VARIABLE)
159159
TFLITE_OP(Register_ASSIGN_VARIABLE)
160160
TFLITE_OP(Register_BROADCAST_ARGS)
161161
TFLITE_OP(Register_RANDOM_STANDARD_NORMAL)
162+
TFLITE_OP(Register_BUCKETIZE)

tensorflow/lite/kernels/BUILD

+15
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ BUILTIN_KERNEL_SRCS = [
556556
"bidirectional_sequence_rnn.cc",
557557
"broadcast_args.cc",
558558
"broadcast_to.cc",
559+
"bucketize.cc",
559560
"call_once.cc",
560561
"cast.cc",
561562
"ceil.cc",
@@ -1118,6 +1119,20 @@ cc_test(
11181119
],
11191120
)
11201121

1122+
# Test target "bucketize_test", built from bucketize_test.cc and marked as a
# "small" test. It depends on the builtin op kernels, the shared kernel test
# main/utilities, the TFLite schema, and googletest.
cc_test(
1123+
name = "bucketize_test",
1124+
size = "small",
1125+
srcs = ["bucketize_test.cc"],
1126+
deps = [
1127+
":builtin_ops",
1128+
":test_main",
1129+
":test_util",
1130+
"//tensorflow/lite/schema:schema_fbs",
1131+
"//tensorflow/lite/testing:util",
1132+
"@com_google_googletest//:gtest",
1133+
],
1134+
)
1135+
11211136
cc_test(
11221137
name = "cast_test",
11231138
size = "small",

tensorflow/lite/kernels/bucketize.cc

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <stdint.h>
17+
18+
#include <algorithm>
19+
20+
#include "tensorflow/lite/c/builtin_op_data.h"
21+
#include "tensorflow/lite/c/common.h"
22+
#include "tensorflow/lite/kernels/internal/tensor.h"
23+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24+
#include "tensorflow/lite/kernels/kernel_util.h"
25+
26+
namespace tflite {
27+
namespace ops {
28+
namespace builtin {
29+
namespace bucketize {
30+
namespace {
31+
32+
constexpr int kInputTensor = 0;
33+
constexpr int kOutputTensor = 0;
34+
35+
// Per-node state for the bucketize kernel: a non-owning view of the boundary
// values plus their count, copied from TfLiteBucketizeParams in Init().
struct OpData {
36+
// boundaries array is owned by the buffer housing TfLiteBucketizeParams.
37+
const float* boundaries;
38+
// Number of floats readable through `boundaries`.
int num_boundaries;
39+
};
40+
41+
// Builds the per-node OpData for a bucketize node.
//
// `buffer` is reinterpreted as a TfLiteBucketizeParams; its boundaries pointer
// and count are copied (the boundary values themselves are not). `context` and
// `length` are unused. The returned object is heap-allocated and is destroyed
// by Free().
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const auto* bucketize_params =
      reinterpret_cast<const TfLiteBucketizeParams*>(buffer);
  return new OpData{bucketize_params->boundaries,
                    bucketize_params->num_boundaries};
}
49+
50+
// Destroys the OpData previously returned by Init(). `context` is unused.
void Free(TfLiteContext* context, void* buffer) {
  auto* op_data = static_cast<OpData*>(buffer);
  delete op_data;
}
53+
54+
// Validates a bucketize node and sizes its output.
//
// Checks, in order: exactly one input and one output; boundaries are sorted
// in non-decreasing order; the input type is int32, int64, float32 or
// float64. On success the output is typed int32 and resized to the input's
// dimensions. Returns kTfLiteError (after logging) on the sortedness and type
// failures, or the status of the failing helper otherwise.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const float* const boundaries_begin = op_data->boundaries;
  const float* const boundaries_end =
      boundaries_begin + op_data->num_boundaries;
  const bool boundaries_sorted =
      std::is_sorted(boundaries_begin, boundaries_end);
  if (!boundaries_sorted) {
    TF_LITE_KERNEL_LOG(context, "Expected sorted boundaries");
    return kTfLiteError;
  }

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  switch (input->type) {
    case kTfLiteInt32:
    case kTfLiteFloat32:
    case kTfLiteInt64:
    case kTfLiteFloat64:
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = kTfLiteInt32;

  // The output mirrors the input's shape; ResizeTensor receives a fresh copy
  // of the input dimensions.
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}
82+
83+
template <typename T>
84+
inline void Bucketize(const RuntimeShape& input_shape, const T* input_data,
85+
const float* boundaries, int num_boundaries,
86+
const RuntimeShape& output_shape, int32_t* output_data) {
87+
const int flat_size = MatchingFlatSize(input_shape, output_shape);
88+
89+
for (int i = 0; i < flat_size; i++) {
90+
auto first_bigger_it = std::upper_bound(
91+
boundaries, boundaries + num_boundaries, input_data[i]);
92+
output_data[i] = first_bigger_it - boundaries;
93+
}
94+
}
95+
96+
template <typename T>
97+
TfLiteStatus BucketizeImpl(TfLiteContext* context, TfLiteNode* node) {
98+
const TfLiteTensor* input;
99+
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
100+
OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
101+
TfLiteTensor* output;
102+
TF_LITE_ENSURE_OK(context,
103+
GetOutputSafe(context, node, kOutputTensor, &output));
104+
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt32);
105+
106+
Bucketize<T>(GetTensorShape(input), GetTensorData<T>(input),
107+
opdata->boundaries, opdata->num_boundaries,
108+
GetTensorShape(output), GetTensorData<int32_t>(output));
109+
110+
return kTfLiteOk;
111+
}
112+
113+
// Kernel entry point: dispatches to BucketizeImpl<T> according to the input
// tensor's element type (float32, float64, int32 or int64). Any other type is
// logged and reported as kTfLiteError.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  const auto input_type = input->type;
  if (input_type == kTfLiteFloat32) {
    return BucketizeImpl<float>(context, node);
  }
  if (input_type == kTfLiteFloat64) {
    return BucketizeImpl<double>(context, node);
  }
  if (input_type == kTfLiteInt32) {
    return BucketizeImpl<int32_t>(context, node);
  }
  if (input_type == kTfLiteInt64) {
    return BucketizeImpl<int64_t>(context, node);
  }

  TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by bucketize.",
                     TfLiteTypeGetName(input_type));
  return kTfLiteError;
}
137+
138+
} // namespace
139+
} // namespace bucketize
140+
141+
// Returns the registration record for the builtin bucketize kernel. The
// record is a function-local static initialized once with the bucketize
// Init, Free, Prepare and Eval callbacks, in that positional order; every
// call returns a pointer to the same object.
TfLiteRegistration* Register_BUCKETIZE() {
  static TfLiteRegistration registration = {
      bucketize::Init,
      bucketize::Free,
      bucketize::Prepare,
      bucketize::Eval,
  };
  return &registration;
}
146+
147+
} // namespace builtin
148+
} // namespace ops
149+
} // namespace tflite

0 commit comments

Comments
 (0)