Skip to content

Commit e68465c

Browse files
committed
Add basic test for classification
1 parent 60ddef0 commit e68465c

File tree

7 files changed

+63
-0
lines changed

7 files changed

+63
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*.py[cod]
22
.*cache
3+
.valohai

README.md

+8
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ Linting/formatting happens via `pre-commit`. Install it with `pip install pre-co
1414

1515
The linters run by `pre-commit` are `ruff`, `black`, and `prettier`;
1616
you can (should) set up your IDE to run them automatically too.
17+
18+
### Tests
19+
20+
You can run tests with `py.test`:
21+
22+
```
23+
py.test -v .
24+
```

conftest.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import pytest
2+
from valohai.internals import global_state
3+
4+
5+
@pytest.fixture()
6+
def valohai_utils_global_state():
7+
global_state.flush_global_state()
8+
return global_state
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import csv
2+
3+
import valohai
4+
5+
# Snippet from https://valohai-ecosystem-datasets.s3.eu-west-1.amazonaws.com/yelp_reviews_batch_inference.txt
6+
EXAMPLE_DATA = """
7+
Old school.....traditional "mom 'n pop" quality and perfection.
8+
A great out of the way, non-corporate, vestige of Americana. You will love it.
9+
Good fish sandwich.
10+
I always feel like I am constantly bashing breweries for their food, but in my opinion, I feel the bar is raised for places like this.
11+
I called to complain, and the "manager" didn't even apologize!!! So frustrated. Never going back. They seem overpriced, too.
12+
""".strip()
13+
14+
15+
def test_inference(valohai_utils_global_state, monkeypatch, tmp_path):
16+
monkeypatch.setenv("VH_OUTPUTS_DIR", str(tmp_path))
17+
input_path = tmp_path / "input.txt"
18+
input_path.write_text(EXAMPLE_DATA)
19+
valohai.prepare(
20+
step="huggingface-classification-inference",
21+
default_parameters={
22+
"log_frequency": 1,
23+
# This is an untrained model, so the results won't be very interesting.
24+
"huggingface_repository": "distilbert-base-uncased",
25+
"output_path": "test.csv",
26+
},
27+
default_inputs={
28+
"data": str(input_path),
29+
},
30+
)
31+
from models.nlp.classification.huggingface.inference import main
32+
33+
main()
34+
with (tmp_path / "test.csv").open() as f:
35+
results = list(csv.DictReader(f))
36+
assert len(results) == 5

mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ ignore_missing_imports = True
1111

1212
[mypy-datasets.*]
1313
ignore_missing_imports = True
14+
15+
[mypy-pytest.*]
16+
ignore_missing_imports = True

pytest.ini

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[pytest]
2+
# this is needed because models.nlp has its own utils which would conflict
3+
# with the utils in the root directory with the default `prepend` mode
4+
addopts = --import-mode=append

ruff.toml

+3
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ ignore = [
2020
"T2",
2121
"TRY003",
2222
]
23+
24+
[per-file-ignores]
25+
"**/test*.py" = ["S101"] # tests can use assertions

0 commit comments

Comments
 (0)