Skip to content
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

The `milvus-model` library provides integration with common embedding and reranker models for Milvus, a high-performance open-source vector database built for AI applications. The `milvus-model` library is included as a dependency in `pymilvus`, the Python SDK of Milvus.

`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers.
`milvus-model` supports embedding and reranker models from service providers like OpenAI, Voyage AI, Cohere, and open-source models through SentenceTransformers or [text-embeddings-inference](https://github.com/huggingface/text-embeddings-inference).

`milvus-model` supports Python 3.8 and above.

Expand Down
2 changes: 2 additions & 0 deletions src/pymilvus/model/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pymilvus.model.dense.sentence_transformer import SentenceTransformerEmbeddingFunction
from pymilvus.model.dense.voyageai import VoyageEmbeddingFunction
from pymilvus.model.dense.jinaai import JinaEmbeddingFunction
from pymilvus.model.dense.opensource import OpenSourceEmbeddingFunction
from pymilvus.model.dense.onnx import OnnxEmbeddingFunction
from pymilvus.model.dense.cohere import CohereEmbeddingFunction
from pymilvus.model.dense.mistralai import MistralAIEmbeddingFunction
Expand All @@ -13,6 +14,7 @@
"SentenceTransformerEmbeddingFunction",
"VoyageEmbeddingFunction",
"JinaEmbeddingFunction",
"OpenSourceEmbeddingFunction",
"OnnxEmbeddingFunction",
"CohereEmbeddingFunction",
"MistralAIEmbeddingFunction",
Expand Down
47 changes: 47 additions & 0 deletions src/pymilvus/model/dense/opensource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import List, Optional

import numpy as np
import requests

from pymilvus.model.base import BaseEmbeddingFunction


class OpenSourceEmbeddingFunction(BaseEmbeddingFunction):
    """Dense embedding function backed by a self-hosted HTTP embedding server.

    Texts are POSTed as ``{"input": [...]}`` to ``<api_url>/v1/embeddings`` and
    the ``data`` entries of the JSON reply are turned into NumPy vectors.

    Args:
        api_url: Base URL of the embedding server, with or without a trailing
            slash (e.g. ``"http://localhost:8080"``).
        dimensions: Embedding dimension, if already known. When omitted it is
            discovered lazily on first access of :attr:`dim`.
    """

    def __init__(
        self,
        api_url: str,
        dimensions: Optional[int] = None,
    ):
        # Strip trailing slashes so "http://host:8080/" does not turn into
        # "http://host:8080//v1/embeddings".
        self.api_url = api_url.rstrip("/") + "/v1/embeddings"
        self._session = requests.Session()
        self._dim = dimensions

    @property
    def dim(self):
        """Return the embedding dimension, probing the server once if unknown."""
        if self._dim is None:
            # The API does not report the vector dimension directly, so embed a
            # dummy text and measure the length of the returned vector. The
            # result is cached, so this costs one request at most.
            self._dim = self._call_api(["get dim"])[0].shape[0]
        return self._dim

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed search queries; returns one vector per query, in input order."""
        return self._call_api(queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed documents; returns one vector per document, in input order."""
        return self._call_api(documents)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed arbitrary texts; returns one vector per text, in input order."""
        return self._call_api(texts)

    def _call_api(self, texts: List[str]) -> List[np.ndarray]:
        """POST ``texts`` to the embeddings endpoint and parse the reply.

        Returns:
            One ``np.ndarray`` per input text, ordered by the ``index`` field
            of the response items.

        Raises:
            RuntimeError: If the JSON reply carries no ``data`` field. The
                exception argument is the server's ``message`` (or ``error``)
                when present, otherwise the whole decoded payload.
        """
        resp = self._session.post(
            self.api_url,
            json={"input": texts},
        ).json()

        if not isinstance(resp, dict) or "data" not in resp:
            # Not every error payload has a "message" key; fall back gracefully
            # instead of masking the real failure behind a KeyError/TypeError.
            detail = resp
            if isinstance(resp, dict):
                detail = resp.get("message") or resp.get("error") or resp
            raise RuntimeError(detail)

        # Restore input order: items are matched to inputs via their "index".
        ordered = sorted(resp["data"], key=lambda item: item["index"])
        return [np.array(item["embedding"]) for item in ordered]
2 changes: 2 additions & 0 deletions src/pymilvus/model/reranker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from pymilvus.model.reranker.voyageai import VoyageRerankFunction
from pymilvus.model.reranker.cross_encoder import CrossEncoderRerankFunction
from pymilvus.model.reranker.jinaai import JinaRerankFunction
from pymilvus.model.reranker.opensource import OpenSourceRerankFunction

# Public names of the reranker package: these are what
# `from pymilvus.model.reranker import *` exports.
__all__ = [
"CohereRerankFunction",
"BGERerankFunction",
"VoyageRerankFunction",
"CrossEncoderRerankFunction",
"JinaRerankFunction",
"OpenSourceRerankFunction",
]
30 changes: 30 additions & 0 deletions src/pymilvus/model/reranker/opensource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List

import requests

from pymilvus.model.base import BaseRerankFunction, RerankResult


class OpenSourceRerankFunction(BaseRerankFunction):
    """Rerank function backed by a self-hosted HTTP rerank server.

    The query and candidate documents are POSTed to ``<api_url>/rerank`` and
    each item of the JSON reply is converted into a ``RerankResult``.

    Args:
        api_url: Base URL of the rerank server, with or without a trailing
            slash (e.g. ``"http://localhost:8080"``).
    """

    def __init__(self, api_url: str):
        # Strip trailing slashes so "http://host:8080/" does not turn into
        # "http://host:8080//rerank".
        self.api_url = api_url.rstrip("/") + "/rerank"
        self._session = requests.Session()

    def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
        """Score ``documents`` against ``query`` and return the best matches.

        Args:
            query: The search query.
            documents: Candidate texts to rerank.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` results, highest score first. Each result carries
            the document text, its score and its index in ``documents``.

        Raises:
            RuntimeError: If the server reply contains an ``error`` field.
        """
        resp = self._session.post(
            self.api_url,
            json={
                "query": query,
                # False asks for scores normalized to the 0-1 range rather than
                # raw, unnormalized model outputs.
                "raw_scores": False,
                # The document text is needed to populate RerankResult.text.
                "return_text": True,
                "texts": documents,
                "truncate": False,
            },
        ).json()
        if "error" in resp:
            raise RuntimeError(resp["error"])

        # Order by score before truncating so that top_k always keeps the
        # highest-scoring documents, whatever order the server replied in.
        ranked = sorted(resp, key=lambda item: item["score"], reverse=True)
        return [
            RerankResult(text=item["text"], score=item["score"], index=item["index"])
            for item in ranked[:top_k]
        ]