Skip to content
Open
4 changes: 3 additions & 1 deletion configs/baselines/era5/ace-train-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
experiment_dir: /results
save_checkpoint: true
validate_using_ema: true
max_epochs: 120
max_epochs: 2
n_forward_steps: 2
train_aggregator:
per_channel_loss: True
ema:
decay: 0.999
inference:
Expand Down
2 changes: 1 addition & 1 deletion configs/baselines/era5/run-ace-train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ run_training() {
-- torchrun --nproc_per_node $N_GPUS -m fme.ace.train $CONFIG_PATH
}

run_training "ace-train-config.yaml" "ace2-era5-train" "ace2-era5"
run_training "ace-train-config.yaml" "ace2-era5-train-test-per-var-loss" "ace2-era5"
90 changes: 90 additions & 0 deletions fme/ace/aggregator/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,93 @@ def test_aggregator_gets_logs_with_no_batches(config: TrainAggregatorConfig):
logs = agg.get_logs(label="test")
assert np.isnan(logs.pop("test/mean/loss"))
assert logs == {}


def test_aggregator_logs_per_channel_loss():
    """
    Per-channel (per-variable) loss is accumulated and reported in get_logs.
    """
    n_samples, n_members, n_steps = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    ops = LatLonOperations(area_weights=torch.ones(nx, ny, device=device))
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False,
            weighted_rmse=False,
            per_channel_loss=True,
        ),
        operations=ops,
    )
    target_data = EnsembleTensorDict(
        {"a": torch.randn(n_samples, 1, n_steps, nx, ny, device=device)},
    )
    gen_data = EnsembleTensorDict(
        {"a": torch.randn(n_samples, n_members, n_steps, nx, ny, device=device)},
    )
    time = xr.DataArray(np.zeros((n_samples, n_steps)), dims=["sample", "time"])
    # Record two batches so the reported values are means, not sums.
    for total_loss, per_var_loss in [(1.0, 0.5), (2.0, 1.0)]:
        agg.record_batch(
            batch=TrainOutput(
                metrics={
                    "loss": torch.tensor(total_loss, device=device),
                    "loss/a": torch.tensor(per_var_loss, device=device),
                },
                target_data=target_data,
                gen_data=gen_data,
                time=time,
                normalize=lambda x: x,
            ),
        )
    logs = agg.get_logs(label="train")
    assert logs["train/mean/loss"] == 1.5
    assert logs["train/mean/loss/a"] == 0.75


def test_aggregator_per_channel_loss_disabled():
    """When per_channel_loss=False, get_logs does not include per-variable loss."""
    n_samples, n_members, n_steps = 4, 1, 2
    nx, ny = 2, 2
    device = get_device()
    agg = TrainAggregator(
        config=TrainAggregatorConfig(
            spherical_power_spectrum=False,
            weighted_rmse=False,
            per_channel_loss=False,
        ),
        operations=LatLonOperations(
            area_weights=torch.ones(nx, ny, device=device)
        ),
    )
    target_data = EnsembleTensorDict(
        {"a": torch.randn(n_samples, 1, n_steps, nx, ny, device=device)},
    )
    gen_data = EnsembleTensorDict(
        {"a": torch.randn(n_samples, n_members, n_steps, nx, ny, device=device)},
    )
    agg.record_batch(
        batch=TrainOutput(
            metrics={
                "loss": torch.tensor(1.0, device=device),
                # Per-variable metric is recorded but must be ignored.
                "loss/a": torch.tensor(0.5, device=device),
            },
            target_data=target_data,
            gen_data=gen_data,
            time=xr.DataArray(
                np.zeros((n_samples, n_steps)), dims=["sample", "time"]
            ),
            normalize=lambda x: x,
        ),
    )
    logs = agg.get_logs(label="train")
    assert logs["train/mean/loss"] == 1.0
    assert "train/mean/loss/a" not in logs
26 changes: 26 additions & 0 deletions fme/ace/aggregator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import TensorMapping

# Metric key prefix for per-variable loss (must match stepper's metrics["loss/<var>"]).
PER_CHANNEL_LOSS_PREFIX = "loss/"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not ideal to have this coupling with the naming in the stepper metrics, but this already exists for the other loss terms so I think it's okay.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was not a fan of this at all either, but Claude and I couldn't think of a good way around it. I guess the one thing I can do is make this an aggregator itself and decouple it from the stepper. This would also help reduce the need to record it when we aren't using it during training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem then becomes getting the loss function to the aggregator, which has its own complications. I defer to you or Jeremy on whether it's worth decoupling this from the stepper and just passing a loss_fn to a "PerChannelLossAggregator".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, like I said it's a pre-existing issue so I don't think we should worry about decoupling in this PR. But open to other thoughts from @mcgibbon on this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest using a new attribute on the TrainOutput instead of a string label.



