Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for MariaDB database #548

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ jobs:
- kdtree
- kgn
- luceneknn
- mariadb
- milvus
- mrpt
- nndescent
Expand Down
14 changes: 14 additions & 0 deletions ann_benchmarks/algorithms/mariadb/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
FROM ann-benchmarks

# Install the MariaDB client headers required to build the Python "mariadb"
# module, plus curl for fetching the repository setup script below.
# `apt-get update` lives in the same RUN as `apt-get install` so a cached
# layer can never pair a stale package index with a newer install step.
RUN apt-get update \
    && apt-get install -y libmariadb-dev curl \
    && pip3 install mariadb

# Install server
# The setup script is downloaded to a file first (curl -f fails on HTTP
# errors) instead of being piped into bash: with a pipe, a failed download
# would feed bash an empty script and the build step would still succeed.
# No sudo is needed: the build already runs apt-get as root.
RUN curl -LsSf -o /tmp/mariadb_repo_setup https://r.mariadb.com/downloads/mariadb_repo_setup \
    && bash /tmp/mariadb_repo_setup --mariadb-server-version="mariadb-11.8" \
    && rm /tmp/mariadb_repo_setup \
    && apt-get update \
    && apt-get install -y mariadb-server

WORKDIR /home/app
17 changes: 17 additions & 0 deletions ann_benchmarks/algorithms/mariadb/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Benchmark definition for the MariaDB algorithm wrapper.
# NOTE(review): indentation was lost in this view; nesting must be restored
# to match the other algorithms' config.yml files before this parses.
float:
any:
# '@metric' is presumably replaced by the dataset's distance metric and
# passed as the constructor's first argument -- confirm against the runner.
- base_args: ['@metric']
constructor: MariaDB
disabled: false
docker_tag: ann-benchmarks-mariadb
module: ann_benchmarks.algorithms.mariadb
name: mariadb
run_groups:
# One run group per storage engine. The MariaDB constructor reads the keys
# 'M' and 'engine' from its method_param argument.
myisam:
arg_groups: [{M: [6, 8, 16, 32, 48], engine: 'MyISAM'}]
args: {}
# Presumably the ef_search values handed to set_query_arguments() -- confirm.
query_args: [[5, 10, 20, 40, 60, 100]]
innodb:
arg_groups: [{M: [6, 8, 16, 32, 48], engine: 'InnoDB'}]
args: {}
# Presumably the ef_search values handed to set_query_arguments() -- confirm.
query_args: [[5, 10, 20, 40, 60, 100]]
124 changes: 124 additions & 0 deletions ann_benchmarks/algorithms/mariadb/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import glob
import os
import subprocess
import sys
import time

import numpy as np
import mariadb
import psutil

from ..base.module import BaseANN

