
[feat] Track entropy and MI of routing distribution for topk MoE #188

Open
wants to merge 10 commits into base: main
4 changes: 4 additions & 0 deletions fast_llm/engine/base_model/base_model.py
@@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
@abc.abstractmethod
def loss_defs(self) -> list[LossDef]:
pass

@property
Collaborator: This loss/metric split is way more complicated than needed. How about having a single entry and using an is_metric flag in LossDef (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from run_step.

def metric_defs(self) -> list[LossDef]:
return []
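
A minimal sketch of the suggested is_metric approach, assuming LossDef carries the fields already used in this diff (name, formatted_name, count, dtype); the actual LossDef definition in fast_llm may differ:

# Sketch only: a LossDef-like entry with an is_metric flag, so losses and metrics
# can share a single loss_defs list (field names assumed from this diff).
import dataclasses

import torch


@dataclasses.dataclass
class LossDefSketch:
    name: str
    formatted_name: str
    count: int
    dtype: torch.dtype = torch.float32
    is_metric: bool = False  # metrics would be split out of the context before run_step returns


def split_losses_and_metrics(defs: list[LossDefSketch]) -> tuple[list[LossDefSketch], list[LossDefSketch]]:
    """Separate plain losses from metric-style entries."""
    losses = [d for d in defs if not d.is_metric]
    metrics = [d for d in defs if d.is_metric]
    return losses, metrics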
46 changes: 41 additions & 5 deletions fast_llm/engine/schedule/runner.py
@@ -20,6 +20,7 @@
from fast_llm.engine.schedule.schedule import Schedule, Step
from fast_llm.logging import log_memory_usage
from fast_llm.utils import Assert
from typing import Callable

logger = logging.getLogger(__name__)

@@ -94,6 +95,7 @@ def __init__(
self._tied_parameters = self._multi_stage.tied_parameters
self._num_stages = len(self._stages)
self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
@@ -265,20 +267,41 @@ def run_step(
log_pipeline_parallel_main_rank(
lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
)

return self._reduce_losses(context), update_successful, metrics
metrics = self._reduce_metrics(context) if return_metrics else metrics
return (
self._reduce_losses(context),
update_successful,
metrics,
)

def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses")

def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
return self._reduce_metric_or_loss(
context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric
)

def _reduce_metric_or_loss(
self,
context: BatchContext,
check_count: Callable[[str], int],
reduce_attr: str = "losses",
check_reduce: Callable[[str], bool] = lambda _: True,
) -> dict[str, float | int]:
reduced_losses = {}
num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
for name, losses in context.losses.items():
for name, losses in context.__getattribute__(reduce_attr).items():
if not check_reduce(name):
reduced_losses[name] = losses
continue
if losses or self._distributed.pipeline_group:
if losses:
reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name)
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
else:
reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
reduced_loss = torch.zeros([1], dtype=check_count(name).dtype, device=self._distributed.device)
if self._distributed.pipeline_group:
all_reduce(reduced_loss, group=self._distributed.pipeline_group)
else:
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
for name, reduced_loss in reduced_losses.items()
}

def _is_reduced_metric(self, metric_name: str) -> bool:
"""Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass)."""
from fast_llm.layers.transformer.config import TransformerReducedMetrics
Collaborator: We can't use hard-coded values here. The suggestion above would fix it, or there are a few other ways to get this dynamically.


if metric_name not in self._metric_defs:
return False
if not hasattr(self, "_reduced_metrics"):
self._reduced_metrics = set()
for cls in TransformerReducedMetrics.__subclasses__():
for attr_name in dir(cls):
self._reduced_metrics.add(attr_name)
return metric_name in self._reduced_metrics
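
One possible dynamic alternative to the dir()-based scan, as the comment above suggests (a sketch only, assuming metric names are plain string class attributes on TransformerReducedMetrics subclasses, as in the PR's TransformerRoutingMetrics; this is not part of the PR):

# Sketch: collect reduced-metric names from the string values declared on
# TransformerReducedMetrics subclasses, skipping private attributes and non-strings.
def _collect_reduced_metric_names() -> set[str]:
    from fast_llm.layers.transformer.config import TransformerReducedMetrics

    names: set[str] = set()
    for cls in TransformerReducedMetrics.__subclasses__():
        for attr_name, value in vars(cls).items():
            if not attr_name.startswith("_") and isinstance(value, str):
                names.add(value)
    return names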

def _train_step(self, context: BatchContext, step: Step) -> None:
if step.throttle_event is not None:
step.throttle_event.record()
9 changes: 9 additions & 0 deletions fast_llm/layers/transformer/config.py
@@ -75,6 +75,15 @@ class TransformerLossNames:
load_balancing_loss = "load_balancing_loss"
router_z_loss = "router_z_loss"

class TransformerReducedMetrics:
"""
Metrics that are reduced in the same way as loss before logging.
"""
pass

class TransformerRoutingMetrics(TransformerReducedMetrics):
normalized_average_entropy = "normalized_average_entropy"
mutual_info = "mutual_info"

class RotaryEmbeddingType(str, enum.Enum):
none = "none"
48 changes: 47 additions & 1 deletion fast_llm/layers/transformer/mixture_of_experts.py
@@ -17,6 +17,7 @@
TransformerDimNames,
TransformerKwargs,
TransformerLossNames,
TransformerRoutingMetrics
)
from fast_llm.layers.transformer.mlp import MLPBase
from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
@@ -26,6 +27,35 @@
logger = logging.getLogger(__name__)


def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
Collaborator: Could try @torch.compile on these for a free performance boost.

"""
Calculates routing entropy for each token, then averages over all tokens.
If low, most of the probability mass is put on a single expert for every token, which can indicate collapse or specialization.
"""
n_experts = probs.size(-1)
entropy_values = entropy(probs)
average_entropy = entropy_values.mean() # Average over batch and tokens
return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))

def entropy(probs: torch.Tensor) -> torch.Tensor:
Collaborator: calculate_entropy (suggested rename).

probs = torch.clamp(probs, min=1e-9) # Avoid log(0)
return -torch.sum(probs * torch.log(probs), dim=-1)


def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
"""
Calculates the difference between the entropy of the average routing and the average routing entropy,
averaged across all tokens of all examples in the batch.
If low, the routing is not informative.
"""
n_experts = probs.size(-1)
average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens
entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) # H[E[X]]
entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]]

return entropy_avg_routing - entropy_routing
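
A small standalone check of how these helpers behave on toy routing distributions (illustrative values only, assuming the three functions above are in scope); the last line also shows the torch.compile usage suggested in the comment above:

# Sanity check on toy routing distributions (illustrative only).
import torch

batch, tokens, n_experts = 2, 8, 8

# Uniform routing: high per-token entropy, no information in the routing.
uniform = torch.full((batch, tokens, n_experts), 1.0 / n_experts)
print(calculate_normalized_average_entropy(uniform))  # ~1.0
print(calculate_mutual_information(uniform))          # ~0.0

# Collapsed routing: every token sends all mass to expert 0.
collapsed = torch.zeros(batch, tokens, n_experts)
collapsed[..., 0] = 1.0
print(calculate_normalized_average_entropy(collapsed))  # ~0.0
print(calculate_mutual_information(collapsed))          # ~0.0

# Specialized routing: each token picks a different expert, so per-token entropy
# is ~0 while the batch-averaged routing is uniform.
specialized = torch.eye(n_experts).repeat(batch, 1, 1)
print(calculate_normalized_average_entropy(specialized))  # ~0.0
print(calculate_mutual_information(specialized))          # ~1.0

# Optional: compile a helper, as suggested above.
compiled_entropy = torch.compile(calculate_normalized_average_entropy)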


class MixtureOfExpertMLP(MLPBase):
"""
MoeLayer following implementation from
@@ -103,7 +133,7 @@ def forward(

# Routing
if self._routing_type == RoutingType.topk:
scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses)
scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics)
if self._num_shared_experts > 0:
scores, top_experts = self._add_shared_experts(top_experts, scores)
elif self._routing_type == RoutingType.sinkhorn:
@@ -169,11 +199,27 @@ def _topk_routing(
logits: torch.Tensor,
grad_scale: float | None = None,
losses: dict | None = None,
metrics: dict | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
if losses is not None or (self.training and grad_scale is not None):
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)


# Store these metrics
if metrics is not None:
Collaborator: Given the extra computation involved, this should be enabled through a config parameter.
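
A rough sketch of how such a gate might look; the log_routing_metrics field name, the self._config reference, and the use of setdefault are assumptions rather than the PR's code (the PR's unconditional version follows below):

# Hypothetical config-gated version: only pay for the full-softmax metrics when enabled.
if metrics is not None and getattr(self._config, "log_routing_metrics", False):
    entropy_value = calculate_normalized_average_entropy(probs)
    mutual_info = calculate_mutual_information(probs)
    metrics.setdefault(TransformerRoutingMetrics.normalized_average_entropy, []).append(entropy_value.detach())
    metrics.setdefault(TransformerRoutingMetrics.mutual_info, []).append(mutual_info.detach())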

# Calculate and log entropy and mutual information
entropy = calculate_normalized_average_entropy(probs)
mutual_info = calculate_mutual_information(probs)
if TransformerRoutingMetrics.normalized_average_entropy not in metrics:
metrics[TransformerRoutingMetrics.normalized_average_entropy] = []
if TransformerRoutingMetrics.mutual_info not in metrics:
metrics[TransformerRoutingMetrics.mutual_info] = []

metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach())
metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach())

mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
# Auxiliary loss, corresponding to the sum of probabilities for the top experts.
# In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
25 changes: 25 additions & 0 deletions fast_llm/models/gpt/model.py
@@ -19,6 +19,7 @@
TransformerDimNames,
TransformerKwargs,
TransformerLossNames,
TransformerRoutingMetrics,
)
from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor
from fast_llm.layers.transformer.transformer import TransformerLayer
@@ -308,10 +309,34 @@ def loss_defs(self) -> list[LossDef]:
count=self._config.transformer.num_layers,
)
)

if self._config.logit_z_loss:
LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
return loss_defs

@property
def metric_defs(self) -> list[LossDef]:
metric_defs = []
if (
self._config.transformer.num_experts > 1
and self._config.transformer.expert_routing_type == RoutingType.topk
):
metric_defs.append(
LossDef(
name=TransformerRoutingMetrics.normalized_average_entropy,
formatted_name="Normalized Entropy",
count=self._config.transformer.num_layers,
)
)
metric_defs.append(
LossDef(
name=TransformerRoutingMetrics.mutual_info,
formatted_name="Mutual Information",
count=self._config.transformer.num_layers,
)
)
return metric_defs


class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig