Skip to content

Commit be5fc6e

Browse files
committed
fix: added heapsort implementation
1 parent cf79136 commit be5fc6e

File tree

5 files changed

+63
-47
lines changed

5 files changed

+63
-47
lines changed

experiment.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import random
3+
4+
from rtree_query_manager import RTreeQueryManager
5+
from sequential_query_manager import *
6+
from highd_query_manager import *
7+
import pickle
8+
9+
10+
def print_result(result):
11+
for i in result:
12+
for j in i:
13+
print(j)
14+
print("\n\n")
15+
16+
17+
def main() -> None:
18+
with open("out.embeds", mode="rb") as collection_file:
19+
collection = pickle.load(collection_file)
20+
k = 8
21+
for n in [100*(2**p) for p in range(0, 1)]:
22+
print("=" * 60)
23+
print(f"Extracting random sample of size {n}")
24+
sample = random.sample(collection, n)
25+
print("Building managers")
26+
sequential_query_manager = SequentialQueryManager(collection=sample)
27+
rtree_query_manager = RTreeQueryManager(collection=sample, m=3)
28+
high_d_query_manager = HighDQueryManager(collection=sample, num_bits=2000)
29+
print()
30+
query = os.path.join(os.getcwd(), "lfw/Adam_Sandler/Adam_Sandler_0002.jpg")
31+
print("Sequential query")
32+
print_result(sequential_query_manager.knn_query(q=query, k=k))
33+
print("RTree query")
34+
print_result(rtree_query_manager.knn_query(q=query, k=k))
35+
print("LSH query")
36+
print_result(high_d_query_manager.knn_query(q=query, k=k))
37+
print()
38+
39+
40+
if __name__ == "__main__":
41+
main()

heap.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def __lt__(self, other) -> bool:
1616
def __eq__(self, other) -> bool:
1717
return self.val == other.val
1818

19+
def __repr__(self) -> str:
20+
return str(self.val)
21+
1922

2023
class MaxHeap(Generic[T]):
2124

@@ -38,5 +41,11 @@ def empty(self):
3841
def size(self) -> int:
3942
return len(self.heap)
4043

41-
def top(self):
44+
def top(self) -> T:
4245
return self.heap[0].val
46+
47+
def heapsort(self) -> List[T]:
48+
49+
result = [self.pop() for i in range(len(self.heap))]
50+
result.reverse()
51+
return result

highd_query_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
class HighDQueryManager:
9+
@measure_execution_time
910
def __init__(self, num_bits: int, collection: List[Tuple[str, np.ndarray]]) -> None:
1011
self.collection = collection
1112

sequential_query_manager.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class SequentialQueryManager:
1010

11+
@measure_execution_time
1112
def __init__(self, collection: List[Tuple[str, np.ndarray]]) -> None:
1213
self.collection = collection
1314

@@ -47,27 +48,23 @@ def __gt__(self, other):
4748
def __eq__(self, other):
4849
return self.dist == other.dist
4950

51+
def __repr__(self) -> str:
52+
return str((self.file_name, self.dist))
53+
5054
# process face of query path file
5155
image = face_recognition.load_image_file(q)
5256
query = face_recognition.face_encodings(image)
5357

5458
result: List[List[Tuple[str, float]]] = []
5559

5660
for face_embed in query:
57-
result_tmp = MaxHeap[DistWrapper]()
61+
result_heap = MaxHeap[DistWrapper]()
5862
for c in self.collection:
5963
dist: float = np.linalg.norm(c[1] - face_embed)
60-
if result_tmp.size() < k:
61-
result_tmp.push(DistWrapper(d=dist, embed=c[1], file_name=c[0]))
62-
elif result_tmp.top().dist > dist:
63-
result_tmp.pop()
64-
result_tmp.push(DistWrapper(d=dist, embed=c[1], file_name=c[0]))
65-
result_tmp2: List[Tuple[str, float]] = []
66-
67-
while result_tmp.size() != 0:
68-
result_tmp2.append((result_tmp.top().file_name, result_tmp.top().dist))
69-
result_tmp.pop()
70-
result_tmp2 = sorted(result_tmp2, key=lambda x: x[1], reverse=False)
71-
result.append(result_tmp2)
72-
64+
if result_heap.size() < k:
65+
result_heap.push(DistWrapper(d=dist, embed=c[1], file_name=c[0]))
66+
elif result_heap.top().dist > dist:
67+
result_heap.pop()
68+
result_heap.push(DistWrapper(d=dist, embed=c[1], file_name=c[0]))
69+
result.append([(wrapper.file_name, wrapper.dist) for wrapper in result_heap.heapsort()])
7370
return result

test.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)