Skip to content

Commit ba7e1c1

Browse files
Mock sentence-transformers and nltk in tests
1 parent a57e556 commit ba7e1c1

File tree

4 files changed

+54
-27
lines changed

4 files changed

+54
-27
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ dev = [
7777
"pandas",
7878
"mapbox-vector-tile",
7979
"jinja2",
80-
"nltk",
81-
"sentence_transformers",
8280
"tqdm",
8381
"mypy",
8482
"pyright",

test_elasticsearch/test_dsl/conftest.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
import asyncio
2020
import os
2121
import re
22+
import sys
2223
import time
2324
from datetime import datetime
24-
from typing import Any, AsyncGenerator, Dict, Generator, Tuple, cast
25+
from hashlib import md5
26+
from typing import Any, AsyncGenerator, Dict, Generator, List, Tuple, cast
2527
from unittest import SkipTest
2628
from unittest.mock import AsyncMock, Mock
2729

30+
import pytest
2831
import pytest_asyncio
2932
from elastic_transport import ObjectApiResponse
3033
from pytest import fixture, skip
@@ -49,6 +52,54 @@
4952
)
5053

5154

55+
def pytest_configure(config: "pytest.Config"):
56+
# mock setence-transformers
57+
class MockSentenceTransformer:
58+
def __init__(self, model: Any):
59+
pass
60+
61+
def encode(self, text: str) -> List[float]:
62+
vector = [int(ch) for ch in md5(text.encode()).digest()]
63+
total = sum(vector)
64+
return [float(v) / total for v in vector]
65+
66+
mock_sentence_transformers_mod = type(sys)("sentence_transformers")
67+
setattr(
68+
mock_sentence_transformers_mod,
69+
"__original_mod",
70+
sys.modules.get("sentence_transformers"),
71+
)
72+
setattr(
73+
mock_sentence_transformers_mod, "SentenceTransformer", MockSentenceTransformer
74+
)
75+
sys.modules[mock_sentence_transformers_mod.__name__] = (
76+
mock_sentence_transformers_mod
77+
)
78+
79+
# mock nltk
80+
def mock_tokenize(content: str):
81+
return content.split("\n")
82+
83+
mock_nltk_mod = type(sys)("nlkt")
84+
setattr(mock_nltk_mod, "__original_mod", sys.modules.get("nltk"))
85+
setattr(mock_nltk_mod, "download", Mock())
86+
setattr(mock_nltk_mod, "sent_tokenize", mock_tokenize)
87+
sys.modules["nltk"] = mock_nltk_mod
88+
89+
90+
def pytest_unconfigure(config: "pytest.Config"):
91+
original_sentence_transformers = sys.modules["sentence_transformers"].__original_mod
92+
if original_sentence_transformers:
93+
sys.modules["sentence_transformers"] = original_sentence_transformers
94+
else:
95+
del sys.modules["sentence_transformers"]
96+
original_nltk = sys.modules["nltk"].__original_mod
97+
if original_nltk:
98+
sys.modules["nltk"] = original_nltk
99+
else:
100+
del sys.modules["nltk"]
101+
102+
52103
def get_test_client(
53104
elasticsearch_url, wait: bool = True, **kwargs: Any
54105
) -> Elasticsearch:

test_elasticsearch/test_dsl/test_integration/test_examples/_async/test_vectors.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,7 @@ async def test_vector_search(
3535
if es_version < (8, 11):
3636
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3737

38-
class MockModel:
39-
def __init__(self, model: Any):
40-
pass
41-
42-
def encode(self, text: str) -> List[float]:
43-
vector = [int(ch) for ch in md5(text.encode()).digest()]
44-
total = sum(vector)
45-
return [float(v) / total for v in vector]
46-
47-
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
48-
4938
await vectors.create()
5039
await vectors.WorkplaceDoc._index.refresh()
5140
results = await (await vectors.search("Welcome to our team!")).execute()
52-
assert results[0].name == "New Employee Onboarding Guide"
41+
assert results[0].name == "Intellectual Property Policy"

test_elasticsearch/test_dsl/test_integration/test_examples/_sync/test_vectors.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,7 @@ def test_vector_search(
3535
if es_version < (8, 11):
3636
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3737

38-
class MockModel:
39-
def __init__(self, model: Any):
40-
pass
41-
42-
def encode(self, text: str) -> List[float]:
43-
vector = [int(ch) for ch in md5(text.encode()).digest()]
44-
total = sum(vector)
45-
return [float(v) / total for v in vector]
46-
47-
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
48-
4938
vectors.create()
5039
vectors.WorkplaceDoc._index.refresh()
5140
results = (vectors.search("Welcome to our team!")).execute()
52-
assert results[0].name == "New Employee Onboarding Guide"
41+
assert results[0].name == "Intellectual Property Policy"

0 commit comments

Comments
 (0)