|
| 1 | +import hashlib |
| 2 | +from functools import wraps |
| 3 | +from time import sleep |
| 4 | +from typing import Callable, List, Tuple |
| 5 | + |
| 6 | +from meilisearch.client import Client, TaskInfo |
| 7 | + |
| 8 | + |
| 9 | +# References: |
| 10 | +# https://www.meilisearch.com/docs/learn/experimental/vector_search |
| 11 | +# https://github.com/meilisearch/meilisearch-python/blob/d5a0babe50b4ce5789892845db98b30d4db72203/tests/index/test_index_search_meilisearch.py#L491-L493 |
| 12 | +# https://github.com/meilisearch/meilisearch-python/blob/d5a0babe50b4ce5789892845db98b30d4db72203/tests/conftest.py#L132-L146 |
| 13 | + |
| 14 | +VECOR_NAME = "docs-embed" |
| 15 | +VECOR_DIM = 768 |
| 16 | + |
| 17 | +MeilisearchFunc = Callable[..., Tuple[Client, TaskInfo]] |
| 18 | + |
| 19 | + |
| 20 | +def wait_for_task_completion(func: MeilisearchFunc) -> MeilisearchFunc: |
| 21 | + """ |
| 22 | + Decorator to wait for MeiliSearch task completion |
| 23 | + A function that is being decorated should return (Client, TaskInfo) |
| 24 | + """ |
| 25 | + |
| 26 | + @wraps(func) |
| 27 | + def wrapped_meilisearch_function(*args, **kwargs): |
| 28 | + # Extract the Client and Task info from the function's return value |
| 29 | + client, task = func(*args, **kwargs) |
| 30 | + index_id = args[1] # Adjust this index based on where it actually appears in your arguments |
| 31 | + task_id = task.task_uid |
| 32 | + |
| 33 | + while True: |
| 34 | + # task failed |
| 35 | + if task.status == "failed": |
| 36 | + # Optionally, retrieve more detailed error information if available |
| 37 | + error_message = task.error.get("message") if task.error else "Unknown error" |
| 38 | + error_type = task.error.get("type") if task.error else "Unknown" |
| 39 | + error_link = task.error.get("link") if task.error else "No additional information" |
| 40 | + |
| 41 | + # Raise an exception with the error details |
| 42 | + raise Exception( |
| 43 | + f"Task {task_id} failed with error type '{error_type}': {error_message}. More info: {error_link}" |
| 44 | + ) |
| 45 | + task = client.index(index_id).get_task(task_id) # Use the Index object's uid |
| 46 | + # task succeeded |
| 47 | + if task.status == "succeeded": |
| 48 | + return task |
| 49 | + # task processing |
| 50 | + sleep(1) |
| 51 | + |
| 52 | + return wrapped_meilisearch_function |
| 53 | + |
| 54 | + |
| 55 | +@wait_for_task_completion |
| 56 | +def create_embedding_db(client: Client, index_name: str): |
| 57 | + index = client.index(index_name) |
| 58 | + task_info = index.update_embedders({VECOR_NAME: {"source": "userProvided", "dimensions": VECOR_DIM}}) |
| 59 | + return client, task_info |
| 60 | + |
| 61 | + |
| 62 | +@wait_for_task_completion |
| 63 | +def delete_embedding_db(client: Client, index_name: str): |
| 64 | + index = client.index(index_name) |
| 65 | + task_info = index.delete() |
| 66 | + return client, task_info |
| 67 | + |
| 68 | + |
| 69 | +def hash_text_sha1(text): |
| 70 | + hash_object = hashlib.sha1() |
| 71 | + # Encode the text to bytes and update the hash object |
| 72 | + hash_object.update(text.encode("utf-8")) |
| 73 | + # Get the hexadecimal digest of the hash |
| 74 | + hex_dig = hash_object.hexdigest() |
| 75 | + return hex_dig |
| 76 | + |
| 77 | + |
| 78 | +@wait_for_task_completion |
| 79 | +def add_embeddings_to_db(client: Client, index_name: str, embeddings): |
| 80 | + index = client.index(index_name) |
| 81 | + payload_data = [ |
| 82 | + { |
| 83 | + "id": hash_text_sha1(e.text), |
| 84 | + "text": e.text, |
| 85 | + "source": e.source, |
| 86 | + "library": e.package_name, |
| 87 | + "_vectors": {VECOR_NAME: e.embedding}, |
| 88 | + } |
| 89 | + for e in embeddings |
| 90 | + ] |
| 91 | + task_info = index.add_documents(payload_data) |
| 92 | + return client, task_info |
0 commit comments