Skip to content

Commit 274b37a

Browse files
committed
added support for bytes, some complex types
1 parent 0f3f7ba commit 274b37a

File tree

2 files changed

+128
-74
lines changed

2 files changed

+128
-74
lines changed

src/_arraykit.c

Lines changed: 110 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6057,7 +6057,7 @@ TriMap_dst_no_fill(TriMapObject *self, PyObject *Py_UNUSED(unused)) {
60576057
# define TRANSFER_SCALARS(npy_type_to, npy_type_from) { \
60586058
npy_type_to* array_to_data = (npy_type_to*)PyArray_DATA(array_to); \
60596059
for (Py_ssize_t i = 0; i < one_count; i++) { \
6060-
array_to_data[one_pairs[i].to] = (npy_type_to)*(npy_type_from*)PyArray_GETPTR1(\
6060+
array_to_data[one_pairs[i].to] = (npy_type_to)*(npy_type_from*)PyArray_GETPTR1( \
60616061
array_from, one_pairs[i].from); \
60626062
} \
60636063
npy_type_to* t; \
@@ -6088,8 +6088,43 @@ TriMap_dst_no_fill(TriMapObject *self, PyObject *Py_UNUSED(unused)) {
60886088
} \
60896089
} \
60906090

6091+
#define TRANSFER_FLEXIBLE(npy_type) { \
6092+
npy_intp element_size = PyArray_DESCR(array_to)->elsize; \
6093+
npy_intp element_cp = element_size / sizeof(npy_type); \
6094+
npy_type* array_to_data = (npy_type*)PyArray_DATA(array_to); \
6095+
npy_type* f; \
6096+
npy_type* t; \
6097+
npy_type* t_end; \
6098+
npy_intp dst_pos; \
6099+
npy_int64 f_pos; \
6100+
PyArrayObject* dst; \
6101+
for (Py_ssize_t i = 0; i < one_count; i++) { \
6102+
f = (npy_type*)PyArray_GETPTR1(array_from, one_pairs[i].from); \
6103+
t = array_to_data + element_cp * one_pairs[i].to; \
6104+
memcpy(t, f, element_size); \
6105+
} \
6106+
for (Py_ssize_t i = 0; i < tm->many_count; i++) { \
6107+
t = array_to_data + element_cp * tm->many_to[i].start; \
6108+
t_end = array_to_data + element_cp * tm->many_to[i].stop; \
6109+
if (from_src) { \
6110+
f = (npy_type*)PyArray_GETPTR1(array_from, tm->many_from[i].src); \
6111+
for (; t < t_end; t += element_cp) { \
6112+
memcpy(t, f, element_size); \
6113+
} \
6114+
} \
6115+
else { \
6116+
dst_pos = 0; \
6117+
dst = tm->many_from[i].dst; \
6118+
for (; t < t_end; t += element_cp) { \
6119+
f_pos = *(npy_int64*)PyArray_GETPTR1(dst, dst_pos); \
6120+
f = (npy_type*)PyArray_GETPTR1(array_from, f_pos); \
6121+
memcpy(t, f, element_size); \
6122+
dst_pos++; \
6123+
} \
6124+
} \
6125+
} \
6126+
} \
60916127

6092-
// #define TO_TYPE_PAIR(e1, e2) ((e1 << 8) | e2)
60936128

60946129
// Based on `tm` state, transfer from src or from dst (depending on `from_src`) to a `array_to`, a newly created contiguous array that is compatible with the values in `array_from`. Returns -1 on error. This only needs to match to / from type combinations that are possible from `resolve_dtype`, i.e., bool never goes to integer.
60956130
static inline int
@@ -6285,88 +6320,90 @@ AK_TM_transfer(TriMapObject* tm,
62856320
}
62866321
break;
62876322

