Skip to content

Commit 3c7a0a1

Browse files
committed
finish exercise.
1 parent 765f17b commit 3c7a0a1

File tree

2 files changed

+5
-15
lines changed

2 files changed

+5
-15
lines changed

src/attention_model.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,10 @@ def dot_product_attention(
2626
Returns:
2727
torch.Tensor: The attention values of shape [batch, heads, out_length, d_v]
2828
"""
29-
t, dk = q.shape[-2:]
30-
device = q.device
31-
32-
scaled = q @ k.transpose(-2, -1) / torch.sqrt(torch.tensor(dk))
33-
if is_causal:
34-
scaled = scaled + (-1) * torch.exp(
35-
(torch.tril(torch.ones(t, t).to(device)) - 0.5) * -2.0 * torch.inf
36-
)
37-
soft_scaled = f.softmax(scaled, dim=-1)
38-
attention_out = soft_scaled @ v
29+
# TODO implement multi head attention.
30+
# Use e.g. torch.transpose, torch.sqrt, torch.tril, torch.exp, torch.inf
31+
# as well as torch.nn.functional.softmax .
32+
attention_out = None
3933
return attention_out
4034

4135

src/util.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,5 @@ def convert(sequences: torch.Tensor, inv_vocab: dict) -> list:
9898
list: A list of characters.
9999
"""
100100
res = []
101-
for int_seq in sequences:
102-
char_seq = []
103-
for int_char in int_seq:
104-
char_seq.append(inv_vocab[int(int_char)])
105-
res.append(char_seq)
101+
# TODO: Return a nested list of characters.
106102
return res

0 commit comments

Comments
 (0)