
Commit 5eeb7c7

Merge pull request tensorflow#46275 from benbarsdell:gpu-SparseReshape-cpu-refactor
PiperOrigin-RevId: 357024042
Change-Id: I63ec2724c86e1def68962a40e375c152dda8fcaa
2 parents 2e73f72 + b1a0dbd commit 5eeb7c7

File tree: 4 files changed, +98 -42 lines


tensorflow/core/kernels/deserialize_sparse_string_op.cc

+5-3
@@ -35,6 +35,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
 namespace {
 
 using sparse::SparseTensor;
@@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
       target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
     }
 
-    ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
-                        0 /* output indices index */,
-                        2 /* output shape index */);
+    ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
+                                   target_shape, 0 /* output indices index */,
+                                   2 /* output shape index */);
     context->set_output(1, output.values());
   }

tensorflow/core/kernels/reshape_util.cc

+67-35
@@ -31,6 +31,53 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+template <>
+struct ReshapeSparseTensorFunctor<CPUDevice> {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const {
+    const int64 input_rank = input_shape.dims();
+    const int64 output_rank = output_shape.dims();
+    const int64 nnz = input_indices.dimension(0);
+    gtl::InlinedVector<int64, 8> input_strides(input_rank);
+    if (input_rank > 0) {
+      input_strides[input_rank - 1] = 1;
+      for (int d = input_rank - 2; d >= 0; --d) {
+        input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+      }
+    }
+
+    gtl::InlinedVector<int64, 8> output_strides(output_rank);
+    if (output_rank > 0) {
+      output_strides[output_rank - 1] = 1;
+      for (int d = output_rank - 2; d >= 0; --d) {
+        output_strides[d] =
+            output_strides[d + 1] * output_shape.dim_size(d + 1);
+      }
+    }
+
+    for (int i = 0; i < nnz; ++i) {
+      int64 id = 0;
+      for (int j = 0; j < input_rank; ++j) {
+        id += input_indices(i, j) * input_strides[j];
+      }
+      for (int j = 0; j < output_rank; ++j) {
+        output_indices(i, j) = id / output_strides[j];
+        id %= output_strides[j];
+      }
+    }
+    return Status::OK();
+  }
+};
+
+}  // namespace functor
+
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
@@ -49,7 +96,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
                   "Target shape should be a vector but received shape ",
                   target_shape_in.shape().DebugString()));
 
-  const int64 input_rank = input_shape_in.NumElements();
   const int64 output_rank = target_shape_in.NumElements();
   const TensorShape input_shape(input_shape_in.vec<int64>());
   const int64 dense_size = input_shape.num_elements();
@@ -111,40 +157,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
     return;
   }
 
-  gtl::InlinedVector<int64, 8> input_strides(input_rank);
-  if (input_rank > 0) {
-    input_strides[input_rank - 1] = 1;
-    for (int d = input_rank - 2; d >= 0; --d) {
-      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
-    }
-  }
-
-  gtl::InlinedVector<int64, 8> output_strides(output_rank);
-  if (output_rank > 0) {
-    output_strides[output_rank - 1] = 1;
-    for (int d = output_rank - 2; d >= 0; --d) {
-      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
-    }
-  }
-
-  Tensor *result_indices = nullptr;
-  OP_REQUIRES_OK(context,
-                 context->allocate_output(output_indices_idx,
-                                          TensorShape({nnz, output_rank}),
-                                          &result_indices));
-  auto input_ind = input_indices_in.matrix<int64>();
-  auto output_ind = result_indices->matrix<int64>();
-  for (int i = 0; i < nnz; ++i) {
-    int64 id = 0;
-    for (int j = 0; j < input_rank; ++j) {
-      id += input_ind(i, j) * input_strides[j];
-    }
-    for (int j = 0; j < output_rank; ++j) {
-      output_ind(i, j) = id / output_strides[j];
-      id %= output_strides[j];
-    }
-  }
-
   Tensor *result_shape = nullptr;
   OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
                                                    TensorShape({output_rank}),
@@ -153,6 +165,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
   for (int j = 0; j < output_shape.dims(); ++j) {
     output_shape_vec(j) = output_shape.dim_size(j);
  }
+
+  Tensor *result_indices = nullptr;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(output_indices_idx,
+                                          TensorShape({nnz, output_rank}),
+                                          &result_indices));
+  if (nnz > 0) {
+    OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
+                                input_shape, output_shape,
+                                input_indices_in.matrix<int64>(),
+                                result_indices->matrix<int64>()));
+  }
 }
 
+#define EXPLICITLY_INSTANTIATE_FUNCTION(Device)                     \
+  template void ReshapeSparseTensor<Device>(                        \
+      OpKernelContext * context, const Tensor &input_indices_in,    \
+      const Tensor &input_shape_in, const Tensor &target_shape_in,  \
+      int output_indices_idx, int output_shape_idx)
+EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
+#undef EXPLICITLY_INSTANTIATE_FUNCTION
+
 }  // namespace tensorflow
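The new CPU functor keeps the original algorithm: each sparse index is flattened to a linear offset using the input shape's row-major strides, then unflattened with the output shape's strides. The following standalone sketch (plain C++ with an illustrative RemapIndex helper, not TensorFlow code) shows the same remapping on a concrete case:

#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides: the innermost dimension has stride 1.
std::vector<int64_t> Strides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> s(shape.size());
  if (!shape.empty()) {
    s.back() = 1;
    for (int d = static_cast<int>(shape.size()) - 2; d >= 0; --d) {
      s[d] = s[d + 1] * shape[d + 1];
    }
  }
  return s;
}

// Maps one sparse index from input_shape to output_shape by flattening it to a
// linear offset and unflattening it again, mirroring the functor's inner loop.
std::vector<int64_t> RemapIndex(const std::vector<int64_t>& index,
                                const std::vector<int64_t>& input_shape,
                                const std::vector<int64_t>& output_shape) {
  const std::vector<int64_t> in_strides = Strides(input_shape);
  const std::vector<int64_t> out_strides = Strides(output_shape);
  int64_t id = 0;
  for (size_t j = 0; j < index.size(); ++j) id += index[j] * in_strides[j];
  std::vector<int64_t> out(output_shape.size());
  for (size_t j = 0; j < output_shape.size(); ++j) {
    out[j] = id / out_strides[j];
    id %= out_strides[j];
  }
  return out;
}

int main() {
  // Index (1, 2, 3) in a [2, 3, 4] tensor is linear offset 23, which lands at
  // (5, 3) after reshaping to [6, 4].
  for (int64_t v : RemapIndex({1, 2, 3}, {2, 3, 4}, {6, 4})) {
    std::cout << v << " ";  // prints: 5 3
  }
  std::cout << "\n";
}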

tensorflow/core/kernels/reshape_util.h

+18
@@ -16,18 +16,36 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 #define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
 
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
 namespace tensorflow {
 
 class OpKernelContext;
 class Tensor;
 
 // Reshapes the input indices and input shape to the target shape.
+// Note: This template is explicitly instantiated for CPU device only.
+template <typename Device>
 void ReshapeSparseTensor(OpKernelContext *context,
                          const Tensor &input_indices_in,
                          const Tensor &input_shape_in,
                          const Tensor &target_shape_in, int output_indices_idx,
                          int output_shape_idx);
 
+namespace functor {
+
+template <typename Device>
+struct ReshapeSparseTensorFunctor {
+  Status operator()(const TensorShape &input_shape,
+                    const TensorShape &output_shape,
+                    typename TTypes<int64>::ConstMatrix input_indices,
+                    typename TTypes<int64>::Matrix output_indices) const;
+};
+
+}  // namespace functor
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
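The header now only declares the device-templated functor; each device's source file is expected to supply a specialization (the CPU one lives in reshape_util.cc above), so a GPU implementation can be added later without changing callers. A minimal single-file sketch of that specialization pattern, with hypothetical names (ScaleFunctor, FakeCpu) standing in for the TensorFlow types:

#include <iostream>
#include <vector>

struct FakeCpu {};  // stand-in for Eigen::ThreadPoolDevice
struct FakeGpu {};  // stand-in for Eigen::GpuDevice

// "Header": declare the functor template only, so each device library can
// provide its own specialization in its own translation unit.
template <typename Device>
struct ScaleFunctor {
  void operator()(std::vector<float>* data, float factor) const;
};

// "CPU source file": full specialization compiled into the CPU build. A GPU
// build would add ScaleFunctor<FakeGpu> in a separate .cu.cc-style file.
template <>
struct ScaleFunctor<FakeCpu> {
  void operator()(std::vector<float>* data, float factor) const {
    for (float& v : *data) v *= factor;
  }
};

int main() {
  std::vector<float> data = {1.f, 2.f, 3.f};
  ScaleFunctor<FakeCpu>()(&data, 2.f);         // same call shape as the functor above
  for (float v : data) std::cout << v << " ";  // prints: 2 4 6
  std::cout << "\n";
}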

tensorflow/core/kernels/sparse_reshape_op.cc

+8-4
@@ -29,17 +29,21 @@ limitations under the License.
 
 namespace tensorflow {
 
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device>
 class SparseReshapeOp : public OpKernel {
  public:
   explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    ReshapeSparseTensor(context, context->input(0), context->input(1),
-                        context->input(2), 0 /* output indices index */,
-                        1 /* output shape index */);
+    ReshapeSparseTensor<Device>(context, context->input(0), context->input(1),
+                                context->input(2), 0 /* output indices index */,
+                                1 /* output shape index */);
   }
 };
 
 REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
-                        SparseReshapeOp)
+                        SparseReshapeOp<CPUDevice>)
+
 }  // namespace tensorflow
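Since SparseReshapeOp is now templated on the device, a follow-up change could register a GPU kernel once a ReshapeSparseTensorFunctor<GPUDevice> specialization exists. The sketch below is hypothetical and not part of this commit; it assumes the SparseReshape op's dense shape inputs and output (input_shape, new_shape, output_shape) stay pinned to host memory, since they are small shape vectors:

// Hypothetical follow-up registration (not in this commit), assuming a
// ReshapeSparseTensorFunctor<Eigen::GpuDevice> specialization is compiled in.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;
REGISTER_KERNEL_BUILDER(Name("SparseReshape")
                            .Device(DEVICE_GPU)
                            .HostMemory("input_shape")
                            .HostMemory("new_shape")
                            .HostMemory("output_shape"),
                        SparseReshapeOp<GPUDevice>)
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM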