@dataclasses.dataclass
class TrainAggregatorConfig:
Expand All @@ -22,10 +25,13 @@ class TrainAggregatorConfig:
Attributes:
spherical_power_spectrum: Whether to compute the spherical power spectrum.
weighted_rmse: Whether to compute the weighted RMSE.
per_channel_loss: Whether to accumulate and report per-variable (per-channel)
loss in get_logs (e.g. train/mean/loss/<var_name>).
"""

spherical_power_spectrum: bool = True
weighted_rmse: bool = True
per_channel_loss: bool = False


class Aggregator(Protocol):
Expand All @@ -47,6 +53,11 @@ class TrainAggregator(AggregatorABC[TrainOutput]):
def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations):
self._n_loss_batches = 0
self._loss = torch.tensor(0.0, device=get_device())
if config.per_channel_loss:
self._per_channel_loss: dict[str, torch.Tensor] = {}
self._per_channel_loss_enabled = True
else:
self._per_channel_loss_enabled = False
self._paired_aggregators: dict[str, Aggregator] = {}
if config.spherical_power_spectrum:
self._paired_aggregators["power_spectrum"] = (
Expand All @@ -66,6 +77,16 @@ def __init__(self, config: TrainAggregatorConfig, operations: GriddedOperations)
def record_batch(self, batch: TrainOutput):
self._loss += batch.metrics["loss"]
self._n_loss_batches += 1
if self._per_channel_loss_enabled:
for key, value in batch.metrics.items():
if not key.startswith(PER_CHANNEL_LOSS_PREFIX):
continue
var_name = key.removeprefix(PER_CHANNEL_LOSS_PREFIX)
acc = self._per_channel_loss.get(
var_name,
torch.tensor(0.0, device=get_device(), dtype=value.dtype),
)
self._per_channel_loss[var_name] = acc + value

folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data)
folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
Expand Down Expand Up @@ -93,6 +114,11 @@ def get_logs(self, label: str) -> dict[str, torch.Tensor]:
logs[f"{label}/mean/loss"] = float(
dist.reduce_mean(self._loss / self._n_loss_batches).cpu().numpy()
)
if self._n_loss_batches > 0 and self._per_channel_loss_enabled:
for var_name, acc in self._per_channel_loss.items():
logs[f"{label}/mean/loss/{var_name}"] = float(
dist.reduce_mean(acc / self._n_loss_batches).cpu().numpy()
)
return logs

@torch.no_grad()
Expand Down
16 changes: 16 additions & 0 deletions fme/ace/stepper/single_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ def _accumulate_loss(
"data requirements are retrieved, so this is a bug."
)
n_forward_steps = stochastic_n_forward_steps
per_channel_sum: dict[str, torch.Tensor] | None = None
for step in range(n_forward_steps):
optimize_step = (
step == n_forward_steps - 1 or not self._config.optimize_last_step_only
Expand All @@ -1628,8 +1629,23 @@ def _accumulate_loss(
)
step_loss = self._loss_obj(gen_step, target_step, step=step)
metrics[f"loss_step_{step}"] = step_loss.detach()
per_channel = self._loss_obj.forward_per_channel(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to avoid computing this when it is not needed? I.e. it seems we only should do this calculation when computing metrics for the train/val aggregator, but not when actually training.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it's a little hacky, but the newest commit has a way around this.

gen_step, target_step, step=step
)
if per_channel_sum is None:
per_channel_sum = {
k: v.detach().clone() for k, v in per_channel.items()
}
else:
for k in per_channel_sum:
per_channel_sum[k] = (
per_channel_sum[k] + per_channel[k].detach()
)
if optimize_step:
optimization.accumulate_loss(step_loss)
if per_channel_sum is not None and n_forward_steps > 0:
for k, v in per_channel_sum.items():
metrics[f"loss/{k}"] = (v / n_forward_steps).detach()
return output_list

def update_training_history(self, training_job: TrainingJob) -> None:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
60 changes: 60 additions & 0 deletions fme/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,50 @@ def __call__(

return self.loss(predict_tensors, target_tensors)

def call_per_channel(
    self,
    predict_dict: "TensorMapping",
    target_dict: "TensorMapping",
) -> dict[str, torch.Tensor]:
    """
    Compute the loss separately for each variable (channel).

    Mirrors __call__: inputs are normalized and packed along the channel
    dimension, NaNs in the target are zeroed in both tensors, and the
    per-channel weight is applied before calling the underlying loss.

    Args:
        predict_dict: Mapping from variable name to predicted tensor.
        target_dict: Mapping from variable name to target tensor.

    Returns:
        Dict mapping each variable name to its scalar loss tensor.
    """
    predict_tensors = self.packer.pack(
        self.normalizer.normalize(predict_dict), axis=self.channel_dim
    )
    target_tensors = self.packer.pack(
        self.normalizer.normalize(target_dict), axis=self.channel_dim
    )
    # Zero out positions where the target is NaN so they do not contribute
    # to the loss (matches the NaN handling in __call__).
    nan_mask = target_tensors.isnan()
    if nan_mask.any():
        predict_tensors = torch.where(nan_mask, 0.0, predict_tensors)
        target_tensors = torch.where(nan_mask, 0.0, target_tensors)

    result: dict[str, torch.Tensor] = {}
    # Normalize a negative channel_dim to a positive axis index for select().
    channel_dim = (
        self.channel_dim
        if self.channel_dim >= 0
        else predict_tensors.dim() + self.channel_dim
    )
    # _weight_tensor is always 4D (1, n_channels, 1, 1) from
    # _construct_weight_tensor, so its channel axis is 1.
    weight_channel_dim = 1
    for i, name in enumerate(self.packer.names):
        pred_slice = predict_tensors.select(channel_dim, i).unsqueeze(channel_dim)
        target_slice = target_tensors.select(channel_dim, i).unsqueeze(channel_dim)
        weight_slice = self._weight_tensor.select(weight_channel_dim, i)
        # Broadcast weight to match pred_slice ndim
        # (e.g. 5D when an ensemble dim is present).
        while weight_slice.dim() < pred_slice.dim():
            weight_slice = weight_slice.unsqueeze(-1)
        result[name] = self.loss.loss(
            weight_slice * pred_slice, weight_slice * target_slice
        )
    return result

def get_normalizer_state(self) -> dict[str, float]:
return self.normalizer.get_state()

Expand Down Expand Up @@ -470,6 +514,22 @@ def forward(
step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
return self.loss(predict_dict, target_dict) * step_weight

def forward_per_channel(
    self,
    predict_dict: "TensorMapping",
    target_dict: "TensorMapping",
    step: int,
) -> dict[str, torch.Tensor]:
    """
    Compute per-variable (per-channel) loss with step weighting.

    Applies the same (1 + decay * step) ** -0.5 weight used by forward()
    to each channel's loss.

    Args:
        predict_dict: Mapping from variable name to predicted tensor.
        target_dict: Mapping from variable name to target tensor.
        step: Zero-based forward-step index used for the decay weight.

    Returns:
        Dict mapping each variable name to its scalar loss tensor.
    """
    step_weight = (1.0 + self.sqrt_loss_decay_constant * step) ** (-0.5)
    per_channel = self.loss.call_per_channel(predict_dict, target_dict)
    return {k: v * step_weight for k, v in per_channel.items()}


@dataclasses.dataclass
class StepLossConfig:
Expand Down
27 changes: 27 additions & 0 deletions fme/core/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,30 @@ def test_WeightedMappingLoss_with_target_nans():
x[:, 0, :, 0] = 0.0
y[:, 0, :, 0] = 0.0
assert torch.allclose(mapping_loss(x_mapping, y_mapping), loss(x, y))


def test_WeightedMappingLoss_call_per_channel():
    """Per-channel loss sums to total loss and has one entry per variable."""
    n_channels = 3
    names = [f"var_{i}" for i in range(n_channels)]
    normalizer = StandardNormalizer(
        means={name: torch.as_tensor(0.0) for name in names},
        stds={name: torch.as_tensor(1.0) for name in names},
    )
    mapping_loss = WeightedMappingLoss(
        torch.nn.MSELoss(),
        weights={},
        out_names=names,
        normalizer=normalizer,
    )
    device = get_device()
    x = torch.randn(4, n_channels, 5, 5, device=device, dtype=torch.float)
    y = torch.randn(4, n_channels, 5, 5, device=device, dtype=torch.float)
    x_mapping = {name: x[:, i] for i, name in enumerate(names)}
    y_mapping = {name: y[:, i] for i, name in enumerate(names)}
    total = mapping_loss(x_mapping, y_mapping)
    per_channel = mapping_loss.call_per_channel(x_mapping, y_mapping)
    assert set(per_channel.keys()) == set(names)
    # Mean of per-channel losses equals total (total is mean over all channels).
    mean_per_channel = torch.stack(list(per_channel.values())).mean()
    assert torch.allclose(total, mean_per_channel)
Loading