Commit: Styling

michaelbenayoun committed Jan 15, 2025
1 parent 9c2165c commit 01d4a18
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/base.py
@@ -680,7 +680,7 @@ def should_parallelize_layer_predicate_func(layer):
"kv_size_multiplier": None,
"fuse_qkv": None,
"q_output_size_per_partition": None,
"kv_output_size_per_partition": None ,
"kv_output_size_per_partition": None,
}
for mod in model.modules():
if isinstance(mod, OptimumGQAQKVColumnParallelLinear):
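
For context on this hunk: the dict initializes per-model GQA QKV metadata, which the loop then fills in from any OptimumGQAQKVColumnParallelLinear module it finds. A minimal sketch of that pattern, using a hypothetical stand-in class; the diff cuts off before the loop body, so the assignments below are an assumption based on the attribute names used in the utils.py hunks further down:

    from dataclasses import dataclass

    @dataclass
    class StandInGQAQKVLayer:
        # Hypothetical stand-in for OptimumGQAQKVColumnParallelLinear;
        # field names mirror the attributes used in the utils.py hunks below.
        kv_size_multiplier: int
        fuse_qkv: bool
        q_output_size_per_partition: int
        kv_output_size_per_partition: int

    def collect_gqa_qkv_metadata(modules):
        metadata = {
            "kv_size_multiplier": None,
            "fuse_qkv": None,
            "q_output_size_per_partition": None,
            "kv_output_size_per_partition": None,
        }
        for mod in modules:
            if isinstance(mod, StandInGQAQKVLayer):
                # Assumed: copy the layer's sizing attributes into the metadata dict.
                metadata["kv_size_multiplier"] = mod.kv_size_multiplier
                metadata["fuse_qkv"] = mod.fuse_qkv
                metadata["q_output_size_per_partition"] = mod.q_output_size_per_partition
                metadata["kv_output_size_per_partition"] = mod.kv_output_size_per_partition
        return metadata

    # Example: one GQA layer with 8 query rows and 2 key/value rows per partition.
    layer = StandInGQAQKVLayer(1, False, 8, 2)
    assert collect_gqa_qkv_metadata([layer])["q_output_size_per_partition"] == 8
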
12 changes: 10 additions & 2 deletions optimum/neuron/distributed/checkpointing.py
@@ -158,9 +158,17 @@ def consolidate_tensor_parallel_checkpoints(
     if weight_name == "weight_q":
         s = slice(0, gqa_qkv_metadata["q_output_size_per_partition"])
     elif weight_name == "weight_k":
-        s = slice(gqa_qkv_metadata["q_output_size_per_partition"], gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"])
+        s = slice(
+            gqa_qkv_metadata["q_output_size_per_partition"],
+            gqa_qkv_metadata["q_output_size_per_partition"]
+            + gqa_qkv_metadata["kv_output_size_per_partition"],
+        )
     elif weight_name == "weight_v":
-        s = slice(gqa_qkv_metadata["q_output_size_per_partition"] + gqa_qkv_metadata["kv_output_size_per_partition"], None)
+        s = slice(
+            gqa_qkv_metadata["q_output_size_per_partition"]
+            + gqa_qkv_metadata["kv_output_size_per_partition"],
+            None,
+        )
     else:
         s = slice(None, None)
 else:
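
Why the slices look like this: the fused QKV shard stacks the Q, K, and V row blocks, so Q occupies rows [0, q_size), K occupies [q_size, q_size + kv_size), and V takes the remainder. A small self-contained sketch with illustrative sizes (the torch tensor and shapes are assumptions for demonstration, not the real checkpoint layout):

    import torch

    # Illustrative per-partition sizes, not taken from any real config.
    q_size, kv_size, hidden = 8, 2, 4

    # Fused shard: Q rows, then K rows, then V rows, stacked along dim 0.
    fused = torch.randn(q_size + 2 * kv_size, hidden)

    slices = {
        "weight_q": slice(0, q_size),
        "weight_k": slice(q_size, q_size + kv_size),
        "weight_v": slice(q_size + kv_size, None),
    }

    weight_q = fused[slices["weight_q"]]  # shape (8, 4)
    weight_k = fused[slices["weight_k"]]  # shape (2, 4)
    weight_v = fused[slices["weight_v"]]  # shape (2, 4)

    # The three blocks tile the fused shard exactly.
    assert weight_q.shape[0] + weight_k.shape[0] + weight_v.shape[0] == fused.shape[0]
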
9 changes: 6 additions & 3 deletions optimum/neuron/distributed/utils.py
@@ -187,8 +187,8 @@ def __init__(
     keep_master_weight: bool = False,
     kv_size_multiplier: int = 1,
 ):
-    from neuronx_distributed.parallel_layers.utils import set_tensor_model_parallel_attributes
     from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
+    from neuronx_distributed.parallel_layers.utils import set_tensor_model_parallel_attributes

     super().__init__(
         input_size,
@@ -784,7 +784,7 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
 # proj_name = weight_name[-1]
 if layer.fuse_qkv:
     weight = getattr(layer, "weight_qkv")
-    bias = getattr(layer, f"bias_qkv")
+    bias = getattr(layer, "bias_qkv")
 else:
     weight = getattr(layer, weight_name)
     bias = getattr(layer, f"bias_{proj_name}")
@@ -814,7 +814,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
 if proj_name == "q":
     s = slice(0, layer.q_output_size_per_partition)
 elif proj_name == "k":
-    s = slice(layer.q_output_size_per_partition, layer.q_output_size_per_partition + layer.kv_output_size_per_partition)
+    s = slice(
+        layer.q_output_size_per_partition,
+        layer.q_output_size_per_partition + layer.kv_output_size_per_partition,
+    )
 else:
     s = slice(layer.q_output_size_per_partition + layer.kv_output_size_per_partition, None)
 weight[s, :] = weight_data
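
The loading path runs the same arithmetic in the other direction: a per-projection weight is written into the fused parameter at its slice. A minimal sketch, assuming a fused weight laid out as Q rows, then K rows, then V rows (sizes illustrative):

    import torch

    q_size, kv_size, hidden = 8, 2, 4
    fused_weight = torch.zeros(q_size + 2 * kv_size, hidden)

    def load_projection(fused, proj_name, weight_data, q_out, kv_out):
        # Mirrors the slice logic from the hunk above.
        if proj_name == "q":
            s = slice(0, q_out)
        elif proj_name == "k":
            s = slice(q_out, q_out + kv_out)
        else:
            s = slice(q_out + kv_out, None)
        fused[s, :] = weight_data

    # Write the K block and check it landed in rows [q_size, q_size + kv_size).
    load_projection(fused_weight, "k", torch.ones(kv_size, hidden), q_size, kv_size)
    assert fused_weight[q_size : q_size + kv_size].eq(1).all()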
