add attention lora flag
BerenMillidge committed Aug 16, 2024
1 parent cafc1bd commit 38ceb85
Showing 5 changed files with 10 additions and 4 deletions.
Binary file added: __pycache__/mamba_model.cpython-312.pyc (not shown).
attention.py: 4 changes (2 additions, 2 deletions)
@@ -36,7 +36,7 @@ def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **
             attention_dropout=0.0, layer_number=layer_number,
             attn_mask_type="no_mask")
 
-        if self.config.use_shared_block_lora and 0 == 1:
+        if self.config.use_shared_attention_lora:
             self.linear_q_lora_A_list = nn.ParameterList([])
             self.linear_q_lora_B_list = nn.ParameterList([])
             self.linear_k_lora_A_list = nn.ParameterList([])
@@ -153,7 +153,7 @@ def forward(self, hidden_states, attention_mask, key_value_states=None, inferenc
         )
 
 
-        if self.config.use_shared_block_lora and 0 == 1:
+        if self.config.use_shared_attention_lora:
             new_lora_tensor_shape = new_tensor_shape[:-1] + (-1,)
             linear_q_lora_A = self.linear_q_lora_A_list[forward_layer_idx]
             linear_q_lora_B = self.linear_q_lora_B_list[forward_layer_idx]
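The change above retires the dead self.config.use_shared_block_lora and 0 == 1 guard, which could never fire, and gates the shared Q/K LoRA parameter lists behind a dedicated use_shared_attention_lora flag instead. Below is a minimal sketch of the pattern those lists implement, assuming a rank field and a shared-slot count (hidden_size, lora_rank, num_shared are placeholder names) plus the per-call forward_layer_idx seen in the diff; shapes, initialization, and where the update is applied are illustrative, not the repository's exact code.

    # Sketch of shared attention LoRA adapters (Q shown; K follows the same
    # pattern). hidden_size, lora_rank and num_shared are assumed parameters.
    import torch
    import torch.nn as nn

    class SharedAttentionLoRA(nn.Module):
        def __init__(self, hidden_size, lora_rank, num_shared):
            super().__init__()
            # One low-rank (A, B) pair per shared slot; every layer mapped to
            # the same slot index reuses the same adapter weights.
            self.linear_q_lora_A_list = nn.ParameterList(
                [nn.Parameter(0.01 * torch.randn(hidden_size, lora_rank))
                 for _ in range(num_shared)])
            self.linear_q_lora_B_list = nn.ParameterList(
                [nn.Parameter(torch.zeros(lora_rank, hidden_size))
                 for _ in range(num_shared)])

        def forward(self, q, forward_layer_idx):
            # Standard LoRA update: q + (q @ A) @ B, with B zero-initialised so
            # the adapter starts as a no-op.
            A = self.linear_q_lora_A_list[forward_layer_idx]
            B = self.linear_q_lora_B_list[forward_layer_idx]
            return q + (q @ A) @ B

    # Usage: layers that share slot 0 all pass forward_layer_idx=0.
    adapter = SharedAttentionLoRA(hidden_size=64, lora_rank=8, num_shared=2)
    q = torch.randn(1, 16, 64)                 # (batch, seq, hidden)
    q = adapter(q, forward_layer_idx=0)

Sharing one (A, B) pair across several layers keeps the extra parameter count proportional to the number of shared slots rather than the number of layers.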
mamba_config.py: 2 changes (2 additions, 0 deletions)
@@ -54,6 +54,8 @@ class MambaConfig():
     use_low_rank_mamba_proj: bool = False
     num_shared_mamba_proj: int = 1
     mamba_lora_rank: int = 1
+    use_shared_attention_lora: bool = False
+    rope_theta: int = 10000
 
 
 
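With the new field in place, the shared attention adapters become a config-level switch. A hedged example of flipping it on when constructing the config by hand, assuming MambaConfig accepts its declared fields as keyword arguments the way from_pretrained below calls it; the structural values are placeholders, not repository defaults.

    from mamba_config import MambaConfig

    config = MambaConfig(
        num_layers = 4,                     # placeholder value
        hidden_size = 512,                  # placeholder value
        use_shared_attention_lora = True,   # new flag from this commit
        rope_theta = 10000,                 # new field; 10000 is the diff's default
    )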
mamba_model.py: 6 changes (5 additions, 1 deletion)
@@ -122,6 +122,10 @@ def from_pretrained(cls, model_name, **kwargs):
         NUM_MEM_BLOCKS =2
         json_config = load_config_hf(model_name)
         state_dict = load_state_dict_hf(model_name)
+        if "num_mem_blocks" in json_config.keys():
+            num_mem_blocks = json_config["num_mem_blocks"]
+        else:
+            num_mem_blocks = NUM_MEM_BLOCKS
         config = MambaConfig(
             num_layers = json_config["num_hidden_layers"],
             hidden_size = json_config["hidden_size"],
@@ -141,7 +145,7 @@ def from_pretrained(cls, model_name, **kwargs):
             kv_channels = json_config["kv_channels"],
             ffn_hidden_size = json_config["ffn_hidden_size"],
             vocab_size = json_config["vocab_size"],
-            num_mem_blocks = NUM_MEM_BLOCKS,
+            num_mem_blocks = num_mem_blocks,
         )
         model = MambaModel(config = config, max_sequence_length = 4096)
         model.load_state_dict(state_dict)
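The from_pretrained change makes num_mem_blocks backward compatible: checkpoint configs that carry the key use their own value, and older ones fall back to the hard-coded default of 2. The same guard can be written more compactly with dict.get, as in this small equivalent sketch (the example config dict is illustrative).

    NUM_MEM_BLOCKS = 2
    json_config = {"num_hidden_layers": 54, "hidden_size": 2048}  # older config, key absent
    num_mem_blocks = json_config.get("num_mem_blocks", NUM_MEM_BLOCKS)
    assert num_mem_blocks == NUM_MEM_BLOCKS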
requirements.txt: 2 changes (1 addition, 1 deletion)
@@ -32,7 +32,7 @@ requests==2.31.0
 safetensors==0.4.2
 sympy==1.12
 tokenizers==0.15.2
-torch==2.2.2
+#torch==2.2.2
 tqdm==4.66.2
 transformers==4.39.3
 triton==2.2.0
