Skip to content
Open
4 changes: 3 additions & 1 deletion configs/baselines/era5/ace-train-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
experiment_dir: /results
save_checkpoint: true
validate_using_ema: true
max_epochs: 120
max_epochs: 2
n_forward_steps: 2
train_aggregator:
per_channel_loss: True
ema:
decay: 0.999
inference:
Expand Down
2 changes: 1 addition & 1 deletion configs/baselines/era5/run-ace-train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ run_training() {
-- torchrun --nproc_per_node $N_GPUS -m fme.ace.train $CONFIG_PATH
}

run_training "ace-train-config.yaml" "ace2-era5-train" "ace2-era5"
run_training "ace-train-config.yaml" "ace2-era5-train-test-per-var-loss" "ace2-era5"
90 changes: 90 additions & 0 deletions fme/ace/aggregator/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,93 @@ def test_aggregator_gets_logs_with_no_batches(config: TrainAggregatorConfig):
logs = agg.get_logs(label="test")
assert np.isnan(logs.pop("test/mean/loss"))
assert logs == {}


def test_aggregator_logs_per_channel_loss():
    """
    Per-channel (per-variable) loss is accumulated from batch.metrics and reported.
    """
    n_samples, n_members, n_steps = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    operations = LatLonOperations(
        area_weights=torch.ones(nx, ny, device=device)
    )
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False, weighted_rmse=False, per_channel_loss=True
        ),
        operations=operations,
    )
    targets = EnsembleTensorDict(
        {"a": torch.randn(n_samples, 1, n_steps, nx, ny, device=device)},
    )
    predictions = EnsembleTensorDict(
        {"a": torch.randn(n_samples, n_members, n_steps, nx, ny, device=device)},
    )

    def make_batch(total_loss: float, loss_a: float) -> TrainOutput:
        # Build a TrainOutput whose metrics carry the given scalar losses;
        # the tensor data is identical across batches and irrelevant here.
        return TrainOutput(
            metrics={
                "loss": torch.tensor(total_loss, device=device),
                "loss/a": torch.tensor(loss_a, device=device),
            },
            target_data=targets,
            gen_data=predictions,
            time=xr.DataArray(
                np.zeros((n_samples, n_steps)), dims=["sample", "time"]
            ),
            normalize=lambda x: x,
        )

    agg.record_batch(batch=make_batch(1.0, 0.5))
    agg.record_batch(batch=make_batch(2.0, 1.0))
    logs = agg.get_logs(label="train")
    # Both the total and the per-variable loss are averaged over batches.
    assert logs["train/mean/loss"] == 1.5
    assert logs["train/mean/loss/a"] == 0.75


def test_aggregator_per_channel_loss_disabled():
    """When per_channel_loss=False, get_logs does not include per-variable loss."""
    n_samples, n_members, n_steps = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    operations = LatLonOperations(
        area_weights=torch.ones(nx, ny, device=device)
    )
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False,
            weighted_rmse=False,
            per_channel_loss=False,
        ),
        operations=operations,
    )
    targets = EnsembleTensorDict(
        {"a": torch.randn(n_samples, 1, n_steps, nx, ny, device=device)},
    )
    predictions = EnsembleTensorDict(
        {"a": torch.randn(n_samples, n_members, n_steps, nx, ny, device=device)},
    )
    agg.record_batch(
        batch=TrainOutput(
            # "loss/a" is present in the metrics but must be ignored because
            # per-channel accumulation is disabled in the config.
            metrics={
                "loss": torch.tensor(1.0, device=device),
                "loss/a": torch.tensor(0.5, device=device),
            },
            target_data=targets,
            gen_data=predictions,
            time=xr.DataArray(
                np.zeros((n_samples, n_steps)), dims=["sample", "time"]
            ),
            normalize=lambda x: x,
        ),
    )
    logs = agg.get_logs(label="train")
    assert logs["train/mean/loss"] == 1.0
    assert "train/mean/loss/a" not in logs
23 changes: 23 additions & 0 deletions fme/ace/aggregator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import TensorMapping

# Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
PER_CHANNEL_LOSS_PREFIX = "loss/"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way around it. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest using a new attribute on the TrainOutput instead of a string label.



@dataclasses.dataclass
class TrainAggregatorConfig:
Expand All @@ -23,10 +26,13 @@ class TrainAggregatorConfig:
Attributes:
spherical_power_spectrum: Whether to compute the spherical power spectrum.
weighted_rmse: Whether to compute the weighted RMSE.
per_channel_loss: Whether to accumulate and report per-variable (per-channel)
loss in get_logs (e.g. train/mean/loss/<var_name>).
"""

spherical_power_spectrum: bool = True
weighted_rmse: bool = True
per_channel_loss: bool = True


class Aggregator(Protocol):
Expand All @@ -48,6 +54,8 @@ class TrainAggregator(AggregatorABC[TrainOutput]):
def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations):
self._n_loss_batches = 0
self._loss = torch.tensor(0.0, device=get_device())
self._per_channel_loss: dict[str, torch.Tensor] = {}
self._per_channel_loss_enabled = config.per_channel_loss
self._paired_aggregators: dict[str, Aggregator] = {}
if config.spherical_power_spectrum:
try:
Expand All @@ -73,6 +81,16 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations)
def record_batch(self, batch: TrainOutput):
self._loss += batch.metrics["loss"]
self._n_loss_batches += 1
if self._per_channel_loss_enabled:
for key, value in batch.metrics.items():
if not key.startswith(PER_CHANNEL_LOSS_PREFIX):
continue
var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX)
acc = self._per_channel_loss.get(
var_name,
torch.tensor(0.0, device=get_device(), dtype=value.dtype),
)
self._per_channel_loss[var_name] = acc + value

folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data)
folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
Expand Down Expand Up @@ -100,6 +118,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]:
logs[f"{label}/mean/loss"] = float(
dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy()
)
if self._n_loss_batches > 0 and self._per_channel_loss_enabled:
for var_name, acc in self._per_channel_loss.items():
logs[f"{label}/mean/loss/{var_name}"] = float(
dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy()
)
return logs

@torch.no_grad()
Expand Down
24 changes: 24 additions & 0 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,6 +1503,7 @@ def train_on_batch(
data: BatchData,
optimization: OptimizationABC,
compute_derived_variables: bool = False,
compute_per_channel_metrics: bool = False,
) -> TrainOutput:
"""
Train the model on a batch of data with one or more forward steps.
Expand All @@ -1519,6 +1520,9 @@ def train_on_batch(
Use `NullOptimization` to disable training.
compute_derived_variables: Whether to compute derived variables for the
prediction and target data.
compute_per_channel_metrics: Whether to compute per-variable loss and add
to metrics (for TrainAggregator). Only set True when evaluating for
logging, not during training steps.

Returns:
The loss metrics, the generated data, the normalized generated data,
Expand All @@ -1539,6 +1543,7 @@ def train_on_batch(
target_data,
optimization,
metrics,
compute_per_channel_metrics=compute_per_channel_metrics,
)

regularizer_loss = self._stepper.get_regularizer_loss()
Expand Down Expand Up @@ -1573,6 +1578,8 @@ def _accumulate_loss(
target_data: BatchData,
optimization: OptimizationABC,
metrics: dict[str, float],
*,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why *?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude has been trying to add this back in every time I've made changes happen, for some reason it likes this style.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forces kwargs to be passed as kwargs and not as positional arguments, which IMO is nice though not something we enforce.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, at least for me Claude really wants us to use this style. There is some benefit IMO to this way of passing kwargs, but that's outside the scope of this PR.

compute_per_channel_metrics: bool = False,
) -> list[EnsembleTensorDict]:
input_data = data.get_start(self._prognostic_names, self.n_ic_timesteps)
# output from self.predict_paired does not include initial condition
Expand Down Expand Up @@ -1606,6 +1613,7 @@ def _accumulate_loss(
"data requirements are retrieved, so this is a bug."
)
n_forward_steps = stochastic_n_forward_steps
per_channel_sum: dict[str, torch.Tensor] | None = None
for step in range(n_forward_steps):
optimize_step = (
step == n_forward_steps - 1 or not self._config.optimize_last_step_only
Expand All @@ -1628,8 +1636,24 @@ def _accumulate_loss(
)
step_loss = self._loss_obj(gen_step, target_step, step=step)
metrics[f"loss_step_{step}"] = step_loss.detach()
if compute_per_channel_metrics:
per_channel = self._loss_obj.forward_per_channel(
gen_step, target_step, step=step
)
if per_channel_sum is None:
per_channel_sum = {
k: v.detach().clone() for k, v in per_channel.items()
}
else:
for k in per_channel_sum:
per_channel_sum[k] = (
per_channel_sum[k] + per_channel[k].detach()
)
Comment on lines +1639 to +1651
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than re-computing the per-channel metrics, I suggest updating the code so _loss_obj returns a 1-D vector over the channel dimension. Then at this level you can use the unpacker to turn that into a dict of per-channel losses. while still using step_loss.sum() as the optimized/accumulated loss. This should make it cheap enough that you can just always "compute" these per-channel metrics, instead of needing new boolean flags.

Additionally, right now you have low-level accumulation code in the middle of the optimize function. I suggest instead using a simple Aggregator implementation that takes in a dict of values and keeps a running mean of those values (maybe the existing batch metrics aggregator already handles this) to accumulate these channel losses. You could either attach them as an additional attribute on TrainOutput that then gets passed to this aggregator above this scope (probably better? at least more consistent with what we currently do), or you could take the aggregator as an Optional input argument and record them in this scope.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure that when refactoring this, you maintain the sum-ness of loss across steps (perhaps make sure this is tested).

if optimize_step:
optimization.accumulate_loss(step_loss)
if compute_per_channel_metrics and per_channel_sum is not None:
for k, v in per_channel_sum.items():
metrics[f"loss/{k}"] = v.detach()
return output_list

def update_training_history(self, training_job: TrainingJob) -> None:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions fme/core/generics/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def train_on_batch(
batch: BDType,
optimization: OptimizationABC,
compute_derived_variables: bool = False,
compute_per_channel_metrics: bool = False,
) -> TrainOutput:
optimization.accumulate_loss(torch.tensor(float("inf")))
optimization.step_weights()
Expand Down
1 change: 1 addition & 0 deletions fme/core/generics/train_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def train_on_batch(
data: BD,
optimization: OptimizationABC,
compute_derived_variables: bool = False,
compute_per_channel_metrics: bool = False,
) -> TO:
pass

Expand Down
6 changes: 5 additions & 1 deletion fme/core/generics/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ def train_one_epoch(self):
stop_batch=self.config.train_evaluation_batches
):
with GlobalTimer():
stepped = self.stepper.train_on_batch(batch, self._no_optimization)
stepped = self.stepper.train_on_batch(
batch,
self._no_optimization,
compute_per_channel_metrics=True,
)
aggregator.record_batch(stepped)
if (
self._should_save_checkpoints()
Expand Down
60 changes: 60 additions & 0 deletions fme/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,50 @@ def __call__(

return self.loss(predict_tensors, target_tensors)

def call_per_channel(
    self,
    predict_dict: TensorMapping,
    target_dict: TensorMapping,
) -> dict[str, torch.Tensor]:
    """
    Compute the loss separately for each variable (channel).

    Args:
        predict_dict: Mapping of variable name to predicted tensor.
        target_dict: Mapping of variable name to target tensor.

    Returns:
        Dict mapping each variable name to its scalar loss tensor.
    """
    predicted = self.packer.pack(
        self.normalizer.normalize(predict_dict), axis=self.channel_dim
    )
    target = self.packer.pack(
        self.normalizer.normalize(target_dict), axis=self.channel_dim
    )
    # Zero out NaN targets in both tensors so they contribute no loss.
    nan_mask = target.isnan()
    if nan_mask.any():
        predicted = torch.where(nan_mask, 0.0, predicted)
        target = torch.where(nan_mask, 0.0, target)

    # Resolve a possibly-negative channel axis to a non-negative index.
    axis = self.channel_dim
    if axis < 0:
        axis = predicted.dim() + axis
    # _weight_tensor is always 4D (1, n_channels, 1, 1) from
    # _construct_weight_tensor, so its channel axis is fixed at 1.
    weight_axis = 1
    per_channel: dict[str, torch.Tensor] = {}
    for idx, var_name in enumerate(self.packer.names):
        gen_slice = predicted.select(axis, idx).unsqueeze(axis)
        tgt_slice = target.select(axis, idx).unsqueeze(axis)
        weight = self._weight_tensor.select(weight_axis, idx)
        # Append trailing singleton dims so the weight broadcasts against
        # slices of any rank (e.g. 5D when an ensemble dim is present).
        while weight.dim() < gen_slice.dim():
            weight = weight.unsqueeze(-1)
        per_channel[var_name] = self.loss.loss(
            weight * gen_slice, weight * tgt_slice
        )
    return per_channel

def get_normalizer_state(self) -> dict[str, float]:
return self.normalizer.get_state()

Expand Down Expand Up @@ -470,6 +514,22 @@ def forward(
step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
return self.loss(predict_dict, target_dict) * step_weight

def forward_per_channel(
    self,
    predict_dict: TensorMapping,
    target_dict: TensorMapping,
    step: int,
) -> dict[str, torch.Tensor]:
    """
    Compute per-variable (per-channel) loss with step weighting.

    Returns:
        Dict mapping each variable name to its scalar loss tensor.
    """
    # Same inverse-square-root decay weight applied by forward().
    decay = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
    return {
        name: channel_loss * decay
        for name, channel_loss in self.loss.call_per_channel(
            predict_dict, target_dict
        ).items()
    }


@dataclasses.dataclass
class StepLossConfig:
Expand Down
27 changes: 27 additions & 0 deletions fme/core/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,30 @@ def test_WeightedMappingLoss_with_target_nans():
x[:, 0, :, 0] = 0.0
y[:, 0, :, 0] = 0.0
assert torch.allclose(mapping_loss(x_mapping, y_mapping), loss(x, y))


def test_WeightedMappingLoss_call_per_channel():
    """Per-channel loss sums to total loss and has one entry per variable."""
    n_channels = 3
    names = [f"var_{i}" for i in range(n_channels)]
    normalizer = StandardNormalizer(
        means={name: torch.as_tensor(0.0) for name in names},
        stds={name: torch.as_tensor(1.0) for name in names},
    )
    mapping_loss = WeightedMappingLoss(
        torch.nn.MSELoss(),
        weights={},
        out_names=names,
        normalizer=normalizer,
    )
    device = get_device()
    pred = torch.randn(4, n_channels, 5, 5, device=device, dtype=torch.float)
    tgt = torch.randn(4, n_channels, 5, 5, device=device, dtype=torch.float)
    pred_mapping = {name: pred[:, i, :, :] for i, name in enumerate(names)}
    tgt_mapping = {name: tgt[:, i, :, :] for i, name in enumerate(names)}
    total = mapping_loss(pred_mapping, tgt_mapping)
    per_channel = mapping_loss.call_per_channel(pred_mapping, tgt_mapping)
    assert set(per_channel.keys()) == set(names)
    # The total loss is the mean over all channels, so the mean of the
    # per-channel losses must reproduce it.
    assert torch.allclose(total, torch.stack(list(per_channel.values())).mean())
5 changes: 4 additions & 1 deletion fme/coupled/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,7 @@ def train_on_batch(
data: CoupledBatchData,
optimization: OptimizationABC,
compute_derived_variables: bool = False,
compute_per_channel_metrics: bool = False,
) -> CoupledTrainOutput:
"""
Args:
Expand All @@ -1508,7 +1509,9 @@ def train_on_batch(
Use `NullOptimization` to disable training.
compute_derived_variables: Whether to compute derived variables for the
prediction and target atmosphere data.

compute_per_channel_metrics: Whether to compute per-variable loss and add
to metrics (for TrainAggregator). Only set True when evaluating for
logging, not during training steps.
"""
# get initial condition prognostic variables
input_data = CoupledPrognosticState(
Expand Down
Loading
Loading