Skip to content

Commit 5e6ddfd

Browse files
fix(server): fix llamav2 config (#635)
1 parent cf83f9b commit 5e6ddfd

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from torch import nn
2525
from transformers.activations import ACT2FN
26+
from transformers.configuration_utils import PretrainedConfig
2627
from typing import Optional, List, Tuple
2728

2829
# Flash attention imports
@@ -43,6 +44,56 @@
4344
)
4445

4546

47+
class LlamaConfig(PretrainedConfig):
48+
def __init__(
49+
self,
50+
vocab_size=32000,
51+
hidden_size=4096,
52+
intermediate_size=11008,
53+
num_hidden_layers=32,
54+
num_attention_heads=32,
55+
num_key_value_heads=None,
56+
hidden_act="silu",
57+
max_position_embeddings=2048,
58+
initializer_range=0.02,
59+
rms_norm_eps=1e-6,
60+
use_cache=True,
61+
pad_token_id=0,
62+
bos_token_id=1,
63+
eos_token_id=2,
64+
pretraining_tp=1,
65+
tie_word_embeddings=False,
66+
rope_scaling=None,
67+
**kwargs,
68+
):
69+
self.vocab_size = vocab_size
70+
self.max_position_embeddings = max_position_embeddings
71+
self.hidden_size = hidden_size
72+
self.intermediate_size = intermediate_size
73+
self.num_hidden_layers = num_hidden_layers
74+
self.num_attention_heads = num_attention_heads
75+
76+
# for backward compatibility
77+
if num_key_value_heads is None:
78+
num_key_value_heads = num_attention_heads
79+
80+
self.num_key_value_heads = num_key_value_heads
81+
self.hidden_act = hidden_act
82+
self.initializer_range = initializer_range
83+
self.rms_norm_eps = rms_norm_eps
84+
self.pretraining_tp = pretraining_tp
85+
self.use_cache = use_cache
86+
self.rope_scaling = rope_scaling
87+
88+
super().__init__(
89+
pad_token_id=pad_token_id,
90+
bos_token_id=bos_token_id,
91+
eos_token_id=eos_token_id,
92+
tie_word_embeddings=tie_word_embeddings,
93+
**kwargs,
94+
)
95+
96+
4697
class LlamaRMSNorm(nn.Module):
4798
def __init__(self, prefix, weights, eps=1e-6):
4899
"""

server/text_generation_server/models/flash_llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import torch.distributed
33

44
from opentelemetry import trace
5-
from transformers import AutoConfig
65
from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
76
from typing import Optional
87

98
from text_generation_server.models import FlashCausalLM
109
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
1110
FlashLlamaForCausalLM,
11+
LlamaConfig,
1212
)
1313
from text_generation_server.utils import (
1414
initialize_torch_distributed,
@@ -52,7 +52,7 @@ def __init__(
5252
trust_remote_code=trust_remote_code,
5353
)
5454

55-
config = AutoConfig.from_pretrained(
55+
config = LlamaConfig.from_pretrained(
5656
model_id, revision=revision, trust_remote_code=trust_remote_code
5757
)
5858

0 commit comments

Comments
 (0)