Commit 3f016c1

slight cleanup

BerenMillidge committed Nov 19, 2024
1 parent 6866df1
Showing 1 changed file with 0 additions and 8 deletions.
mamba_model.py (8 changes: 0 additions & 8 deletions)
@@ -151,7 +151,6 @@ def from_pretrained(cls, model_name, **kwargs):
             ffn_hidden_size = json_config["ffn_hidden_size"],
             vocab_size = json_config["vocab_size"],
             num_mem_blocks = num_mem_blocks,
-            #num_key_value_heads = json_config["num_key_value_heads"],
             num_query_groups = json_config["num_query_groups"],
             use_shared_attention_lora = use_shared_attention_lora,
             use_mem_rope = json_config["use_mem_rope"],
@@ -163,13 +162,10 @@ def from_pretrained(cls, model_name, **kwargs):
             if el == "g":
                 g_indices.append(i)
         i = 0
-        #die
         for k in list(state_dict.keys()):
             new_k = k.replace("model","decoder").replace("mamba_layers","layers").replace("mamba","mixer").replace("input_layernorm","norm").replace("linear_layers","block_map").replace("feed_forward","mlp.mixer").replace("self_attn.o_proj","sa.mixer.linear_proj").replace("pre_ff_layernorm","mlp.norm").replace("in_proj","in_proj.0").replace("self_attn.linear_q","sa.mixer.linear_q").replace("self_attn.linear_k","sa.mixer.linear_k").replace("self_attn.linear_v","sa.mixer.linear_v")
-            #print("NUM MEM BLOCKS: ", num_mem_blocks)
             for i in range(num_mem_blocks):
                 new_k = new_k.replace("decoder.blocks." + str(i) + ".norm.weight","decoder.blocks." + str(i) + ".sa.norm.weight")
-                #new_k = new_k.replace("decoder.blocks.1.norm.weight","decoder.blocks.1.sa.norm.weight")
             if "block_map" in new_k:
                 block_idx = int(new_k.split("block_map.")[1].split(".")[0])
                 i +=1
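The long `.replace(...)` chain in this hunk renames Hugging Face-style checkpoint keys into the repository's internal module names. A minimal standalone sketch of that mapping (the example key below is a hypothetical HF-style name, chosen only to illustrate the renaming):

```python
# Standalone sketch of the key-renaming chain above; the example key is a
# hypothetical HF-style name, not taken from a real checkpoint.
RENAMES = [
    ("model", "decoder"), ("mamba_layers", "layers"), ("mamba", "mixer"),
    ("input_layernorm", "norm"), ("linear_layers", "block_map"),
    ("feed_forward", "mlp.mixer"), ("self_attn.o_proj", "sa.mixer.linear_proj"),
    ("pre_ff_layernorm", "mlp.norm"), ("in_proj", "in_proj.0"),
    ("self_attn.linear_q", "sa.mixer.linear_q"),
    ("self_attn.linear_k", "sa.mixer.linear_k"),
    ("self_attn.linear_v", "sa.mixer.linear_v"),
]

def rename_key(key: str) -> str:
    # Apply the substitutions in the same order as the .replace(...) chain.
    for old, new in RENAMES:
        key = key.replace(old, new)
    return key

print(rename_key("model.mamba_layers.0.mamba.in_proj.weight"))
# -> decoder.layers.0.mixer.in_proj.0.weight
```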
@@ -185,7 +181,6 @@ def from_pretrained(cls, model_name, **kwargs):
             q = state_dict["decoder.blocks." + str(i) + ".self_attn.q_proj.weight"]
             k = state_dict["decoder.blocks." + str(i) + ".self_attn.k_proj.weight"]
             v = state_dict["decoder.blocks." + str(i) + ".self_attn.v_proj.weight"]
-            print("NUM HEADS: ", num_heads)
             qkv = HF_QKV_Inverse_Transform(q,k,v,num_heads)
             state_dict["decoder.blocks."+str(i)+".sa.mixer.linear_qkv.weight"] = qkv
             del state_dict["decoder.blocks." + str(i) + ".self_attn.q_proj.weight"]
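`HF_QKV_Inverse_Transform` itself is not shown in this diff; it repacks the separate HF `q_proj`/`k_proj`/`v_proj` weights into the single fused `linear_qkv` weight expected by the internal attention block. The sketch below is only a rough idea of such a fusion, under the assumption of a grouped layout in which each query group contributes its query heads followed by one key head and one value head; the real function may order heads differently.

```python
import torch

def hf_qkv_inverse_transform_sketch(q, k, v, num_heads, num_query_groups=None):
    # Hypothetical stand-in for HF_QKV_Inverse_Transform (not shown in the diff).
    # Assumes a grouped fused layout: for each query group, its query heads are
    # followed by that group's single key head and single value head.
    hidden_size = q.shape[1]
    head_dim = q.shape[0] // num_heads
    num_query_groups = num_query_groups or num_heads
    heads_per_group = num_heads // num_query_groups

    q = q.view(num_query_groups, heads_per_group * head_dim, hidden_size)
    k = k.view(num_query_groups, head_dim, hidden_size)
    v = v.view(num_query_groups, head_dim, hidden_size)
    return torch.cat([q, k, v], dim=1).reshape(-1, hidden_size)

if __name__ == "__main__":
    hidden, heads, groups, head_dim = 64, 8, 2, 8
    q = torch.randn(heads * head_dim, hidden)
    k = torch.randn(groups * head_dim, hidden)
    v = torch.randn(groups * head_dim, hidden)
    qkv = hf_qkv_inverse_transform_sketch(q, k, v, heads, groups)
    print(qkv.shape)  # torch.Size([96, 64]): (heads + 2 * groups) * head_dim rows
```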
@@ -196,15 +191,12 @@ def from_pretrained(cls, model_name, **kwargs):
         return model
 
     def save_pretrained(self, save_directory):
-        # Ensure save_directory exists
         if not os.path.exists(save_directory):
            os.makedirs(save_directory)
 
-        # Save the model's state_dict
         model_path = os.path.join(save_directory, 'pytorch_model.bin')
         torch.save(self.state_dict(), model_path)
 
-        # Save the configuration of the model
         config_path = os.path.join(save_directory, 'config.json')
         with open(config_path, 'w') as f:
             json.dump(self.config.__dict__, f)
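For context, the two methods touched above are this file's checkpoint entry points. A hypothetical round-trip (the class name `MambaModel` and the checkpoint id are assumptions, not taken from this diff):

```python
# Hypothetical usage sketch; MambaModel and the checkpoint id are assumed,
# not shown in this diff.
from mamba_model import MambaModel

model = MambaModel.from_pretrained("Zyphra/Zamba2-2.7B")  # remaps HF keys as in the hunks above
model.save_pretrained("./zamba2_checkpoint")              # writes pytorch_model.bin and config.json
```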
