# Patch (git diff; line breaks appear to have been collapsed to spaces).
# Purpose: add optional per-variable ("per-channel") loss reporting.
#   - Steppers' train_on_batch gains compute_per_channel_metrics (default False).
#   - _accumulate_loss sums per-variable step losses into metrics[f"loss/{k}"].
#   - WeightedMappingLoss.call_per_channel / forward_per_channel compute them.
#   - TrainAggregator averages "loss/..." metrics over recorded batches and logs
#     them as "{label}/mean/loss/{var_name}" when per_channel_loss is True
#     (the config default).
#   - Four binary regression fixtures (*.pt) are changed.
# NOTE(review): the trainer hunk passes compute_per_channel_metrics=True
#   unconditionally, even if TrainAggregatorConfig.per_channel_loss is False.
# NOTE(review): the coupled-stepper hunks only add the parameter and its
#   docstring; no use of the flag is visible there -- confirm it is honored.
diff --git a/fme/ace/aggregator/test_train.py b/fme/ace/aggregator/test_train.py index d2ce71589..459f5916e 100644 --- a/fme/ace/aggregator/test_train.py +++ b/fme/ace/aggregator/test_train.py @@ -78,3 +78,93 @@ def test_aggregator_gets_logs_with_no_batches(config: TrainAggregatorConfig): logs = agg.get_logs(label="test") assert np.isnan(logs.pop("test/mean/loss")) assert logs == {} + + +def test_aggregator_logs_per_channel_loss(): + """ + Per-channel (per-variable) loss is accumulated from batch.metrics and reported. + """ + batch_size = 4 + n_ensemble = 1 + n_time = 2 + nx, ny = 2, 2 + device = get_device() + gridded_operations = LatLonOperations( + area_weights=torch.ones(nx, ny, device=device) + ) + config = TrainAggregatorConfig( + spherical_power_spectrum=False, weighted_rmse=False, per_channel_loss=True + ) + agg = TrainAggregator(config=config, operations=gridded_operations) + target_data = EnsembleTensorDict( + {"a": torch.randn(batch_size, 1, n_time, nx, ny, device=device)}, + ) + gen_data = EnsembleTensorDict( + {"a": torch.randn(batch_size, n_ensemble, n_time, nx, ny, device=device)}, + ) + agg.record_batch( + batch=TrainOutput( + metrics={ + "loss": torch.tensor(1.0, device=device), + "loss/a": torch.tensor(0.5, device=device), + }, + target_data=target_data, + gen_data=gen_data, + time=xr.DataArray(np.zeros((batch_size, n_time)), dims=["sample", "time"]), + normalize=lambda x: x, + ), + ) + agg.record_batch( + batch=TrainOutput( + metrics={ + "loss": torch.tensor(2.0, device=device), + "loss/a": torch.tensor(1.0, device=device), + }, + target_data=target_data, + gen_data=gen_data, + time=xr.DataArray(np.zeros((batch_size, n_time)), dims=["sample", "time"]), + normalize=lambda x: x, + ), + ) + logs = agg.get_logs(label="train") + assert logs["train/mean/loss"] == 1.5 + assert logs["train/mean/loss/a"] == 0.75 + + +def test_aggregator_per_channel_loss_disabled(): + """When per_channel_loss=False, get_logs does not include per-variable loss.""" + 
batch_size = 4 + n_ensemble = 1 + n_time = 2 + nx, ny = 2, 2 + device = get_device() + gridded_operations = LatLonOperations( + area_weights=torch.ones(nx, ny, device=device) + ) + config = TrainAggregatorConfig( + spherical_power_spectrum=False, + weighted_rmse=False, + per_channel_loss=False, + ) + agg = TrainAggregator(config=config, operations=gridded_operations) + target_data = EnsembleTensorDict( + {"a": torch.randn(batch_size, 1, n_time, nx, ny, device=device)}, + ) + gen_data = EnsembleTensorDict( + {"a": torch.randn(batch_size, n_ensemble, n_time, nx, ny, device=device)}, + ) + agg.record_batch( + batch=TrainOutput( + metrics={ + "loss": torch.tensor(1.0, device=device), + "loss/a": torch.tensor(0.5, device=device), + }, + target_data=target_data, + gen_data=gen_data, + time=xr.DataArray(np.zeros((batch_size, n_time)), dims=["sample", "time"]), + normalize=lambda x: x, + ), + ) + logs = agg.get_logs(label="train") + assert logs["train/mean/loss"] == 1.0 + assert "train/mean/loss/a" not in logs diff --git a/fme/ace/aggregator/train.py b/fme/ace/aggregator/train.py index 0e1071e0b..d75bfeabb 100644 --- a/fme/ace/aggregator/train.py +++ b/fme/ace/aggregator/train.py @@ -14,6 +14,9 @@ from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim from fme.core.typing_ import TensorMapping +# Metric key prefix for per-variable loss (must match stepper's metrics[f"loss/{k}"]). +PER_CHANNEL_LOSS_PREFIX = "loss/" + @dataclasses.dataclass class TrainAggregatorConfig: @@ -23,10 +26,13 @@ class TrainAggregatorConfig: Attributes: spherical_power_spectrum: Whether to compute the spherical power spectrum. weighted_rmse: Whether to compute the weighted RMSE. + per_channel_loss: Whether to accumulate and report per-variable (per-channel) + loss in get_logs (e.g. train/mean/loss/{var_name}). 
""" spherical_power_spectrum: bool = True weighted_rmse: bool = True + per_channel_loss: bool = True class Aggregator(Protocol): @@ -48,6 +54,8 @@ class TrainAggregator(AggregatorABC[TrainOutput]): def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations): self._n_loss_batches = 0 self._loss = torch.tensor(0.0, device=get_device()) + self._per_channel_loss: dict[str, torch.Tensor] = {} + self._per_channel_loss_enabled = config.per_channel_loss self._paired_aggregators: dict[str, Aggregator] = {} if config.spherical_power_spectrum: try: @@ -73,6 +81,16 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations) def record_batch(self, batch: TrainOutput): self._loss += batch.metrics["loss"] self._n_loss_batches += 1 + if self._per_channel_loss_enabled: + for key, value in batch.metrics.items(): + if not key.startswith(PER_CHANNEL_LOSS_PREFIX): + continue + var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX) + acc = self._per_channel_loss.get( + var_name, + torch.tensor(0.0, device=get_device(), dtype=value.dtype), + ) + self._per_channel_loss[var_name] = acc + value folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data) folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble) @@ -100,6 +118,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]: logs[f"{label}/mean/loss"] = float( dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy() ) + if self._n_loss_batches > 0 and self._per_channel_loss_enabled: + for var_name, acc in self._per_channel_loss.items(): + logs[f"{label}/mean/loss/{var_name}"] = float( + dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy() + ) return logs @torch.no_grad() diff --git a/fme/ace/stepper/single_module.py b/fme/ace/stepper/single_module.py index afb2ea5fd..c3e1fc0ed 100644 --- a/fme/ace/stepper/single_module.py +++ b/fme/ace/stepper/single_module.py @@ -1503,6 +1503,7 @@ def train_on_batch( data: BatchData, optimization: OptimizationABC, 
compute_derived_variables: bool = False, + compute_per_channel_metrics: bool = False, ) -> TrainOutput: """ Train the model on a batch of data with one or more forward steps. @@ -1519,6 +1520,9 @@ def train_on_batch( Use `NullOptimization` to disable training. compute_derived_variables: Whether to compute derived variables for the prediction and target data. + compute_per_channel_metrics: Whether to compute per-variable loss and add + to metrics (for TrainAggregator). Only set True when evaluating for + logging, not during training steps. Returns: The loss metrics, the generated data, the normalized generated data, @@ -1539,6 +1543,7 @@ def train_on_batch( target_data, optimization, metrics, + compute_per_channel_metrics=compute_per_channel_metrics, ) regularizer_loss = self._stepper.get_regularizer_loss() @@ -1573,6 +1578,8 @@ def _accumulate_loss( target_data: BatchData, optimization: OptimizationABC, metrics: dict[str, float], + *, + compute_per_channel_metrics: bool = False, ) -> list[EnsembleTensorDict]: input_data = data.get_start(self._prognostic_names, self.n_ic_timesteps) # output from self.predict_paired does not include initial condition @@ -1606,6 +1613,7 @@ def _accumulate_loss( "data requirements are retrieved, so this is a bug." 
) n_forward_steps = stochastic_n_forward_steps + per_channel_sum: dict[str, torch.Tensor] | None = None for step in range(n_forward_steps): optimize_step = ( step == n_forward_steps - 1 or not self._config.optimize_last_step_only @@ -1628,8 +1636,24 @@ def _accumulate_loss( ) step_loss = self._loss_obj(gen_step, target_step, step=step) metrics[f"loss_step_{step}"] = step_loss.detach() + if compute_per_channel_metrics: + per_channel = self._loss_obj.forward_per_channel( + gen_step, target_step, step=step + ) + if per_channel_sum is None: + per_channel_sum = { + k: v.detach().clone() for k, v in per_channel.items() + } + else: + for k in per_channel_sum: + per_channel_sum[k] = ( + per_channel_sum[k] + per_channel[k].detach() + ) if optimize_step: optimization.accumulate_loss(step_loss) + if compute_per_channel_metrics and per_channel_sum is not None: + for k, v in per_channel_sum.items(): + metrics[f"loss/{k}"] = v.detach() return output_list def update_training_history(self, training_job: TrainingJob) -> None: diff --git a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False-crps.pt b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False-crps.pt index 6378eb1d6..9f6f82508 100644 Binary files a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False-crps.pt and b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False-crps.pt differ diff --git a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False.pt b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False.pt index 42834eedc..080e94523 100644 Binary files a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False.pt and b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-False.pt differ diff --git a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True-crps.pt b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True-crps.pt index c501a17a2..4e676e3b0 100644 Binary files 
a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True-crps.pt and b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True-crps.pt differ diff --git a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True.pt b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True.pt index 05dc5079d..aade0fc39 100644 Binary files a/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True.pt and b/fme/ace/stepper/testdata/stepper_train_on_batch_regression-True.pt differ diff --git a/fme/core/generics/test_trainer.py b/fme/core/generics/test_trainer.py index 3539f43c5..1eb3f01e0 100644 --- a/fme/core/generics/test_trainer.py +++ b/fme/core/generics/test_trainer.py @@ -198,6 +198,7 @@ def train_on_batch( batch: BDType, optimization: OptimizationABC, compute_derived_variables: bool = False, + compute_per_channel_metrics: bool = False, ) -> TrainOutput: optimization.accumulate_loss(torch.tensor(float("inf"))) optimization.step_weights() diff --git a/fme/core/generics/train_stepper.py b/fme/core/generics/train_stepper.py index 254daebe3..8d72eec45 100644 --- a/fme/core/generics/train_stepper.py +++ b/fme/core/generics/train_stepper.py @@ -31,6 +31,7 @@ def train_on_batch( data: BD, optimization: OptimizationABC, compute_derived_variables: bool = False, + compute_per_channel_metrics: bool = False, ) -> TO: pass diff --git a/fme/core/generics/trainer.py b/fme/core/generics/trainer.py index 7bedc1fe9..24ad3549c 100644 --- a/fme/core/generics/trainer.py +++ b/fme/core/generics/trainer.py @@ -495,7 +495,11 @@ def train_one_epoch(self): stop_batch=self.config.train_evaluation_batches ): with GlobalTimer(): - stepped = self.stepper.train_on_batch(batch, self._no_optimization) + stepped = self.stepper.train_on_batch( + batch, + self._no_optimization, + compute_per_channel_metrics=True, + ) aggregator.record_batch(stepped) if ( self._should_save_checkpoints() diff --git a/fme/core/loss.py b/fme/core/loss.py index 13c6c0460..fc8683d93 100644 --- 
a/fme/core/loss.py +++ b/fme/core/loss.py @@ -73,6 +73,50 @@ def __call__( return self.loss(predict_tensors, target_tensors) + def call_per_channel( + self, + predict_dict: TensorMapping, + target_dict: TensorMapping, + ) -> dict[str, torch.Tensor]: + """ + Compute loss per variable (per channel). + + Returns: + Dict mapping each variable name to its scalar loss tensor. + """ + predict_tensors = self.packer.pack( + self.normalizer.normalize(predict_dict), axis=self.channel_dim + ) + target_tensors = self.packer.pack( + self.normalizer.normalize(target_dict), axis=self.channel_dim + ) + nan_mask = target_tensors.isnan() + if nan_mask.any(): + predict_tensors = torch.where(nan_mask, 0.0, predict_tensors) + target_tensors = torch.where(nan_mask, 0.0, target_tensors) + + result: dict[str, torch.Tensor] = {} + channel_dim = ( + self.channel_dim + if self.channel_dim >= 0 + else predict_tensors.dim() + self.channel_dim + ) + # _weight_tensor is always 4D (1, n_channels, 1, 1) from + # _construct_weight_tensor + weight_channel_dim = 1 + for i, name in enumerate(self.packer.names): + pred_slice = predict_tensors.select(channel_dim, i).unsqueeze(channel_dim) + target_slice = target_tensors.select(channel_dim, i).unsqueeze(channel_dim) + weight_slice = self._weight_tensor.select(weight_channel_dim, i) + # Broadcast weight to match pred_slice ndim + # (e.g. 
5D when ensemble dim present) + while weight_slice.dim() < pred_slice.dim(): + weight_slice = weight_slice.unsqueeze(-1) + result[name] = self.loss.loss( + weight_slice * pred_slice, weight_slice * target_slice + ) + return result + def get_normalizer_state(self) -> dict[str, float]: return self.normalizer.get_state() @@ -470,6 +514,22 @@ def forward( step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5) return self.loss(predict_dict, target_dict) * step_weight + def forward_per_channel( + self, + predict_dict: TensorMapping, + target_dict: TensorMapping, + step: int, + ) -> dict[str, torch.Tensor]: + """ + Compute per-variable (per-channel) loss with step weighting. + + Returns: + Dict mapping each variable name to its scalar loss tensor. + """ + step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5) + per_channel = self.loss.call_per_channel(predict_dict, target_dict) + return {k: v * step_weight for k, v in per_channel.items()} + @dataclasses.dataclass class StepLossConfig: diff --git a/fme/core/test_loss.py b/fme/core/test_loss.py index 5c955dc0a..e665ec645 100644 --- a/fme/core/test_loss.py +++ b/fme/core/test_loss.py @@ -402,3 +402,30 @@ def test_WeightedMappingLoss_with_target_nans(): x[:, 0, :, 0] = 0.0 y[:, 0, :, 0] = 0.0 assert torch.allclose(mapping_loss(x_mapping, y_mapping), loss(x, y)) + + +def test_WeightedMappingLoss_call_per_channel(): + """Per-channel loss sums to total loss and has one entry per variable.""" + loss_fn = torch.nn.MSELoss() + n_channels = 3 + out_names = [f"var_{i}" for i in range(n_channels)] + normalizer = StandardNormalizer( + means={name: torch.as_tensor(0.0) for name in out_names}, + stds={name: torch.as_tensor(1.0) for name in out_names}, + ) + mapping_loss = WeightedMappingLoss( + loss_fn, + weights={}, + out_names=out_names, + normalizer=normalizer, + ) + x = torch.randn(4, n_channels, 5, 5, device=get_device(), dtype=torch.float) + y = torch.randn(4, n_channels, 5, 5, device=get_device(), 
dtype=torch.float) + x_mapping = {name: x[:, i, :, :] for i, name in enumerate(out_names)} + y_mapping = {name: y[:, i, :, :] for i, name in enumerate(out_names)} + total = mapping_loss(x_mapping, y_mapping) + per_channel = mapping_loss.call_per_channel(x_mapping, y_mapping) + assert set(per_channel.keys()) == set(out_names) + # Mean of per-channel losses equals total (total is mean over all channels) + mean_per_channel = torch.stack(list(per_channel.values())).mean() + assert torch.allclose(total, mean_per_channel) diff --git a/fme/coupled/stepper.py b/fme/coupled/stepper.py index 1fa9f3efc..bbac23a79 100644 --- a/fme/coupled/stepper.py +++ b/fme/coupled/stepper.py @@ -1499,6 +1499,7 @@ def train_on_batch( data: CoupledBatchData, optimization: OptimizationABC, compute_derived_variables: bool = False, + compute_per_channel_metrics: bool = False, ) -> CoupledTrainOutput: """ Args: @@ -1508,7 +1509,9 @@ def train_on_batch( Use `NullOptimization` to disable training. compute_derived_variables: Whether to compute derived variables for the prediction and target atmosphere data. - + compute_per_channel_metrics: Whether to compute per-variable loss and add + to metrics (for TrainAggregator). Only set True when evaluating for + logging, not during training steps. """ # get initial condition prognostic variables input_data = CoupledPrognosticState( diff --git a/fme/diffusion/stepper.py b/fme/diffusion/stepper.py index 67935bca6..1a6bec997 100644 --- a/fme/diffusion/stepper.py +++ b/fme/diffusion/stepper.py @@ -1006,6 +1006,7 @@ def train_on_batch( data: BatchData, optimization: OptimizationABC, compute_derived_variables: bool = False, + compute_per_channel_metrics: bool = False, ) -> TrainOutput: """ Step the model forward multiple steps on a batch of data. @@ -1017,6 +1018,9 @@ def train_on_batch( Use `NullOptimization` to disable training. compute_derived_variables: Whether to compute derived variables for the prediction and target data. 
+ compute_per_channel_metrics: Whether to compute per-variable loss and add + to metrics (for TrainAggregator). Only set True when evaluating for + logging, not during training steps. Returns: The loss metrics, the generated data, the normalized generated data,