@@ -31,6 +31,53 @@ limitations under the License.
31
31
32
32
namespace tensorflow {
33
33
34
// Shorthand for Eigen's thread-pool device type. It is the Device argument of
// the ReshapeSparseTensorFunctor specialization and of the explicit
// instantiation of ReshapeSparseTensor found later in this file.
+ using CPUDevice = Eigen::ThreadPoolDevice;
35
+
36
+ namespace functor {
37
+
38
+ template <>
39
+ struct ReshapeSparseTensorFunctor <CPUDevice> {
40
+ Status operator ()(const TensorShape &input_shape,
41
+ const TensorShape &output_shape,
42
+ typename TTypes<int64>::ConstMatrix input_indices,
43
+ typename TTypes<int64>::Matrix output_indices) const {
44
+ const int64 input_rank = input_shape.dims ();
45
+ const int64 output_rank = output_shape.dims ();
46
+ const int64 nnz = input_indices.dimension (0 );
47
+ gtl::InlinedVector<int64, 8 > input_strides (input_rank);
48
+ if (input_rank > 0 ) {
49
+ input_strides[input_rank - 1 ] = 1 ;
50
+ for (int d = input_rank - 2 ; d >= 0 ; --d) {
51
+ input_strides[d] = input_strides[d + 1 ] * input_shape.dim_size (d + 1 );
52
+ }
53
+ }
54
+
55
+ gtl::InlinedVector<int64, 8 > output_strides (output_rank);
56
+ if (output_rank > 0 ) {
57
+ output_strides[output_rank - 1 ] = 1 ;
58
+ for (int d = output_rank - 2 ; d >= 0 ; --d) {
59
+ output_strides[d] =
60
+ output_strides[d + 1 ] * output_shape.dim_size (d + 1 );
61
+ }
62
+ }
63
+
64
+ for (int i = 0 ; i < nnz; ++i) {
65
+ int64 id = 0 ;
66
+ for (int j = 0 ; j < input_rank; ++j) {
67
+ id += input_indices (i, j) * input_strides[j];
68
+ }
69
+ for (int j = 0 ; j < output_rank; ++j) {
70
+ output_indices (i, j) = id / output_strides[j];
71
+ id %= output_strides[j];
72
+ }
73
+ }
74
+ return Status::OK ();
75
+ }
76
+ };
77
+
78
+ } // namespace functor
79
+
80
+ template <typename Device>
34
81
void ReshapeSparseTensor (OpKernelContext *context,
35
82
const Tensor &input_indices_in,
36
83
const Tensor &input_shape_in,
@@ -49,7 +96,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
49
96
" Target shape should be a vector but received shape " ,
50
97
target_shape_in.shape ().DebugString ()));
51
98
52
- const int64 input_rank = input_shape_in.NumElements ();
53
99
const int64 output_rank = target_shape_in.NumElements ();
54
100
const TensorShape input_shape (input_shape_in.vec <int64>());
55
101
const int64 dense_size = input_shape.num_elements ();
@@ -111,40 +157,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
111
157
return ;
112
158
}
113
159
114
- gtl::InlinedVector<int64, 8 > input_strides (input_rank);
115
- if (input_rank > 0 ) {
116
- input_strides[input_rank - 1 ] = 1 ;
117
- for (int d = input_rank - 2 ; d >= 0 ; --d) {
118
- input_strides[d] = input_strides[d + 1 ] * input_shape.dim_size (d + 1 );
119
- }
120
- }
121
-
122
- gtl::InlinedVector<int64, 8 > output_strides (output_rank);
123
- if (output_rank > 0 ) {
124
- output_strides[output_rank - 1 ] = 1 ;
125
- for (int d = output_rank - 2 ; d >= 0 ; --d) {
126
- output_strides[d] = output_strides[d + 1 ] * output_shape.dim_size (d + 1 );
127
- }
128
- }
129
-
130
- Tensor *result_indices = nullptr ;
131
- OP_REQUIRES_OK (context,
132
- context->allocate_output (output_indices_idx,
133
- TensorShape ({nnz, output_rank}),
134
- &result_indices));
135
- auto input_ind = input_indices_in.matrix <int64>();
136
- auto output_ind = result_indices->matrix <int64>();
137
- for (int i = 0 ; i < nnz; ++i) {
138
- int64 id = 0 ;
139
- for (int j = 0 ; j < input_rank; ++j) {
140
- id += input_ind (i, j) * input_strides[j];
141
- }
142
- for (int j = 0 ; j < output_rank; ++j) {
143
- output_ind (i, j) = id / output_strides[j];
144
- id %= output_strides[j];
145
- }
146
- }
147
-
148
160
Tensor *result_shape = nullptr ;
149
161
OP_REQUIRES_OK (context, context->allocate_output (output_shape_idx,
150
162
TensorShape ({output_rank}),
@@ -153,6 +165,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
153
165
for (int j = 0 ; j < output_shape.dims (); ++j) {
154
166
output_shape_vec (j) = output_shape.dim_size (j);
155
167
}
168
+
169
+ Tensor *result_indices = nullptr ;
170
+ OP_REQUIRES_OK (context,
171
+ context->allocate_output (output_indices_idx,
172
+ TensorShape ({nnz, output_rank}),
173
+ &result_indices));
174
+ if (nnz > 0 ) {
175
+ OP_REQUIRES_OK (context, functor::ReshapeSparseTensorFunctor<Device>()(
176
+ input_shape, output_shape,
177
+ input_indices_in.matrix <int64>(),
178
+ result_indices->matrix <int64>()));
179
+ }
156
180
}
157
181
182
+ #define EXPLICITLY_INSTANTIATE_FUNCTION (Device ) \
183
+ template void ReshapeSparseTensor<Device>( \
184
+ OpKernelContext * context, const Tensor &input_indices_in, \
185
+ const Tensor &input_shape_in, const Tensor &target_shape_in, \
186
+ int output_indices_idx, int output_shape_idx)
187
+ EXPLICITLY_INSTANTIATE_FUNCTION (CPUDevice);
188
+ #undef EXPLICITLY_INSTANTIATE_FUNCTION
189
+
158
190
} // namespace tensorflow
0 commit comments