Add per variable loss to Stepper and Log using TrainAggs #981
Arcomano1234 wants to merge 11 commits into main from feature/per-channel-loss-train-agg
Conversation
    from fme.core.typing_ import TensorMapping


    # Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
    PER_CHANNEL_LOSS_PREFIX = "loss/"
not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.
Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good alternative. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.
The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".
Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.
I'd suggest using a new attribute on the TrainOutput instead of a string label.
fme/ace/stepper/single_module.py
Outdated
    )
    step_loss = self._loss_obj(gen_step, target_step, step=step)
    metrics[f"loss_step_{step}"] = step_loss.detach()
    per_channel = self._loss_obj.forward_per_channel(
Is there a way to avoid computing this when it is not needed? I.e., it seems we should only do this calculation when computing metrics for the train/val aggregator, but not when actually training.
Agreed, it's a little hacky, but the newest commit has a way around this.
fme/core/generics/trainer.py
Outdated
    stepped = self.stepper.train_on_batch(
        batch,
        self._no_optimization,
        compute_per_channel_metrics=compute_per_channel,
I think it would be fine to just hard-code this to True instead of coupling to the aggregator. The important thing is that it's only done for the train_evaluation_batches, not the full training dataset.
Yeah, I agree; that should simplify the code.
    target_data: BatchData,
    optimization: OptimizationABC,
    metrics: dict[str, float],
    *,
Claude has been trying to add this back in every time I've made changes; for some reason it likes this style.
This forces kwargs to be passed as kwargs and not as positional arguments, which IMO is nice though not something we enforce.
Yeah, at least for me Claude really wants to use this style. There is some benefit IMO to this way of passing kwargs, but that's outside the scope of this PR.
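For reference, the effect of the bare `*` being discussed can be shown in a few lines. The function and argument names here are made up for the example; only the `*` behavior is the point:

```python
# Everything after the bare "*" is keyword-only: it must be passed
# by keyword and is rejected if supplied positionally.
def train_on_batch(batch, optimization, *, compute_per_channel_metrics=False):
    return compute_per_channel_metrics


# Keyword form is accepted.
result = train_on_batch("batch", "opt", compute_per_channel_metrics=True)

# Positional form raises TypeError.
rejected = False
try:
    train_on_batch("batch", "opt", True)
except TypeError:
    rejected = True
```

This is why call sites stay self-documenting: a reader never sees a bare `True` and has to guess which flag it is.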
…/ace into feature/per-channel-loss-train-agg
I think we can iterate on this design to reduce the amount of code needed to handle this concern, and avoid re-computing the loss N times (with N GPU kernel dispatches). I'd suggest making a substantial change to the API of the loss object, having it return 1D vectors across channels, to avoid most of the code here as well as to avoid recomputing the loss (which can be expensive; I expect the GPUs to have very low occupancy during these calls). If we need to allow support for cross-channel losses that don't nicely fit into per-channel losses, we can iterate later by having loss return a data type allowing both per-channel and scalar losses (each optional).
Also, it feels a little odd to have TrainAggregator be so responsible for these logs. At the least, the validation aggregator should also log them, and it should be easy in principle (even if we don't want to run it that way as a matter of course) to include these in the per-batch metrics we get during training. The suggested change(s) would make this easier to implement if we choose, as it avoids the low-level prefix arithmetic in the aggregator.
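A minimal sketch of the API shape being suggested, under stated assumptions (the variable names, tensor layout `[batch, channel, height, width]`, and MSE choice are all illustrative, not the repo's actual loss): the loss is computed once per batch as a 1D vector over channels, the scalar training loss is its sum, and a channel-name list plays the role of the "unpacker" that maps entries back to variables.

```python
import torch


def per_channel_mse(gen: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Reduce over batch and spatial dims only, leaving a 1D vector
    # with one loss value per channel: a single kernel dispatch.
    return ((gen - target) ** 2).mean(dim=(0, 2, 3))


names = ["air_temperature", "specific_humidity"]  # the "unpacker"
gen = torch.zeros(2, len(names), 4, 4)
target = torch.ones(2, len(names), 4, 4)

per_channel = per_channel_mse(gen, target)  # shape [num_channels]
step_loss = per_channel.sum()               # scalar used for optimization
per_channel_dict = {n: v for n, v in zip(names, per_channel)}
```

Because the per-channel values fall out of the same forward pass that produces `step_loss`, logging them becomes essentially free and the `compute_per_channel_metrics` flag can go away.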
    if compute_per_channel_metrics:
        per_channel = self._loss_obj.forward_per_channel(
            gen_step, target_step, step=step
        )
        if per_channel_sum is None:
            per_channel_sum = {
                k: v.detach().clone() for k, v in per_channel.items()
            }
        else:
            for k in per_channel_sum:
                per_channel_sum[k] = (
                    per_channel_sum[k] + per_channel[k].detach()
                )
Rather than re-computing the per-channel metrics, I suggest updating the code so _loss_obj returns a 1D vector over the channel dimension. Then at this level you can use the unpacker to turn that into a dict of per-channel losses, while still using step_loss.sum() as the optimized/accumulated loss. This should make it cheap enough that you can just always "compute" these per-channel metrics, instead of needing new boolean flags.
Additionally, right now you have low-level accumulation code in the middle of the optimize function. I suggest instead using a simple Aggregator implementation that takes in a dict of values and keeps a running mean of those values (maybe the existing batch metrics aggregator already handles this) to accumulate these channel losses. You could either attach them as an additional attribute on TrainOutput that then gets passed to this aggregator above this scope (probably better? at least more consistent with what we currently do), or you could take the aggregator as an Optional input argument and record them in this scope.
Make sure that when refactoring this, you maintain the sum-ness of loss across steps (perhaps make sure this is tested).
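The running-mean aggregator suggested above could look something like this. A hedged sketch only: the class name is made up, and an existing batch-metrics aggregator in the repo may already cover this; the point is pulling the accumulation bookkeeping out of the optimize function.

```python
class RunningMeanAggregator:
    """Keeps a running mean of a dict of scalar values across batches."""

    def __init__(self) -> None:
        self._sums: dict[str, float] = {}
        self._count = 0

    def record(self, values: dict[str, float]) -> None:
        for name, value in values.items():
            self._sums[name] = self._sums.get(name, 0.0) + value
        self._count += 1

    def get_means(self) -> dict[str, float]:
        return {name: s / self._count for name, s in self._sums.items()}


agg = RunningMeanAggregator()
agg.record({"air_temperature": 0.2, "specific_humidity": 0.4})
agg.record({"air_temperature": 0.4, "specific_humidity": 0.6})
means = agg.get_means()
```

The caller (either via a TrainOutput attribute or an optional aggregator argument) just calls `record` per batch; note this averages over batches, so the per-step sum-ness of the loss must still be preserved upstream.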
    return self.loss(predict_tensors, target_tensors)


    def call_per_channel(
I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.
    step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
    return self.loss(predict_dict, target_dict) * step_weight


    def forward_per_channel(
What's the meaning of "forward" here? In any case, I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.
Currently we do not log the individual components of the loss (e.g., per-variable contributions to the overall loss), although we do log weighted RMSE for training, val, and inference (which is very rarely the actual loss). This can make diagnosing overfitting difficult. This PR adds a per_channel (per-variable) loss method to fme/core/loss.py. This gets called in the Stepper inside _accumulate_loss by the TrainAggregator (not during regular training). Example of a run on wandb.
Changes:
- Add per-channel loss to fme/core/loss.py
- Log per-channel loss to metrics inside _accumulate_loss of the Stepper when called by TrainAggregator
- Add per-channel metrics to TrainAggregator
- Tests added

Resolves #485