Skip to content

Commit a94611c

Browse files
authored
Adds reinterpret_const with a test. (#86)
1 parent c65e05f commit a94611c

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

include/array/array.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,7 @@ NDARRAY_HOST_DEVICE auto transpose_impl(const Shape& shape, index_sequence<Extra
12401240
shape.template dim<DimIndices>()..., shape.template dim<sizeof...(DimIndices) + Extras>()...);
12411241
}
12421242

1243-
} // namespace internal
1243+
} // namespace internal
12441244

12451245
/** Create a new shape using a list of `DimIndices...` to use as the
12461246
* dimensions of the shape. The new shape's i'th dimension will be the
@@ -1252,7 +1252,8 @@ NDARRAY_HOST_DEVICE auto transpose_impl(const Shape& shape, index_sequence<Extra
12521252
*
12531253
* Examples:
12541254
* - `transpose<2, 0, 1>(s) == make_shape(s.dim<2>(), s.dim<0>(), s.dim<1>())`
1255-
* - `transpose<1, 0>(s) == make_shape(s.dim<1>(), s.dim<0>(), ...)` where ... is all dimensions after dimension 1. */
1255+
* - `transpose<1, 0>(s) == make_shape(s.dim<1>(), s.dim<0>(), ...)` where ... is all dimensions
1256+
* after dimension 1. */
12561257
template <size_t... DimIndices, class... Dims,
12571258
class = internal::enable_if_permutation<sizeof...(DimIndices), DimIndices...>>
12581259
NDARRAY_HOST_DEVICE auto transpose(const shape<Dims...>& shape) {
@@ -1391,8 +1392,7 @@ NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_value_in_order(
13911392

13921393
// Scalar buffers are a special case. The enable_if here (and above) are a workaround for a bug in
13931394
// old versions of GCC that causes this overload to be ambiguous.
1394-
template <size_t D, class Fn, class... Ptrs,
1395-
std::enable_if_t<(D == -1), int> = 0>
1395+
template <size_t D, class Fn, class... Ptrs, std::enable_if_t<(D == -1), int> = 0>
13961396
NDARRAY_INLINE NDARRAY_HOST_DEVICE void for_each_value_in_order(
13971397
const std::tuple<>& extent, Fn&& fn, Ptrs... ptrs) {
13981398
fn(*std::get<0>(ptrs)...);
@@ -1553,7 +1553,8 @@ NDARRAY_HOST_DEVICE auto make_compact(const Shape& s) {
15531553

15541554
/** A `shape` where all extents (and automatically computed compact strides) are constant. */
15551555
template <index_t... Extents>
1556-
using fixed_dense_shape = decltype(make_shape_from_tuple(internal::make_compact_dims<1>(dim<0, Extents>()...)));
1556+
using fixed_dense_shape =
1557+
decltype(make_shape_from_tuple(internal::make_compact_dims<1>(dim<0, Extents>()...)));
15571558

15581559
/** Returns `true` if a shape `src` can be assigned to a shape of type
15591560
* `ShapeDst` without error. */
@@ -2870,6 +2871,13 @@ const_array_ref<U, Shape> reinterpret(const array<T, Shape, Alloc>& a) {
28702871
return reinterpret<const U>(a.cref());
28712872
}
28722873

2874+
/** Reinterpret the const_array_ref `a` of type `T` (aka array_ref<const T>) to have a different
2875+
* type `U` using `const_cast`. */
2876+
template <class U, class T, class Shape>
2877+
array_ref<U, Shape> reinterpret_const(const const_array_ref<T, Shape>& a) {
2878+
return array_ref<U, Shape>(const_cast<U*>(a.base()), a.shape());
2879+
}
2880+
28732881
/** Reinterpret the shape of the array or array_ref `a` to be a new shape
28742882
* `new_shape`, with a base pointer offset `offset`. */
28752883
template <class NewShape, class T, class OldShape>

test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ cc_test(
44
srcs = [
55
"algorithm.cpp",
66
"array.cpp",
7+
"array_ref.cpp",
78
"ein_reduce.cpp",
89
"image.cpp",
910
"lifetime.cpp",

test/array_ref.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ TEST(reinterpret) {
6060
ASSERT_EQ(scalar(), eight);
6161
}
6262

63+
TEST(reinterpret_const) {
64+
array_of_rank<int, 1> a1({3});
65+
a1(0) = 5;
66+
a1(1) = 2;
67+
a1(2) = 6;
68+
69+
const_array_ref_of_rank<int, 1> a1_cref = a1.cref();
70+
array_ref_of_rank<int, 1> a1_ref = reinterpret_const<int>(a1_cref);
71+
a1_ref(0) = 10;
72+
a1_ref(1) = 20;
73+
a1_ref(2) = 30;
74+
75+
ASSERT_EQ(a1(0), 10);
76+
ASSERT_EQ(a1(1), 20);
77+
ASSERT_EQ(a1(2), 30);
78+
}
79+
6380
TEST(array_ref_copy) {
6481
int data[100];
6582
for (int i = 0; i < 100; i++) {
@@ -87,8 +104,8 @@ TEST(array_ref_incompatible_shape) {
87104
}
88105

89106
{
90-
array_ref<int, nda::shape<nda::fixed_dim<10, nda::dynamic>,
91-
nda::fixed_dim<2, nda::dynamic>>> dst;
107+
array_ref<int, nda::shape<nda::fixed_dim<10, nda::dynamic>, nda::fixed_dim<2, nda::dynamic>>>
108+
dst;
92109
array_ref<int, nda::shape_of_rank<2>> src(nullptr, {{0, 10}, {0, 3}});
93110
// Error converting dim 1: expected static extent 2, got 3
94111
// dst = src;
@@ -194,11 +211,11 @@ TEST(array_ref_static_convertibilty) {
194211
using AR3 = array_ref_of_rank<int, 3>;
195212

196213
static_assert(std::is_convertible<AR0&, int&>::value,
197-
"rank-0 array_ref should be convertible to scalar element");
214+
"rank-0 array_ref should be convertible to scalar element");
198215
static_assert(std::is_convertible<const AR0&, int&>::value,
199-
"rank-0 array_ref should be convertible to scalar element");
216+
"rank-0 array_ref should be convertible to scalar element");
200217
static_assert(!std::is_convertible<AR3&, int&>::value,
201-
"rank-3 array_ref should not be convertible to element");
218+
"rank-3 array_ref should not be convertible to element");
202219
}
203220

204221
TEST(array_ref_crop_slice) {

test/errors.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@ void shape_operator_eq_different_rank() {
8686
s == s2;
8787
}
8888

89-
void shape_modify_runtime_dim() {
90-
s.dim(1).set_min(0);
91-
}
89+
void shape_modify_runtime_dim() { s.dim(1).set_min(0); }
9290

9391
void is_compatible_different_dims() { is_compatible<shape_of_rank<3>>(s); }
9492

@@ -148,6 +146,15 @@ void make_move_different_rank() { auto a2 = make_move(a, shape_of_rank<2>()); }
148146

149147
void make_move_ref_different_rank() { auto a2 = make_move(ref, shape_of_rank<3>()); }
150148

149+
// When uncommented, this fails to compile as expected, but since the error is due to `const_cast`
150+
// itself, it will say something like:
151+
//
152+
// include/array/array.h:2878:30: error: const_cast from 'nda::array_ref<const int,
153+
// nda::shape<nda::dim<-9, -9, 1>, nda::dim<-9, -9, -9>, nda::dim<-9, -9, -9>>>::pointer' (aka
154+
// 'const int *') to 'float *' is not allowed
155+
//
156+
// void reinterpret_const_wrong_type() { reinterpret_const<float>(a.cref()); }
157+
151158
void ein_wrong_rank() { ein<0, 1>(a); }
152159

153160
void ein_scalar_arithmetic() { ein<0, 1, 2>(a) + 2; }

0 commit comments

Comments
 (0)