Skip to content

Commit 3024520

Browse files
committed
Update type annotation and expand hint in attention function
1 parent db82c61 commit 3024520

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/attention_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@ def dot_product_attention(
2222
q (torch.Tensor): The query tensor of shape [batch, heads, out_length, d_k].
2323
k (torch.Tensor): The key tensor of shape [batch, heads, out_length, d_k].
2424
v (torch.Tensor): The value tensor of shape [batch, heads, out_length, d_v].
25+
is_causal (bool): Whether to apply a causal mask.
2526
2627
Returns:
2728
torch.Tensor: The attention values of shape [batch, heads, out_length, d_v]
2829
"""
2930
# TODO implement multi head attention.
30-
# Use i.e. torch.transpose, torch.sqrt, torch.tril, torch.exp, torch.inf
31-
# as well as torch.nn.functional.softmax .
31+
# Hint: You will likely need torch.transpose, torch.sqrt, torch.tril,
32+
# torch.inf, and torch.nn.functional.softmax.
33+
# For applying the causal mask, you can either try using torch.exp or torch.masked_fill.
34+
3235
attention_out = None
3336
return attention_out
3437

src/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def convert(sequences: torch.Tensor, inv_vocab: dict) -> list:
9191
"""Convert an array of character-integers to a list of letters.
9292
9393
Args:
94-
sequences (jnp.ndarray): An integer array, which represents characters.
94+
sequences (torch.Tensor): An integer array, which represents characters.
9595
inv_vocab (dict): The dictionary with the integer to char mapping.
9696
9797
Returns:

0 commit comments

Comments
 (0)