Commit

black
danieldk committed Oct 3, 2024
1 parent 9675fe6 commit 2628268
Showing 6 changed files with 24 additions and 42 deletions.
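
Every hunk below is the same mechanical change, presumably from re-running Black after an upgrade: the parenthesized multi-line assert is collapsed so the comparison stays on the assert line and the bracketed list carries the line breaks. A minimal standalone sketch of the two equivalent spellings (the values here are placeholders, not taken from any model config):

shape = [4096, 4096]       # stands in for list(weight.weight.shape)
expected = [4096, 4096]    # stands in for [(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]

# Old style (the deleted lines): the whole comparison wrapped in parentheses.
assert (
    shape
    == expected
), f"{shape} != {expected}"

# New style (the added lines): assert and the comparison on one line,
# with the list literal split across lines.
assert shape == [
    4096,
    4096,
], f"{shape} != {expected}"
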
@@ -179,13 +179,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     if config.attention_bias:
         w = [
@@ -153,13 +153,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

@@ -151,13 +151,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

@@ -135,13 +135,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     return TensorParallelColumnLinear(get_linear(weight, bias=None))

@@ -95,13 +95,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     # this is the same as llama except for Phi uses bias=True
     return TensorParallelColumnLinear(get_linear(weight, bias=True))
@@ -137,13 +137,10 @@ def _load_gqa(config, prefix: str, weights):
     head_size = config.hidden_size // config.num_attention_heads
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert (
-        list(weight.weight.shape)
-        == [
-            (num_heads + 2 * num_key_value_heads) * head_size,
-            config.hidden_size,
-        ]
-    ), f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    assert list(weight.weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

     if config.use_bias:
         w = [
