diff --git a/README.md b/README.md index 7b5c69e..f591aec 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ The `milvus-model` library provides the integration with common embedding and reranker models for Milvus, a high performance open-source vector database built for AI applications. `milvus-model` lib is included as a dependency in `pymilvus`, the Python SDK of Milvus. -`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers. +`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers or Hugging Face [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference) . `milvus-model` supports Python 3.8 and above. diff --git a/src/pymilvus/model/dense/__init__.py b/src/pymilvus/model/dense/__init__.py index a8e91d5..f5c26bd 100644 --- a/src/pymilvus/model/dense/__init__.py +++ b/src/pymilvus/model/dense/__init__.py @@ -2,6 +2,7 @@ from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction from pymilvus.model.dense.jinaai import JinaEmbeddingFunction +from pymilvus.model.dense.tei import TEIEmbeddingFunction from pymilvus.model.dense.onnx import OnnxEmbeddingFunction from pymilvus.model.dense.cohere import CohereEmbeddingFunction from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction @@ -13,9 +14,10 @@ "SentenceTransformerEmbeddingFunction", "VoyageEmbeddingFunction", "JinaEmbeddingFunction", + "TEIEmbeddingFunction", "OnnxEmbeddingFunction", "CohereEmbeddingFunction", "MistralAIEmbeddingFunction", "NomicEmbeddingFunction", - "InstructorEmbeddingFunction" + "InstructorEmbeddingFunction", ] diff --git a/src/pymilvus/model/dense/tei.py b/src/pymilvus/model/dense/tei.py new file mode 100644 index 0000000..e69e1d6 --- /dev/null +++ b/src/pymilvus/model/dense/tei.py @@ -0,0 +1,49 @@ +from typing import List, Optional + +import numpy as np +import requests + +from pymilvus.model.base import BaseEmbeddingFunction + + +class TEIEmbeddingFunction(BaseEmbeddingFunction): + def __init__( + self, + api_url: str, + dimensions: Optional[int] = None, + ): + self.api_url = api_url + "/v1/embeddings" + self._session = requests.Session() + self._dim = dimensions + + @property + def dim(self): + if self._dim is None: + # This works by sending a dummy message to the API to retrieve the vector dimension, + # as the original API does not directly provide this information + self._dim = self._call_api(["get dim"])[0].shape[0] + return self._dim + + def encode_queries(self, queries: List[str]) -> List[np.array]: + return self._call_api(queries) + + def encode_documents(self, documents: List[str]) -> List[np.array]: + return self._call_api(documents) + + def __call__(self, texts: List[str]) -> List[np.array]: + return self._call_api(texts) + + def _call_api(self, texts: List[str]): + data = {"input": texts} + resp = self._session.post( # type: ignore[assignment] + self.api_url, + json=data, + ).json() + if "data" not in resp: + raise RuntimeError(resp["message"]) + + embeddings = resp["data"] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return] + return [np.array(result["embedding"]) for result in sorted_embeddings] diff --git a/src/pymilvus/model/reranker/__init__.py b/src/pymilvus/model/reranker/__init__.py index bd28640..7ac4f17 100644 --- a/src/pymilvus/model/reranker/__init__.py +++ b/src/pymilvus/model/reranker/__init__.py @@ -3,6 +3,7 @@ from pymilvus.model.reranker.voyageai import VoyageRerankFunction from pymilvus.model.reranker.cross_encoder import CrossEncoderRerankFunction from pymilvus.model.reranker.jinaai import JinaRerankFunction +from pymilvus.model.reranker.tei import TEIRerankFunction __all__ = [ "CohereRerankFunction", @@ -10,4 +11,5 @@ "VoyageRerankFunction", "CrossEncoderRerankFunction", "JinaRerankFunction", + "TEIRerankFunction", ] diff --git a/src/pymilvus/model/reranker/tei.py b/src/pymilvus/model/reranker/tei.py new file mode 100644 index 0000000..1b81953 --- /dev/null +++ b/src/pymilvus/model/reranker/tei.py @@ -0,0 +1,28 @@ +from typing import List + +import requests + +from pymilvus.model.base import BaseRerankFunction, RerankResult + + +class TEIRerankFunction(BaseRerankFunction): + def __init__(self, api_url: str): + self.api_url = api_url + "/rerank" + self._session = requests.Session() + + def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: + resp = self._session.post( # type: ignore[assignment] + self.api_url, + json={ + "query": query, + "return_text": True, + "texts": documents, + }, + ).json() + if "error" in resp: + raise RuntimeError(resp["error"]) + + results = [] + for res in resp[:5]: + results.append(RerankResult(text=res["text"], score=res["score"], index=res["index"])) + return results