3 changes: 2 additions & 1 deletion classifier.py
@@ -6,9 +6,10 @@
from config import LlamaConfig
from llama import load_pretrained
from tokenizer import Tokenizer
+from typing import List

class LlamaZeroShotClassifier(torch.nn.Module):
-def __init__(self, config: LlamaConfig, tokenizer: Tokenizer, label_names: list[str]):
+def __init__(self, config: LlamaConfig, tokenizer: Tokenizer, label_names: List[str]):
super(LlamaZeroShotClassifier, self).__init__()
self.num_labels = config.num_labels
self.llama = load_pretrained(config.pretrained_model_path)
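The `list[str]` → `List[str]` change above (like the `tuple` → `Tuple` change in `rope_test.py` below) is a compatibility fix: subscripting builtin types in annotations requires Python 3.9+ (PEP 585), while `typing.List` also works on 3.7/3.8. A minimal sketch with a hypothetical helper:

```python
from typing import List

def label_lengths(label_names: List[str]) -> List[int]:
    # typing.List[str] is valid on Python 3.7+; the builtin form list[str]
    # raises TypeError at definition time on Python < 3.9 (see PEP 585)
    return [len(name) for name in label_names]
```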
6 changes: 3 additions & 3 deletions llama.py
@@ -191,9 +191,9 @@ def forward(self, x):
1) layer normalization of the input (via Root Mean Square layer normalization)
2) self-attention on the layer-normalized input
3) a residual connection (i.e., add the input to the output of the self-attention)
-3) layer normalization on the output of the self-attention
-4) a feed-forward network on the layer-normalized output of the self-attention
-5) add a residual connection from the unnormalized self-attention output to the
+4) layer normalization on the output of the self-attention
+5) a feed-forward network on the layer-normalized output of the self-attention
+6) add a residual connection from the unnormalized self-attention output to the
output of the feed-forward network
'''
# todo
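The six steps in the renumbered docstring describe a standard pre-norm transformer block. A minimal runnable sketch, using `torch.nn.LayerNorm` and identity modules as stand-ins for the repo's RMSNorm, attention, and feed-forward components (all module names here are assumptions, not the assignment's solution):

```python
import torch

class ToyLlamaLayer(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention_norm = torch.nn.LayerNorm(dim)  # stand-in for RMSNorm
        self.ffn_norm = torch.nn.LayerNorm(dim)        # stand-in for RMSNorm
        self.attention = torch.nn.Identity()           # stand-in for self-attention
        self.feed_forward = torch.nn.Identity()        # stand-in for the FFN

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # steps 1-3: normalize the input, attend, add a residual from the input
        h = x + self.attention(self.attention_norm(x))
        # steps 4-6: normalize, feed-forward, add a residual from the attention output
        return h + self.feed_forward(self.ffn_norm(h))
```

Note that both residual connections add the *unnormalized* tensor, which is what distinguishes the pre-norm layout from post-norm.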
3 changes: 2 additions & 1 deletion rope_test.py
@@ -2,6 +2,7 @@
import numpy as np

from rope import apply_rotary_emb
+from typing import Tuple

seed = 0

@@ -17,7 +18,7 @@ def construct_key() -> torch.Tensor:
'''
return 3 * torch.ones([1, 2, 2, 4])

-def test_apply_rotary_emb() -> tuple[torch.Tensor, torch.Tensor]:
+def test_apply_rotary_emb() -> Tuple[torch.Tensor, torch.Tensor]:
rng = np.random.default_rng(seed)
torch.manual_seed(seed)
model = torch.nn.Linear(3, 2, bias=False)
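For intuition about what `apply_rotary_emb` should compute, here is an illustrative sketch that rotates consecutive channel pairs of a `(seqlen, dim)` tensor by position-dependent angles. The pairing convention, base `theta`, and function name are assumptions and may not match the repo's implementation:

```python
import torch

def toy_rotary(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # x: (seqlen, dim) with even dim; pair i at position p is rotated by p / theta**(2i/dim)
    seqlen, dim = x.shape
    pos = torch.arange(seqlen, dtype=torch.float32).unsqueeze(-1)                # (seqlen, 1)
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim/2,)
    angles = pos * freqs                                                         # (seqlen, dim/2)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * angles.cos() - x2 * angles.sin()
    out[:, 1::2] = x1 * angles.sin() + x2 * angles.cos()
    return out
```

Because each pair undergoes a pure 2-D rotation, the vector at position 0 is unchanged and every position keeps its norm, both of which are handy sanity checks for a RoPE implementation.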
4 changes: 2 additions & 2 deletions structure.md
@@ -37,9 +37,9 @@ The desired outputs are

### To be implemented
Components that require your implementation are commented with ```#todo```. The detailed instructions can be found in their corresponding code blocks.
* ```llama.Attention.forward```
* ```llama.Attention.compute_query_key_value_scores```
* ```llama.RMSNorm.norm```
* ```llama.Llama.forward```
* ```llama.LlamaLayer.forward```
* ```llama.Llama.generate```
* ```rope.apply_rotary_emb``` (this one may be tricky! you can use `rope_test.py` to test your implementation)
* ```optimizer.AdamW.step```
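`optimizer.AdamW.step` in the list above is the decoupled-weight-decay Adam update from Loshchilov & Hutter. A single-tensor sketch of one step (the helper name and `state` dict layout are hypothetical, not the assignment's API):

```python
import torch

def adamw_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=0.01):
    # exponential moving averages of the gradient and its square
    t = state["t"] + 1
    m = betas[0] * state["m"] + (1 - betas[0]) * grad
    v = betas[1] * state["v"] + (1 - betas[1]) * grad * grad
    # bias correction for the zero-initialized averages
    m_hat = m / (1 - betas[0] ** t)
    v_hat = v / (1 - betas[1] ** t)
    # decoupled weight decay: applied to the parameter directly, not folded into the gradient
    new_param = param - lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * param)
    state.update(m=m, v=v, t=t)
    return new_param
```

The decoupled decay term `lr * weight_decay * param` is what separates AdamW from Adam with L2 regularization, where the decay would instead be added to `grad` before the moment updates.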