Skip to content

Commit 308ff3d

Browse files
committed
added annotations throughout and bug fixes
1 parent b36bbaa commit 308ff3d

File tree

4 files changed

+289
-156
lines changed

4 files changed

+289
-156
lines changed

config.py

+118-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,13 @@
77
####################
88

99
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10+
C = 512
1011
T = 30 # max context length; informed via quick analysis
12+
N_LAYERS = 6
13+
NUM_HEADS = 8
14+
HEAD_SIZE = 64 # C // NUM_HEADS = 512 // 8
15+
16+
DROPOUT = 0.1
1117
BATCH_SIZE = 8
1218
BATCH_SIZE_VAL = 50
1319

@@ -24,11 +30,116 @@
2430
}
2531
VOCAB_SIZE = 30000
2632
IGNORE = [
27-
# chars that appear very infrequently (1-5 times) in the dataset. Ignoring these sentence
33+
# chars that appear very infrequently (~1-5 times) in the dataset. Ignoring these sentence
2834
# all together given negligible impact on training and project focus is educational
29-
'°', '²', '½', 'Á', 'Ç', 'É', '×', 'ß', 'à', 'á', 'â', 'ã', 'ä', 'å', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï',
30-
'ð', 'ñ', 'ó', 'ô', 'ö', 'ú', 'û', 'ü', 'ā', 'ă', 'Ĉ', 'ĉ', 'Č', 'ĝ', 'ĥ', 'ī', 'ı', 'ĵ', 'ł', 'ō', 'ŝ', 'ş', 'š',
31-
'ŭ', 'ș', 'ə', 'ʻ', 'π', 'ḥ', 'ṛ', '/', '…', '√', '🌡', '😷', '🤒', '🤧', '🤮', '🦠', '🧼', '«', '»', 'Í',
32-
'Ö', 'ć', 'ń', 'ŏ', 'ū', 'М', 'Ч', 'а', 'з', 'и', 'к', 'л', 'о', 'р', 'с', 'т', 'ы', 'э', 'ׁ', '‐', '–', '—', '‘',
33-
'’', '“', '”', '₂', '€', '→', 'あ', '@', '^', '+', '"', '&', '_', '{', '}', '(', ')', '[', ']', '#', "...",
34-
]
35+
"°",
36+
"²",
37+
"½",
38+
"Á",
39+
"Ç",
40+
"É",
41+
"×",
42+
"ß",
43+
"à",
44+
"á",
45+
"â",
46+
"ã",
47+
"ä",
48+
"å",
49+
"ç",
50+
"è",
51+
"é",
52+
"ê",
53+
"ë",
54+
"ì",
55+
"í",
56+
"î",
57+
"ï",
58+
"ð",
59+
"ñ",
60+
"ó",
61+
"ô",
62+
"ö",
63+
"ú",
64+
"û",
65+
"ü",
66+
"ā",
67+
"ă",
68+
"Ĉ",
69+
"ĉ",
70+
"Č",
71+
"ĝ",
72+
"ĥ",
73+
"ī",
74+
"ı",
75+
"ĵ",
76+
"ł",
77+
"ō",
78+
"ŝ",
79+
"ş",
80+
"š",
81+
"ŭ",
82+
"ș",
83+
"ə",
84+
"ʻ",
85+
"π",
86+
"ḥ",
87+
"ṛ",
88+
"/",
89+
"…",
90+
"√",
91+
"🌡",
92+
"😷",
93+
"🤒",
94+
"🤧",
95+
"🤮",
96+
"🦠",
97+
"🧼",
98+
"«",
99+
"»",
100+
"Í",
101+
"Ö",
102+
"ć",
103+
"ń",
104+
"ŏ",
105+
"ū",
106+
"М",
107+
"Ч",
108+
"а",
109+
"з",
110+
"и",
111+
"к",
112+
"л",
113+
"о",
114+
"р",
115+
"с",
116+
"т",
117+
"ы",
118+
"э",
119+
"ׁ",
120+
"‐",
121+
"–",
122+
"—",
123+
"‘",
124+
"’",
125+
"“",
126+
"”",
127+
"₂",
128+
"€",
129+
"→",
130+
"あ",
131+
"@",
132+
"^",
133+
"+",
134+
'"',
135+
"&",
136+
"_",
137+
"{",
138+
"}",
139+
"(",
140+
")",
141+
"[",
142+
"]",
143+
"#",
144+
"...",
145+
]

dataset.py

+5-37
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,17 @@
77
class LanguageDataset(Dataset):
88
def __init__(self, X1, X2, y, pad_token_id=None):
99
super(LanguageDataset).__init__()
10-
self.X1 = torch.tensor(X1, dtype=torch.int32, device=config.DEVICE)
11-
self.X2 = torch.tensor(X2, dtype=torch.int32, device=config.DEVICE)
12-
self.y = torch.tensor(y, dtype=torch.float32, device=config.DEVICE) # float to compare w/model output
10+
self.X1 = torch.tensor(X1, device=config.DEVICE) # NL sequences
11+
self.X2 = torch.tensor(X2, device=config.DEVICE) # EN sequences
12+
self.y = torch.tensor(y, dtype=torch.float32, device=config.DEVICE) # EN shifted
1313
self.pad_token_id = pad_token_id
1414

1515
def __getitem__(self, index):
16+
# each sample returns: NL seq (X1), EN seq (X2), EN+1 seq (y) and pad masking for NL seq (x1pad)
1617
if self.pad_token_id is not None:
1718
x1pad = self.X1[index]==self.pad_token_id
1819
return self.X1[index], self.X2[index], self.y[index], x1pad[None,:]
1920
return self.X1[index], self.X2[index], self.y[index]
2021

2122
def __len__(self):
22-
return len(self.X1)
23-
24-
25-
X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split(
26-
X1,
27-
X2,
28-
y,
29-
test_size=0.15,
30-
shuffle=False, # already shuffled and grouped
31-
)
32-
train_data = LanguageDataset(X1_train, X2_train, y_train, pad_token_id=pad_token_id)
33-
test_data = LanguageDataset(X1_test, X2_test, y_test, pad_token_id=pad_token_id)
34-
35-
training_dl = DataLoader(
36-
train_data,
37-
batch_size=config.BATCH_SIZE,
38-
shuffle=False, # keep sequences of the same length together
39-
drop_last=False,
40-
)
41-
42-
# for loss performance over train/test datasets (vs. batch being trained on)
43-
train_dl = DataLoader(
44-
train_data,
45-
batch_size=config.BATCH_SIZE_VAL,
46-
shuffle=True, # sample across the dataset, regardless of sequence len
47-
drop_last=False,
48-
)
49-
50-
test_dl = DataLoader(
51-
test_data,
52-
batch_size=config.BATCH_SIZE_VAL,
53-
shuffle=True,
54-
drop_last=False,
55-
)
23+
return len(self.X1)

0 commit comments

Comments
 (0)