From cc4cbecd0fab562f174c55f9700c386916bbbd01 Mon Sep 17 00:00:00 2001
From: Artem Bolgar
Date: Mon, 18 Mar 2024 19:03:19 -0700
Subject: [PATCH 1/2] Fixing block size for Mistral-7B.

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index dbf24e5d..4900f2b2 100644
--- a/model.py
+++ b/model.py
@@ -62,7 +62,7 @@ def from_name(cls, name: str):
     "30B": dict(n_layer=60, n_head=52, dim=6656),
     "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016, rope_base=1000000), # CodeLlama-34B-Python-hf
     "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
-    "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
+    "Mistral-7B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
     "stories15M": dict(n_layer=6, n_head=6, dim=288),
     "stories110M": dict(n_layer=12, n_head=12, dim=768),
 }

From 8ca548c07d61315dbef4645a0f23b83015d582d5 Mon Sep 17 00:00:00 2001
From: Artem Bolgar
Date: Tue, 19 Mar 2024 16:39:56 -0700
Subject: [PATCH 2/2] Saving some memory on freq_cis tensor with big
 block_size but small max_seq_length

---
 model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model.py b/model.py
index 4900f2b2..67158fa2 100644
--- a/model.py
+++ b/model.py
@@ -110,7 +110,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.max_seq_length, self.config.dim // self.config.n_head, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
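
For context on the second patch: sizing freqs_cis to the runtime max_seq_length instead of the model's full block_size shrinks the RoPE cache proportionally. Below is a minimal back-of-the-envelope sketch, assuming gpt-fast's cache layout of shape (seq_len, head_dim // 2, 2) in bfloat16 (2 bytes per element); the max_seq_length of 2048 is an illustrative assumption, not a value from the patch.

    # Minimal sketch (not part of the patch): estimate the freqs_cis cache size,
    # assuming a (seq_len, head_dim // 2, 2) bfloat16 tensor as in gpt-fast.
    def freqs_cis_bytes(seq_len: int, head_dim: int, bytes_per_elem: int = 2) -> int:
        return seq_len * (head_dim // 2) * 2 * bytes_per_elem

    head_dim = 4096 // 32  # Mistral-7B: dim=4096, n_head=32 -> head_dim=128
    print(freqs_cis_bytes(8192, head_dim))  # sized to block_size: 2097152 bytes (2 MiB)
    print(freqs_cis_bytes(2048, head_dim))  # sized to max_seq_length=2048: 524288 bytes (0.5 MiB)

The absolute saving per tensor is modest, but it avoids paying for the model's full 8192-token context window when generation only needs a shorter sequence.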