diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index 7233c183..c94e77e2 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
     @abc.abstractmethod
     def loss_defs(self) -> list[LossDef]:
         pass
+
+    @property
+    def metric_defs(self) -> list[LossDef]:
+        return []
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index d4b1da10..bcf6c674 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -20,6 +20,7 @@
 from fast_llm.engine.schedule.schedule import Schedule, Step
 from fast_llm.logging import log_memory_usage
 from fast_llm.utils import Assert
+from typing import Callable

 logger = logging.getLogger(__name__)

@@ -94,6 +95,7 @@ def __init__(
         self._tied_parameters = self._multi_stage.tied_parameters
         self._num_stages = len(self._stages)
         self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
+        self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}

     def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
         assert not self._is_setup
@@ -265,20 +267,41 @@ def run_step(
             log_pipeline_parallel_main_rank(
                 lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
             )
-
-        return self._reduce_losses(context), update_successful, metrics
+        metrics = self._reduce_metrics(context) if return_metrics else metrics
+        return (
+            self._reduce_losses(context),
+            update_successful,
+            metrics,
+        )

     def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(context, self._loss_defs, "losses")
+
+    def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(context, self._metric_defs, "metrics", self._is_reduced_metric)
+
+    def _reduce_metric_or_loss(
+        self,
+        context: BatchContext,
+        defs: dict,
+        reduce_attr: str = "losses",
+        check_reduce: Callable[[str], bool] = lambda _: True,
+    ) -> dict[str, float | int]:
         reduced_losses = {}
         num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
-        for name, losses in context.losses.items():
+        for name, losses in getattr(context, reduce_attr).items():
+            if not check_reduce(name):
+                reduced_losses[name] = losses
+                continue
             if losses or self._distributed.pipeline_group:
                 if losses:
-                    reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
+                    reduced_loss = torch.stack(losses).sum() / num_inputs / defs[name].count
                     if self._distributed.data_group:
                         all_reduce(reduced_loss, group=self._distributed.data_group)
                 else:
-                    reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
+                    reduced_loss = torch.zeros([1], dtype=defs[name].dtype, device=self._distributed.device)
                     if self._distributed.pipeline_group:
                         all_reduce(reduced_loss, group=self._distributed.pipeline_group)
             else:
@@ -289,6 +312,19 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
             for name, reduced_loss in reduced_losses.items()
         }

+    def _is_reduced_metric(self, metric_name: str) -> bool:
+        """Check whether a metric should be reduced like a loss (i.e. it is defined in a TransformerReducedMetrics subclass)."""
+        from fast_llm.layers.transformer.config import TransformerReducedMetrics
+
+        if metric_name not in self._metric_defs:
+            return False
+        if not hasattr(self, "_reduced_metrics"):
+            # Build the lookup set once, from all TransformerReducedMetrics subclasses.
+            self._reduced_metrics = set()
+            for cls in TransformerReducedMetrics.__subclasses__():
+                for attr_name in dir(cls):
+                    if not attr_name.startswith("_"):
+                        self._reduced_metrics.add(getattr(cls, attr_name))
+        return metric_name in self._reduced_metrics
+
     def _train_step(self, context: BatchContext, step: Step) -> None:
         if step.throttle_event is not None:
             step.throttle_event.record()
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index cf985392..bf3a0a14 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -75,6 +75,15 @@ class TransformerLossNames:
     load_balancing_loss = "load_balancing_loss"
     router_z_loss = "router_z_loss"

+
+class TransformerReducedMetrics:
+    """
+    Metrics that are reduced across ranks in the same way as losses before logging.
+    """
+
+
+class TransformerRoutingMetrics(TransformerReducedMetrics):
+    normalized_average_entropy = "normalized_average_entropy"
+    mutual_info = "mutual_info"
+

 class RotaryEmbeddingType(str, enum.Enum):
     none = "none"
diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py
index 85c6686f..9e194d2c 100644
--- a/fast_llm/layers/transformer/mixture_of_experts.py
+++ b/fast_llm/layers/transformer/mixture_of_experts.py
@@ -17,6 +17,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.mlp import MLPBase
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
@@ -26,6 +27,35 @@

 logger = logging.getLogger(__name__)

+
+def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the routing entropy for each token, then averages over all tokens.
+    A low value means that most of the routing mass is placed on a single expert for every token,
+    which can indicate collapse or over-specialization.
+    """
+    n_experts = probs.size(-1)
+    entropy_values = entropy(probs)
+    average_entropy = entropy_values.mean()  # Average over batch and tokens
+    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
+
+
+def entropy(probs: torch.Tensor) -> torch.Tensor:
+    probs = torch.clamp(probs, min=1e-9)  # Avoid log(0)
+    return -torch.sum(probs * torch.log(probs), dim=-1)
+
+
+def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the difference between the entropy of the average routing distribution, H[E[X]],
+    and the average routing entropy, E[H[X]], over all tokens of all examples in the batch.
+    A low value means that the routing is not informative (all tokens are routed the same way).
+    """
+    n_experts = probs.size(-1)
+    average_routing = torch.mean(probs.view(-1, n_experts), dim=0)  # Average over tokens
+    normalization = torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
+    entropy_avg_routing = entropy(average_routing) / normalization  # H[E[X]]
+    entropy_routing = calculate_normalized_average_entropy(probs)  # E[H[X]]
+
+    return entropy_avg_routing - entropy_routing
+
+
 class MixtureOfExpertMLP(MLPBase):
     """
     MoeLayer following implementation from
@@ -103,7 +133,7 @@ def forward(

         # Routing
         if self._routing_type == RoutingType.topk:
-            scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses)
+            scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics)
             if self._num_shared_experts > 0:
                 scores, top_experts = self._add_shared_experts(top_experts, scores)
         elif self._routing_type == RoutingType.sinkhorn:
@@ -169,11 +199,27 @@ def _topk_routing(
         logits: torch.Tensor,
         grad_scale: float | None = None,
         losses: dict | None = None,
+        metrics: dict | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1)
         scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
         if losses is not None or (self.training and grad_scale is not None):
             probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
+
+            # Calculate and store the routing metrics (normalized entropy and mutual information).
+            if metrics is not None:
+                normalized_entropy = calculate_normalized_average_entropy(probs)
+                mutual_info = calculate_mutual_information(probs)
+                metrics.setdefault(TransformerRoutingMetrics.normalized_average_entropy, []).append(
+                    normalized_entropy.detach()
+                )
+                metrics.setdefault(TransformerRoutingMetrics.mutual_info, []).append(mutual_info.detach())
+
             mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
             # Auxiliary loss, corresponding to the sum of probabilities for the top experts.
             # In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 8aa68333..d1bc9e13 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -19,6 +19,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor
 from fast_llm.layers.transformer.transformer import TransformerLayer
@@ -308,10 +309,34 @@ def loss_defs(self) -> list[LossDef]:
                     count=self._config.transformer.num_layers,
                 )
             )
+
         if self._config.logit_z_loss:
             LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
         return loss_defs

+    @property
+    def metric_defs(self) -> list[LossDef]:
+        metric_defs = []
+        if (
+            self._config.transformer.num_experts > 1
+            and self._config.transformer.expert_routing_type == RoutingType.topk
+        ):
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.normalized_average_entropy,
+                    formatted_name="Normalized Entropy",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.mutual_info,
+                    formatted_name="Mutual Information",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+        return metric_defs
+

 class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]):
     config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig
diff --git a/tests/test_moe_metrics.py b/tests/test_moe_metrics.py
new file mode 100644
index 00000000..7ca1fb15
--- /dev/null
+++ b/tests/test_moe_metrics.py
@@ -0,0 +1,282 @@
+from unittest import mock
+
+import pytest
+import torch
+
+from fast_llm.engine.base_model.base_model import LossDef
+from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
+from fast_llm.engine.schedule.config import ScheduleConfig
+from fast_llm.engine.schedule.runner import BatchContext, ScheduleRunner
+from fast_llm.engine.schedule.schedule import Schedule
+from fast_llm.layers.transformer.config import TransformerRoutingMetrics
+from fast_llm.layers.transformer.mixture_of_experts import (
+    calculate_mutual_information,
+    calculate_normalized_average_entropy,
+)
+
+
+def test_diversity_entropy():
+    """
+    Collapsed routing should have both low entropy and low mutual information.
+    """
+    collapsed_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+            ],
+            # Batch 2
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+            ],
+        ]
+    )
+    norm_entropy = calculate_normalized_average_entropy(collapsed_probs)
+    mutual_info = calculate_mutual_information(collapsed_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected ~0.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}"
+
+    # Diverse routing without collapse: low per-token entropy but high mutual information.
+    diverse_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.01, 0.99, 0.0, 0.0],
+                [0.01, 0.01, 0.99, 0.0],
+            ],
+            # Batch 2
+            [
+                [0.01, 0.01, 0.99, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.01, 0.01, 0.01, 0.99],
+            ],
+        ]
+    )
+    norm_entropy = calculate_normalized_average_entropy(diverse_probs)
+    mutual_info = calculate_mutual_information(diverse_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected ~0.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.9), atol=1e-1), f"Expected ~0.9, got {mutual_info}"
+
+
+def test_calculate_normalized_average_entropy():
+    # AI-generated test cases
+    # Create a batch of routing probabilities
+    batch_size = 2
+    seq_len = 3
+    n_experts = 4
+
+    # Test 1: Uniform distribution (should give normalized entropy of 1.0)
+    uniform_probs = torch.ones(batch_size, seq_len, n_experts) / n_experts
+    norm_entropy = calculate_normalized_average_entropy(uniform_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}"
+
+    # Test 2: One-hot distribution (should give normalized entropy of 0.0)
+    one_hot = torch.zeros(batch_size, seq_len, n_experts)
+    for b in range(batch_size):
+        for s in range(seq_len):
+            one_hot[b, s, b % n_experts] = 1.0
+    norm_entropy = calculate_normalized_average_entropy(one_hot)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}"
+
+    # Test 3: Mixed distribution
+    mixed_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.7, 0.1, 0.1, 0.1],  # Token 1: mostly expert 0
+                [0.1, 0.7, 0.1, 0.1],  # Token 2: mostly expert 1
+                [0.25, 0.25, 0.25, 0.25],  # Token 3: uniform
+            ],
+            # Batch 2
+            [
+                [0.4, 0.4, 0.1, 0.1],  # Token 1: split between experts 0 and 1
+                [0.1, 0.1, 0.4, 0.4],  # Token 2: split between experts 2 and 3
+                [0.1, 0.1, 0.1, 0.7],  # Token 3: mostly expert 3
+            ],
+        ]
+    )
+    norm_entropy = calculate_normalized_average_entropy(mixed_probs)
+    # The expected value is between 0 and 1
+    assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}"
+
+
+def test_calculate_mutual_information():
+    # AI-generated test cases
+    # Create a batch of routing probabilities
+    batch_size = 2
+    seq_len = 3
+    n_experts = 4
+
+    # Test 1: All tokens route to the same expert (low mutual information)
+    same_expert = torch.zeros(batch_size, seq_len, n_experts)
+    same_expert[:, :, 0] = 1.0  # All tokens route to expert 0
+    mutual_info = calculate_mutual_information(same_expert)
+    assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}"
+
+    # Test 2: Each token routes to a different expert (high mutual information)
+    different_experts = torch.zeros(batch_size, seq_len, n_experts)
+    for b in range(batch_size):
+        for s in range(seq_len):
+            different_experts[b, s, s % n_experts] = 1.0
+    mutual_info = calculate_mutual_information(different_experts)
+    # The value should be positive and closer to 1
+    assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}"
+
+    # Test 3: Mixed routing pattern
+    mixed_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.7, 0.1, 0.1, 0.1],  # Token 1: mostly expert 0
+                [0.1, 0.7, 0.1, 0.1],  # Token 2: mostly expert 1
+                [0.1, 0.1, 0.7, 0.1],  # Token 3: mostly expert 2
+            ],
+            # Batch 2
+            [
+                [0.1, 0.1, 0.1, 0.7],  # Token 1: mostly expert 3
+                [0.7, 0.1, 0.1, 0.1],  # Token 2: mostly expert 0
+                [0.1, 0.7, 0.1, 0.1],  # Token 3: mostly expert 1
+            ],
+        ]
+    )
+    mutual_info = calculate_mutual_information(mixed_probs)
+    # The expected value is between 0 and 1
+    assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}"
+
+
+def test_edge_cases():
+    # AI-generated test cases
+    # Test with a very small batch and sequence length
+    tiny_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25]]])  # batch=1, seq_len=1, n_experts=4
+    norm_entropy = calculate_normalized_average_entropy(tiny_probs)
+    mutual_info = calculate_mutual_information(tiny_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}"
+
+    # Test with very small probabilities
+    small_probs = torch.ones(2, 3, 4) * 1e-8
+    small_probs[:, :, 0] = 1.0 - 3e-8  # Make sure each distribution sums to 1
+    norm_entropy = calculate_normalized_average_entropy(small_probs)
+    mutual_info = calculate_mutual_information(small_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}"
+
+
+@pytest.fixture
+def setup_runner():
+    """Set up a minimal ScheduleRunner with mocked collaborators."""
+    distributed_config = DistributedConfig()
+
+    # Mock MultiStageModel with loss and metric definitions
+    multi_stage = mock.MagicMock(spec=MultiStageModel)
+    multi_stage.base_model.loss_defs = [LossDef(name="test_loss", formatted_name="Test Loss", count=1)]
+    multi_stage.base_model.metric_defs = [
+        LossDef(
+            name=TransformerRoutingMetrics.normalized_average_entropy,
+            formatted_name="Normalized Entropy",
+            count=1,
+        ),
+        LossDef(name=TransformerRoutingMetrics.mutual_info, formatted_name="Mutual Information", count=1),
+    ]
+
+    # Create a schedule runner
+    schedule_config = ScheduleConfig()
+    runner = ScheduleRunner(
+        config=schedule_config,
+        multi_stage=multi_stage,
+        distributed_config=distributed_config,
+    )
+
+    # Mock distributed object
+    distributed = mock.MagicMock(spec=Distributed)
+    distributed.config = distributed_config
+    distributed.device = torch.device("cpu")
+    distributed.data_group = None
+    distributed.pipeline_group = None
+
+    # Set up the runner
+    runner._distributed = distributed
+    runner.is_initialized = True
+
+    # Create a mock schedule
+    schedule = mock.MagicMock(spec=Schedule)
+    schedule.phase = PhaseType.training
+    schedule.batch_config.num_inputs = 3
+    schedule._schedule_config = schedule_config
+
+    # Create a batch context with metrics and losses
+    context = BatchContext(
+        iteration=1,
+        schedule=schedule,
+    )
+
+    # Add test metrics
+    context.metrics = {
+        # Metrics that should be reduced (defined in a TransformerReducedMetrics subclass)
+        TransformerRoutingMetrics.normalized_average_entropy: [
+            torch.tensor(0.5),
+            torch.tensor(0.6),
+            torch.tensor(0.7),
+        ],
+        TransformerRoutingMetrics.mutual_info: [torch.tensor(0.2), torch.tensor(0.3), torch.tensor(0.4)],
+        # Metric that should be passed through without reduction
+        "non_reduced_metric": [torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)],
+    }
+
+    # Add test losses
+    context.losses = {"test_loss": [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]}
+
+    return runner, context, schedule
+
+
+def test_reduce_metrics(setup_runner):
+    """_reduce_metrics should average only the metrics defined in a TransformerReducedMetrics subclass."""
+    runner, context, _ = setup_runner
+
+    assert runner._is_reduced_metric(TransformerRoutingMetrics.normalized_average_entropy) is True
+    assert runner._is_reduced_metric(TransformerRoutingMetrics.mutual_info) is True
+    assert runner._is_reduced_metric("non_reduced_metric") is False
+    assert runner._is_reduced_metric("random_metric") is False
+
+    reduced_metrics = runner._reduce_metrics(context)
+
+    # Metrics defined in a TransformerReducedMetrics subclass are reduced
+    assert TransformerRoutingMetrics.normalized_average_entropy in reduced_metrics
+    assert TransformerRoutingMetrics.mutual_info in reduced_metrics
+
+    # The values are averaged over the three inputs
+    assert pytest.approx(reduced_metrics[TransformerRoutingMetrics.normalized_average_entropy], 0.001) == 0.6
+    assert pytest.approx(reduced_metrics[TransformerRoutingMetrics.mutual_info], 0.001) == 0.3
+
+    # Other metrics are passed through without reduction
+    assert "non_reduced_metric" in reduced_metrics
+    assert sum(reduced_metrics["non_reduced_metric"]) == 3.0
+
+
+def test_reduce_losses(setup_runner):
+    """_reduce_losses should average losses over the number of inputs."""
+    runner, context, _ = setup_runner
+
+    reduced_losses = runner._reduce_losses(context)
+
+    assert "test_loss" in reduced_losses
+    assert pytest.approx(reduced_losses["test_loss"], 0.001) == 2.0
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
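
The new tests rely only on mocked collaborators, so they can be run in isolation. Assuming a standard pytest setup for the repository, something like:

    pytest tests/test_moe_metrics.py -q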