Skip to content

Commit 801ba9c

Browse files
author
Sam Partee
authored
OpenAI and Huggingface Providers (#16)
Providers for creating embeddings with OpenAI and Huggingface. Includes: a base class, simple integration testing, and async methods for clients that support them.
1 parent c14093b commit 801ba9c

File tree

7 files changed

+193
-2
lines changed

7 files changed

+193
-2
lines changed

.gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
__pycache__/
22
redisvl.egg-info/
3-
.coverage
4-
scratch
3+
.coverage*
4+
scratch

conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ def df():
3737
}
3838
)
3939
return data
40+
41+
42+
@pytest.fixture
def openai_key():
    """Provide the value of the ``OPENAI_KEY`` environment variable.

    Yields ``None`` to the test when the variable is not set.
    """
    return os.environ.get("OPENAI_KEY")

redisvl/providers/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
from redisvl.providers.openai import OpenAIProvider
3+
from redisvl.providers.huggingface import HuggingfaceProvider
4+
5+
6+
# Public names re-exported by the providers package.
__all__ = ["OpenAIProvider", "HuggingfaceProvider"]

redisvl/providers/base.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Callable, Dict, List, Optional
2+
3+
4+
class BaseProvider:
    """Common interface for embedding providers.

    Stores the model name and the embedding dimensionality, and supplies the
    ``batchify`` helper. The ``embed*`` methods raise ``NotImplementedError``
    here and are meant to be overridden by concrete providers.
    """

    def __init__(self, model: str, dims: int, api_config: Optional[Dict] = None):
        """Record the model name and embedding size.

        Args:
            model: Name of the embedding model.
            dims: Length of the vectors the model produces.
            api_config: Optional provider-specific settings. The base class
                accepts it for signature compatibility but does not use it.
        """
        self._dims = dims
        self._model = model

    @property
    def model(self) -> str:
        """Name of the currently configured model."""
        return self._model

    @property
    def dims(self) -> int:
        """Length of the embedding vectors."""
        return self._dims

    def set_model(self, model, dims: Optional[int] = None):
        """Switch to another model, optionally updating the embedding size.

        Args:
            model: New model name.
            dims: New embedding length; the current value is kept when this
                is ``None``.
        """
        self._model = model
        # Compare against None explicitly so "not provided" is the only case
        # that leaves the stored dimensionality untouched.
        if dims is not None:
            self._dims = dims

    def embed_many(
        self,
        inputs: List[str],
        preprocess: Optional[Callable] = None,
        chunk_size: int = 1000,
    ) -> List[List[float]]:
        """Embed several inputs; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def embed(
        self, emb_input: str, preprocess: Optional[Callable] = None
    ) -> List[float]:
        """Embed a single input; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    async def aembed_many(
        self,
        inputs: List[str],
        preprocess: Optional[Callable] = None,
        chunk_size: int = 1000,
    ) -> List[List[float]]:
        """Asynchronously embed several inputs; optional for subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    async def aembed(
        self, emb_input: str, preprocess: Optional[Callable] = None
    ) -> List[float]:
        """Asynchronously embed a single input; optional for subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
        """Yield consecutive batches of at most ``size`` items from ``seq``.

        Args:
            seq: Items to split into batches.
            size: Maximum number of items per batch; must be at least 1.
            preprocess: Optional callable applied to every item before it is
                placed in its batch.

        Yields:
            A new list per batch, in the original order.

        Raises:
            ValueError: If ``size`` is smaller than 1. Because this is a
                generator, the error surfaces when iteration starts.
        """
        # A negative step would make range() empty and silently drop every
        # input, so reject non-positive sizes up front.
        if size < 1:
            raise ValueError(f"size must be a positive integer, got {size}")
        for pos in range(0, len(seq), size):
            batch = seq[pos : pos + size]
            if preprocess is not None:
                yield [preprocess(item) for item in batch]
            else:
                yield list(batch)

