diff --git a/sqlite-vec.c b/sqlite-vec.c index 3cc802f0..b4412fa9 100644 --- a/sqlite-vec.c +++ b/sqlite-vec.c @@ -460,6 +460,58 @@ static double distance_l1_f32(const void *a, const void *b, const void *d) { return l1_f32(a, b, d); } +// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34 +static u8 hamdist_table[256] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, + 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, + 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, + 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; + +static f32 distance_cosine_bit_u64(u64 *a, u64 *b, size_t n) { + f32 dot = 0; + f32 aMag = 0; + f32 bMag = 0; + + for (size_t i = 0; i < n; i++) { + dot += __builtin_popcountl(a[i] & b[i]); + aMag += __builtin_popcountl(a[i]); + bMag += __builtin_popcountl(b[i]); + } + + return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); +} + +static f32 distance_cosine_bit_u8(u8 *a, u8 *b, size_t n) { + f32 dot = 0; + f32 aMag = 0; + f32 bMag = 0; + + for (size_t i = 0; i < n; i++) { + dot += hamdist_table[a[i] & b[i]]; + aMag += hamdist_table[a[i]]; + bMag += hamdist_table[b[i]]; + } + + return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); +} + +static f32 distance_cosine_bit(const void *pA, const void *pB, + const void *pD) { + size_t dim = *((size_t *)pD); + + if ((dim % 64) == 0) { + return distance_cosine_bit_u64((u64 *)pA, (u64 *)pB, dim / 8 / CHAR_BIT); + } + return distance_cosine_bit_u8((u8 *)pA, (u8 *)pB, dim / CHAR_BIT); +} + static f32 distance_cosine_float(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { f32 *pVect1 = (f32 *)pVect1v; @@ -497,20 +549,6 @@ static f32 distance_cosine_int8(const void *pA, const void *pB, return 1 - (dot / (sqrt(aMag) * sqrt(bMag))); } -// https://github.com/facebookresearch/faiss/blob/77e2e79cd0a680adc343b9840dd865da724c579e/faiss/utils/hamming_distance/common.h#L34 -static u8 hamdist_table[256] = { - 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, - 2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4, - 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, - 4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5, - 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, - 4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, - 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; - static f32 distance_hamming_u8(u8 *a, u8 *b, size_t n) { int same = 0; for (unsigned long i = 0; i < n; i++) { @@ -1167,9 +1205,8 @@ static void vec_distance_cosine(sqlite3_context *context, int argc, switch (elementType) { case SQLITE_VEC_ELEMENT_TYPE_BIT: { - sqlite3_result_error( - context, "Cannot calculate cosine distance between two bitvectors.", - -1); + f32 result = distance_cosine_bit(a, b, &dimensions); + sqlite3_result_double(context, result); goto finish; } case SQLITE_VEC_ELEMENT_TYPE_FLOAT32: { diff --git a/tests/test-loadable.py b/tests/test-loadable.py index a8058c9e..38618f08 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -423,6 +423,25 @@ def check(a, b, dtype=np.float32): check([1, 2, 3], [-9, -8, -7], dtype=np.int8) assert vec_distance_cosine("[1.1, 1.0]", "[1.2, 1.2]") == 0.001131898257881403 + vec_distance_cosine_bit = lambda *args: db.execute( + "select vec_distance_cosine(vec_bit(?), vec_bit(?))", args + ).fetchone()[0] + assert isclose( + vec_distance_cosine_bit(b"\xff", b"\x01"), + npy_cosine([1,1,1,1,1,1,1,1], [0,0,0,0,0,0,0,1]), + abs_tol=1e-6 + ) + assert isclose( + vec_distance_cosine_bit(b"\xab", b"\xab"), + npy_cosine([1,0,1,0,1,0,1,1], [1,0,1,0,1,0,1,1]), + abs_tol=1e-6 + ) + # test 64-bit + assert isclose( + vec_distance_cosine_bit(b"\xaa" * 8, b"\xff" * 8), + npy_cosine([1,0] * 32, [1] * 64), + abs_tol=1e-6 + ) def test_vec_distance_hamming(): vec_distance_hamming = lambda *args: db.execute(