module.py
import glob
import os
import subprocess
import sys
import time

import numpy as np

import mariadb
import psutil

from ..base.module import BaseANN


class MariaDB(BaseANN):
    """ann-benchmarks wrapper around the MariaDB vector (HNSW) index."""

    def __init__(self, metric, method_param):
        self._m = method_param['M']
        self._engine = method_param['engine']
        self._cur = None

        self._metric_type = {"angular": "cosine", "euclidean": "euclidean"}.get(metric, None)
        if self._metric_type is None:
            raise ValueError(f"[MariaDB] Unsupported metric type: {metric}")

        self._sql_create_table = f"CREATE TABLE ann.ann (id INT PRIMARY KEY, v VECTOR(%d) NOT NULL) ENGINE={self._engine}"
        self._sql_insert = "INSERT INTO ann.ann (id, v) VALUES (%s, %s)"
        self._sql_create_index = f"ALTER TABLE ann.ann ADD VECTOR KEY v(v) DISTANCE={self._metric_type} M={self._m}"
        self._sql_search = f"SELECT id FROM ann.ann ORDER BY vec_distance_{self._metric_type}(v, %s) LIMIT %d"

        self.start_db()

        # Connect to MariaDB using the Unix socket
        conn = mariadb.connect()
        self._cur = conn.cursor()

    def start_db(self):
        """Start the MariaDB server with buffer sizes scaled to available memory."""
        # Available system memory in bytes (psutil reports bytes)
        free_memory = psutil.virtual_memory().available

        # Set buffer/cache sizes as fractions of available memory
        innodb_buffer_size = int(free_memory * 0.4)
        key_buffer_size = int(free_memory * 0.3)
        mhnsw_cache_size = int(free_memory * 0.4)

        subprocess.run(
            f"service mariadb start --skip-networking "
            f"--innodb-buffer-pool-size={innodb_buffer_size} "
            f"--key-buffer-size={key_buffer_size} "
            # f"--general_log=1 --general_log_file=/var/lib/mysql/general.log "
            f"--mhnsw-max-cache-size={mhnsw_cache_size}",
            shell=True,
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr
        )

    @staticmethod
    def vector_to_hex(v):
        """Convert a vector to bytes for database storage"""
        return np.array(v, 'float32').tobytes()

    def fit(self, X, batch_size=1000):
        """
        Fit the database with vectors

        Parameters:
            X: numpy array of vectors to insert
            batch_size: number of records to insert in each batch
        """
        # Prepare database and table
        self._cur.execute("SET GLOBAL max_allowed_packet = 1073741824")
        self._cur.execute("DROP DATABASE IF EXISTS ann")
        self._cur.execute("CREATE DATABASE ann")
        self._cur.execute(self._sql_create_table, (len(X[0]),))

        # Insert data in batches
        print("Inserting data in batches...")
        start_time = time.time()

        batch = []
        for i, embedding in enumerate(X):
            batch.append((int(i), self.vector_to_hex(embedding)))

            # Insert when batch is full
            if len(batch) >= batch_size:
                self._cur.executemany(self._sql_insert, batch)
                batch.clear()

        # Insert remaining records in final batch
        if batch:
            self._cur.executemany(self._sql_insert, batch)

        insert_time = time.time() - start_time
        print(f"Insert time for {len(X)} records: {insert_time:.2f}s")

        self._cur.execute("COMMIT")
        self._cur.execute("FLUSH TABLES")

        # Create index
        print("Creating index...")
        start_time = time.time()

        self._cur.execute(self._sql_create_index)

        index_time = time.time() - start_time
        print(f"Index creation time: {index_time:.2f}s")

        self._cur.execute("COMMIT")
        self._cur.execute("FLUSH TABLES")

    def set_query_arguments(self, ef_search):
        # Set the search-time candidate list size (ef_search) for this session
        self._ef_search = ef_search
        self._cur.execute(f"SET mhnsw_ef_search = {ef_search}")
        self._cur.execute("COMMIT")

    def query(self, v, n):
        """Return the ids of the n nearest neighbours of v."""
        self._cur.execute(self._sql_search, (self.vector_to_hex(v), n))
        return [id for id, in self._cur.fetchall()]

    def get_memory_usage(self):
        """Return the on-disk size of the index files for the ann table (InnoDB .ibd or MyISAM .MYI/.MYD), in KiB."""
        stem = '/var/lib/mysql/ann/ann#i#01.'
        return sum(os.stat(f).st_size for f in glob.glob(stem + 'ibd') + glob.glob(stem + 'MY[ID]')) / 1024

    def __str__(self):
        return f"MariaDB(m={self._m}, ef_search={self._ef_search}, engine={self._engine})"