Skip to content

Commit 702594e

Browse files
switch to a local mock that only affects the one test
1 parent fe196e3 commit 702594e

File tree

3 files changed

+67
-68
lines changed

3 files changed

+67
-68
lines changed

test_elasticsearch/test_dsl/conftest.py

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,12 @@
1919
import asyncio
2020
import os
2121
import re
22-
import sys
2322
import time
2423
from datetime import datetime
25-
from hashlib import md5
26-
from typing import Any, AsyncGenerator, Dict, Generator, List, Tuple, cast
24+
from typing import Any, AsyncGenerator, Dict, Generator, Tuple, cast
2725
from unittest import SkipTest
2826
from unittest.mock import AsyncMock, Mock
2927

30-
import pytest
3128
import pytest_asyncio
3229
from elastic_transport import ObjectApiResponse
3330
from pytest import fixture, skip
@@ -52,54 +49,6 @@
5249
)
5350

5451

55-
def pytest_configure(config: "pytest.Config"):
56-
# mock sentence-transformers
57-
class MockSentenceTransformer:
58-
def __init__(self, model: Any):
59-
pass
60-
61-
def encode(self, text: str) -> List[float]:
62-
vector = [int(ch) for ch in md5(text.encode()).digest()]
63-
total = sum(vector)
64-
return [float(v) / total for v in vector]
65-
66-
mock_sentence_transformers_mod = type(sys)("sentence_transformers")
67-
setattr(
68-
mock_sentence_transformers_mod,
69-
"__original_mod",
70-
sys.modules.get("sentence_transformers"),
71-
)
72-
setattr(
73-
mock_sentence_transformers_mod, "SentenceTransformer", MockSentenceTransformer
74-
)
75-
sys.modules[mock_sentence_transformers_mod.__name__] = (
76-
mock_sentence_transformers_mod
77-
)
78-
79-
# mock nltk
80-
def mock_tokenize(content: str):
81-
return content.split("\n")
82-
83-
mock_nltk_mod = type(sys)("nlkt")
84-
setattr(mock_nltk_mod, "__original_mod", sys.modules.get("nltk"))
85-
setattr(mock_nltk_mod, "download", Mock())
86-
setattr(mock_nltk_mod, "sent_tokenize", mock_tokenize)
87-
sys.modules["nltk"] = mock_nltk_mod
88-
89-
90-
def pytest_unconfigure(config: "pytest.Config"):
91-
original_sentence_transformers = sys.modules["sentence_transformers"].__original_mod
92-
if original_sentence_transformers:
93-
sys.modules["sentence_transformers"] = original_sentence_transformers
94-
else:
95-
del sys.modules["sentence_transformers"]
96-
original_nltk = sys.modules["nltk"].__original_mod
97-
if original_nltk:
98-
sys.modules["nltk"] = original_nltk
99-
else:
100-
del sys.modules["nltk"]
101-
102-
10352
def get_test_client(
10453
elasticsearch_url, wait: bool = True, **kwargs: Any
10554
) -> Elasticsearch:

test_elasticsearch/test_dsl/test_integration/test_examples/_async/test_vectors.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,51 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Any, Tuple
18+
import sys
19+
from hashlib import md5
20+
from typing import Any, List, Tuple
1921
from unittest import SkipTest
22+
from unittest.mock import Mock, patch
2023

2124
import pytest
2225

2326
from elasticsearch import AsyncElasticsearch
2427

25-
from ..async_examples import vectors
26-
2728

2829
@pytest.mark.asyncio
2930
async def test_vector_search(
30-
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...], mocker: Any
31+
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...]
3132
) -> None:
3233
# this test only runs on Elasticsearch >= 8.11 because the example uses
3334
# a dense vector without specifying an explicit size
3435
if es_version < (8, 11):
3536
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3637

37-
await vectors.create()
38-
await vectors.WorkplaceDoc._index.refresh()
39-
results = await (await vectors.search("Welcome to our team!")).execute()
40-
assert results[0].name == "Intellectual Property Policy"
38+
class MockSentenceTransformer:
39+
def __init__(self, model: Any):
40+
pass
41+
42+
def encode(self, text: str) -> List[float]:
43+
vector = [int(ch) for ch in md5(text.encode()).digest()]
44+
total = sum(vector)
45+
return [float(v) / total for v in vector]
46+
47+
def mock_nltk_tokenize(content: str):
48+
return content.split("\n")
49+
50+
# mock sentence_transformare and nltk, because they are quite big and
51+
# irrelevant for testing the example logic
52+
with patch.dict(
53+
sys.modules,
54+
{
55+
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
56+
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
57+
},
58+
):
59+
# import the example after the dependencies are mocked
60+
from ..async_examples import vectors
61+
62+
await vectors.create()
63+
await vectors.WorkplaceDoc._index.refresh()
64+
results = await (await vectors.search("Welcome to our team!")).execute()
65+
assert results[0].name == "Intellectual Property Policy"

test_elasticsearch/test_dsl/test_integration/test_examples/_sync/test_vectors.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,51 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from typing import Any, Tuple
18+
import sys
19+
from hashlib import md5
20+
from typing import Any, List, Tuple
1921
from unittest import SkipTest
22+
from unittest.mock import Mock, patch
2023

2124
import pytest
2225

2326
from elasticsearch import Elasticsearch
2427

25-
from ..examples import vectors
26-
2728

2829
@pytest.mark.sync
2930
def test_vector_search(
30-
write_client: Elasticsearch, es_version: Tuple[int, ...], mocker: Any
31+
write_client: Elasticsearch, es_version: Tuple[int, ...]
3132
) -> None:
3233
# this test only runs on Elasticsearch >= 8.11 because the example uses
3334
# a dense vector without specifying an explicit size
3435
if es_version < (8, 11):
3536
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3637

37-
vectors.create()
38-
vectors.WorkplaceDoc._index.refresh()
39-
results = (vectors.search("Welcome to our team!")).execute()
40-
assert results[0].name == "Intellectual Property Policy"
38+
class MockSentenceTransformer:
39+
def __init__(self, model: Any):
40+
pass
41+
42+
def encode(self, text: str) -> List[float]:
43+
vector = [int(ch) for ch in md5(text.encode()).digest()]
44+
total = sum(vector)
45+
return [float(v) / total for v in vector]
46+
47+
def mock_nltk_tokenize(content: str):
48+
return content.split("\n")
49+
50+
# mock sentence_transformare and nltk, because they are quite big and
51+
# irrelevant for testing the example logic
52+
with patch.dict(
53+
sys.modules,
54+
{
55+
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
56+
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
57+
},
58+
):
59+
# import the example after the dependencies are mocked
60+
from ..examples import vectors
61+
62+
vectors.create()
63+
vectors.WorkplaceDoc._index.refresh()
64+
results = (vectors.search("Welcome to our team!")).execute()
65+
assert results[0].name == "Intellectual Property Policy"

0 commit comments

Comments
 (0)