-
Notifications
You must be signed in to change notification settings - Fork 32
Add support for open source models based on text-embeddings-inference #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
8bc54c1
68ca669
7530119
1238eb7
2fb7f46
30d0d96
6bd6a78
8cfca21
1d42ad0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from typing import List, Optional | ||
|
|
||
| import numpy as np | ||
| import requests | ||
|
|
||
| from pymilvus.model.base import BaseEmbeddingFunction | ||
|
|
||
|
|
||
| class OpenSourceEmbeddingFunction(BaseEmbeddingFunction): | ||
| def __init__( | ||
| self, | ||
| api_url: str, | ||
| dimensions: Optional[int] = None, | ||
| ): | ||
| self.api_url = api_url + "/v1/embeddings" | ||
| self._session = requests.Session() | ||
| self._dim = dimensions | ||
|
|
||
| @property | ||
| def dim(self): | ||
| if self._dim is None: | ||
| self._dim = self._call_api(["get dim"])[0].shape[0] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this really work? i.e. self._call_api(["get dim"]) aka self._session.post(self.api_url, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This works by sending a dummy message to the API to retrieve the vector dimension, as the original API does not directly provide this information. I'll add a comment here for clarification. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SGTM. |
||
| return self._dim | ||
|
|
||
| def encode_queries(self, queries: List[str]) -> List[np.array]: | ||
| return self._call_api(queries) | ||
|
|
||
| def encode_documents(self, documents: List[str]) -> List[np.array]: | ||
| return self._call_api(documents) | ||
|
|
||
| def __call__(self, texts: List[str]) -> List[np.array]: | ||
| return self._call_api(texts) | ||
|
|
||
| def _call_api(self, texts: List[str]): | ||
| data = {"input": texts} | ||
| resp = self._session.post( # type: ignore[assignment] | ||
| self.api_url, | ||
| json=data, | ||
| ).json() | ||
| if "data" not in resp: | ||
| raise RuntimeError(resp["message"]) | ||
|
|
||
| embeddings = resp["data"] | ||
|
|
||
| # Sort resulting embeddings by index | ||
| sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore[no-any-return] | ||
| return [np.array(result["embedding"]) for result in sorted_embeddings] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| from typing import List | ||
|
|
||
| import requests | ||
|
|
||
| from pymilvus.model.base import BaseRerankFunction, RerankResult | ||
|
|
||
|
|
||
| class OpenSourceRerankFunction(BaseRerankFunction): | ||
| def __init__(self, api_url: str): | ||
| self.api_url = api_url + "/rerank" | ||
| self._session = requests.Session() | ||
|
|
||
| def __call__(self, query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]: | ||
| resp = self._session.post( # type: ignore[assignment] | ||
| self.api_url, | ||
| json={ | ||
| "query": query, | ||
| "raw_scores": False, | ||
|
||
| "return_text": True, | ||
| "texts": documents, | ||
| "truncate": False, | ||
| }, | ||
| ).json() | ||
| if "error" in resp: | ||
| raise RuntimeError(resp["error"]) | ||
|
|
||
| results = [] | ||
| for res in resp: | ||
| results.append(RerankResult(text=res["text"], score=res["score"], index=res["index"])) | ||
| return results | ||
Uh oh!
There was an error while loading. Please reload this page.