Skip to content

Commit dd6fe99

Browse files
committed
style: update
1 parent 88e9900 commit dd6fe99

File tree

5 files changed

+10
-7
lines changed

5 files changed

+10
-7
lines changed

src/dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ def __init__(
4040
self.model_config = OmegaConf.load(model_config_dir / "transformers.yaml")
4141
self.tokenizer_config = OmegaConf.load(self.tokenizer_config_path)
4242
self.tokenizer = SentencePieceBPETokenizer(
43-
vocab_file=str(tokenizer_dir / (self.tokenizer_config.tokenizer_name + "-vocab.json")),
44-
merges_file=str(tokenizer_dir / (self.tokenizer_config.tokenizer_name + "-merges.txt")),
43+
vocab_file=str(
44+
tokenizer_dir / (self.tokenizer_config.tokenizer_name + "-vocab.json")
45+
),
46+
merges_file=str(
47+
tokenizer_dir / (self.tokenizer_config.tokenizer_name + "-merges.txt")
48+
),
4549
)
4650
self.source_lines = source_lines
4751
self.target_lines = target_lines
@@ -50,7 +54,9 @@ def __len__(self) -> int:
5054
return len(self.source_lines)
5155

5256
def __getitem__(self, index: int) -> Tuple[List[int], List[int]]:
53-
source_encoded, target_encoded = self.collate(self.source_lines[index], self.target_lines[index])
57+
source_encoded, target_encoded = self.collate(
58+
self.source_lines[index], self.target_lines[index]
59+
)
5460
return source_encoded, target_encoded
5561

5662
def _encode(

tests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
22
import sys
33

4-
54
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

tests/test_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest # noqa: E902
2-
32
from tokenizers import SentencePieceBPETokenizer
43

54
from dataset import WMT14Dataset

tests/test_load_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@
66
@pytest.mark.parametrize("langpair", ["de-en"])
77
def test_setup(langpair):
88
dm = WMT14DataModule(langpair)
9-

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from src.utils import read_lines
44

5-
test_filepath = 'data/example.de'
5+
test_filepath = "data/example.de"
66

77

88
@pytest.mark.parametrize("filepath", test_filepath)

0 commit comments

Comments
 (0)