Skip to content

Commit d1edd85

Browse files
committed
handling for uints, floats
1 parent 5fa7e71 commit d1edd85

File tree

2 files changed

+215
-7
lines changed

2 files changed

+215
-7
lines changed

src/_arraykit.c

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6091,7 +6091,7 @@ TriMap_dst_no_fill(TriMapObject *self, PyObject *Py_UNUSED(unused)) {
60916091

60926092
// #define TO_TYPE_PAIR(e1, e2) ((e1 << 8) | e2)
60936093

6094-
// Based on `tm` state, transfer from src or from dst (depending on `from_src`) to a `array_to`, a newly created contiguous array that is compatible with the values in `array_from`. Returns -1 on error.
6094+
// Based on `tm` state, transfer from src or from dst (depending on `from_src`) to a `array_to`, a newly created contiguous array that is compatible with the values in `array_from`. Returns -1 on error. This only needs to match to / from type combinations that are possible from `resolve_dtype`, i.e., bool never goes to integer.
60956095
static inline int
60966096
AK_TM_transfer(TriMapObject* tm,
60976097
bool from_src,
@@ -6118,6 +6118,15 @@ AK_TM_transfer(TriMapObject* tm,
61186118
case NPY_INT8:
61196119
TRANSFER_SCALARS(npy_int64, npy_int8); // to, from
61206120
break;
6121+
case NPY_UINT32:
6122+
TRANSFER_SCALARS(npy_int64, npy_uint32); // to, from
6123+
break;
6124+
case NPY_UINT16:
6125+
TRANSFER_SCALARS(npy_int64, npy_uint16); // to, from
6126+
break;
6127+
case NPY_UINT8:
6128+
TRANSFER_SCALARS(npy_int64, npy_uint8); // to, from
6129+
break;
61216130
}
61226131
break;
61236132
case NPY_INT32:
@@ -6198,6 +6207,45 @@ AK_TM_transfer(TriMapObject* tm,
61986207
TRANSFER_SCALARS(npy_uint8, npy_uint8); // to, from
61996208
break;
62006209

6210+
case NPY_FLOAT64:
6211+
switch (PyArray_TYPE(array_from)) {
6212+
case NPY_FLOAT64:
6213+
TRANSFER_SCALARS(npy_float64, npy_float64); // to, from
6214+
break;
6215+
case NPY_FLOAT32:
6216+
TRANSFER_SCALARS(npy_float64, npy_float32); // to, from
6217+
break;
6218+
case NPY_FLOAT16:
6219+
TRANSFER_SCALARS(npy_float64, npy_float16); // to, from
6220+
break;
6221+
case NPY_INT64:
6222+
TRANSFER_SCALARS(npy_float64, npy_int64); // to, from
6223+
break;
6224+
case NPY_INT32:
6225+
TRANSFER_SCALARS(npy_float64, npy_int32); // to, from
6226+
break;
6227+
case NPY_INT16:
6228+
TRANSFER_SCALARS(npy_float64, npy_int16); // to, from
6229+
break;
6230+
case NPY_INT8:
6231+
TRANSFER_SCALARS(npy_float64, npy_int8); // to, from
6232+
break;
6233+
case NPY_UINT64:
6234+
TRANSFER_SCALARS(npy_float64, npy_uint64); // to, from
6235+
break;
6236+
case NPY_UINT32:
6237+
TRANSFER_SCALARS(npy_float64, npy_uint32); // to, from
6238+
break;
6239+
case NPY_UINT16:
6240+
TRANSFER_SCALARS(npy_float64, npy_uint16); // to, from
6241+
break;
6242+
case NPY_UINT8:
6243+
TRANSFER_SCALARS(npy_float64, npy_uint8); // to, from
6244+
break;
6245+
}
6246+
break;
6247+
6248+
62016249
// unicode
62026250
case NPY_UNICODE: {
62036251
if (PyArray_TYPE(array_from) != NPY_UNICODE) {

test/test_tri_map.py

Lines changed: 166 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,6 @@ def test_tri_map_map_uint_b(self) -> None:
521521
tm.register_one(3, 3)
522522
tm.register_unmatched_dst()
523523

524-
# import ipdb; ipdb.set_trace()
525524
post_src = tm.map_src_fill(src, 17, np.dtype(np.uint32))
526525
del src
527526
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
@@ -545,7 +544,6 @@ def test_tri_map_map_uint_c(self) -> None:
545544
tm.register_one(3, 3)
546545
tm.register_unmatched_dst()
547546

548-
# import ipdb; ipdb.set_trace()
549547
post_src = tm.map_src_fill(src, 17, np.dtype(np.uint64))
550548
del src
551549
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
@@ -569,7 +567,6 @@ def test_tri_map_map_uint_d(self) -> None:
569567
tm.register_one(3, 3)
570568
tm.register_unmatched_dst()
571569

572-
# import ipdb; ipdb.set_trace()
573570
post_src = tm.map_src_fill(src, 17, np.dtype(np.int16))
574571
del src
575572
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
@@ -593,13 +590,176 @@ def test_tri_map_map_uint_e(self) -> None:
593590
tm.register_one(3, 3)
594591
tm.register_unmatched_dst()
595592

596-
# import ipdb; ipdb.set_trace()
597593
post_src = tm.map_src_fill(src, 17, np.dtype(np.int32))
598-
del src
599594
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
600595
self.assertEqual(post_src.dtype, np.dtype(np.int32))
601596

602597
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int32))
603-
del dst
604598
self.assertEqual(post_dst.dtype, np.dtype(np.int32))
605599
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
600+
601+
602+
def test_tri_map_map_uint_f(self) -> None:
603+
src = np.array([0, 20, 8, 8], dtype=np.uint16)
604+
dst = np.array([7, 20, 20, 8], dtype=np.uint16)
605+
606+
# full outer
607+
tm = TriMap(len(src), len(dst))
608+
tm.register_one(0, -1)
609+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
610+
tm.register_one(2, 3)
611+
tm.register_one(3, 3)
612+
tm.register_unmatched_dst()
613+
614+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int32))
615+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
616+
self.assertEqual(post_src.dtype, np.dtype(np.int32))
617+
618+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int32))
619+
self.assertEqual(post_dst.dtype, np.dtype(np.int32))
620+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
621+
622+
def test_tri_map_map_uint_g(self) -> None:
623+
src = np.array([0, 20, 8, 8], dtype=np.uint32)
624+
dst = np.array([7, 20, 20, 8], dtype=np.uint32)
625+
626+
# full outer
627+
tm = TriMap(len(src), len(dst))
628+
tm.register_one(0, -1)
629+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
630+
tm.register_one(2, 3)
631+
tm.register_one(3, 3)
632+
tm.register_unmatched_dst()
633+
634+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int64))
635+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
636+
self.assertEqual(post_src.dtype, np.dtype(np.int64))
637+
638+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int64))
639+
self.assertEqual(post_dst.dtype, np.dtype(np.int64))
640+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
641+
642+
def test_tri_map_map_uint_h(self) -> None:
643+
src = np.array([0, 20, 8, 8], dtype=np.uint8)
644+
dst = np.array([7, 20, 20, 8], dtype=np.uint8)
645+
646+
# full outer
647+
tm = TriMap(len(src), len(dst))
648+
tm.register_one(0, -1)
649+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
650+
tm.register_one(2, 3)
651+
tm.register_one(3, 3)
652+
tm.register_unmatched_dst()
653+
654+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int64))
655+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
656+
self.assertEqual(post_src.dtype, np.dtype(np.int64))
657+
658+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int64))
659+
self.assertEqual(post_dst.dtype, np.dtype(np.int64))
660+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
661+
662+
663+
def test_tri_map_map_uint_i(self) -> None:
664+
src = np.array([0, 20, 8, 8], dtype=np.uint64)
665+
dst = np.array([7, 20, 20, 8], dtype=np.uint64)
666+
667+
# full outer
668+
tm = TriMap(len(src), len(dst))
669+
tm.register_one(0, -1)
670+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
671+
tm.register_one(2, 3)
672+
tm.register_one(3, 3)
673+
tm.register_unmatched_dst()
674+
675+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int64))
676+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
677+
self.assertEqual(post_src.dtype, np.dtype(np.float64))
678+
679+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int64))
680+
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
681+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
682+
683+
684+
685+
def test_tri_map_map_float_a(self) -> None:
686+
src = np.array([0, 20, 8, 8], dtype=np.uint8)
687+
dst = np.array([7, 20, 20, 8], dtype=np.uint8)
688+
689+
# full outer
690+
tm = TriMap(len(src), len(dst))
691+
tm.register_one(0, -1)
692+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
693+
tm.register_one(2, 3)
694+
tm.register_one(3, 3)
695+
tm.register_unmatched_dst()
696+
697+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float64))
698+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
699+
self.assertEqual(post_src.dtype, np.dtype(np.float64))
700+
701+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float64))
702+
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
703+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
704+
705+
706+
def test_tri_map_map_float_b(self) -> None:
707+
src = np.array([0, 20, 8, 8], dtype=np.int8)
708+
dst = np.array([7, 20, 20, 8], dtype=np.int8)
709+
710+
# full outer
711+
tm = TriMap(len(src), len(dst))
712+
tm.register_one(0, -1)
713+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
714+
tm.register_one(2, 3)
715+
tm.register_one(3, 3)
716+
tm.register_unmatched_dst()
717+
718+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float64))
719+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
720+
self.assertEqual(post_src.dtype, np.dtype(np.float64))
721+
722+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float64))
723+
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
724+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
725+
726+
727+
def test_tri_map_map_float_c(self) -> None:
728+
src = np.array([0, 20, 8, 8], dtype=np.float32)
729+
dst = np.array([7, 20, 20, 8], dtype=np.float32)
730+
731+
# full outer
732+
tm = TriMap(len(src), len(dst))
733+
tm.register_one(0, -1)
734+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
735+
tm.register_one(2, 3)
736+
tm.register_one(3, 3)
737+
tm.register_unmatched_dst()
738+
739+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float64))
740+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
741+
self.assertEqual(post_src.dtype, np.dtype(np.float64))
742+
743+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float64))
744+
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
745+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
746+
747+
def test_tri_map_map_float_d(self) -> None:
748+
src = np.array([0, 20, 8, 8], dtype=np.float64)
749+
dst = np.array([7, 20, 20, 8], dtype=np.float64)
750+
751+
# full outer
752+
tm = TriMap(len(src), len(dst))
753+
tm.register_one(0, -1)
754+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
755+
tm.register_one(2, 3)
756+
tm.register_one(3, 3)
757+
tm.register_unmatched_dst()
758+
759+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float64))
760+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
761+
self.assertEqual(post_src.dtype, np.dtype(np.float64))
762+
763+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float64))
764+
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
765+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])

0 commit comments

Comments
 (0)