redisvl/providers/huggingface.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Dict, List, Optional
2+
3+
from sentence_transformers import SentenceTransformer
4+
5+
from redisvl.providers.base import BaseProvider
6+
7+
8+
class HuggingfaceProvider(BaseProvider):
    """Embedding provider backed by a ``SentenceTransformer`` model."""

    def __init__(self, model: str, api_config: Optional[Dict] = None):
        """Load the model and record its embedding size.

        Args:
            model: Model name or path passed to ``SentenceTransformer``.
            api_config: Optional settings, forwarded to the base class.
        """
        # Load the model first so the embedding size can be read from it
        # instead of being assumed.
        client = SentenceTransformer(model)
        # The model may report no dimension; fall back to the previous
        # hard-coded default of 768 in that case.
        dims = client.get_sentence_embedding_dimension() or 768
        super().__init__(model, dims, api_config)
        self._model_client = client

    def embed(self, emb_input: str, preprocess: callable = None) -> List[float]:
        """Embed a single text.

        Args:
            emb_input: Text to embed.
            preprocess: Optional callable applied to the text first.

        Returns:
            The embedding converted to a plain list via ``tolist()``.
        """
        if preprocess:
            emb_input = preprocess(emb_input)
        # encode() is given a one-element batch; take its only result.
        embedding = self._model_client.encode([emb_input])[0]
        return embedding.tolist()

    def embed_many(
        self, inputs: List[str], preprocess: callable = None, chunk_size: int = 1000
    ) -> List[List[float]]:
        """Embed several texts, encoding them ``chunk_size`` at a time.

        Args:
            inputs: Texts to embed.
            preprocess: Optional callable applied to every text first.
            chunk_size: Maximum number of texts per ``encode`` call.

        Returns:
            One embedding list per input, in input order.
        """
        embeddings = []
        for batch in self.batchify(inputs, chunk_size, preprocess):
            batch_embeddings = self._model_client.encode(batch)
            embeddings.extend(embedding.tolist() for embedding in batch_embeddings)
        return embeddings

redisvl/providers/openai.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Dict, List, Optional
2+
3+
import openai
4+
5+
from redisvl.providers.base import BaseProvider
6+
7+
8+
class OpenAIProvider(BaseProvider):
    """Embedding provider that calls the OpenAI embeddings endpoint."""

    def __init__(self, model: str, api_config: Optional[Dict] = None):
        """Configure the OpenAI client.

        Args:
            model: Name of the OpenAI embedding model.
            api_config: Settings dict that must contain a non-empty
                ``"api_key"`` entry.

        Raises:
            ValueError: If ``api_config`` is missing or has no API key.
        """
        # NOTE(review): 1536 is assumed for every model; it matches what the
        # integration tests expect for text-embedding-ada-002 — confirm
        # before using other models.
        dims = 1536
        super().__init__(model, dims, api_config)

        # Validate the key itself, not just the dict: a config without
        # "api_key" would otherwise set openai.api_key to None silently.
        api_key = api_config.get("api_key") if api_config else None
        if not api_key:
            raise ValueError("OpenAI API key is required in api_config")

        openai.api_key = api_key
        self._model_client = openai.Embedding

    def embed_many(
        self, inputs: List[str], preprocess: callable = None, chunk_size: int = 1000
    ) -> List[List[float]]:
        """Embed several texts, sending them ``chunk_size`` per request.

        Args:
            inputs: Texts to embed.
            preprocess: Optional callable applied to every text first.
            chunk_size: Maximum number of texts per API request.

        Returns:
            One embedding per input, collected in request order.
        """
        results = []
        for batch in self.batchify(inputs, chunk_size, preprocess):
            response = self._model_client.create(input=batch, engine=self._model)
            results += [r["embedding"] for r in response["data"]]
        return results

    def embed(self, emb_input: str, preprocess: callable = None) -> List[float]:
        """Embed a single text.

        Args:
            emb_input: Text to embed.
            preprocess: Optional callable applied to the text first.

        Returns:
            The embedding of the text.
        """
        if preprocess:
            emb_input = preprocess(emb_input)
        result = self._model_client.create(input=[emb_input], engine=self._model)
        return result["data"][0]["embedding"]

    async def aembed_many(
        self, inputs: List[str], preprocess: callable = None, chunk_size: int = 1000
    ) -> List[List[float]]:
        """Asynchronous counterpart of ``embed_many``.

        Batches are awaited one after another, so results keep input order.
        """
        results = []
        for batch in self.batchify(inputs, chunk_size, preprocess):
            response = await self._model_client.acreate(input=batch, engine=self._model)
            results += [r["embedding"] for r in response["data"]]
        return results

    async def aembed(self, emb_input: str, preprocess: callable = None) -> List[float]:
        """Asynchronous counterpart of ``embed``."""
        if preprocess:
            emb_input = preprocess(emb_input)
        result = await self._model_client.acreate(input=[emb_input], engine=self._model)
        return result["data"][0]["embedding"]

