23 | 23 |
24 | 24 | from torch import nn |
25 | 25 | from transformers.activations import ACT2FN |
| 26 | +from transformers.configuration_utils import PretrainedConfig |
26 | 27 | from typing import Optional, List, Tuple |
27 | 28 |
28 | 29 | # Flash attention imports |
|
43 | 44 | ) |
44 | 45 |
45 | 46 |
| 47 | +class LlamaConfig(PretrainedConfig): |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + vocab_size=32000, |
| 51 | + hidden_size=4096, |
| 52 | + intermediate_size=11008, |
| 53 | + num_hidden_layers=32, |
| 54 | + num_attention_heads=32, |
| 55 | + num_key_value_heads=None, |
| 56 | + hidden_act="silu", |
| 57 | + max_position_embeddings=2048, |
| 58 | + initializer_range=0.02, |
| 59 | + rms_norm_eps=1e-6, |
| 60 | + use_cache=True, |
| 61 | + pad_token_id=0, |
| 62 | + bos_token_id=1, |
| 63 | + eos_token_id=2, |
| 64 | + pretraining_tp=1, |
| 65 | + tie_word_embeddings=False, |
| 66 | + rope_scaling=None, |
| 67 | + **kwargs, |
| 68 | + ): |
| 69 | + self.vocab_size = vocab_size |
| 70 | + self.max_position_embeddings = max_position_embeddings |
| 71 | + self.hidden_size = hidden_size |
| 72 | + self.intermediate_size = intermediate_size |
| 73 | + self.num_hidden_layers = num_hidden_layers |
| 74 | + self.num_attention_heads = num_attention_heads |
| 75 | + |
| 76 | + # for backward compatibility |
| 77 | + if num_key_value_heads is None: |
| 78 | + num_key_value_heads = num_attention_heads |
| 79 | + |
| 80 | + self.num_key_value_heads = num_key_value_heads |
| 81 | + self.hidden_act = hidden_act |
| 82 | + self.initializer_range = initializer_range |
| 83 | + self.rms_norm_eps = rms_norm_eps |
| 84 | + self.pretraining_tp = pretraining_tp |
| 85 | + self.use_cache = use_cache |
| 86 | + self.rope_scaling = rope_scaling |
| 87 | + |
| 88 | + super().__init__( |
| 89 | + pad_token_id=pad_token_id, |
| 90 | + bos_token_id=bos_token_id, |
| 91 | + eos_token_id=eos_token_id, |
| 92 | + tie_word_embeddings=tie_word_embeddings, |
| 93 | + **kwargs, |
| 94 | + ) |
| 95 | + |
| 96 | + |
46 | 97 | class LlamaRMSNorm(nn.Module): |
47 | 98 | def __init__(self, prefix, weights, eps=1e-6): |
48 | 99 | """ |
|
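For reference, below is a minimal sketch of how the LlamaConfig class added in this diff can be exercised, assuming the class is in scope (e.g. imported from this modeling module). The argument names and the num_key_value_heads fallback come directly from the __init__ shown above; the specific values used here are illustrative only, not taken from the diff.

# Sketch only (not part of the diff): using the LlamaConfig class added above.
# Assumes LlamaConfig is in scope, e.g. imported from this modeling module.

# With no arguments, the defaults from the signature above apply
# (vocab_size=32000, hidden_size=4096, 32 layers, 32 attention heads).
config = LlamaConfig()

# num_key_value_heads defaults to None and falls back to num_attention_heads,
# matching the "for backward compatibility" branch in __init__.
assert config.num_key_value_heads == config.num_attention_heads == 32

# An illustrative grouped-query-attention style configuration:
# fewer key/value heads than attention heads.
gqa_config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    rms_norm_eps=1e-5,
)
head_dim = gqa_config.hidden_size // gqa_config.num_attention_heads  # 128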