
Add per variable loss to Stepper and Log using TrainAggs#981

Open
Arcomano1234 wants to merge 11 commits into main from feature/per-channel-loss-train-agg

Conversation


@Arcomano1234 Arcomano1234 commented Mar 16, 2026

Currently we do not log the individual components of the loss (e.g., each variable's contribution to the overall loss), although we do log weighted RMSE for training, validation, and inference (which is very rarely the actual loss). This can make diagnosing overfitting difficult. This PR adds a per_channel (per-variable) loss method to fme/core/loss.py, which the Stepper calls inside _accumulate_loss when running the TrainAggregator (not during regular training).

Example of a run on wandb.

Changes:

  • Add per channel loss to fme/core/loss.py

  • Log per channel loss to metrics inside _accumulate_loss of the Stepper when called by TrainAggregator

  • Add per channel metrics to TrainAggregator

  • Tests added

Resolves #485
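For illustration, a minimal sketch of what such a per-variable loss could look like (the function name `per_channel_loss`, the unweighted MSE reduction, and the variable names are assumptions for this sketch, not the PR's actual implementation):

```python
import torch


def per_channel_loss(
    gen: dict[str, torch.Tensor],
    target: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Return each variable's contribution to the loss as a scalar tensor.

    Assumed reduction: unweighted MSE per variable (the real loss object
    may apply weights and step-dependent scaling).
    """
    return {
        name: torch.nn.functional.mse_loss(gen[name], target[name])
        for name in gen
    }


# Example: two hypothetical variables, logged under "loss/<var>" keys.
gen = {"T2m": torch.zeros(4), "Q": torch.ones(4)}
target = {"T2m": torch.ones(4), "Q": torch.ones(4)}
metrics = {f"loss/{k}": v.item() for k, v in per_channel_loss(gen, target).items()}
# metrics["loss/T2m"] == 1.0 (all-zero prediction vs all-ones target)
# metrics["loss/Q"] == 0.0 (perfect prediction)
```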

@Arcomano1234 Arcomano1234 changed the title from "claude first attempt to add per variable loss to TrainAgg" to "Add per variable loss to TrainAggs" on Mar 16, 2026
from fme.core.typing_ import TensorMapping

# Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
PER_CHANNEL_LOSS_PREFIX = "loss/"
Collaborator

not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.

Contributor Author

Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way around it. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper entirely. That would also reduce the need to record it when we aren't using it during training.

Contributor Author

The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".

Collaborator

Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.

Contributor

I'd suggest using a new attribute on the TrainOutput instead of a string label.

)
step_loss = self._loss_obj(gen_step, target_step, step=step)
metrics[f"loss_step_{step}"] = step_loss.detach()
per_channel = self._loss_obj.forward_per_channel(
Collaborator

Is there a way to avoid computing this when it is not needed? I.e. it seems we only should do this calculation when computing metrics for the train/val aggregator, but not when actually training.

Contributor Author

Agreed, it's a little hacky, but the newest commit has a way around this.

@Arcomano1234 Arcomano1234 changed the title from "Add per variable loss to TrainAggs" to "Add per variable loss to Stepper and Log using TrainAggs" on Mar 18, 2026
stepped = self.stepper.train_on_batch(
    batch,
    self._no_optimization,
    compute_per_channel_metrics=compute_per_channel,
Collaborator

I think it would be fine to just hard-code this to True instead of coupling to the aggregator. The important thing is that it's only done for the train_evaluation_batches, not the full training dataset.

Contributor Author

Yeah, I agree; that should simplify the code.

Contributor Author

Updated

target_data: BatchData,
optimization: OptimizationABC,
metrics: dict[str, float],
*,
Collaborator

why *?

Contributor Author

Claude has been trying to add this back in every time I've made changes; for some reason it likes this style.

Contributor

This forces kwargs to be passed as kwargs and not as positional arguments, which IMO is nice though not something we enforce.
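For illustration, a bare `*` in a signature makes every parameter after it keyword-only (the parameter names here mirror the PR's, but the function body is a stand-in stub):

```python
def train_on_batch(batch, optimization, *, compute_per_channel_metrics=False):
    # Everything after the bare * must be passed by keyword.
    return compute_per_channel_metrics


# Keyword form works:
train_on_batch("batch", "opt", compute_per_channel_metrics=True)  # -> True

# Positional form is rejected at call time:
try:
    train_on_batch("batch", "opt", True)
except TypeError:
    print("positional flag rejected")
```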

Contributor Author

Yeah, at least for me Claude really wants us to use this style. There is some benefit IMO to this way of passing kwargs, but that's outside the scope of this PR.


@Arcomano1234 Arcomano1234 marked this pull request as ready for review March 18, 2026 22:50
Contributor

@mcgibbon mcgibbon left a comment

I think we can iterate on this design to reduce the amount of code needed to handle this concern, and avoid re-computing the loss N times (with N GPU kernel dispatches). I'd suggest making a substantial change to the API of the loss objects, having them return 1-D vectors across channels, to avoid most of the code here as well as to avoid recomputing the loss (which can be expensive; I expect the GPUs to have very low occupancy during these calls). If we need to support cross-channel losses that don't fit nicely into per-channel losses, we can iterate later by having the loss return a data type allowing both per-channel and scalar losses (each optional).

Also, it feels a little odd to have TrainAggregator responsible so much for these logs. At the least, the validation aggregator should also log them, and it should be easy in principle (even if we don't want to run it that way as a matter of course) to include these in the per-batch metrics we get during training. The suggested change(s) would make this easier to implement if we choose, as it avoids the low-level prefix arithmetic in the aggregator.

Comment on lines +1639 to +1651
if compute_per_channel_metrics:
    per_channel = self._loss_obj.forward_per_channel(
        gen_step, target_step, step=step
    )
    if per_channel_sum is None:
        per_channel_sum = {
            k: v.detach().clone() for k, v in per_channel.items()
        }
    else:
        for k in per_channel_sum:
            per_channel_sum[k] = (
                per_channel_sum[k] + per_channel[k].detach()
            )
Contributor

Rather than re-computing the per-channel metrics, I suggest updating the code so _loss_obj returns a 1-D vector over the channel dimension. Then at this level you can use the unpacker to turn that into a dict of per-channel losses, while still using step_loss.sum() as the optimized/accumulated loss. This should make it cheap enough that you can just always "compute" these per-channel metrics, instead of needing new boolean flags.

Additionally, right now you have low-level accumulation code in the middle of the optimize function. I suggest instead using a simple Aggregator implementation that takes in a dict of values and keeps a running mean of those values (maybe the existing batch metrics aggregator already handles this) to accumulate these channel losses. You could either attach them as an additional attribute on TrainOutput that then gets passed to this aggregator above this scope (probably better? at least more consistent with what we currently do), or you could take the aggregator as an Optional input argument and record them in this scope.
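A sketch of the suggested shape (all names here are illustrative, not the repo's actual classes): the loss returns a 1-D tensor over channels, `sum()` gives the scalar used for optimization, and a small aggregator keeps a running mean of the unpacked per-channel values.

```python
import torch


def channel_loss(gen: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # gen/target: [batch, channel, ...]; reduce over every dim except channel,
    # yielding a 1-D vector of per-channel losses (assumed MSE here).
    dims = [d for d in range(gen.dim()) if d != 1]
    return ((gen - target) ** 2).mean(dim=dims)  # shape: [channel]


class RunningMeanAggregator:
    """Keeps a running mean of each recorded scalar (hypothetical helper)."""

    def __init__(self):
        self._sums: dict[str, float] = {}
        self._count = 0

    def record(self, values: dict[str, float]) -> None:
        self._count += 1
        for k, v in values.items():
            self._sums[k] = self._sums.get(k, 0.0) + v

    def get_logs(self) -> dict[str, float]:
        return {k: s / self._count for k, s in self._sums.items()}


channel_names = ["T2m", "Q"]  # illustrative variable names
agg = RunningMeanAggregator()
for _ in range(2):  # two batches
    per_channel = channel_loss(torch.zeros(4, 2, 8), torch.ones(4, 2, 8))
    scalar_loss = per_channel.sum()  # still the optimized/accumulated loss
    # "Unpack" the vector into named per-channel losses for logging.
    agg.record({f"loss/{n}": v.item() for n, v in zip(channel_names, per_channel)})
logs = agg.get_logs()
```

This keeps the low-level accumulation out of the optimize path: the optimizer only ever sees `scalar_loss`, and the aggregator owns the per-channel bookkeeping.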

Contributor

Make sure that when refactoring this, you maintain the sum-ness of loss across steps (perhaps make sure this is tested).


        return self.loss(predict_tensors, target_tensors)

    def call_per_channel(
Contributor

I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.

        step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
        return self.loss(predict_dict, target_dict) * step_weight

    def forward_per_channel(
Contributor

What's the meaning of "forward" here? In any case, I think this function can be avoided if loss_obj were to return a 1D vector of per-channel losses.



Development

Successfully merging this pull request may close these issues.

Add metrics to track variable-specific contributions to the training loss
