-
Notifications
You must be signed in to change notification settings - Fork 38
Add per variable loss to Stepper and Log using TrainAggs #981
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
62077b9
1d1aded
7079529
ef1b9f6
87af429
6223e0e
35f2f1d
84211b1
7def3ca
a67c993
377fdc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,9 @@ | |
| from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim | ||
| from fme.core.typing_ import TensorMapping | ||
|
|
||
| # Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]). | ||
| PER_CHANNEL_LOSS_PREFIX = "loss/" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd suggest using a new attribute on the TrainOutput instead of a string label. |
||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class TrainAggregatorConfig: | ||
|
|
@@ -22,10 +25,13 @@ class TrainAggregatorConfig: | |
| Attributes: | ||
| spherical_power_spectrum: Whether to compute the spherical power spectrum. | ||
| weighted_rmse: Whether to compute the weighted RMSE. | ||
| per_channel_loss: Whether to accumulate and report per-variable (per-channel) | ||
| loss in get_logs (e.g. train/mean/loss/<var_name>). | ||
| """ | ||
|
|
||
| spherical_power_spectrum: bool = True | ||
| weighted_rmse: bool = True | ||
| per_channel_loss: bool = True | ||
|
|
||
|
|
||
| class Aggregator(Protocol): | ||
|
|
@@ -47,6 +53,8 @@ class TrainAggregator(AggregatorABC[TrainOutput]): | |
| def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations): | ||
| self._n_loss_batches = 0 | ||
| self._loss = torch.tensor(0.0, device=get_device()) | ||
| self._per_channel_loss: dict[str, torch.Tensor] = {} | ||
| self._per_channel_loss_enabled = config.per_channel_loss | ||
| self._paired_aggregators: dict[str, Aggregator] = {} | ||
| if config.spherical_power_spectrum: | ||
| self._paired_aggregators["power_spectrum"] = ( | ||
|
|
@@ -62,10 +70,27 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations) | |
| include_grad_mag_percent_diff=False, | ||
| ) | ||
|
|
||
| @property | ||
| def per_channel_loss_enabled(self) -> bool: | ||
| """ | ||
| Whether this aggregator accumulates per-variable loss from batch metrics. | ||
| """ | ||
| return self._per_channel_loss_enabled | ||
|
|
||
| @torch.no_grad() | ||
| def record_batch(self, batch: TrainOutput): | ||
| self._loss += batch.metrics["loss"] | ||
| self._n_loss_batches += 1 | ||
| if self._per_channel_loss_enabled: | ||
| for key, value in batch.metrics.items(): | ||
| if not key.startswith(PER_CHANNEL_LOSS_PREFIX): | ||
| continue | ||
| var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX) | ||
| acc = self._per_channel_loss.get( | ||
| var_name, | ||
| torch.tensor(0.0, device=get_device(), dtype=value.dtype), | ||
| ) | ||
| self._per_channel_loss[var_name] = acc + value | ||
|
|
||
| folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data) | ||
| folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble) | ||
|
|
@@ -93,6 +118,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]: | |
| logs[f"{label}/mean/loss"] = float( | ||
| dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy() | ||
| ) | ||
| if self._n_loss_batches > 0 and self._per_channel_loss_enabled: | ||
| for var_name, acc in self._per_channel_loss.items(): | ||
| logs[f"{label}/mean/loss/{var_name}"] = float( | ||
| dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy() | ||
| ) | ||
| return logs | ||
|
|
||
| @torch.no_grad() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1503,6 +1503,7 @@ def train_on_batch( | |
| data: BatchData, | ||
| optimization: OptimizationABC, | ||
| compute_derived_variables: bool = False, | ||
| compute_per_channel_metrics: bool = False, | ||
| ) -> TrainOutput: | ||
| """ | ||
| Train the model on a batch of data with one or more forward steps. | ||
|
|
@@ -1519,6 +1520,9 @@ def train_on_batch( | |
| Use `NullOptimization` to disable training. | ||
| compute_derived_variables: Whether to compute derived variables for the | ||
| prediction and target data. | ||
| compute_per_channel_metrics: Whether to compute per-variable loss and add | ||
| to metrics (for TrainAggregator). Only set True when evaluating for | ||
| logging, not during training steps. | ||
|
|
||
| Returns: | ||
| The loss metrics, the generated data, the normalized generated data, | ||
|
|
@@ -1539,6 +1543,7 @@ def train_on_batch( | |
| target_data, | ||
| optimization, | ||
| metrics, | ||
| compute_per_channel_metrics=compute_per_channel_metrics, | ||
| ) | ||
|
|
||
| regularizer_loss = self._stepper.get_regularizer_loss() | ||
|
|
@@ -1573,6 +1578,8 @@ def _accumulate_loss( | |
| target_data: BatchData, | ||
| optimization: OptimizationABC, | ||
| metrics: dict[str, float], | ||
| *, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why *?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude has been trying to add this back in every time I've made changes; for some reason it likes this style.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This forces kwargs to be passed as kwargs and not as positional arguments, which IMO is nice though not something we enforce.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, at least for me Claude really wants us to use this style. There is some benefit IMO to this way of passing kwargs, but that's outside the scope of this PR. |
||
| compute_per_channel_metrics: bool = False, | ||
| ) -> list[EnsembleTensorDict]: | ||
| input_data = data.get_start(self._prognostic_names, self.n_ic_timesteps) | ||
| # output from self.predict_paired does not include initial condition | ||
|
|
@@ -1606,6 +1613,7 @@ def _accumulate_loss( | |
| "data requirements are retrieved, so this is a bug." | ||
| ) | ||
| n_forward_steps = stochastic_n_forward_steps | ||
| per_channel_sum: dict[str, torch.Tensor] | None = None | ||
| for step in range(n_forward_steps): | ||
| optimize_step = ( | ||
| step == n_forward_steps - 1 or not self._config.optimize_last_step_only | ||
|
|
@@ -1628,8 +1636,24 @@ def _accumulate_loss( | |
| ) | ||
| step_loss = self._loss_obj(gen_step, target_step, step=step) | ||
| metrics[f"loss_step_{step}"] = step_loss.detach() | ||
| if compute_per_channel_metrics: | ||
| per_channel = self._loss_obj.forward_per_channel( | ||
| gen_step, target_step, step=step | ||
| ) | ||
| if per_channel_sum is None: | ||
| per_channel_sum = { | ||
| k: v.detach().clone() for k, v in per_channel.items() | ||
| } | ||
| else: | ||
| for k in per_channel_sum: | ||
| per_channel_sum[k] = ( | ||
| per_channel_sum[k] + per_channel[k].detach() | ||
| ) | ||
|
Comment on lines
+1639
to
+1651
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than re-computing the per-channel metrics, I suggest updating the code so Additionally, right now you have low-level accumulation code in the middle of the optimize function. I suggest instead using a simple
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sure that when refactoring this, you maintain the sum-ness of loss across steps (perhaps make sure this is tested). |
||
| if optimize_step: | ||
| optimization.accumulate_loss(step_loss) | ||
| if compute_per_channel_metrics and per_channel_sum is not None: | ||
| for k, v in per_channel_sum.items(): | ||
| metrics[f"loss/{k}"] = v.detach() | ||
| return output_list | ||
|
|
||
| def update_training_history(self, training_job: TrainingJob) -> None: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -490,12 +490,17 @@ def train_one_epoch(self): | |
| self.train_data.alternate_shuffle() | ||
| aggregator = self._aggregator_builder.get_train_aggregator() | ||
| self.stepper.set_eval() | ||
| compute_per_channel = getattr(aggregator, "per_channel_loss_enabled", False) | ||
| with torch.no_grad(), self.validation_context(): | ||
| for batch in self.train_data.subset_loader( | ||
| stop_batch=self.config.train_evaluation_batches | ||
| ): | ||
| with GlobalTimer(): | ||
| stepped = self.stepper.train_on_batch(batch, self._no_optimization) | ||
| stepped = self.stepper.train_on_batch( | ||
| batch, | ||
| self._no_optimization, | ||
| compute_per_channel_metrics=compute_per_channel, | ||
|
||
| ) | ||
| aggregator.record_batch(stepped) | ||
| if ( | ||
| self._should_save_checkpoints() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,6 +73,50 @@ def __call__( | |
|
|
||
| return self.loss(predict_tensors, target_tensors) | ||
|
|
||
| def call_per_channel( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses. |
||
| self, | ||
| predict_dict: TensorMapping, | ||
| target_dict: TensorMapping, | ||
| ) -> dict[str, torch.Tensor]: | ||
| """ | ||
| Compute loss per variable (per channel). | ||
|
|
||
| Returns: | ||
| Dict mapping each variable name to its scalar loss tensor. | ||
| """ | ||
| predict_tensors = self.packer.pack( | ||
| self.normalizer.normalize(predict_dict), axis=self.channel_dim | ||
| ) | ||
| target_tensors = self.packer.pack( | ||
| self.normalizer.normalize(target_dict), axis=self.channel_dim | ||
| ) | ||
| nan_mask = target_tensors.isnan() | ||
| if nan_mask.any(): | ||
| predict_tensors = torch.where(nan_mask, 0.0, predict_tensors) | ||
| target_tensors = torch.where(nan_mask, 0.0, target_tensors) | ||
|
|
||
| result: dict[str, torch.Tensor] = {} | ||
| channel_dim = ( | ||
| self.channel_dim | ||
| if self.channel_dim >= 0 | ||
| else predict_tensors.dim() + self.channel_dim | ||
| ) | ||
| # _weight_tensor is always 4D (1, n_channels, 1, 1) from | ||
| # _construct_weight_tensor | ||
| weight_channel_dim = 1 | ||
| for i, name in enumerate(self.packer.names): | ||
| pred_slice = predict_tensors.select(channel_dim, i).unsqueeze(channel_dim) | ||
| target_slice = target_tensors.select(channel_dim, i).unsqueeze(channel_dim) | ||
| weight_slice = self._weight_tensor.select(weight_channel_dim, i) | ||
| # Broadcast weight to match pred_slice ndim | ||
| # (e.g. 5D when ensemble dim present) | ||
| while weight_slice.dim() < pred_slice.dim(): | ||
| weight_slice = weight_slice.unsqueeze(-1) | ||
| result[name] = self.loss.loss( | ||
| weight_slice * pred_slice, weight_slice * target_slice | ||
| ) | ||
| return result | ||
|
|
||
| def get_normalizer_state(self) -> dict[str, float]: | ||
| return self.normalizer.get_state() | ||
|
|
||
|
|
@@ -470,6 +514,22 @@ def forward( | |
| step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5) | ||
| return self.loss(predict_dict, target_dict) * step_weight | ||
|
|
||
| def forward_per_channel( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the meaning of "forward" here? In any case, I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses. |
||
| self, | ||
| predict_dict: TensorMapping, | ||
| target_dict: TensorMapping, | ||
| step: int, | ||
| ) -> dict[str, torch.Tensor]: | ||
| """ | ||
| Compute per-variable (per-channel) loss with step weighting. | ||
|
|
||
| Returns: | ||
| Dict mapping each variable name to its scalar loss tensor. | ||
| """ | ||
| step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5) | ||
| per_channel = self.loss.call_per_channel(predict_dict, target_dict) | ||
| return {k: v * step_weight for k, v in per_channel.items()} | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class StepLossConfig: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.