20 changes: 20 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -989,6 +989,12 @@ def __init__(
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")

# TODO: add is_te_min_version("2.9.0") before merge
if config.qk_clip:
# TE 2.9.0 introduces return_max_logit, which exposes the max attention logits needed by qk-clip
extra_kwargs["return_max_logit"] = True
self.current_max_attn_logits = None

super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
@@ -1058,6 +1064,20 @@ def forward(
**attention_bias_kwargs,
**packed_seq_kwargs,
)

# TODO: add is_te_min_version("2.9.0") before merge
if self.config.qk_clip:
# With return_max_logit=True, TE returns (output, max_logits); Q/K weights are updated outside the TE attention API via clip_qk()
core_attn_out, batch_max_attention_logits = core_attn_out

# Keep a running max of the attention logits across micro-batches; clip_qk() consumes and resets it
if self.current_max_attn_logits is None:
self.current_max_attn_logits = batch_max_attention_logits
else:
self.current_max_attn_logits = torch.max(
self.current_max_attn_logits, batch_max_attention_logits
)

else:
core_attn_out = super().forward(
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
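As a reading aid for the hunk above: with return_max_logit=True the TE attention call returns an (output, max_logits) pair, the module keeps an elementwise running max of those logits across micro-batches, and clip_qk() later consumes and resets that state after the optimizer step. A minimal standalone sketch of that accumulate-then-consume lifecycle (MaxLogitTracker and its methods are illustrative names, not the module's API):

```python
import torch

class MaxLogitTracker:
    """Accumulate per-head max attention logits across micro-batches,
    then hand them to a clipping step that resets the state."""

    def __init__(self, threshold: float = 100.0):
        self.threshold = threshold
        self.current_max_attn_logits = None

    def update(self, batch_max_attention_logits: torch.Tensor) -> None:
        # Elementwise running max over all micro-batches seen since the last clip.
        if self.current_max_attn_logits is None:
            self.current_max_attn_logits = batch_max_attention_logits
        else:
            self.current_max_attn_logits = torch.max(
                self.current_max_attn_logits, batch_max_attention_logits
            )

    def consume_eta(self) -> torch.Tensor:
        # Called once per optimizer step: compute eta, then reset the tracker.
        eta = torch.clamp(self.threshold / self.current_max_attn_logits, max=1.0)
        self.current_max_attn_logits = None
        return eta

tracker = MaxLogitTracker()
tracker.update(torch.tensor([80.0, 120.0]))  # micro-batch 1, two attention heads
tracker.update(torch.tensor([90.0, 250.0]))  # micro-batch 2
print(tracker.consume_eta())                 # tensor([1.0000, 0.4000])
```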
105 changes: 105 additions & 0 deletions megatron/core/transformer/attention.py
@@ -926,6 +926,13 @@ def set_for_recompute_input_layernorm(self):
"""Set the attention layer for recompute input_layernorm. Only needed for fp8."""
raise NotImplementedError("set_for_recompute_input_layernorm is not implemented.")

def clip_qk(self):
"""
QK-Clip rescales the query and key projection weights so that the maximum attention
logit stays below a configured threshold, preventing the attention logits from exploding.
"""
raise NotImplementedError("clip_qk is not implemented.")


class SelfAttention(Attention):
"""Self-attention layer class
@@ -1134,6 +1141,104 @@ def set_for_recompute_input_layernorm(self):

set_save_original_input(self.linear_qkv)

def clip_qk(self):
"""
QK-Clip rescales the query and key projection weights so that the maximum attention
logit stays below qk_clip_threshold. Support for GQA is experimental.
"""
if not self.config.qk_clip:
raise ValueError("qk_clip option needs to be enabled")

if self.core_attention.current_max_attn_logits is None:
raise ValueError("current_max_attn_logits is None")

assert self.core_attention.current_max_attn_logits.shape == (
self.num_attention_heads_per_partition,
), f"current_max_attn_logits shape is not ({self.num_attention_heads_per_partition}, ) \
but {self.core_attention.current_max_attn_logits.shape}"

grouped_max_attn_logits = torch.max(
self.core_attention.current_max_attn_logits.view(
self.num_query_groups_per_partition, -1
),
dim=1,
).values

# only update the weights if any query group's
# max attention logit exceeds qk_clip_threshold
if torch.any(grouped_max_attn_logits > self.config.qk_clip_threshold):
# Use num_query_groups_per_partition for tensor parallel scenarios

# qk_clip_balancing_eta (g, 1, 1)
assert grouped_max_attn_logits.shape == (
self.num_query_groups_per_partition,
), f"current_max_attn_logits shape is not ({self.num_query_groups_per_partition},) \
but {grouped_max_attn_logits.shape}"
qk_clip_balancing_eta = torch.clamp(
self.config.qk_clip_threshold / grouped_max_attn_logits, max=1.0
).view(self.num_query_groups_per_partition, 1, 1)
assert torch.all(qk_clip_balancing_eta <= 1.0)

# Handle different weight access patterns (main_param vs direct access)
if hasattr(self.linear_qkv.weight, 'main_param'):
weight = self.linear_qkv.weight.main_param.data
else:
weight = self.linear_qkv.weight.data

# Reshape to (g, (query_projection_size + 2 * kv_projection_size) // g, -1)
weight_reshaped = weight.view(
self.num_query_groups_per_partition,
(self.query_projection_size + 2 * self.kv_projection_size)
// self.num_query_groups_per_partition,
-1,
)

# Split each group into its Q, K and V row blocks:
# (g, q_per_group, -1), (g, k_per_group, -1) and (g, v_per_group, -1)
weight_q = weight_reshaped[
:, : self.query_projection_size // self.num_query_groups_per_partition, :
]
weight_k = weight_reshaped[
:,
self.query_projection_size
// self.num_query_groups_per_partition : (
self.query_projection_size + self.kv_projection_size
)
// self.num_query_groups_per_partition,
:,
]
weight_v = weight_reshaped[
:,
(self.query_projection_size + self.kv_projection_size)
// self.num_query_groups_per_partition :,
:,
]

# Clipping: eta has shape (g, 1, 1) and broadcasts over the per-group rows and the hidden dim
weight_q = weight_q * torch.pow(qk_clip_balancing_eta, self.config.qk_clip_alpha)
weight_k = weight_k * torch.pow(qk_clip_balancing_eta, 1 - self.config.qk_clip_alpha)

# Concatenate back and reshape to original shape
weight_updated = torch.cat([weight_q, weight_k, weight_v], dim=1)
weight_updated = weight_updated.view(
self.query_projection_size + 2 * self.kv_projection_size, -1
)

# Apply the updated weights
if hasattr(self.linear_qkv.weight, 'main_param'):
self.linear_qkv.weight.main_param.data.copy_(weight_updated)
else:
self.linear_qkv.weight.data.copy_(weight_updated)

# reset current_max_attn_logits
self.core_attention.current_max_attn_logits = None


class CrossAttention(Attention):
"""Cross-attention layer class
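For orientation, the arithmetic in SelfAttention.clip_qk above reduces the per-head max logits to a per-query-group max, clamps eta to at most 1, and scales the Q rows of each group by eta ** alpha and the K rows by eta ** (1 - alpha). A minimal sketch of just that scale-factor computation under the default threshold and alpha (the function name and shapes are illustrative):

```python
import torch

def qk_clip_scales(max_logits_per_head: torch.Tensor,
                   heads_per_group: int,
                   threshold: float = 100.0,
                   alpha: float = 0.5):
    """Return per-query-group scale factors for the Q rows and the K rows."""
    # GQA shares one K/V head per query group, so reduce per-head maxima per group.
    grouped_max = max_logits_per_head.view(-1, heads_per_group).max(dim=1).values
    # eta = min(threshold / max_logit, 1): only groups above the threshold shrink.
    eta = torch.clamp(threshold / grouped_max, max=1.0)
    return eta ** alpha, eta ** (1.0 - alpha)

# 8 heads in 4 query groups; the second group peaks at 200, above the threshold.
max_logits = torch.tensor([50., 60., 80., 200., 90., 95., 40., 30.])
q_scale, k_scale = qk_clip_scales(max_logits, heads_per_group=2)
print(q_scale)  # tensor([1.0000, 0.7071, 1.0000, 1.0000])
print(k_scale)  # tensor([1.0000, 0.7071, 1.0000, 1.0000])
```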
124 changes: 124 additions & 0 deletions megatron/core/transformer/multi_latent_attention.py
@@ -917,3 +917,127 @@ def set_for_recompute_input_layernorm(self):
if self.config.q_lora_rank is not None:
set_save_original_input(self.linear_q_down_proj)
set_save_original_input(self.linear_kv_down_proj)

def clip_qk(self):
"""
QK-Clip rescales the query and key projection weights so that the maximum attention
logit stays below qk_clip_threshold. Following MuonClip, this function is called after
each Muon optimizer step to update the weights.
"""

if not self.config.qk_clip:
raise ValueError("qk_clip option needs to be enabled")

if self.core_attention.current_max_attn_logits is None:
raise ValueError("current_max_attn_logits is None")

# Check if we're in absorption mode
if self.cache_mla_latents and not hasattr(self, 'linear_kv_up_proj'):
raise ValueError(
"qk_clip is not supported when cache_mla_latents is enabled and absorption is "
"active. The linear_kv_up_proj layer has been deleted during absorption "
"preparation."
)

assert self.core_attention.current_max_attn_logits.shape == (
self.num_attention_heads_per_partition,
), f"current_max_attn_logits shape is not ({self.num_attention_heads_per_partition}, ) \
but {self.core_attention.current_max_attn_logits.shape}"

# only update the weight if any head has
# current_max_attn_logits > qk_clip_threshold
if torch.any(self.core_attention.current_max_attn_logits > self.config.qk_clip_threshold):
# Use num_attention_heads_per_partition for tensor parallel scenarios

# qk_clip_balancing_eta (n, 1, 1)
qk_clip_balancing_eta = torch.clamp(
self.config.qk_clip_threshold / self.core_attention.current_max_attn_logits, max=1.0
).view(self.num_attention_heads_per_partition, 1, 1)
assert torch.all(qk_clip_balancing_eta <= 1.0)

# Update the Q-side weights: scale the no-position (nope) rows by eta**alpha and the rotary (pe) rows by eta
if self.config.q_lora_rank is None:
q_proj_weight = self.linear_q_proj.weight
else:
q_proj_weight = self.linear_q_up_proj.weight

# Handle different weight access patterns (main_param vs direct access)
if hasattr(q_proj_weight, 'main_param'):
weight = q_proj_weight.main_param.data
else:
weight = q_proj_weight.data

# Reshape to (n, qk_head_dim + qk_pos_emb_head_dim, -1)
weight_reshaped = weight.view(
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.qk_pos_emb_head_dim,
-1,
)

# Split into qk_head_dim and qk_pos_emb_head_dim parts: (n, a, -1) and (n, b, -1)
weight_q_nope = weight_reshaped[:, : self.config.qk_head_dim, :]
weight_q_pe = weight_reshaped[:, self.config.qk_head_dim :, :]

# Clipping
weight_q_nope = weight_q_nope * torch.pow(
qk_clip_balancing_eta, self.config.qk_clip_alpha
)
weight_q_pe = weight_q_pe * qk_clip_balancing_eta

# Concatenate back and reshape to original shape
weight_q_updated = torch.cat([weight_q_nope, weight_q_pe], dim=1)
weight_q_updated = weight_q_updated.view(
self.num_attention_heads_per_partition
* (self.config.qk_head_dim + self.config.qk_pos_emb_head_dim),
-1,
)

# Apply the updated weights
if hasattr(q_proj_weight, 'main_param'):
q_proj_weight.main_param.data.copy_(weight_q_updated)
else:
q_proj_weight.data.copy_(weight_q_updated)

# Update k side weight, keep v side weight unchanged
kv_proj_weight = self.linear_kv_up_proj.weight

# Handle different weight access patterns
if hasattr(kv_proj_weight, 'main_param'):
weight_kv = kv_proj_weight.main_param.data
else:
weight_kv = kv_proj_weight.data

# shape: (n, qk_head_dim + v_head_dim, kv_lora_rank)
weight_reshaped = weight_kv.view(
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.v_head_dim,
-1,
)

# Split into qk_head_dim and v_head_dim parts: (n, a, -1) and (n, b, -1)
weight_k = weight_reshaped[:, : self.config.qk_head_dim, :]
weight_v = weight_reshaped[:, self.config.qk_head_dim :, :]

# Clipping
weight_k = weight_k * torch.pow(qk_clip_balancing_eta, 1 - self.config.qk_clip_alpha)

# Concatenate back and reshape to original shape
weight_kv_updated = torch.cat([weight_k, weight_v], dim=1)
weight_kv_updated = weight_kv_updated.view(
self.num_attention_heads_per_partition
* (self.config.qk_head_dim + self.config.v_head_dim),
-1,
)

# Apply the updated weights
if hasattr(kv_proj_weight, 'main_param'):
kv_proj_weight.main_param.data.copy_(weight_kv_updated)
else:
kv_proj_weight.data.copy_(weight_kv_updated)

# reset current_max_attn_logits
self.core_attention.current_max_attn_logits = None
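The MLA variant above applies different factors to different weight slices: the non-positional query rows are scaled by eta ** alpha, the rotary (pe) query rows by eta, the K rows of the KV up-projection by eta ** (1 - alpha), and the V rows (plus the shared k_pe down-projection) are left unchanged. A hedged sketch of those per-head factors, assuming the default alpha = 0.5 and threshold = 100 (the function and key names are illustrative):

```python
import torch

def mla_clip_factors(max_logits_per_head: torch.Tensor,
                     threshold: float = 100.0,
                     alpha: float = 0.5):
    """Per-head scale factors for the MLA weight slices touched by clip_qk."""
    eta = torch.clamp(threshold / max_logits_per_head, max=1.0)
    return {
        "q_nope": eta ** alpha,          # no-position rows of the Q up-projection
        "q_pe": eta,                     # rotary rows of the Q up-projection
        "k_nope": eta ** (1.0 - alpha),  # K rows of the KV up-projection
        "v": torch.ones_like(eta),       # V rows are left unchanged
    }

factors = mla_clip_factors(torch.tensor([80.0, 400.0]))
# Head 0 is under the threshold, so all its factors are 1. Head 1 has eta = 0.25:
# q_nope * k_nope = 0.5 * 0.5 = 0.25 and q_pe * 1 = 0.25, so both logit components
# shrink by eta and the head's max logit (400) lands back at the threshold (100).
print({k: v.tolist() for k, v in factors.items()})
```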
10 changes: 10 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -192,6 +192,16 @@ class TransformerConfig(ModelParallelConfig):
qk_layernorm: bool = False
"""Whether to apply `normalization` type of normalization to the query and key embeddings."""

qk_clip: bool = False
"""Whether to clip the query and key weights. Introduced in TE 2.9.0. Needed for Muon LLM
training."""

qk_clip_alpha: float = 0.5
"""The balancing alpha for qk-clip. Q = Q * (eta ** alpha)"""

qk_clip_threshold: float = 100
"""The balancing threshold for qk-clip. eta = min(threshold / max_attention_logits, 1.0)"""

test_mode: bool = False
"""Whether to run real-time tests."""

22 changes: 21 additions & 1 deletion megatron/training/arguments.py
@@ -977,6 +977,15 @@ def validate_args(args, defaults={}):
if args.add_bias_linear:
args.add_qkv_bias = True

if args.qk_clip:
# TODO: add is_te_min_version("2.9.0") before merge
# assert is_te_min_version("2.9.0"), \
# '--qk-clip is only supported with TE >= 2.9.0.'
assert 0.0 < args.qk_clip_alpha < 1.0, \
'--qk-clip-alpha must be between 0.0 and 1.0 when using --qk-clip.'
assert args.qk_clip_threshold > 0, \
'--qk-clip-threshold must be greater than 0 when using --qk-clip.'

# Retro checks.
if args.retro_add_retriever:

@@ -1205,6 +1214,9 @@ def validate_args(args, defaults={}):
assert (
args.recompute_granularity != 'full'
), 'recompute_granularity must not be full when CUDA Graphs are enabled.'

if args.multi_latent_attention:
assert not args.group_query_attention, "Group query attention is mutually exclusive with multi-latent attention."

# Print arguments.
_print_args("arguments", args)
@@ -1860,6 +1872,8 @@ def _add_logging_args(parser):
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
group.add_argument('--log-max-attention-logit', action='store_true',
help='Enable max attention logit logging to tensorboard.')
group.add_argument('--wandb-project', type=str, default='',
help='The wandb project name. Ignore wandb by default.')
group.add_argument('--wandb-entity', type=str, default='',
@@ -2202,6 +2216,12 @@ def _add_training_args(parser):
group.add_argument('--add-qkv-bias', action='store_true',
help='Enable bias only in the QKV linear layers',
dest='add_qkv_bias')
group.add_argument('--qk-clip', action='store_true',
help='Enable qk-clip for training stabilization; strongly recommended when using the Muon optimizer.')
group.add_argument('--qk-clip-alpha', type=float, default=0.5,
help='The balancing alpha for qk-clip.')
group.add_argument('--qk-clip-threshold', type=float, default=100,
help='The balancing threshold for qk-clip.')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
@@ -3196,7 +3216,7 @@ def _add_mla_args(parser):
group = parser.add_argument_group(title="mla")
group.add_argument('--q-lora-rank', type=int, default=None,
help="Rank of Query tensor's low rank representation.")
group.add_argument('--kv-lora-rank', type=int, default=32,
group.add_argument('--kv-lora-rank', type=int, default=512,
help="Rank of Key and Value tensors' low rank representation.")
group.add_argument('--qk-head-dim', type=int, default=128,
help="Dimension of the head in the QK projection. q_head_dim = qk_head_dim + qk_pos_emb_head_dim")