Skip to content

Commit aac635e

Browse files
committed
tests: add HNSW first unit test objectbox#24
!!! :)
1 parent 948c43e commit aac635e

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

tests/test_hnsw.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import math
2+
import numpy as np
3+
import random
4+
from common import *
5+
from objectbox.query_builder import QueryBuilder
6+
7+
8+
def _find_expected_nn(points: np.ndarray, query: np.ndarray, n: int):
9+
""" Given a set of points of shape (N, P) and a query of shape (P), finds the n points nearest to query. """
10+
11+
assert points.ndim == 2 and query.ndim == 1
12+
assert points.shape[1] == query.shape[0]
13+
14+
d = np.linalg.norm(points - query, axis=1) # Euclidean distance
15+
return np.argsort(d)[:n]
16+
17+
18+
def _test_random_points(num_points: int, num_query_points: int, seed: Optional[int] = None):
    """Seed a test store with random 2D points and check NN queries against brute force.

    For every random query point, the IDs returned by the nearest-neighbor
    query must overlap with the brute-force ground truth by at least 50%
    (the search is approximate, so an exact match is not required).

    Args:
        num_points: Number of random points stored in the DB.
        num_query_points: Number of random query points to search for.
        seed: If not None, seeds NumPy's global random generator so the
            generated points are reproducible.

    Raises:
        AssertionError: If the stored count, a result length, or the overlap
            score does not match expectations.
    """
    # NOTE(review): `Optional`, `objectbox`, `VectorEntity` and
    # `create_test_objectbox` are not imported explicitly in this file;
    # presumably they come from `from common import *` - confirm.
    print(f"Test random points; Points: {num_points}, Query points: {num_query_points}, Seed: {seed}")

    # Number of neighbors requested per query. The length checks below
    # require num_points >= k.
    k = 10

    if seed is not None:
        np.random.seed(seed)

    points = np.random.rand(num_points, 2).astype(np.float32)

    db = create_test_objectbox()
    # NOTE(review): the store is never explicitly closed here - confirm
    # whether the object returned by create_test_objectbox() needs it.

    # Init and seed DB
    box = objectbox.Box(db, VectorEntity)

    print(f"Seeding DB with {num_points} points...")
    objects = []
    for i, point in enumerate(points):
        object_ = VectorEntity()
        object_.name = f"point_{i}"
        object_.vector = point
        objects.append(object_)
    box.put(*objects)
    print(f"DB seeded with {box.count()} random points!")

    assert box.count() == num_points, "DB does not contain the expected number of points"

    # Generate a random list of query points
    query_points = np.random.rand(num_query_points, 2).astype(np.float32)

    # The property ID is the same for every query, so look it up once.
    vector_property_id = VectorEntity.get_property("vector")._id

    # Iterate query points, and compare expected result with OBX result
    print(f"Running {num_query_points} searches...")
    for query_point in query_points:
        # Find the ground truth (brute force).
        # + 1 because OBX IDs start from 1 (assumes IDs were assigned in
        # insertion order to a fresh store - TODO confirm).
        expected_result = _find_expected_nn(points, query_point, k) + 1
        assert len(expected_result) == k, "brute-force search returned fewer than k neighbors"

        # Run ANN with OBX
        query_builder = QueryBuilder(db, box)
        query_builder.nearest_neighbors_f32(vector_property_id, query_point, k)
        query = query_builder.build()
        obx_result = [id_ for id_, _score in query.find_ids_with_scores()]  # Ignore score
        assert len(obx_result) == k, "OBX search returned an unexpected number of results"

        # We would like at least half of the expected results to be returned by the search (in any order).
        # Remember: it's an approximate search!
        search_score = len(np.intersect1d(expected_result, obx_result)) / k
        assert search_score >= 0.5, f"search score too low: {search_score}"  # TODO likely could be increased

    print("Done!")
72+
73+
74+
def test_random_points():
    """Run the random-points nearest-neighbor check once per fixed seed (10 through 15)."""
    for seed in range(10, 16):
        _test_random_points(num_points=100, num_query_points=10, seed=seed)

0 commit comments

Comments
 (0)