Skip to content

Commit 5fa7e71

Browse files
committed
uint handling
1 parent 96ea600 commit 5fa7e71

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/_arraykit.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6131,6 +6131,12 @@ AK_TM_transfer(TriMapObject* tm,
61316131
case NPY_INT8:
61326132
TRANSFER_SCALARS(npy_int32, npy_int8); // to, from
61336133
break;
6134+
case NPY_UINT16:
6135+
TRANSFER_SCALARS(npy_int32, npy_uint16); // to, from
6136+
break;
6137+
case NPY_UINT8:
6138+
TRANSFER_SCALARS(npy_int32, npy_uint8); // to, from
6139+
break;
61346140
}
61356141
break;
61366142
case NPY_INT16:

test/test_tri_map.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,3 +579,27 @@ def test_tri_map_map_uint_d(self) -> None:
579579
del dst
580580
self.assertEqual(post_dst.dtype, np.dtype(np.int16))
581581
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
582+
583+
584+
def test_tri_map_map_uint_e(self) -> None:
585+
src = np.array([0, 20, 8, 8], dtype=np.uint8)
586+
dst = np.array([7, 20, 20, 8], dtype=np.uint8)
587+
588+
# full outer
589+
tm = TriMap(len(src), len(dst))
590+
tm.register_one(0, -1)
591+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
592+
tm.register_one(2, 3)
593+
tm.register_one(3, 3)
594+
tm.register_unmatched_dst()
595+
596+
# import ipdb; ipdb.set_trace()
597+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int32))
598+
del src
599+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
600+
self.assertEqual(post_src.dtype, np.dtype(np.int32))
601+
602+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int32))
603+
del dst
604+
self.assertEqual(post_dst.dtype, np.dtype(np.int32))
605+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])

0 commit comments

Comments
 (0)