
Commit 8d036db

Remove unneeded changes
tianmu-li committed Dec 19, 2024
1 parent a234ff8 commit 8d036db
Showing 2 changed files with 4 additions and 10 deletions.
vllm/model_executor/layers/linear.py: 10 changes (4 additions, 6 deletions)
@@ -686,14 +686,9 @@ def __init__(self,
         else:
             self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
             self.num_kv_head_replicas = 1
-        self.q_size = self.num_heads * self.head_size * tp_size
-        self.kv_size = self.num_kv_heads * self.head_size * tp_size
         input_size = self.hidden_size
-        self.output_sizes = [
-            self.q_size, # q_proj
-        ]
         output_size = (self.num_heads +
-                        2 * self.num_kv_heads) * tp_size * self.head_size
+                       2 * self.num_kv_heads) * tp_size * self.head_size
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size, # q_proj
             self.num_kv_heads * self.head_size * tp_size, # k_proj
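For intuition, here is a small standalone arithmetic check (not part of this diff) of how the retained output_size / output_sizes lines size the fused QKV projection. The head counts below are hypothetical Llama-7B-like values, not anything read from this commit.

# Illustrative only: sizing of the fused QKV projection as in the retained
# upstream code, with made-up Llama-7B-like numbers and tensor parallelism 2.
total_num_heads = 32
total_num_kv_heads = 32
head_size = 128
tp_size = 2

num_heads = total_num_heads // tp_size        # query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # kv heads per rank

# Full (unpartitioned) fused output width, as computed in __init__:
output_size = (num_heads + 2 * num_kv_heads) * tp_size * head_size

# Per-projection widths that make up the fused weight:
output_sizes = [
    num_heads * head_size * tp_size,     # q_proj
    num_kv_heads * head_size * tp_size,  # k_proj
    num_kv_heads * head_size * tp_size,  # v_proj
]
assert sum(output_sizes) == output_size  # 12288 == 12288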
@@ -927,6 +922,8 @@ def weight_loader(self,
                     ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                      0)
                 }
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                    param, orig_qkv_offsets, loaded_shard_id)

             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
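For readers unfamiliar with this loader, here is a toy standalone illustration (not vLLM code) of the narrow-and-copy pattern that the re-added adjust_bitsandbytes_4bit_shard call feeds into. The shapes, offsets, and variable names below are made up for the example.

# Illustrative only: a fused QKV parameter is viewed along output_dim and each
# loaded shard (q, k, or v) is copied into its slice via narrow() + copy_().
import torch

output_dim = 0
head_size, num_heads, num_kv_heads, hidden = 4, 2, 1, 8
qkv_rows = (num_heads + 2 * num_kv_heads) * head_size

param_data = torch.empty(qkv_rows, hidden)                      # fused QKV weight
loaded_weight = torch.randn(num_kv_heads * head_size, hidden)   # e.g. a k_proj shard

# Analogous to one entry of orig_qkv_offsets: (start_row, num_rows) for "k".
shard_offset = num_heads * head_size   # k rows start right after the q rows
shard_size = num_kv_heads * head_size

view = param_data.narrow(output_dim, shard_offset, shard_size)
assert view.shape == loaded_weight.shape
view.copy_(loaded_weight)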
@@ -964,6 +961,7 @@ def weight_loader(self,
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

+
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
vllm/model_executor/models/llama.py: 4 changes (0 additions, 4 deletions)
@@ -601,10 +601,6 @@ def load_weights(self, weights: Iterable[Tuple[str,

                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                if self.split_qk_v and (shard_id == "v" or shard_id == "k" or shard_id == "q") :
-                    weight_loader(param, loaded_weight)
-                else:
-                    weight_loader(param, loaded_weight, shard_id)

                 break
             else:
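For context, a rough, simplified sketch of the upstream-style stacked-parameter loading loop that the deleted split_qk_v branch was special-casing. This is not the code touched by this commit; params_dict, the mapping entries, and the weight_loader attribute are simplified stand-ins for the real vLLM objects.

# Illustrative sketch only: checkpoint q/k/v weights are mapped onto the fused
# "qkv_proj" parameter, and the shard id is forwarded to its weight loader.
from typing import Iterable, Tuple

import torch

# (fused param name, checkpoint weight name, shard id), as in upstream llama.py
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def load_weights(params_dict: dict,
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            # In upstream llama.py the shard id is always forwarded so the
            # loader can place q/k/v into the right slice of qkv_proj.
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Non-stacked parameters fall back to a plain copy-style loader.
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)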
