[feat] Track entropy and MI of routing distribution for topk MoE #188
base: main
Changes from all commits: 2a7cf1b, dd85e84, aef18e7, bef39d8, 620ec76, 7a93aee, eb617e8, 440738a, e5f3c4b, 27e2a5c
@@ -20,6 +20,7 @@
 from fast_llm.engine.schedule.schedule import Schedule, Step
 from fast_llm.logging import log_memory_usage
 from fast_llm.utils import Assert
+from typing import Callable


 logger = logging.getLogger(__name__)

@@ -94,6 +95,7 @@ def __init__(
         self._tied_parameters = self._multi_stage.tied_parameters
         self._num_stages = len(self._stages)
         self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
+        self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}

     def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
         assert not self._is_setup
@@ -265,20 +267,41 @@ def run_step(
         log_pipeline_parallel_main_rank(
             lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
         )

-        return self._reduce_losses(context), update_successful, metrics
+        metrics = self._reduce_metrics(context) if return_metrics else metrics
+        return (
+            self._reduce_losses(context),
+            update_successful,
+            metrics,
+        )

     def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses")
+
+    def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(
+            context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric
+        )
+
+    def _reduce_metric_or_loss(
+        self,
+        context: BatchContext,
+        check_count: Callable[[str], int],
+        reduce_attr: str = "losses",
+        check_reduce: Callable[[str], bool] = lambda _: True,
+    ) -> dict[str, float | int]:
         reduced_losses = {}
         num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
-        for name, losses in context.losses.items():
+        for name, losses in context.__getattribute__(reduce_attr).items():
+            if not check_reduce(name):
+                reduced_losses[name] = losses
+                continue
             if losses or self._distributed.pipeline_group:
                 if losses:
-                    reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
+                    reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name)
                     if self._distributed.data_group:
                         all_reduce(reduced_loss, group=self._distributed.data_group)
                 else:
-                    reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
+                    reduced_loss = torch.zeros([1], dtype=check_count(name).dtype, device=self._distributed.device)
                 if self._distributed.pipeline_group:
                     all_reduce(reduced_loss, group=self._distributed.pipeline_group)
             else:
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
             for name, reduced_loss in reduced_losses.items()
         }

+    def _is_reduced_metric(self, metric_name: str) -> bool:
+        """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
+        from fast_llm.layers.transformer.config import TransformerReducedMetrics

Review comment: We can't use hard-coded values here. Suggestion above would fix it, or there are a few other ways to get this dynamically.

+
+        if metric_name not in self._metric_defs:
+            return False
+        if not hasattr(self, "_reduced_metrics"):
+            self._reduced_metrics = set()
+            for cls in TransformerReducedMetrics.__subclasses__():
+                for attr_name in dir(cls):
+                    self._reduced_metrics.add(attr_name)
+        return metric_name in self._reduced_metrics
+
     def _train_step(self, context: BatchContext, step: Step) -> None:
         if step.throttle_event is not None:
             step.throttle_event.record()
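As a side note on the review comment about hard-coded values: one dynamic variant is sketched below. It is only an editorial illustration, not part of the PR. It assumes that TransformerReducedMetrics subclasses expose their metric names as plain string class attributes (which is how _is_reduced_metric appears to use them), and the helper name _collect_reduced_metric_names is made up for this sketch.

def _collect_reduced_metric_names(base_cls: type) -> set[str]:
    # Walk every direct subclass of the marker class once and keep only
    # non-dunder string class attributes, instead of everything dir() returns.
    names: set[str] = set()
    for cls in base_cls.__subclasses__():
        for attr_name, value in vars(cls).items():
            if not attr_name.startswith("_") and isinstance(value, str):
                names.add(value)
    return names

Computing this set once, for example in __init__, would also avoid the hasattr-based lazy caching inside _is_reduced_metric.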
@@ -17,6 +17,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics
 )
 from fast_llm.layers.transformer.mlp import MLPBase
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage

@@ -26,6 +27,35 @@
 logger = logging.getLogger(__name__)


+def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates routing entropy for each token, then averages over all tokens.
+    If low, means a lot of mass is put on a single expert in all tokens, which can indicate collapse or specialization.
+    """
+    n_experts = probs.size(-1)
+    entropy_values = entropy(probs)
+    average_entropy = entropy_values.mean()  # Average over batch and tokens
+    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
+
+
+def entropy(probs: torch.Tensor) -> torch.Tensor:
+    probs = torch.clamp(probs, min=1e-9)  # Avoid log(0)
+    return -torch.sum(probs * torch.log(probs), dim=-1)
+
+
+def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the difference between the entropy of the average routing and
+    the average routing entropy, we average across all tokens of all examples in the batch.
+    If low, means that routing is not informative.
+    """
+    n_experts = probs.size(-1)
+    average_routing = torch.mean(probs.view(-1, n_experts), dim=0)  # Average over tokens
+    entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype))  # H[E[X]]
+    entropy_routing = calculate_normalized_average_entropy(probs)  # E[H[X]]
+
+    return entropy_avg_routing - entropy_routing
+
+
 class MixtureOfExpertMLP(MLPBase):
     """
     MoeLayer following implementation from
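To make the two docstrings above concrete, here is a small self-contained sanity check. It is an editorial sketch, not part of the PR, and re-implements the same formulas rather than importing them: uniform routing gives a normalized entropy near 1 with near-zero mutual information, a collapsed router gives both near 0, and a router that is confident but spreads tokens across different experts gives low entropy with high mutual information.

import torch

def token_entropy(probs):
    # Same formula as the PR's entropy(): -sum(p * log p) over the expert dimension.
    probs = torch.clamp(probs, min=1e-9)
    return -torch.sum(probs * torch.log(probs), dim=-1)

def routing_stats(probs):
    n_experts = probs.size(-1)
    log_n = torch.log(torch.tensor(float(n_experts)))
    normalized_entropy = token_entropy(probs).mean() / log_n                      # E[H(p)] / log(n)
    mutual_info = token_entropy(probs.mean(dim=0)) / log_n - normalized_entropy   # H(E[p]) - E[H(p)]
    return float(normalized_entropy), float(mutual_info)

uniform = torch.full((4, 4), 0.25)      # every token spreads mass evenly over 4 experts
collapsed = torch.eye(4)[[0, 0, 0, 0]]  # every token routes to expert 0
diverse = torch.eye(4)                  # each token routes confidently to a different expert

print(routing_stats(uniform))    # (~1.0, ~0.0)
print(routing_stats(collapsed))  # (~0.0, ~0.0)
print(routing_stats(diverse))    # (~0.0, ~1.0)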
@@ -103,7 +133,7 @@ def forward(

         # Routing
         if self._routing_type == RoutingType.topk:
-            scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses)
+            scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics)
             if self._num_shared_experts > 0:
                 scores, top_experts = self._add_shared_experts(top_experts, scores)
         elif self._routing_type == RoutingType.sinkhorn:
@@ -169,11 +199,27 @@ def _topk_routing(
         logits: torch.Tensor,
         grad_scale: float | None = None,
         losses: dict | None = None,
+        metrics: dict | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1)
         scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
         if losses is not None or (self.training and grad_scale is not None):
             probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+
+            # Store these metrics
+            if metrics is not None:

Review comment: Given the extra computation involved, this should be enabled through a config parameter.

+                # Calculate and log entropy and mutual information
+                entropy = calculate_normalized_average_entropy(probs)
+                mutual_info = calculate_mutual_information(probs)
+                if TransformerRoutingMetrics.normalized_average_entropy not in metrics:
+                    metrics[TransformerRoutingMetrics.normalized_average_entropy] = []
+                if TransformerRoutingMetrics.mutual_info not in metrics:
+                    metrics[TransformerRoutingMetrics.mutual_info] = []
+
+                metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach())
+                metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach())
+
             mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
             # Auxiliary loss, corresponding to the sum of probabilities for the top experts.
             # In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
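A rough sketch of the config-parameter suggestion from the review comment above. The flag name calculate_routing_metrics and the attribute it hangs off (self._config) are assumptions for illustration, not the actual Fast-LLM config API; dict.setdefault replaces the two "not in metrics" checks.

# Hypothetical config field, e.g. on the transformer/MoE config (name assumed):
#     calculate_routing_metrics: bool = False
if metrics is not None and self._config.calculate_routing_metrics:
    # Only pay for the extra entropy/MI computation when explicitly enabled.
    routing_entropy = calculate_normalized_average_entropy(probs)
    mutual_info = calculate_mutual_information(probs)
    metrics.setdefault(TransformerRoutingMetrics.normalized_average_entropy, []).append(routing_entropy.detach())
    metrics.setdefault(TransformerRoutingMetrics.mutual_info, []).append(mutual_info.detach())

Renaming the local variable to routing_entropy also avoids shadowing the module-level entropy() helper inside _topk_routing.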
Review comment: This loss/metric split is way more complicated than needed. How about having a single entry, and using an is_metric flag in LossDef (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from run_step.