Skip to content

Commit 48672d1

Browse files
committed
use meilisearch for vector db
1 parent 051a6ac commit 48672d1

File tree

3 files changed

+104
-4
lines changed

3 files changed

+104
-4
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from setuptools import find_packages, setup
55

6-
install_requires = ["black", "GitPython", "tqdm", "pyyaml", "packaging", "nbformat", "huggingface_hub", "pillow"]
6+
install_requires = ["black", "GitPython", "tqdm", "pyyaml", "packaging", "nbformat", "huggingface_hub", "pillow", "meilisearch"]
77

88
extras = {}
99

src/doc_builder/build_embeddings.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323
from pathlib import Path
2424
from typing import List
2525

26+
import meilisearch
2627
from huggingface_hub import get_inference_endpoint
2728
from tqdm import tqdm
2829

2930
from .autodoc import autodoc_markdown, resolve_links_in_text
3031
from .convert_md_to_mdx import process_md
3132
from .convert_rst_to_mdx import find_indent, is_empty_line
33+
from .meilisearch_helper import add_embeddings_to_db
3234
from .utils import read_doc_config
3335

3436

@@ -449,7 +451,13 @@ def build_embeddings(
449451

450452
# Step 2: create embeddings
451453
embeddings = call_embedding_inference(chunks)
452-
print(len(embeddings))
453454

454-
# Step 3: push embeddings to vector database
455-
# TODO
455+
# Step 3: push embeddings to vector database (meilisearch)
456+
client = meilisearch.Client("https://edge.meilisearch.com", os.environ["MEILISEARCH_KEY"])
457+
index_name = "docs-embed"
458+
459+
payload_docs_size = 50
460+
461+
for i in tqdm(range(32, len(embeddings), payload_docs_size)):
462+
chunk_embeddings = embeddings[i : i + payload_docs_size]
463+
add_embeddings_to_db(client, index_name, chunk_embeddings)

src/doc_builder/meilisearch_helper.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import hashlib
2+
from functools import wraps
3+
from time import sleep
4+
from typing import Callable, List, Tuple
5+
6+
from meilisearch.client import Client, TaskInfo
7+
8+
9+
# References:
10+
# https://www.meilisearch.com/docs/learn/experimental/vector_search
11+
# https://github.com/meilisearch/meilisearch-python/blob/d5a0babe50b4ce5789892845db98b30d4db72203/tests/index/test_index_search_meilisearch.py#L491-L493
12+
# https://github.com/meilisearch/meilisearch-python/blob/d5a0babe50b4ce5789892845db98b30d4db72203/tests/conftest.py#L132-L146
13+
14+
VECOR_NAME = "docs-embed"
15+
VECOR_DIM = 768
16+
17+
MeilisearchFunc = Callable[..., Tuple[Client, TaskInfo]]
18+
19+
20+
def wait_for_task_completion(func: MeilisearchFunc) -> MeilisearchFunc:
21+
"""
22+
Decorator to wait for MeiliSearch task completion
23+
A function that is being decorated should return (Client, TaskInfo)
24+
"""
25+
26+
@wraps(func)
27+
def wrapped_meilisearch_function(*args, **kwargs):
28+
# Extract the Client and Task info from the function's return value
29+
client, task = func(*args, **kwargs)
30+
index_id = args[1] # Adjust this index based on where it actually appears in your arguments
31+
task_id = task.task_uid
32+
33+
while True:
34+
# task failed
35+
if task.status == "failed":
36+
# Optionally, retrieve more detailed error information if available
37+
error_message = task.error.get("message") if task.error else "Unknown error"
38+
error_type = task.error.get("type") if task.error else "Unknown"
39+
error_link = task.error.get("link") if task.error else "No additional information"
40+
41+
# Raise an exception with the error details
42+
raise Exception(
43+
f"Task {task_id} failed with error type '{error_type}': {error_message}. More info: {error_link}"
44+
)
45+
task = client.index(index_id).get_task(task_id) # Use the Index object's uid
46+
# task succeeded
47+
if task.status == "succeeded":
48+
return task
49+
# task processing
50+
sleep(1)
51+
52+
return wrapped_meilisearch_function
53+
54+
55+
@wait_for_task_completion
56+
def create_embedding_db(client: Client, index_name: str):
57+
index = client.index(index_name)
58+
task_info = index.update_embedders({VECOR_NAME: {"source": "userProvided", "dimensions": VECOR_DIM}})
59+
return client, task_info
60+
61+
62+
@wait_for_task_completion
63+
def delete_embedding_db(client: Client, index_name: str):
64+
index = client.index(index_name)
65+
task_info = index.delete()
66+
return client, task_info
67+
68+
69+
def hash_text_sha1(text):
70+
hash_object = hashlib.sha1()
71+
# Encode the text to bytes and update the hash object
72+
hash_object.update(text.encode("utf-8"))
73+
# Get the hexadecimal digest of the hash
74+
hex_dig = hash_object.hexdigest()
75+
return hex_dig
76+
77+
78+
@wait_for_task_completion
79+
def add_embeddings_to_db(client: Client, index_name: str, embeddings):
80+
index = client.index(index_name)
81+
payload_data = [
82+
{
83+
"id": hash_text_sha1(e.text),
84+
"text": e.text,
85+
"source": e.source,
86+
"library": e.package_name,
87+
"_vectors": {VECOR_NAME: e.embedding},
88+
}
89+
for e in embeddings
90+
]
91+
task_info = index.add_documents(payload_data)
92+
return client, task_info

0 commit comments

Comments
 (0)