Skip to content

Commit 0291a06

Browse files
committed
Create module for MariaDB
MariaDB supports Vector now. Add new module for benchmark against MariaDB 11.8 database server.
1 parent 33ecd5f commit 0291a06

File tree

4 files changed

+180
-0
lines changed

4 files changed

+180
-0
lines changed

.github/workflows/benchmarks.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
- kdtree
4747
- kgn
4848
- luceneknn
49+
- mariadb
4950
- milvus
5051
- mrpt
5152
- nndescent
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
FROM ann-benchmarks
2+
3+
RUN apt-get update
4+
5+
# Install Python mariadb module
6+
RUN apt-get install -y libmariadb-dev
7+
RUN pip3 install mariadb
8+
9+
# Install server
10+
RUN apt-get install -y curl sudo
11+
RUN curl -LsS https://r.mariadb.com/downloads/mariadb_repo_setup | sudo bash -s -- --mariadb-server-version="mariadb-11.8"
12+
RUN apt-get install -y mariadb-server
13+
14+
WORKDIR /home/app
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
float:
2+
any:
3+
- base_args: ['@metric']
4+
constructor: MariaDB
5+
disabled: false
6+
docker_tag: ann-benchmarks-mariadb
7+
module: ann_benchmarks.algorithms.mariadb
8+
name: mariadb
9+
run_groups:
10+
myisam:
11+
arg_groups: [{M: [6, 8, 12, 16, 32, 48], engine: 'MyISAM'}]
12+
args: {}
13+
query_args: [[10, 20, 30, 40]]
14+
innodb:
15+
arg_groups: [{M: [6, 8, 12, 16, 32, 48], engine: 'InnoDB'}]
16+
args: {}
17+
query_args: [[10, 20, 30, 40]]
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import glob
2+
import os
3+
import subprocess
4+
import sys
5+
import time
6+
7+
from itertools import chain
8+
from multiprocessing.pool import Pool
9+
10+
import mariadb
11+
import numpy as np
12+
import psutil
13+
14+
from ..base.module import BaseANN
15+
16+
def vector_to_hex(v):
17+
"""Convert vector to bytes for database storage"""
18+
return np.array(v, 'float32').tobytes()
19+
20+
def many_queries(arg):
21+
conn = mariadb.connect()
22+
cur = conn.cursor()
23+
24+
res = []
25+
for v in arg[2]:
26+
cur.execute(arg[0], (vector_to_hex(v), arg[1]))
27+
res.append([id for id, in cur.fetchall()])
28+
29+
return res
30+
31+
class MariaDB(BaseANN):
32+
33+
def __init__(self, metric, method_param):
34+
self._m = method_param['M']
35+
self._engine = method_param['engine']
36+
self._cur = None
37+
38+
self._metric_type = {"angular": "cosine", "euclidean": "euclidean"}.get(metric, None)
39+
if self._metric_type is None:
40+
raise Exception(f"[MariaDB] Not support metric type: {metric}!!!")
41+
42+
self._sql_create_table = f"CREATE TABLE ann.ann (id INT PRIMARY KEY, v VECTOR(%d) NOT NULL) ENGINE={self._engine}"
43+
self._sql_insert = f"INSERT INTO ann.ann (id, v) VALUES (%s, %s)"
44+
self._sql_create_index = f"ALTER TABLE ann.ann ADD VECTOR KEY v(v) DISTANCE={self._metric_type} M={self._m}"
45+
self._sql_search = f"SELECT id FROM ann.ann ORDER by vec_distance_{self._metric_type}(v, %s) LIMIT %d"
46+
47+
self.start_db()
48+
49+
# Connect to MariaDB using Unix socket
50+
conn = mariadb.connect()
51+
self._cur = conn.cursor()
52+
53+
def start_db(self):
54+
# Get free memory in MB
55+
free_memory = psutil.virtual_memory().available
56+
57+
# Set buffer/cache size
58+
innodb_buffer_size = int(free_memory * 0.4)
59+
key_buffer_size = int(free_memory * 0.3)
60+
mhnsw_cache_size = int(free_memory * 0.4)
61+
62+
subprocess.run(
63+
f"service mariadb start --skip-networking "
64+
f"--innodb-buffer-pool-size={innodb_buffer_size} "
65+
f"--key-buffer-size={key_buffer_size} "
66+
# f"--general_log=1 --general_log_file=/var/lib/mysql/general.log "
67+
f"--mhnsw-max-cache-size={mhnsw_cache_size}",
68+
shell=True,
69+
check=True,
70+
stdout=sys.stdout,
71+
stderr=sys.stderr
72+
)
73+
74+
def fit(self, X, batch_size=1000):
75+
"""
76+
Fit the database with vectors
77+
78+
Parameters:
79+
X: numpy array of vectors to insert
80+
batch_size: number of records to insert in each batch
81+
"""
82+
# Prepare database and table
83+
self._cur.execute("SET GLOBAL max_allowed_packet = 1073741824")
84+
self._cur.execute("DROP DATABASE IF EXISTS ann")
85+
self._cur.execute("CREATE DATABASE ann")
86+
self._cur.execute(self._sql_create_table, (len(X[0]),))
87+
88+
# Insert data in batches
89+
print("Inserting data in batches...")
90+
start_time = time.time()
91+
92+
batch = []
93+
for i, embedding in enumerate(X):
94+
batch.append((int(i), vector_to_hex(embedding)))
95+
96+
# Insert when batch is full
97+
if len(batch) >= batch_size:
98+
self._cur.executemany(self._sql_insert, batch)
99+
batch.clear()
100+
101+
# Insert remaining records in final batch
102+
if batch:
103+
self._cur.executemany(self._sql_insert, batch)
104+
105+
insert_time = time.time() - start_time
106+
print(f"Insert time for {len(X)} records: {insert_time:.2f}s")
107+
108+
self._cur.execute("COMMIT")
109+
self._cur.execute("FLUSH TABLES")
110+
111+
# Create index
112+
print("Creating index...")
113+
start_time = time.time()
114+
self._cur.execute(self._sql_create_index)
115+
116+
index_time = time.time() - start_time
117+
print(f"Index creation time: {index_time:.2f}s")
118+
119+
self._cur.execute("COMMIT")
120+
self._cur.execute("FLUSH TABLES")
121+
122+
def set_query_arguments(self, ef_search):
123+
# Set ef_search
124+
self._ef_search = ef_search
125+
self._cur.execute(f"SET GLOBAL mhnsw_ef_search = {ef_search}")
126+
self._cur.execute("COMMIT")
127+
128+
def query(self, v, n):
129+
self._cur.execute(self._sql_search, (vector_to_hex(v), n))
130+
131+
return [id for id, in self._cur.fetchall()]
132+
133+
def batch_query(self, X, n):
134+
XX=[]
135+
for i in range(os.cpu_count()):
136+
XX.append((self._sql_search, n, X[int(len(X)/os.cpu_count()*i):int(len(X)/os.cpu_count()*(i+1))]))
137+
pool = Pool()
138+
self._res = pool.map(many_queries, XX)
139+
140+
def get_batch_results(self):
141+
return np.array(list(chain(*self._res)))
142+
143+
def get_memory_usage(self):
144+
stem = '/var/lib/mysql/ann/ann#i#01.'
145+
return sum(os.stat(f).st_size for f in glob.glob(stem + 'ibd') + glob.glob(stem + 'MY[ID]')) / 1024
146+
147+
def __str__(self):
148+
return f"MariaDB(m={self._m}, ef_search={self._ef_search}, engine={self._engine})"

0 commit comments

Comments
 (0)