Skip to content

Commit 54a5892

Browse files
committed
feat: add test for dataset
1 parent 1b56f39 commit 54a5892

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

tests/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
22
import sys
33

4-
sys.path.insert(
5-
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../tokenizer"))
6-
)
4+
75
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

tests/test_dataset.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest # noqa: E902
2+
3+
from tokenizers import SentencePieceBPETokenizer
4+
5+
from dataset import WMT14Dataset
6+
7+
8+
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)
def test_len(langpair, source_lines, target_lines):
    """Check that ``len(ds)`` equals the number of source lines the dataset holds."""
    # The class is imported as ``WMT14Dataset`` (capital D) at the top of the module.
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    assert len(ds.source_lines) == len(ds)
14+
15+
16+
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)
def test_getitem(langpair, source_lines, target_lines):
    """Check the pair returned by ``ds[0]``.

    Both items must report the same ``size()``, and the first dimension of
    the source item must equal ``ds.model_config.max_len``.
    """
    # The class is imported as ``WMT14Dataset`` (capital D) at the top of the module.
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_pad_test, target_encode_pad_test = ds[0]
    assert source_encode_pad_test.size() == target_encode_pad_test.size()
    assert source_encode_pad_test.size()[0] == ds.model_config.max_len
24+
25+
26+
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)
def test_encode(langpair, source_lines, target_lines):
    """Check the output of ``ds._encode`` for the first sentence pair.

    The encoded target must start with the ``<bos>`` id and end with the
    ``<eos>`` id; the encoded source must be a list whose first element is
    an ``int``.
    """
    # The class is imported as ``WMT14Dataset`` (capital D) at the top of the module.
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_test, target_encode_test = ds._encode(source_lines[0], target_lines[0])
    bos = ds.tokenizer.token_to_id('<bos>')
    eos = ds.tokenizer.token_to_id('<eos>')
    assert target_encode_test[0] == bos
    assert target_encode_test[-1] == eos
    assert isinstance(source_encode_test, list)
    assert isinstance(source_encode_test[0], int)
38+
39+
40+
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)
def test_collate(langpair, source_lines, target_lines):
    """Check the pair returned by ``ds.collate`` for the first sentence pair.

    Both items must report the same ``size()``, and the first dimension of
    the source item must equal ``ds.model_config.max_len``.
    """
    # The class is imported as ``WMT14Dataset`` (capital D) at the top of the module.
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_pad_test, target_encode_pad_test = ds.collate(source_lines[0], target_lines[0])
    assert source_encode_pad_test.size() == target_encode_pad_test.size()
    assert source_encode_pad_test.size()[0] == ds.model_config.max_len

tests/test_load_dataset.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import pytest
2+
3+
from src.load_dataset import WMT14DataModule
4+
5+
6+
@pytest.mark.parametrize("langpair", ["de-en"])
def test_setup(langpair):
    """Construct a ``WMT14DataModule`` for the given language pair.

    The test passes as long as construction does not raise; nothing is
    asserted about the resulting object.
    """
    # NOTE(review): despite the name, no setup step is invoked here -- confirm
    # whether the data module should be exercised further.
    data_module = WMT14DataModule(langpair)
9+

tests/test_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest # noqa: F401
2+
3+
from src.utils import read_lines
4+
5+
test_filepath = 'data/example.de'
6+
7+
8+
# argvalues must be a list of cases: a bare string would be iterated
# character by character, producing one test per character of the path.
@pytest.mark.parametrize("filepath", [test_filepath])
def test_read_lines(filepath):
    """Check that ``read_lines`` returns a list whose first entry is the expected sentence."""
    de = read_lines(filepath)
    assert isinstance(de, list)
    assert (
        de[0]
        == "iron cement ist eine gebrauchs-fertige Paste, die mit einem Spachtel oder den Fingern als Hohlkehle in die Formecken (Winkel) der Stahlguss -Kokille aufgetragen wird."
    )

0 commit comments

Comments
 (0)