Add per variable loss to Stepper and Log using TrainAggs #981
@@ -13,6 +13,9 @@
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import TensorMapping

# Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
PER_CHANNEL_LOSS_PREFIX = "loss/"
Collaborator
Not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.

Contributor (Author)
Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way around it. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.

Contributor (Author)
The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".

Collaborator
Yeah, like I said, it's a pre-existing issue, so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.

Contributor
I'd suggest using a new attribute on TrainOutput instead of a string label.
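A minimal sketch of that suggestion, assuming TrainOutput is a dataclass the aggregator already receives; the field name per_variable_loss is hypothetical and other existing fields are omitted:

```python
import dataclasses

import torch


@dataclasses.dataclass
class TrainOutput:
    # Existing fields referenced in this diff (other fields omitted).
    metrics: dict[str, torch.Tensor]
    gen_data: dict[str, torch.Tensor]
    target_data: dict[str, torch.Tensor]
    # Hypothetical new field: the stepper would populate this directly and the
    # aggregator would read it, with no shared "loss/<var>" key convention.
    per_variable_loss: dict[str, torch.Tensor] | None = None
```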
@dataclasses.dataclass
class TrainAggregatorConfig:
@@ -22,10 +25,13 @@ class TrainAggregatorConfig:
    Attributes:
        spherical_power_spectrum: Whether to compute the spherical power spectrum.
        weighted_rmse: Whether to compute the weighted RMSE.
        per_channel_loss: Whether to accumulate and report per-variable (per-channel)
            loss in get_logs (e.g. train/mean/loss/<var_name>).
    """

    spherical_power_spectrum: bool = True
    weighted_rmse: bool = True
    per_channel_loss: bool = False

class Aggregator(Protocol):
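For orientation, a sketch of how the new flag would be opted into; the config is then passed to TrainAggregator(config, operations) as shown in the next hunk:

```python
# Sketch only: per_channel_loss defaults to False, so per-variable logging is opt-in.
config = TrainAggregatorConfig(
    spherical_power_spectrum=True,
    weighted_rmse=True,
    per_channel_loss=True,
)
```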
@@ -47,6 +53,11 @@ class TrainAggregator(AggregatorABC[TrainOutput]):
    def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations):
        self._n_loss_batches = 0
        self._loss = torch.tensor(0.0, device=get_device())
        if config.per_channel_loss:
            self._per_channel_loss: dict[str, torch.Tensor] = {}
            self._per_channel_loss_enabled = True
        else:
            self._per_channel_loss_enabled = False
        self._paired_aggregators: dict[str, Aggregator] = {}
        if config.spherical_power_spectrum:
            self._paired_aggregators["power_spectrum"] = (
@@ -66,6 +77,16 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations)
    def record_batch(self, batch: TrainOutput):
        self._loss += batch.metrics["loss"]
        self._n_loss_batches += 1
        if self._per_channel_loss_enabled:
            for key, value in batch.metrics.items():
                if not key.startswith(PER_CHANNEL_LOSS_PREFIX):
                    continue
                var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX)
                acc = self._per_channel_loss.get(
                    var_name,
                    torch.tensor(0.0, device=get_device(), dtype=value.dtype),
                )
                self._per_channel_loss[var_name] = acc + value

        folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data)
        folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
@@ -93,6 +114,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]:
        logs[f"{label}/mean/loss"] = float(
            dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy()
        )
        if self._n_loss_batches > 0 and self._per_channel_loss_enabled:
            for var_name, acc in self._per_channel_loss.items():
                logs[f"{label}/mean/loss/{var_name}"] = float(
                    dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy()
                )
        return logs

    @torch.no_grad()
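To illustrate the resulting log keys, a self-contained sketch that mirrors the accumulation and averaging above, minus the distributed reduction; the variable names and loss values are invented:

```python
import torch

PER_CHANNEL_LOSS_PREFIX = "loss/"

# Two hypothetical batches of stepper metrics; names and values are made up.
batches = [
    {"loss": torch.tensor(1.0),
     "loss/air_temperature": torch.tensor(0.6),
     "loss/specific_humidity": torch.tensor(0.4)},
    {"loss": torch.tensor(0.8),
     "loss/air_temperature": torch.tensor(0.5),
     "loss/specific_humidity": torch.tensor(0.3)},
]

total = torch.tensor(0.0)
per_channel: dict[str, torch.Tensor] = {}
for metrics in batches:
    total = total + metrics["loss"]
    for key, value in metrics.items():
        if key.startswith(PER_CHANNEL_LOSS_PREFIX):
            name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX)
            per_channel[name] = per_channel.get(name, torch.tensor(0.0)) + value

logs = {"train/mean/loss": float(total / len(batches))}
for name, acc in per_channel.items():
    logs[f"train/mean/loss/{name}"] = float(acc / len(batches))
# logs is approximately:
#   {"train/mean/loss": 0.9,
#    "train/mean/loss/air_temperature": 0.55,
#    "train/mean/loss/specific_humidity": 0.35}
```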
@@ -1606,6 +1606,7 @@ def _accumulate_loss(
                "data requirements are retrieved, so this is a bug."
            )
            n_forward_steps = stochastic_n_forward_steps
        per_channel_sum: dict[str, torch.Tensor] | None = None
        for step in range(n_forward_steps):
            optimize_step = (
                step == n_forward_steps - 1 or not self._config.optimize_last_step_only
@@ -1628,8 +1629,23 @@
            )
            step_loss = self._loss_obj(gen_step, target_step, step=step)
            metrics[f"loss_step_{step}"] = step_loss.detach()
            per_channel = self._loss_obj.forward_per_channel(
                gen_step, target_step, step=step
            )
            if per_channel_sum is None:
                per_channel_sum = {
                    k: v.detach().clone() for k, v in per_channel.items()
                }
            else:
                for k in per_channel_sum:
                    per_channel_sum[k] = (
                        per_channel_sum[k] + per_channel[k].detach()
                    )
            if optimize_step:
                optimization.accumulate_loss(step_loss)
        if per_channel_sum is not None and n_forward_steps > 0:
            for k, v in per_channel_sum.items():
                metrics[f"loss/{k}"] = (v / n_forward_steps).detach()

        return output_list

    def update_training_history(self, training_job: TrainingJob) -> None:
@@ -73,6 +73,50 @@ def __call__(

        return self.loss(predict_tensors, target_tensors)

    def call_per_channel(
Contributor
I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.
        self,
        predict_dict: TensorMapping,
        target_dict: TensorMapping,
    ) -> dict[str, torch.Tensor]:
        """
        Compute loss per variable (per channel).

        Returns:
            Dict mapping each variable name to its scalar loss tensor.
        """
        predict_tensors = self.packer.pack(
            self.normalizer.normalize(predict_dict), axis=self.channel_dim
        )
        target_tensors = self.packer.pack(
            self.normalizer.normalize(target_dict), axis=self.channel_dim
        )
        nan_mask = target_tensors.isnan()
        if nan_mask.any():
            predict_tensors = torch.where(nan_mask, 0.0, predict_tensors)
            target_tensors = torch.where(nan_mask, 0.0, target_tensors)

        result: dict[str, torch.Tensor] = {}
        channel_dim = (
            self.channel_dim
            if self.channel_dim >= 0
            else predict_tensors.dim() + self.channel_dim
        )
        # _weight_tensor is always 4D (1, n_channels, 1, 1) from
        # _construct_weight_tensor
        weight_channel_dim = 1
        for i, name in enumerate(self.packer.names):
            pred_slice = predict_tensors.select(channel_dim, i).unsqueeze(channel_dim)
            target_slice = target_tensors.select(channel_dim, i).unsqueeze(channel_dim)
            weight_slice = self._weight_tensor.select(weight_channel_dim, i)
            # Broadcast weight to match pred_slice ndim
            # (e.g. 5D when ensemble dim present)
            while weight_slice.dim() < pred_slice.dim():
                weight_slice = weight_slice.unsqueeze(-1)
            result[name] = self.loss.loss(
                weight_slice * pred_slice, weight_slice * target_slice
            )
        return result
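To see what the select/unsqueeze bookkeeping above does with shapes, a small self-contained sketch; the tensor layout (batch, channel, lat, lon) and the weight values are invented:

```python
import torch

# Packed predictions with an assumed (batch, channel, lat, lon) layout.
packed = torch.randn(2, 3, 4, 8)
# Per-channel weights shaped like the (1, n_channels, 1, 1) weight tensor.
weights = torch.tensor([1.0, 0.5, 2.0]).view(1, 3, 1, 1)

channel_dim = 1
i = 0  # first channel
pred_slice = packed.select(channel_dim, i).unsqueeze(channel_dim)  # (2, 1, 4, 8)
weight_slice = weights.select(1, i)                                # (1, 1, 1)
while weight_slice.dim() < pred_slice.dim():
    weight_slice = weight_slice.unsqueeze(-1)                      # -> (1, 1, 1, 1)
# The weight now broadcasts cleanly against the single-channel slice.
print((weight_slice * pred_slice).shape)  # torch.Size([2, 1, 4, 8])
```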

    def get_normalizer_state(self) -> dict[str, float]:
        return self.normalizer.get_state()
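Regarding the review suggestion that loss_obj could instead return a 1D vector of per-channel losses, a hedged sketch of what that might look like; the function name, the MSE reduction, and the default channel position are assumptions, not the existing API:

```python
import torch


def per_channel_mse(
    predict: torch.Tensor, target: torch.Tensor, channel_dim: int = 1
) -> torch.Tensor:
    """Return a 1D vector of per-channel MSE values instead of a dict keyed by name.

    The caller could zip this vector with packer.names to recover per-variable losses.
    """
    # Reduce over every dimension except the channel dimension.
    dims = [d for d in range(predict.dim()) if d != channel_dim % predict.dim()]
    return ((predict - target) ** 2).mean(dim=dims)


# Example with an assumed (batch, channel, lat, lon) layout.
pred = torch.randn(2, 3, 4, 8)
tgt = torch.randn(2, 3, 4, 8)
print(per_channel_mse(pred, tgt, channel_dim=1).shape)  # torch.Size([3])
```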
@@ -470,6 +514,22 @@ def forward(
        step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
        return self.loss(predict_dict, target_dict) * step_weight

    def forward_per_channel(
Contributor
What's the meaning of "forward" here? In any case, I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.
        self,
        predict_dict: TensorMapping,
        target_dict: TensorMapping,
        step: int,
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-variable (per-channel) loss with step weighting.

        Returns:
            Dict mapping each variable name to its scalar loss tensor.
        """
        step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
        per_channel = self.loss.call_per_channel(predict_dict, target_dict)
        return {k: v * step_weight for k, v in per_channel.items()}
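As a quick check of the step weighting shared by forward and forward_per_channel, a small sketch; the decay constant value here is made up, the real one comes from the loss config:

```python
sqrt_loss_decay_constant = 1.0  # hypothetical value for illustration
for step in range(3):
    step_weight = (1.0 + sqrt_loss_decay_constant * step) ** (-0.5)
    print(step, round(step_weight, 3))
# 0 1.0
# 1 0.707
# 2 0.577
```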


@dataclasses.dataclass
class StepLossConfig: