Skip to content

Commit eb7e26c

Browse files
committed
support for float16
1 parent d1edd85 commit eb7e26c

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

src/_arraykit.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6245,6 +6245,42 @@ AK_TM_transfer(TriMapObject* tm,
62456245
}
62466246
break;
62476247

6248+
case NPY_FLOAT32:
6249+
switch (PyArray_TYPE(array_from)) {
6250+
case NPY_FLOAT32:
6251+
TRANSFER_SCALARS(npy_float32, npy_float32); // to, from
6252+
break;
6253+
case NPY_FLOAT16:
6254+
TRANSFER_SCALARS(npy_float32, npy_float16); // to, from
6255+
break;
6256+
case NPY_INT16:
6257+
TRANSFER_SCALARS(npy_float32, npy_int16); // to, from
6258+
break;
6259+
case NPY_INT8:
6260+
TRANSFER_SCALARS(npy_float32, npy_int8); // to, from
6261+
break;
6262+
case NPY_UINT16:
6263+
TRANSFER_SCALARS(npy_float32, npy_uint16); // to, from
6264+
break;
6265+
case NPY_UINT8:
6266+
TRANSFER_SCALARS(npy_float32, npy_uint8); // to, from
6267+
break;
6268+
}
6269+
break;
6270+
6271+
case NPY_FLOAT16:
6272+
switch (PyArray_TYPE(array_from)) {
6273+
case NPY_FLOAT16:
6274+
TRANSFER_SCALARS(npy_float16, npy_float16); // to, from
6275+
break;
6276+
case NPY_INT8:
6277+
TRANSFER_SCALARS(npy_float16, npy_int8); // to, from
6278+
break;
6279+
case NPY_UINT8:
6280+
TRANSFER_SCALARS(npy_float16, npy_uint8); // to, from
6281+
break;
6282+
}
6283+
break;
62486284

62496285
// unicode
62506286
case NPY_UNICODE: {

test/test_tri_map.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,124 @@ def test_tri_map_map_float_d(self) -> None:
763763
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float64))
764764
self.assertEqual(post_dst.dtype, np.dtype(np.float64))
765765
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
766+
767+
def test_tri_map_map_float_e(self) -> None:
768+
src = np.array([0, 20, 8, 8], dtype=np.float32)
769+
dst = np.array([7, 20, 20, 8], dtype=np.float32)
770+
771+
# full outer
772+
tm = TriMap(len(src), len(dst))
773+
tm.register_one(0, -1)
774+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
775+
tm.register_one(2, 3)
776+
tm.register_one(3, 3)
777+
tm.register_unmatched_dst()
778+
779+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int8))
780+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
781+
self.assertEqual(post_src.dtype, np.dtype(np.float32))
782+
783+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int8))
784+
self.assertEqual(post_dst.dtype, np.dtype(np.float32))
785+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
786+
787+
def test_tri_map_map_float_f(self) -> None:
788+
src = np.array([0, 20, 8, 8], dtype=np.float32)
789+
dst = np.array([7, 20, 20, 8], dtype=np.float32)
790+
791+
# full outer
792+
tm = TriMap(len(src), len(dst))
793+
tm.register_one(0, -1)
794+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
795+
tm.register_one(2, 3)
796+
tm.register_one(3, 3)
797+
tm.register_unmatched_dst()
798+
799+
post_src = tm.map_src_fill(src, 17, np.dtype(np.uint16))
800+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
801+
self.assertEqual(post_src.dtype, np.dtype(np.float32))
802+
803+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.uint16))
804+
self.assertEqual(post_dst.dtype, np.dtype(np.float32))
805+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
806+
807+
def test_tri_map_map_float_g(self) -> None:
808+
src = np.array([0, 20, 8, 8], dtype=np.float32)
809+
dst = np.array([7, 20, 20, 8], dtype=np.float32)
810+
811+
# full outer
812+
tm = TriMap(len(src), len(dst))
813+
tm.register_one(0, -1)
814+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
815+
tm.register_one(2, 3)
816+
tm.register_one(3, 3)
817+
tm.register_unmatched_dst()
818+
819+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float32))
820+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
821+
self.assertEqual(post_src.dtype, np.dtype(np.float32))
822+
823+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float32))
824+
self.assertEqual(post_dst.dtype, np.dtype(np.float32))
825+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
826+
827+
def test_tri_map_map_float_h(self) -> None:
828+
src = np.array([0, 20, 8, 8], dtype=np.float16)
829+
dst = np.array([7, 20, 20, 8], dtype=np.float16)
830+
831+
# full outer
832+
tm = TriMap(len(src), len(dst))
833+
tm.register_one(0, -1)
834+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
835+
tm.register_one(2, 3)
836+
tm.register_one(3, 3)
837+
tm.register_unmatched_dst()
838+
839+
post_src = tm.map_src_fill(src, 17, np.dtype(np.float16))
840+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
841+
self.assertEqual(post_src.dtype, np.dtype(np.float16))
842+
843+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.float16))
844+
self.assertEqual(post_dst.dtype, np.dtype(np.float16))
845+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
846+
847+
848+
def test_tri_map_map_float_i(self) -> None:
849+
src = np.array([0, 20, 8, 8], dtype=np.float16)
850+
dst = np.array([7, 20, 20, 8], dtype=np.float16)
851+
852+
# full outer
853+
tm = TriMap(len(src), len(dst))
854+
tm.register_one(0, -1)
855+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
856+
tm.register_one(2, 3)
857+
tm.register_one(3, 3)
858+
tm.register_unmatched_dst()
859+
860+
post_src = tm.map_src_fill(src, 17, np.dtype(np.int8))
861+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
862+
self.assertEqual(post_src.dtype, np.dtype(np.float16))
863+
864+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.int8))
865+
self.assertEqual(post_dst.dtype, np.dtype(np.float16))
866+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])
867+
868+
def test_tri_map_map_float_i(self) -> None:
869+
src = np.array([0, 20, 8, 8], dtype=np.float16)
870+
dst = np.array([7, 20, 20, 8], dtype=np.float16)
871+
872+
# full outer
873+
tm = TriMap(len(src), len(dst))
874+
tm.register_one(0, -1)
875+
tm.register_many(1, np.array([1, 2], dtype=np.dtype(np.int64)))
876+
tm.register_one(2, 3)
877+
tm.register_one(3, 3)
878+
tm.register_unmatched_dst()
879+
880+
post_src = tm.map_src_fill(src, 17, np.dtype(np.uint8))
881+
self.assertEqual(post_src.tolist(), [0, 20, 20, 8, 8, 17])
882+
self.assertEqual(post_src.dtype, np.dtype(np.float16))
883+
884+
post_dst = tm.map_dst_fill(dst, 17, np.dtype(np.uint8))
885+
self.assertEqual(post_dst.dtype, np.dtype(np.float16))
886+
self.assertEqual(post_dst.tolist(), [17, 20, 20, 8, 8, 7])

0 commit comments

Comments
 (0)