6323+
// NOTE: full support for scalar to complex requires assigning within complex struct
62886324
case NPY_COMPLEX128:
62896325
switch (PyArray_TYPE(array_from)) {
62906326
case NPY_COMPLEX128:
62916327
TRANSFER_SCALARS(npy_complex128, npy_complex128); // to, from
62926328
break;
6293-
case NPY_FLOAT64:
6294-
TRANSFER_SCALARS(npy_complex128, npy_float64); // to, from
6295-
break;
6296-
case NPY_FLOAT32:
6297-
TRANSFER_SCALARS(npy_complex128, npy_float32); // to, from
6298-
break;
6299-
case NPY_FLOAT16:
6300-
TRANSFER_SCALARS(npy_complex128, npy_float16); // to, from
6301-
break;
6302-
case NPY_INT64:
6303-
TRANSFER_SCALARS(npy_complex128, npy_int64); // to, from
6304-
break;
6305-
case NPY_INT32:
6306-
TRANSFER_SCALARS(npy_complex128, npy_int32); // to, from
6307-
break;
6308-
case NPY_INT16:
6309-
TRANSFER_SCALARS(npy_complex128, npy_int16); // to, from
6310-
break;
6311-
case NPY_INT8:
6312-
TRANSFER_SCALARS(npy_complex128, npy_int8); // to, from
6313-
break;
6314-
case NPY_UINT64:
6315-
TRANSFER_SCALARS(npy_complex128, npy_uint64); // to, from
6316-
break;
6317-
case NPY_UINT32:
6318-
TRANSFER_SCALARS(npy_complex128, npy_uint32); // to, from
6319-
break;
6320-
case NPY_UINT16:
6321-
TRANSFER_SCALARS(npy_complex128, npy_uint16); // to, from
6322-
break;
6323-
case NPY_UINT8:
6324-
TRANSFER_SCALARS(npy_complex128, npy_uint8); // to, from
6329+
// case NPY_COMPLEX64:
6330+
// TRANSFER_SCALARS(npy_complex128, npy_complex64); // to, from
6331+
// break;
6332+
// case NPY_FLOAT64:
6333+
// TRANSFER_SCALARS(npy_complex128, npy_float64); // to, from
6334+
// break;
6335+
// case NPY_FLOAT32:
6336+
// TRANSFER_SCALARS(npy_complex128, npy_float32); // to, from
6337+
// break;
6338+
// case NPY_FLOAT16:
6339+
// TRANSFER_SCALARS(npy_complex128, npy_float16); // to, from
6340+
// break;
6341+
// case NPY_INT64:
6342+
// TRANSFER_SCALARS(npy_complex128, npy_int64); // to, from
6343+
// break;
6344+
// case NPY_INT32:
6345+
// TRANSFER_SCALARS(npy_complex128, npy_int32); // to, from
6346+
// break;
6347+
// case NPY_INT16:
6348+
// TRANSFER_SCALARS(npy_complex128, npy_int16); // to, from
6349+
// break;
6350+
// case NPY_INT8:
6351+
// TRANSFER_SCALARS(npy_complex128, npy_int8); // to, from
6352+
// break;
6353+
// case NPY_UINT64:
6354+
// TRANSFER_SCALARS(npy_complex128, npy_uint64); // to, from
6355+
// break;
6356+
// case NPY_UINT32:
6357+
// TRANSFER_SCALARS(npy_complex128, npy_uint32); // to, from
6358+
// break;
6359+
// case NPY_UINT16:
6360+
// TRANSFER_SCALARS(npy_complex128, npy_uint16); // to, from
6361+
// break;
6362+
// case NPY_UINT8:
6363+
// TRANSFER_SCALARS(npy_complex128, npy_uint8); // to, from
6364+
// break;
6365+
}
6366+
break;
6367+
6368+
case NPY_COMPLEX64:
6369+
switch (PyArray_TYPE(array_from)) {
6370+
case NPY_COMPLEX64:
6371+
TRANSFER_SCALARS(npy_complex64, npy_complex64); // to, from
63256372
break;
6373+
// case NPY_FLOAT32:
6374+
// TRANSFER_SCALARS(npy_complex64, npy_float32); // to, from
6375+
// break;
6376+
// case NPY_FLOAT16:
6377+
// TRANSFER_SCALARS(npy_complex64, npy_float16); // to, from
6378+
// break;
6379+
// case NPY_INT32:
6380+
// TRANSFER_SCALARS(npy_complex64, npy_int32); // to, from
6381+
// break;
6382+
// case NPY_INT16:
6383+
// TRANSFER_SCALARS(npy_complex64, npy_int16); // to, from
6384+
// break;
6385+
// case NPY_INT8:
6386+
// TRANSFER_SCALARS(npy_complex64, npy_int8); // to, from
6387+
// break;
6388+
// case NPY_UINT32:
6389+
// TRANSFER_SCALARS(npy_complex64, npy_uint32); // to, from
6390+
// break;
6391+
// case NPY_UINT16:
6392+
// TRANSFER_SCALARS(npy_complex64, npy_uint16); // to, from
6393+
// break;
6394+
// case NPY_UINT8:
6395+
// TRANSFER_SCALARS(npy_complex64, npy_uint8); // to, from
6396+
// break;
63266397
}
63276398
break;
63286399

