From 2a7cf1b378838026d4cbf150ea997f8967ec7f36 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 14 Mar 2025 19:23:02 +0000 Subject: [PATCH 01/10] added mutual information and entropy for routing probs --- fast_llm/layers/transformer/config.py | 2 + .../layers/transformer/mixture_of_experts.py | 39 +++++ fast_llm/models/gpt/model.py | 16 ++ tests/test_routing_metrics.py | 161 ++++++++++++++++++ 4 files changed, 218 insertions(+) create mode 100644 tests/test_routing_metrics.py diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf985392..9c3caa07 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -74,6 +74,8 @@ class TransformerKwargs: class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" + router_entropy = "router_entropy" + router_mutual_info = "router_mutual_info" class RotaryEmbeddingType(str, enum.Enum): diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 85c6686f..3ad9dfb8 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -25,6 +25,30 @@ logger = logging.getLogger(__name__) +def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: + ''' + Calculates routing entropy for each token, then averages over all tokens. + If low, means a lot of mass is put on a single expert, which can indicate collapse. 
+ ''' + n_experts = probs.size(-1) + entropy_values = entropy(probs) + average_entropy = entropy_values.mean() # Average over batch and tokens + return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) + +def entropy(probs: torch.Tensor) -> torch.Tensor: + probs = torch.clamp(probs, min=1e-9) # Avoid log(0) + return -torch.sum(probs * torch.log(probs), dim=-1) + +def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor: + ''' + Calculates the difference between the entropy of the average routing and + the average routing entropy. If low, means that routing is not informative. + ''' + average_routing = torch.mean(probs, dim=1) # Average over tokens + entropy_avg_routing = entropy(average_routing).mean() # H[E[X]], mean over batch + entropy_routing = entropy(probs).mean() # E[H[X]] + + return (entropy_avg_routing - entropy_routing) / torch.log(torch.tensor(probs.size(-1), dtype=probs.dtype)) # Normalize class MixtureOfExpertMLP(MLPBase): """ @@ -111,6 +135,7 @@ def forward( else: raise NotImplementedError(self._routing_type) + if self._debug_mode: # To log all ranks set `global_=False` self._debug_log(scores, "Router scores", TransformerDimNames.top_experts, kwargs) @@ -174,6 +199,20 @@ def _topk_routing( scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + + # Calculate and log entropy and mutual information + entropy = calculate_normalized_average_entropy(probs) + mutual_info = calculate_mutual_information(probs) + + # Store these metrics + if "router_entropy" not in losses: + losses["router_entropy"] = [] + if "router_mutual_info" not in losses: + losses["router_mutual_info"] = [] + + losses["router_entropy"].append(entropy.detach()) + losses["router_mutual_info"].append(mutual_info.detach()) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) 
# Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8aa68333..166347bc 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -308,6 +308,22 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) + # Add new metrics + loss_defs.append( + LossDef( + name="router_entropy", + formatted_name="router entropy", + count=self._config.transformer.num_layers, + ) + ) + loss_defs.append( + LossDef( + name="router_mutual_info", + formatted_name="router mutual info", + count=self._config.transformer.num_layers, + ) + ) + if self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) return loss_defs diff --git a/tests/test_routing_metrics.py b/tests/test_routing_metrics.py new file mode 100644 index 00000000..b343e351 --- /dev/null +++ b/tests/test_routing_metrics.py @@ -0,0 +1,161 @@ +import torch +import pytest +from fast_llm.layers.transformer.mixture_of_experts import ( + calculate_normalized_average_entropy, + calculate_mutual_information, + entropy +) + +def test_diversity_entropy(): + ''' + collapse routing would have low entropy and low mutual information + ''' + + batch_size = 2 + seq_len = 3 + n_experts = 4 + collapased_probs = torch.tensor([ + # Batch 1 + [ + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + ], + # Batch 2 + [ + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + ] + ]) + norm_entropy = calculate_normalized_average_entropy(collapased_probs) + mutual_info = calculate_mutual_information(collapased_probs) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" + assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got 
{mutual_info}" + + + # diverse but no collapse + # should give low entropy and high mutual information + diverse_probs = torch.tensor([ + # Batch 1 + [ + [0.99, 0.01, 0.0, 0.0], + [0.01, 0.99, 0.0, 0.0], + [0.01, 0.01, 0.99, 0.0], + ], + # Batch 2 + [ + [0.01, 0.01, 0.99, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.01, 0.01, 0.01, 0.99], + ] + ]) + norm_entropy = calculate_normalized_average_entropy(diverse_probs) + mutual_info = calculate_mutual_information(diverse_probs) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" + assert torch.isclose(mutual_info, torch.tensor(0.75), atol=1e-1), f"Expected 1.0, got {mutual_info}" + + +def test_calculate_normalized_average_entropy(): + # AI generated test case + # Create a batch of routing probabilities + batch_size = 2 + seq_len = 3 + n_experts = 4 + + # Test 1: Uniform distribution (should give normalized entropy of 1.0) + uniform_probs = torch.ones(batch_size, seq_len, n_experts) / n_experts + norm_entropy = calculate_normalized_average_entropy(uniform_probs) + assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}" + + # Test 2: One-hot distribution (should give normalized entropy of 0.0) + one_hot = torch.zeros(batch_size, seq_len, n_experts) + for b in range(batch_size): + for s in range(seq_len): + one_hot[b, s, b % n_experts] = 1.0 + norm_entropy = calculate_normalized_average_entropy(one_hot) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}" + + # Test 3: Mixed distribution + mixed_probs = torch.tensor([ + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.25, 0.25, 0.25, 0.25], # Token 3: uniform + ], + # Batch 2 + [ + [0.4, 0.4, 0.1, 0.1], # Token 1: split between experts 0 and 1 + [0.1, 0.1, 0.4, 0.4], # Token 2: split between experts 2 and 3 + [0.1, 0.1, 0.1, 0.7], # Token 3: mostly expert 3 + ] + 
]) + norm_entropy = calculate_normalized_average_entropy(mixed_probs) + # The expected value is between 0 and 1 + assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}" + +def test_calculate_mutual_information(): + # AI generated test cases + # Create a batch of routing probabilities + batch_size = 2 + seq_len = 3 + n_experts = 4 + + # Test 1: All tokens route to the same expert (low mutual information) + same_expert = torch.zeros(batch_size, seq_len, n_experts) + same_expert[:, :, 0] = 1.0 # All tokens route to expert 0 + mutual_info = calculate_mutual_information(same_expert) + assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" + + # Test 2: Each token routes to a different expert (high mutual information) + different_experts = torch.zeros(batch_size, seq_len, n_experts) + for b in range(batch_size): + for s in range(seq_len): + different_experts[b, s, s % n_experts] = 1.0 + mutual_info = calculate_mutual_information(different_experts) + # The value should be positive and closer to 1 + assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}" + + # Test 3: Mixed routing pattern + mixed_probs = torch.tensor([ + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.1, 0.1, 0.7, 0.1], # Token 3: mostly expert 2 + ], + # Batch 2 + [ + [0.1, 0.1, 0.1, 0.7], # Token 1: mostly expert 3 + [0.7, 0.1, 0.1, 0.1], # Token 2: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 3: mostly expert 1 + ] + ]) + mutual_info = calculate_mutual_information(mixed_probs) + # The expected value is between 0 and 1 + assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}" + +def test_edge_cases(): + # AI generated test cases + # Test with very small batch and sequence length + tiny_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25]]]) # batch=1, seq_len=1, n_experts=4 + norm_entropy = 
calculate_normalized_average_entropy(tiny_probs) + mutual_info = calculate_mutual_information(tiny_probs) + assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}" + assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" + + # Test with very small probabilities + small_probs = torch.ones(2, 3, 4) * 1e-8 + small_probs[:, :, 0] = 1.0 - 3e-8 # Make sure they sum to 1 + norm_entropy = calculate_normalized_average_entropy(small_probs) + mutual_info = calculate_mutual_information(small_probs) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {norm_entropy}" + assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}" + + + + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file From dd85e846483626015dfb510e43d58146012b552c Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 14 Mar 2025 19:27:22 +0000 Subject: [PATCH 02/10] format --- .../layers/transformer/mixture_of_experts.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 3ad9dfb8..6e35a942 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -25,30 +25,36 @@ logger = logging.getLogger(__name__) + def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: - ''' + """ Calculates routing entropy for each token, then averages over all tokens. If low, means a lot of mass is put on a single expert, which can indicate collapse. 
- ''' + """ n_experts = probs.size(-1) entropy_values = entropy(probs) average_entropy = entropy_values.mean() # Average over batch and tokens return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) + def entropy(probs: torch.Tensor) -> torch.Tensor: probs = torch.clamp(probs, min=1e-9) # Avoid log(0) return -torch.sum(probs * torch.log(probs), dim=-1) + def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor: - ''' - Calculates the difference between the entropy of the average routing and + """ + Calculates the difference between the entropy of the average routing and the average routing entropy. If low, means that routing is not informative. - ''' + """ average_routing = torch.mean(probs, dim=1) # Average over tokens entropy_avg_routing = entropy(average_routing).mean() # H[E[X]], mean over batch entropy_routing = entropy(probs).mean() # E[H[X]] - - return (entropy_avg_routing - entropy_routing) / torch.log(torch.tensor(probs.size(-1), dtype=probs.dtype)) # Normalize + + return (entropy_avg_routing - entropy_routing) / torch.log( + torch.tensor(probs.size(-1), dtype=probs.dtype) + ) # Normalize + class MixtureOfExpertMLP(MLPBase): """ @@ -135,7 +141,6 @@ def forward( else: raise NotImplementedError(self._routing_type) - if self._debug_mode: # To log all ranks set `global_=False` self._debug_log(scores, "Router scores", TransformerDimNames.top_experts, kwargs) @@ -203,13 +208,13 @@ def _topk_routing( # Calculate and log entropy and mutual information entropy = calculate_normalized_average_entropy(probs) mutual_info = calculate_mutual_information(probs) - + # Store these metrics if "router_entropy" not in losses: losses["router_entropy"] = [] if "router_mutual_info" not in losses: losses["router_mutual_info"] = [] - + losses["router_entropy"].append(entropy.detach()) losses["router_mutual_info"].append(mutual_info.detach()) From aef18e73c1f6fad10c590b584a7e3775a56948e8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 14 Mar 
2025 19:45:44 +0000 Subject: [PATCH 03/10] pre-commits --- tests/test_routing_metrics.py | 140 +++++++++++++++++----------------- 1 file changed, 72 insertions(+), 68 deletions(-) diff --git a/tests/test_routing_metrics.py b/tests/test_routing_metrics.py index b343e351..000cbc1c 100644 --- a/tests/test_routing_metrics.py +++ b/tests/test_routing_metrics.py @@ -1,55 +1,56 @@ -import torch import pytest +import torch + from fast_llm.layers.transformer.mixture_of_experts import ( - calculate_normalized_average_entropy, calculate_mutual_information, - entropy + calculate_normalized_average_entropy, ) + def test_diversity_entropy(): - ''' + """ collapse routing would have low entropy and low mutual information - ''' + """ - batch_size = 2 - seq_len = 3 - n_experts = 4 - collapased_probs = torch.tensor([ - # Batch 1 - [ - [0.99, 0.01, 0.0, 0.0], - [0.99, 0.01, 0.0, 0.0], - [0.99, 0.01, 0.0, 0.0], - ], - # Batch 2 + collapased_probs = torch.tensor( [ - [0.99, 0.01, 0.0, 0.0], - [0.99, 0.01, 0.0, 0.0], - [0.99, 0.01, 0.0, 0.0], + # Batch 1 + [ + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + ], + # Batch 2 + [ + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.99, 0.01, 0.0, 0.0], + ], ] - ]) + ) norm_entropy = calculate_normalized_average_entropy(collapased_probs) mutual_info = calculate_mutual_information(collapased_probs) assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {mutual_info}" - # diverse but no collapse # should give low entropy and high mutual information - diverse_probs = torch.tensor([ - # Batch 1 - [ - [0.99, 0.01, 0.0, 0.0], - [0.01, 0.99, 0.0, 0.0], - [0.01, 0.01, 0.99, 0.0], - ], - # Batch 2 + diverse_probs = torch.tensor( [ - [0.01, 0.01, 0.99, 0.0], - [0.99, 0.01, 0.0, 0.0], - [0.01, 0.01, 0.01, 0.99], + # Batch 1 + [ + [0.99, 0.01, 0.0, 0.0], + [0.01, 0.99, 0.0, 0.0], + [0.01, 
0.01, 0.99, 0.0], + ], + # Batch 2 + [ + [0.01, 0.01, 0.99, 0.0], + [0.99, 0.01, 0.0, 0.0], + [0.01, 0.01, 0.01, 0.99], + ], ] - ]) + ) norm_entropy = calculate_normalized_average_entropy(diverse_probs) mutual_info = calculate_mutual_information(diverse_probs) assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" @@ -62,12 +63,12 @@ def test_calculate_normalized_average_entropy(): batch_size = 2 seq_len = 3 n_experts = 4 - + # Test 1: Uniform distribution (should give normalized entropy of 1.0) uniform_probs = torch.ones(batch_size, seq_len, n_experts) / n_experts norm_entropy = calculate_normalized_average_entropy(uniform_probs) assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}" - + # Test 2: One-hot distribution (should give normalized entropy of 0.0) one_hot = torch.zeros(batch_size, seq_len, n_experts) for b in range(batch_size): @@ -75,39 +76,42 @@ def test_calculate_normalized_average_entropy(): one_hot[b, s, b % n_experts] = 1.0 norm_entropy = calculate_normalized_average_entropy(one_hot) assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}" - + # Test 3: Mixed distribution - mixed_probs = torch.tensor([ - # Batch 1 + mixed_probs = torch.tensor( [ - [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 - [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 - [0.25, 0.25, 0.25, 0.25], # Token 3: uniform - ], - # Batch 2 - [ - [0.4, 0.4, 0.1, 0.1], # Token 1: split between experts 0 and 1 - [0.1, 0.1, 0.4, 0.4], # Token 2: split between experts 2 and 3 - [0.1, 0.1, 0.1, 0.7], # Token 3: mostly expert 3 + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.25, 0.25, 0.25, 0.25], # Token 3: uniform + ], + # Batch 2 + [ + [0.4, 0.4, 0.1, 0.1], # Token 1: split between experts 0 and 1 + [0.1, 0.1, 0.4, 0.4], # Token 2: split between experts 2 and 3 + [0.1, 0.1, 
0.1, 0.7], # Token 3: mostly expert 3 + ], ] - ]) + ) norm_entropy = calculate_normalized_average_entropy(mixed_probs) # The expected value is between 0 and 1 assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}" + def test_calculate_mutual_information(): # AI generated test cases # Create a batch of routing probabilities batch_size = 2 seq_len = 3 n_experts = 4 - + # Test 1: All tokens route to the same expert (low mutual information) same_expert = torch.zeros(batch_size, seq_len, n_experts) same_expert[:, :, 0] = 1.0 # All tokens route to expert 0 mutual_info = calculate_mutual_information(same_expert) assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" - + # Test 2: Each token routes to a different expert (high mutual information) different_experts = torch.zeros(batch_size, seq_len, n_experts) for b in range(batch_size): @@ -116,26 +120,29 @@ def test_calculate_mutual_information(): mutual_info = calculate_mutual_information(different_experts) # The value should be positive and closer to 1 assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}" - + # Test 3: Mixed routing pattern - mixed_probs = torch.tensor([ - # Batch 1 + mixed_probs = torch.tensor( [ - [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 - [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 - [0.1, 0.1, 0.7, 0.1], # Token 3: mostly expert 2 - ], - # Batch 2 - [ - [0.1, 0.1, 0.1, 0.7], # Token 1: mostly expert 3 - [0.7, 0.1, 0.1, 0.1], # Token 2: mostly expert 0 - [0.1, 0.7, 0.1, 0.1], # Token 3: mostly expert 1 + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.1, 0.1, 0.7, 0.1], # Token 3: mostly expert 2 + ], + # Batch 2 + [ + [0.1, 0.1, 0.1, 0.7], # Token 1: mostly expert 3 + [0.7, 0.1, 0.1, 0.1], # Token 2: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 3: mostly expert 1 + ], ] - ]) + ) mutual_info = 
calculate_mutual_information(mixed_probs) # The expected value is between 0 and 1 assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}" + def test_edge_cases(): # AI generated test cases # Test with very small batch and sequence length @@ -144,7 +151,7 @@ def test_edge_cases(): mutual_info = calculate_mutual_information(tiny_probs) assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}" assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" - + # Test with very small probabilities small_probs = torch.ones(2, 3, 4) * 1e-8 small_probs[:, :, 0] = 1.0 - 3e-8 # Make sure they sum to 1 @@ -154,8 +161,5 @@ def test_edge_cases(): assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}" - - - if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) From bef39d84cfc0f142d240d30fd1e5ff175162b52c Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 14 Mar 2025 20:34:29 +0000 Subject: [PATCH 04/10] improved --- .../layers/transformer/mixture_of_experts.py | 16 ++++++++-------- tests/test_routing_metrics.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 6e35a942..c23a544e 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -29,7 +29,7 @@ def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: """ Calculates routing entropy for each token, then averages over all tokens. - If low, means a lot of mass is put on a single expert, which can indicate collapse. + If low, means a lot of mass is put on a single expert in all tokens, which can indicate collapse or specialization. 
""" n_experts = probs.size(-1) entropy_values = entropy(probs) @@ -45,15 +45,15 @@ def entropy(probs: torch.Tensor) -> torch.Tensor: def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor: """ Calculates the difference between the entropy of the average routing and - the average routing entropy. If low, means that routing is not informative. + the average routing entropy, we average across all tokens of all examples in the batch. + If low, means that routing is not informative. """ - average_routing = torch.mean(probs, dim=1) # Average over tokens - entropy_avg_routing = entropy(average_routing).mean() # H[E[X]], mean over batch - entropy_routing = entropy(probs).mean() # E[H[X]] + n_experts = probs.size(-1) + average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens + entropy_avg_routing = entropy(average_routing) / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) # H[E[X]] + entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]] - return (entropy_avg_routing - entropy_routing) / torch.log( - torch.tensor(probs.size(-1), dtype=probs.dtype) - ) # Normalize + return entropy_avg_routing - entropy_routing class MixtureOfExpertMLP(MLPBase): diff --git a/tests/test_routing_metrics.py b/tests/test_routing_metrics.py index 000cbc1c..7cd6d005 100644 --- a/tests/test_routing_metrics.py +++ b/tests/test_routing_metrics.py @@ -54,7 +54,7 @@ def test_diversity_entropy(): norm_entropy = calculate_normalized_average_entropy(diverse_probs) mutual_info = calculate_mutual_information(diverse_probs) assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected 0.0, got {norm_entropy}" - assert torch.isclose(mutual_info, torch.tensor(0.75), atol=1e-1), f"Expected 1.0, got {mutual_info}" + assert torch.isclose(mutual_info, torch.tensor(0.9), atol=1e-1), f"Expected 1.0, got {mutual_info}" def test_calculate_normalized_average_entropy(): From 620ec76dfd4fe07c9ff820aa2548198a5034bad4 Mon Sep 17 00:00:00 2001 
From: oleksost Date: Sat, 15 Mar 2025 00:37:31 +0000 Subject: [PATCH 05/10] using metrics dict instead of losses --- fast_llm/layers/transformer/config.py | 5 +++-- .../layers/transformer/mixture_of_experts.py | 19 +++++++++++-------- fast_llm/models/gpt/model.py | 16 +--------------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 9c3caa07..ecad1407 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -74,9 +74,10 @@ class TransformerKwargs: class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" - router_entropy = "router_entropy" - router_mutual_info = "router_mutual_info" +class TransformerRoutingMetrics: + normalized_average_entropy = "normalized_average_entropy" + mutual_info = "mutual_info" class RotaryEmbeddingType(str, enum.Enum): none = "none" diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index c23a544e..ba391619 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -17,6 +17,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, + TransformerRoutingMetrics ) from fast_llm.layers.transformer.mlp import MLPBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage @@ -133,7 +134,7 @@ def forward( # Routing if self._routing_type == RoutingType.topk: - scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) + scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics) if self._num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) elif self._routing_type == RoutingType.sinkhorn: @@ -199,6 +200,7 @@ def _topk_routing( logits: torch.Tensor, grad_scale: float | 
None = None, losses: dict | None = None, + metrics: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) @@ -209,14 +211,15 @@ def _topk_routing( entropy = calculate_normalized_average_entropy(probs) mutual_info = calculate_mutual_information(probs) - # Store these metrics - if "router_entropy" not in losses: - losses["router_entropy"] = [] - if "router_mutual_info" not in losses: - losses["router_mutual_info"] = [] + # Store these metrics + if metrics is not None: + if TransformerRoutingMetrics.normalized_average_entropy not in metrics: + metrics[TransformerRoutingMetrics.normalized_average_entropy] = [] + if TransformerRoutingMetrics.mutual_info not in metrics: + metrics[TransformerRoutingMetrics.mutual_info] = [] - losses["router_entropy"].append(entropy.detach()) - losses["router_mutual_info"].append(mutual_info.detach()) + metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach()) + metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach()) mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. 
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 166347bc..640de1f2 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -19,6 +19,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, + TransformerRoutingMetrics ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor from fast_llm.layers.transformer.transformer import TransformerLayer @@ -308,21 +309,6 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) - # Add new metrics - loss_defs.append( - LossDef( - name="router_entropy", - formatted_name="router entropy", - count=self._config.transformer.num_layers, - ) - ) - loss_defs.append( - LossDef( - name="router_mutual_info", - formatted_name="router mutual info", - count=self._config.transformer.num_layers, - ) - ) if self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) From 7a93aee2014fc7065867ae4357d651b8bce2b695 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 16 Mar 2025 19:22:27 +0000 Subject: [PATCH 06/10] reduce metrics --- fast_llm/engine/base_model/base_model.py | 4 + fast_llm/engine/schedule/runner.py | 37 +++++- fast_llm/layers/transformer/config.py | 8 +- fast_llm/models/gpt/model.py | 14 +++ ...routing_metrics.py => test_moe_metrics.py} | 117 ++++++++++++++++++ 5 files changed, 175 insertions(+), 5 deletions(-) rename tests/{test_routing_metrics.py => test_moe_metrics.py} (59%) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 7233c183..c94e77e2 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: @abc.abstractmethod def loss_defs(self) -> list[LossDef]: pass + + @property + def metric_defs(self) -> list[LossDef]: + return [] diff --git 
a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index d4b1da10..256ccfd5 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -20,6 +20,7 @@ from fast_llm.engine.schedule.schedule import Schedule, Step from fast_llm.logging import log_memory_usage from fast_llm.utils import Assert +from typing import Callable logger = logging.getLogger(__name__) @@ -94,6 +95,7 @@ def __init__( self._tied_parameters = self._multi_stage.tied_parameters self._num_stages = len(self._stages) self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs} + self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs} def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup @@ -266,19 +268,34 @@ def run_step( lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) ) - return self._reduce_losses(context), update_successful, metrics + return self._reduce_losses(context), update_successful, self._reduce_metrics(context) def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: + return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses") + + def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]: + return self._reduce_metric_or_loss(context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric) + + def _reduce_metric_or_loss( + self, + context: BatchContext, + check_count: Callable[[str], int], + reduce_attr: str = "losses", + check_reduce: Callable[[str], bool] = lambda _: True, + ) -> dict[str, float | int]: reduced_losses = {} num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs - for name, losses in context.losses.items(): + for name, losses in context.__getattribute__(reduce_attr).items(): + if not check_reduce(name): + 
reduced_losses[name] = losses + continue if losses or self._distributed.pipeline_group: if losses: - reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count + reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name) if self._distributed.data_group: all_reduce(reduced_loss, group=self._distributed.data_group) else: - reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device) + reduced_loss = torch.zeros([1], dtype=check_count(name).dtype, device=self._distributed.device) if self._distributed.pipeline_group: all_reduce(reduced_loss, group=self._distributed.pipeline_group) else: @@ -289,6 +306,18 @@ def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: for name, reduced_loss in reduced_losses.items() } + def _is_reduced_metric(self, metric_name: str) -> bool: + """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass).""" + from fast_llm.layers.transformer.config import TransformerReducedMetrics + if metric_name not in self._metric_defs: + return False + if not hasattr(self, "_reduced_metrics"): + self._reduced_metrics = set() + for cls in TransformerReducedMetrics.__subclasses__(): + for attr_name in dir(cls): + self._reduced_metrics.add(attr_name) + return metric_name in self._reduced_metrics + def _train_step(self, context: BatchContext, step: Step) -> None: if step.throttle_event is not None: step.throttle_event.record() diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index ecad1407..bf3a0a14 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -75,7 +75,13 @@ class TransformerLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" -class TransformerRoutingMetrics: +class TransformerReducedMetrics: + """ + Metrics that are reduced in the same way as loss before logging. 
+ """ + pass + +class TransformerRoutingMetrics(TransformerReducedMetrics): normalized_average_entropy = "normalized_average_entropy" mutual_info = "mutual_info" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 640de1f2..6ff82c74 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -314,6 +314,20 @@ def loss_defs(self) -> list[LossDef]: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) return loss_defs + @property + def metric_defs(self) -> list[LossDef]: + metric_defs = [] + if ( + self._config.transformer.num_experts > 1 + and self._config.transformer.expert_routing_type == RoutingType.topk + ): + metric_defs.append( + LossDef(name=TransformerRoutingMetrics.normalized_average_entropy, formatted_name="Normalized Entropy", count=1) + ) + metric_defs.append( + LossDef(name=TransformerRoutingMetrics.mutual_info, formatted_name="Mutual Information", count=1) + ) + return metric_defs class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig diff --git a/tests/test_routing_metrics.py b/tests/test_moe_metrics.py similarity index 59% rename from tests/test_routing_metrics.py rename to tests/test_moe_metrics.py index 7cd6d005..7ca1fb15 100644 --- a/tests/test_routing_metrics.py +++ b/tests/test_moe_metrics.py @@ -5,7 +5,18 @@ calculate_mutual_information, calculate_normalized_average_entropy, ) +import torch +from unittest import mock +from fast_llm.engine.schedule.runner import ScheduleRunner, BatchContext +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.multi_stage.multi_stage import MultiStageModel +from 
fast_llm.engine.base_model.base_model import LossDef +from fast_llm.layers.transformer.config import TransformerRoutingMetrics def test_diversity_entropy(): """ @@ -161,5 +172,111 @@ def test_edge_cases(): assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}" + +@pytest.fixture +def setup_runner(): + """Fixture to set up the test environment.""" + # Mock objects needed for testing + distributed_config = DistributedConfig() + + # Mock MultiStageModel with loss_defs + multi_stage = mock.MagicMock(spec=MultiStageModel) + multi_stage.base_model.loss_defs = [ + LossDef(name="test_loss", formatted_name="Test Loss", count=1) + ] + multi_stage.base_model.metric_defs = [ + LossDef(name=TransformerRoutingMetrics.normalized_average_entropy, formatted_name="Normalized Entropy", count=1), + LossDef(name=TransformerRoutingMetrics.mutual_info, formatted_name="Mutual Information", count=1) + ] + + # Create a schedule runner + schedule_config = ScheduleConfig() + runner = ScheduleRunner( + config=schedule_config, + multi_stage=multi_stage, + distributed_config=distributed_config + ) + + # Mock distributed object + distributed = mock.MagicMock(spec=Distributed) + distributed.config = distributed_config + distributed.device = torch.device("cpu") + distributed.data_group = None + distributed.pipeline_group = None + + # Setup the runner + runner._distributed = distributed + runner.is_initialized = True + + # Create a mock schedule + schedule = mock.MagicMock(spec=Schedule) + schedule.phase = PhaseType.training + schedule.batch_config.num_inputs = 3 + schedule._schedule_config = schedule_config + + # Create a batch context with metrics and losses + context = BatchContext( + iteration=1, + schedule=schedule, + ) + + # Add test metrics + context.metrics = { + # Metrics that should be reduced (in TransformerReducedMetrics) + TransformerRoutingMetrics.normalized_average_entropy: [ + torch.tensor(0.5), torch.tensor(0.6), torch.tensor(0.7) + ], + 
TransformerRoutingMetrics.mutual_info: [ + torch.tensor(0.2), torch.tensor(0.3), torch.tensor(0.4) + ], + # Metric that should not be reduced + "non_reduced_metric": [torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)] + } + + # Add test losses + context.losses = { + "test_loss": [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)] + } + + return runner, context, schedule + + +def test_reduce_metrics(setup_runner): + """Test that _reduce_metrics correctly reduces only the appropriate metrics""" + runner, context, _ = setup_runner + + assert runner._is_reduced_metric(TransformerRoutingMetrics.normalized_average_entropy) is True + assert runner._is_reduced_metric(TransformerRoutingMetrics.mutual_info) is True + + assert runner._is_reduced_metric("non_reduced_metric") is False + assert runner._is_reduced_metric("random_metric") is False + + reduced_metrics = runner._reduce_metrics(context) + + # Check that metrics in TransformerReducedMetrics were reduced + assert TransformerRoutingMetrics.normalized_average_entropy in reduced_metrics + assert TransformerRoutingMetrics.mutual_info in reduced_metrics + + # Check that the values were correctly averaged + assert pytest.approx(reduced_metrics[TransformerRoutingMetrics.normalized_average_entropy], 0.001) == 0.6 + assert pytest.approx(reduced_metrics[TransformerRoutingMetrics.mutual_info], 0.001) == 0.3 + + # Check that non-reduced metrics are kept in the result without being reduced + assert "non_reduced_metric" in reduced_metrics + assert sum(reduced_metrics["non_reduced_metric"]) == 3.0 + + +def test_reduce_losses(setup_runner): + """Test that _reduce_losses correctly reduces losses""" + runner, context, _ = setup_runner + + reduced_losses = runner._reduce_losses(context) + + assert "test_loss" in reduced_losses + assert pytest.approx(reduced_losses["test_loss"], 0.001) == 2.0 + + + + if __name__ == "__main__": pytest.main([__file__]) From eb617e81d9a448e619eb8684041fd12eff1fc089 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 16
Mar 2025 19:42:21 +0000 Subject: [PATCH 07/10] check return_metrics before reducing metrics --- fast_llm/engine/schedule/runner.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 256ccfd5..4184833f 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -267,14 +267,20 @@ def run_step( log_pipeline_parallel_main_rank( lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) ) - - return self._reduce_losses(context), update_successful, self._reduce_metrics(context) + metrics = self._reduce_metrics(context) if return_metrics else None + return ( + self._reduce_losses(context), + update_successful, + metrics, + ) def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]: return self._reduce_metric_or_loss(context, lambda name: self._loss_defs[name].count, "losses") def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]: - return self._reduce_metric_or_loss(context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric) + return self._reduce_metric_or_loss( + context, lambda name: self._metric_defs[name].count, "metrics", self._is_reduced_metric + ) def _reduce_metric_or_loss( self, @@ -309,6 +315,7 @@ def _reduce_metric_or_loss( def _is_reduced_metric(self, metric_name: str) -> bool: """Check if a metric should be reduced (is defined in a TransformerReducedMetrics subclass).""" from fast_llm.layers.transformer.config import TransformerReducedMetrics + if metric_name not in self._metric_defs: return False if not hasattr(self, "_reduced_metrics"): From 440738a56d435639858e36fc8eef6ca70e92dd9a Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 16 Mar 2025 19:43:01 +0000 Subject: [PATCH 08/10] check return metrics before reducing --- fast_llm/engine/schedule/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 4184833f..bcf6c674 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -267,7 +267,7 @@ def run_step( log_pipeline_parallel_main_rank( lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str) ) - metrics = self._reduce_metrics(context) if return_metrics else None + metrics = self._reduce_metrics(context) if return_metrics else metrics return ( self._reduce_losses(context), update_successful, From e5f3c4b6e04fb4501db1617d9e6cbe793920eebc Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 16 Mar 2025 20:25:24 +0000 Subject: [PATCH 09/10] correct averaging with number of layers --- fast_llm/models/gpt/model.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 6ff82c74..d1bc9e13 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -19,7 +19,7 @@ TransformerDimNames, TransformerKwargs, TransformerLossNames, - TransformerRoutingMetrics + TransformerRoutingMetrics, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, RotaryEmbeddingPreprocessor from fast_llm.layers.transformer.transformer import TransformerLayer @@ -309,7 +309,7 @@ def loss_defs(self) -> list[LossDef]: count=self._config.transformer.num_layers, ) ) - + if self._config.logit_z_loss: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) return loss_defs @@ -322,13 +322,22 @@ def metric_defs(self) -> list[LossDef]: and self._config.transformer.expert_routing_type == RoutingType.topk ): metric_defs.append( - LossDef(name=TransformerRoutingMetrics.normalized_average_entropy, formatted_name="Normalized Entropy", count=1) + LossDef( + name=TransformerRoutingMetrics.normalized_average_entropy, + formatted_name="Normalized Entropy", + count=self._config.transformer.num_layers, + )
) metric_defs.append( - LossDef(name=TransformerRoutingMetrics.mutual_info, formatted_name="Mutual Information", count=1) + LossDef( + name=TransformerRoutingMetrics.mutual_info, + formatted_name="Mutual Information", + count=self._config.transformer.num_layers, + ) ) return metric_defs + class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel From 27e2a5c00263993fe29c5666d02b412a4547bd99 Mon Sep 17 00:00:00 2001 From: oleksost Date: Sun, 16 Mar 2025 22:30:16 +0000 Subject: [PATCH 10/10] device --- fast_llm/layers/transformer/mixture_of_experts.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index ba391619..9e194d2c 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -35,8 +35,7 @@ def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor: n_experts = probs.size(-1) entropy_values = entropy(probs) average_entropy = entropy_values.mean() # Average over batch and tokens - return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype)) - + return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device)) def entropy(probs: torch.Tensor) -> torch.Tensor: probs = torch.clamp(probs, min=1e-9) # Avoid log(0) @@ -207,12 +206,12 @@ def _topk_routing( if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - # Calculate and log entropy and mutual information - entropy = calculate_normalized_average_entropy(probs) - mutual_info = calculate_mutual_information(probs) # Store these metrics if metrics is not None: + # Calculate and log entropy and mutual information + entropy = 
calculate_normalized_average_entropy(probs) + mutual_info = calculate_mutual_information(probs) if TransformerRoutingMetrics.normalized_average_entropy not in metrics: metrics[TransformerRoutingMetrics.normalized_average_entropy] = [] if TransformerRoutingMetrics.mutual_info not in metrics: