Skip to content

Commit 1ceb4fc

Browse files
github-actions[bot]miguelgrinbergpquentin
authored
Mock sentence-transformers and nltk in tests (#3059) (#3063)
* Mock sentence-transformers and nltk in tests * Update test_elasticsearch/test_dsl/conftest.py * switch to a local mock that only affects the one test --------- (cherry picked from commit e05d7f1) Co-authored-by: Miguel Grinberg <[email protected]> Co-authored-by: Quentin Pradet <[email protected]>
1 parent 2aa0459 commit 1ceb4fc

File tree

3 files changed

+44
-20
lines changed

3 files changed

+44
-20
lines changed

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ dev = [
7777
"pandas",
7878
"mapbox-vector-tile",
7979
"jinja2",
80-
"nltk",
81-
"sentence_transformers",
8280
"tqdm",
8381
"mypy",
8482
"pyright",

test_elasticsearch/test_dsl/test_integration/test_examples/_async/test_vectors.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sys
1819
from hashlib import md5
1920
from typing import Any, List, Tuple
2021
from unittest import SkipTest
22+
from unittest.mock import Mock, patch
2123

2224
import pytest
2325

2426
from elasticsearch import AsyncElasticsearch
2527

26-
from ..async_examples import vectors
27-
2828

2929
@pytest.mark.asyncio
3030
async def test_vector_search(
31-
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...], mocker: Any
31+
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...]
3232
) -> None:
3333
# this test only runs on Elasticsearch >= 8.11 because the example uses
3434
# a dense vector without specifying an explicit size
3535
if es_version < (8, 11):
3636
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3737

38-
class MockModel:
38+
class MockSentenceTransformer:
3939
def __init__(self, model: Any):
4040
pass
4141

@@ -44,9 +44,22 @@ def encode(self, text: str) -> List[float]:
4444
total = sum(vector)
4545
return [float(v) / total for v in vector]
4646

47-
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
47+
def mock_nltk_tokenize(content: str):
48+
return content.split("\n")
49+
50+
# mock sentence_transformers and nltk, because they are quite big and
51+
# irrelevant for testing the example logic
52+
with patch.dict(
53+
sys.modules,
54+
{
55+
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
56+
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
57+
},
58+
):
59+
# import the example after the dependencies are mocked
60+
from ..async_examples import vectors
4861

49-
await vectors.create()
50-
await vectors.WorkplaceDoc._index.refresh()
51-
results = await (await vectors.search("Welcome to our team!")).execute()
52-
assert results[0].name == "New Employee Onboarding Guide"
62+
await vectors.create()
63+
await vectors.WorkplaceDoc._index.refresh()
64+
results = await (await vectors.search("Welcome to our team!")).execute()
65+
assert results[0].name == "Intellectual Property Policy"

test_elasticsearch/test_dsl/test_integration/test_examples/_sync/test_vectors.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,27 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import sys
1819
from hashlib import md5
1920
from typing import Any, List, Tuple
2021
from unittest import SkipTest
22+
from unittest.mock import Mock, patch
2123

2224
import pytest
2325

2426
from elasticsearch import Elasticsearch
2527

26-
from ..examples import vectors
27-
2828

2929
@pytest.mark.sync
3030
def test_vector_search(
31-
write_client: Elasticsearch, es_version: Tuple[int, ...], mocker: Any
31+
write_client: Elasticsearch, es_version: Tuple[int, ...]
3232
) -> None:
3333
# this test only runs on Elasticsearch >= 8.11 because the example uses
3434
# a dense vector without specifying an explicit size
3535
if es_version < (8, 11):
3636
raise SkipTest("This test requires Elasticsearch 8.11 or newer")
3737

38-
class MockModel:
38+
class MockSentenceTransformer:
3939
def __init__(self, model: Any):
4040
pass
4141

@@ -44,9 +44,22 @@ def encode(self, text: str) -> List[float]:
4444
total = sum(vector)
4545
return [float(v) / total for v in vector]
4646

47-
mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
47+
def mock_nltk_tokenize(content: str):
48+
return content.split("\n")
49+
50+
# mock sentence_transformers and nltk, because they are quite big and
51+
# irrelevant for testing the example logic
52+
with patch.dict(
53+
sys.modules,
54+
{
55+
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
56+
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
57+
},
58+
):
59+
# import the example after the dependencies are mocked
60+
from ..examples import vectors
4861

49-
vectors.create()
50-
vectors.WorkplaceDoc._index.refresh()
51-
results = (vectors.search("Welcome to our team!")).execute()
52-
assert results[0].name == "New Employee Onboarding Guide"
62+
vectors.create()
63+
vectors.WorkplaceDoc._index.refresh()
64+
results = (vectors.search("Welcome to our team!")).execute()
65+
assert results[0].name == "Intellectual Property Policy"

0 commit comments

Comments
 (0)