63296400
// unicode
63306401
case NPY_UNICODE: {
6331-
if (PyArray_TYPE(array_from) != NPY_UNICODE) {
6332-
return -1;
6333-
}
6334-
npy_intp element_size = PyArray_DESCR(array_to)->elsize;
6335-
// get number of UCS4 code points per element
6336-
npy_intp element_cp = element_size / UCS4_SIZE;
6337-
Py_UCS4* array_to_data = (Py_UCS4*)PyArray_DATA(array_to); // contiguous
6338-
Py_UCS4* f;
6339-
Py_UCS4* t;
6340-
Py_UCS4* t_end;
6341-
npy_intp dst_pos;
6342-
npy_int64 f_pos;
6343-
PyArrayObject* dst;
6344-
for (Py_ssize_t i = 0; i < one_count; i++) {
6345-
f = (Py_UCS4*)PyArray_GETPTR1(array_from, one_pairs[i].from);
6346-
t = array_to_data + element_cp * one_pairs[i].to;
6347-
memcpy(t, f, element_size);
6348-
}
6349-
for (Py_ssize_t i = 0; i < tm->many_count; i++) {
6350-
t = array_to_data + element_cp * tm->many_to[i].start;
6351-
t_end = array_to_data + element_cp * tm->many_to[i].stop;
6352-
if (from_src) {
6353-
// copy the same src into multiple final
6354-
f = (Py_UCS4*)PyArray_GETPTR1(array_from, tm->many_from[i].src);
6355-
for (; t < t_end; t += element_cp) {
6356-
memcpy(t, f, element_size);
6357-
}
6358-
}
6359-
else { // from_dst, dst is an array
6360-
dst_pos = 0;
6361-
dst = tm->many_from[i].dst;
6362-
for (; t < t_end; t += element_cp) {
6363-
f_pos = *(npy_int64*)PyArray_GETPTR1(dst, dst_pos); // DO NOT TEMPLATE
6364-
f = (Py_UCS4*)PyArray_GETPTR1(array_from, f_pos);
6365-
memcpy(t, f, element_size);
6366-
dst_pos++;
6367-
}
6368-
}
6369-
}
6402+
TRANSFER_FLEXIBLE(Py_UCS4);
6403+
break;
6404+
}
6405+
case NPY_STRING: {
6406+
TRANSFER_FLEXIBLE(char);
63706407
break;
63716408
}
63726409
// NOTE: could use PyArray_Scalar instead of PyArray_GETITEM if we wanted to store scalars instead of Python objects; however, that is pretty uncommon for object arrays to store PyArray_Scalars

test/test_tri_map.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,4 +883,21 @@ def test_tri_map_map_float_i(self) -> None:
883883

884884
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.uint8))
885885
self.assertEqual(post_dst.dtype, np.dtype(np.float16))
886-
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
886+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
887+
888+
def test_tri_map_map_bytes_a(self) -> None:
889+
src = np.array(['a', 'bbb', 'cc', 'dddd', 'a'], dtype=np.bytes_)
890+
dst = np.array(['cc', 'dddd', 'a', 'bbb', 'cc'], dtype=np.bytes_)
891+
892+
tm = TriMap(len(src), len(dst))
893+
tm.register_one(0, 2)
894+
tm.register_one(1, 3)
895+
tm.register_many(2, np.array([0, 4], dtype=np.dtype(np.int64)))
896+
tm.register_one(3, 1)
897+
tm.register_one(4, 2)
898+
899+
post_src = tm.map_src_no_fill(src)
900+
self.assertEqual(post_src.tolist(), [b'a', b'bbb', b'cc', b'cc', b'dddd', b'a'])
901+
902+
post_dst = tm.map_dst_no_fill(dst)
903+
self.assertEqual(post_dst.tolist(), [b'a', b'bbb', b'cc', b'cc', b'dddd', b'a'])

0 commit comments

Comments
 (0)