class MariaDB(BaseANN):
    """Benchmark wrapper that stores vectors in a local MariaDB server.

    Vectors are inserted into the table ``ann.ann`` and searched through a
    MariaDB ``VECTOR`` index. The server is started by the constructor and
    reached through the connector's default connection settings.
    """

    def __init__(self, metric, method_param):
        """Start the server, connect, and prepare the SQL statements.

        Parameters:
            metric: distance metric name; "angular" or "euclidean".
            method_param: dict with the keys ``'M'`` (vector index ``M``
                option) and ``'engine'`` (storage engine for the table).

        Raises:
            KeyError: if ``method_param`` lacks ``'M'`` or ``'engine'``.
            ValueError: if ``metric`` is not supported.
        """
        self._m = method_param['M']
        self._engine = method_param['engine']
        self._conn = None
        self._cur = None
        # Defined up front so __str__ works before set_query_arguments().
        self._ef_search = None

        self._metric_type = {"angular": "cosine", "euclidean": "euclidean"}.get(metric)
        if self._metric_type is None:
            raise ValueError(f"[MariaDB] Not support metric type: {metric}!!!")

        self._sql_create_table = f"CREATE TABLE ann.ann (id INT PRIMARY KEY, v VECTOR(%d) NOT NULL) ENGINE={self._engine}"
        self._sql_insert = "INSERT INTO ann.ann (id, v) VALUES (%s, %s)"
        self._sql_create_index = f"ALTER TABLE ann.ann ADD VECTOR KEY v(v) DISTANCE={self._metric_type} M={self._m}"
        self._sql_search = f"SELECT id FROM ann.ann ORDER by vec_distance_{self._metric_type}(v, %s) LIMIT %d"

        self.start_db()

        # Connect to MariaDB using Unix socket. The connection is kept on the
        # instance so it stays referenced for as long as the cursor is used.
        self._conn = mariadb.connect()
        self._cur = self._conn.cursor()

    def start_db(self):
        """Start the MariaDB service with caches sized from available memory.

        Raises:
            subprocess.CalledProcessError: if the service command exits
                with a non-zero status.
        """
        # psutil reports available memory in bytes; the sizes below are
        # passed to the server as plain byte counts.
        available_bytes = psutil.virtual_memory().available

        # Set buffer/cache size
        innodb_buffer_size = int(available_bytes * 0.4)
        key_buffer_size = int(available_bytes * 0.3)
        mhnsw_cache_size = int(available_bytes * 0.4)

        # An argument list avoids going through a shell; nothing here needs
        # shell expansion.
        # For debugging, a general query log can be enabled by adding:
        #   --general_log=1 --general_log_file=/var/lib/mysql/general.log
        subprocess.run(
            [
                "service", "mariadb", "start",
                "--skip-networking",
                f"--innodb-buffer-pool-size={innodb_buffer_size}",
                f"--key-buffer-size={key_buffer_size}",
                f"--mhnsw-max-cache-size={mhnsw_cache_size}",
            ],
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

    @staticmethod
    def vector_to_hex(v):
        """Convert vector to bytes for database storage.

        Despite the name, the result is the raw float32 byte string of the
        vector, not a hexadecimal text representation.
        """
        return np.array(v, 'float32').tobytes()

    def fit(self, X, batch_size=1000):
        """
        Fit the database with vectors

        Recreates the ``ann`` database, inserts every vector with its row
        position as ``id``, then builds the vector index.

        Parameters:
            X: numpy array of vectors to insert
            batch_size: number of records to insert in each batch
        """
        # Prepare database and table
        self._cur.execute("SET GLOBAL max_allowed_packet = 1073741824")
        self._cur.execute("DROP DATABASE IF EXISTS ann")
        self._cur.execute("CREATE DATABASE ann")
        # The vector column width is taken from the first vector.
        self._cur.execute(self._sql_create_table, (len(X[0]),))

        # Insert data in batches
        print("Inserting data in batches...")
        start_time = time.time()

        batch = []
        for i, embedding in enumerate(X):
            batch.append((int(i), self.vector_to_hex(embedding)))

            # Insert when batch is full
            if len(batch) >= batch_size:
                self._cur.executemany(self._sql_insert, batch)
                batch.clear()

        # Insert remaining records in final batch
        if batch:
            self._cur.executemany(self._sql_insert, batch)

        insert_time = time.time() - start_time
        print(f"Insert time for {len(X)} records: {insert_time:.2f}s")

        self._cur.execute("COMMIT")
        self._cur.execute("FLUSH TABLES")

        # Create index
        print("Creating index...")
        start_time = time.time()
        self._cur.execute(self._sql_create_index)

        index_time = time.time() - start_time
        print(f"Index creation time: {index_time:.2f}s")

        self._cur.execute("COMMIT")
        self._cur.execute("FLUSH TABLES")

    def set_query_arguments(self, ef_search):
        """Set the session's ``mhnsw_ef_search`` value for later queries."""
        self._ef_search = ef_search
        # int() keeps the interpolated value a plain integer literal.
        self._cur.execute(f"SET mhnsw_ef_search = {int(ef_search)}")
        self._cur.execute("COMMIT")

    def query(self, v, n):
        """Return the ids of the ``n`` rows nearest to vector ``v``."""
        self._cur.execute(self._sql_search, (self.vector_to_hex(v), n))
        # Each result row holds a single column: the id.
        return [row[0] for row in self._cur.fetchall()]

    def get_memory_usage(self):
        """Return the summed on-disk size of the matched files, divided by 1024.

        Matches ``ann#i#01.ibd`` and ``ann#i#01.MYI`` / ``ann#i#01.MYD``
        under ``/var/lib/mysql/ann`` -- presumably the files backing the
        vector index for InnoDB and MyISAM respectively (TODO confirm).
        """
        stem = '/var/lib/mysql/ann/ann#i#01.'
        paths = glob.glob(stem + 'ibd') + glob.glob(stem + 'MY[ID]')
        return sum(os.stat(path).st_size for path in paths) / 1024

    def __str__(self):
        """Describe the configuration; ef_search is None until it is set."""
        return f"MariaDB(m={self._m}, ef_search={self._ef_search}, engine={self._engine})"