tests/integration/test_providers.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
from redisvl.providers import (
3+
HuggingfaceProvider,
4+
OpenAIProvider
5+
)
6+
7+
8+
@pytest.fixture(params=[HuggingfaceProvider, OpenAIProvider])
def provider(request, openai_key):
    """Build each provider under test, one per fixture parameter.

    The OpenAI case is skipped when no API key is available, so a missing
    credential shows up as a skip rather than a failure.
    """
    # Here we use actual models for integration test
    if request.param is HuggingfaceProvider:
        return request.param(model="sentence-transformers/all-mpnet-base-v2")
    if request.param is OpenAIProvider:
        if not openai_key:
            pytest.skip("OpenAI API key not available")
        return request.param(
            model="text-embedding-ada-002",
            api_config={"api_key": openai_key},
        )
17+
18+
def test_provider_embed(provider):
    """A single sentence embeds to a list of ``provider.dims`` values."""
    sentence = 'This is a test sentence.'
    vector = provider.embed(sentence)

    assert isinstance(vector, list)
    assert len(vector) == provider.dims
24+
25+
def test_provider_embed_many(provider):
    """Every input sentence gets its own list of ``provider.dims`` values."""
    sentences = [
        'This is the first test sentence.',
        'This is the second test sentence.',
    ]
    vectors = provider.embed_many(sentences)

    assert isinstance(vectors, list)
    assert len(vectors) == len(sentences)
    for vector in vectors:
        assert isinstance(vector, list) and len(vector) == provider.dims
32+
33+
34+
@pytest.fixture(params=[OpenAIProvider])
def aprovider(request, openai_key):
    """Build each provider that supports the async embedding methods.

    The OpenAI case is skipped when no API key is available, so a missing
    credential shows up as a skip rather than a failure.
    """
    # Here we use actual models for integration test
    if request.param is OpenAIProvider:
        if not openai_key:
            pytest.skip("OpenAI API key not available")
        return request.param(
            model="text-embedding-ada-002",
            api_config={"api_key": openai_key},
        )
    # Fail loudly instead of handing None to the tests if a parameter is
    # added above without a matching branch.
    raise ValueError(f"Unsupported provider parameter: {request.param!r}")
41+
42+
@pytest.mark.asyncio
async def test_provider_aembed(aprovider):
    """``aembed`` yields a list of ``aprovider.dims`` values for one sentence."""
    sentence = 'This is a test sentence.'
    vector = await aprovider.aembed(sentence)

    assert isinstance(vector, list)
    assert len(vector) == aprovider.dims
49+
50+
@pytest.mark.asyncio
async def test_provider_aembed_many(aprovider):
    """``aembed_many`` yields one ``aprovider.dims``-long list per sentence."""
    sentences = [
        'This is the first test sentence.',
        'This is the second test sentence.',
    ]
    vectors = await aprovider.aembed_many(sentences)

    assert isinstance(vectors, list)
    assert len(vectors) == len(sentences)
    for vector in vectors:
        assert isinstance(vector, list) and len(vector) == aprovider.dims

0 commit comments

Comments
 (0)