|
| 1 | +import pytest # noqa: E902 |
| 2 | + |
| 3 | +from tokenizers import SentencePieceBPETokenizer |
| 4 | + |
| 5 | +from dataset import WMT14Dataset |
| 6 | + |
| 7 | + |
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)  # fix: the original decorator was missing this closing parenthesis (SyntaxError)
def test_len(langpair, source_lines, target_lines):
    """len(dataset) must match the number of source lines it was built from."""
    # fix: original used `WMT14dataset` (lowercase d), but the import is `WMT14Dataset`
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    assert len(ds.source_lines) == len(ds)
| 14 | + |
| 15 | + |
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)  # fix: the original decorator was missing this closing parenthesis (SyntaxError)
def test_getitem(langpair, source_lines, target_lines):
    """__getitem__ must return padded source/target tensors of equal size max_len."""
    # fix: original used `WMT14dataset` (lowercase d), but the import is `WMT14Dataset`
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_pad_test, target_encode_pad_test = ds[0]
    # both sequences are padded to the same fixed length
    assert source_encode_pad_test.size() == target_encode_pad_test.size()
    assert source_encode_pad_test.size()[0] == ds.model_config.max_len
| 24 | + |
| 25 | + |
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)  # fix: the original decorator was missing this closing parenthesis (SyntaxError)
def test_encode(langpair, source_lines, target_lines):
    """_encode must wrap the target in <bos>/<eos> and return lists of token ids."""
    # fix: original used `WMT14dataset` (lowercase d), but the import is `WMT14Dataset`
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_test, target_encode_test = ds._encode(source_lines[0], target_lines[0])
    bos = ds.tokenizer.token_to_id('<bos>')
    eos = ds.tokenizer.token_to_id('<eos>')
    # target sequence is framed by the special begin/end-of-sentence ids
    assert target_encode_test[0] == bos
    assert target_encode_test[-1] == eos
    # source encoding is a plain list of integer token ids
    assert isinstance(source_encode_test, list)
    assert isinstance(source_encode_test[0], int)
| 38 | + |
| 39 | + |
@pytest.mark.parametrize("langpair", ["en-de"])
@pytest.mark.parametrize(
    "source_lines",
    [
        [
            "Der Bau und die Reparatur der Autostraßen...",
            "die Mitteilungen sollen den geschäftlichen kommerziellen Charakter tragen.",
        ]
    ],
)
@pytest.mark.parametrize(
    "target_lines",
    [
        [
            "Construction and repair of highways and...",
            "An announcement must be commercial character.",
        ]
    ],
)  # fix: the original decorator was missing this closing parenthesis (SyntaxError)
def test_collate(langpair, source_lines, target_lines):
    """collate must pad source and target to the same fixed length max_len."""
    # fix: original used `WMT14dataset` (lowercase d), but the import is `WMT14Dataset`
    ds = WMT14Dataset(langpair, source_lines, target_lines)
    source_encode_pad_test, target_encode_pad_test = ds.collate(source_lines[0], target_lines[0])
    # both padded sequences share the same shape, fixed by the model config
    assert source_encode_pad_test.size() == target_encode_pad_test.size()
    assert source_encode_pad_test.size()[0] == ds.model_config.max_len
0 commit comments