
Commit 5586d79

Clessig/develop/fix kcrps 1077 (#1078)
* Improved robustness for loss fcts where ch loss does not make sense
* Re-enabled kernel CRPS and added weighting options
* Fixes
* Improved tensor reordering
1 parent aae0b8a commit 5586d79

File tree: 2 files changed, +52 −10 lines changed


src/weathergen/train/loss.py
Lines changed: 51 additions & 9 deletions

```diff
@@ -70,18 +70,60 @@ def mse_ens(target, ens, mu, stddev):
     return torch.stack([mse_loss(target, mem) for mem in ens], 0).mean()


-def kernel_crps(target, ens, mu, stddev, fair=True):
-    ens_size = ens.shape[0]
-    mae = torch.stack([(target - mem).abs().mean() for mem in ens], 0).mean()
+def kernel_crps(
+    targets,
+    preds,
+    weights_channels: torch.Tensor | None,
+    weights_points: torch.Tensor | None,
+    fair=True,
+):
+    """
+    Compute kernel CRPS.
+
+    Params:
+        targets : shape (num_data_points, num_channels)
+        preds : shape (ens_dim, num_data_points, num_channels)
+        weights_channels : shape (num_channels,)
+        weights_points : shape (num_data_points,)
+
+    Returns:
+        loss: scalar - overall weighted CRPS
+        loss_chs: [C] - per-channel CRPS (location-weighted, not channel-weighted)
+    """

-    if ens_size == 1:
-        return mae
+    ens_size = preds.shape[0]
+    assert ens_size > 1, "Ensemble size has to be greater than 1 for kernel CRPS."
+    assert len(preds.shape) == 3, "if data has batch dimension, remove unsqueeze() below"

-    coef = -1.0 / (2.0 * ens_size * (ens_size - 1)) if fair else -1.0 / (2.0 * ens_size**2)
-    ens_var = coef * torch.tensor([(p1 - p2).abs().sum() for p1 in ens for p2 in ens]).sum()
-    ens_var /= ens.shape[1]
+    # replace NaN by 0
+    mask_nan = ~torch.isnan(targets)
+    targets = torch.where(mask_nan, targets, 0)
+    preds = torch.where(mask_nan, preds, 0)
+
+    # permute to enable/simplify broadcasting and contractions below
+    preds = preds.permute([2, 1, 0]).unsqueeze(0).to(torch.float32)
+    targets = targets.permute([1, 0]).unsqueeze(0).to(torch.float32)
+
+    mae = torch.mean(torch.abs(targets[..., None] - preds), dim=-1)
+
+    ens_n = -1.0 / (ens_size * (ens_size - 1)) if fair else -1.0 / (ens_size**2)
+    abs = torch.abs
+    ens_var = torch.zeros(size=preds.shape[:-1], device=preds.device)
+    # loop to reduce memory usage
+    for i in range(ens_size):
+        ens_var += torch.sum(ens_n * abs(preds[..., i].unsqueeze(-1) - preds[..., i + 1 :]), dim=-1)
+
+    kcrps_locs_chs = mae + ens_var
+
+    # apply point weighting
+    if weights_points is not None:
+        kcrps_locs_chs = kcrps_locs_chs * weights_points
+    # apply channel weighting
+    kcrps_chs = torch.mean(torch.mean(kcrps_locs_chs, 0), -1)
+    if weights_channels is not None:
+        kcrps_chs = kcrps_chs * weights_channels

-    return mae + ens_var
+    return torch.mean(kcrps_chs), kcrps_chs


 def mse_channel_location_weighted(
```
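For orientation, here is a minimal usage sketch of the reworked `kernel_crps`, assuming the module path `weathergen.train.loss` and using synthetic tensors that are not part of the commit:

```python
import torch

from weathergen.train.loss import kernel_crps  # assumed import path for the file above

# illustrative sizes: 4 ensemble members, 1000 data points, 3 channels
ens_size, num_points, num_channels = 4, 1000, 3
targets = torch.randn(num_points, num_channels)
preds = targets.unsqueeze(0) + 0.1 * torch.randn(ens_size, num_points, num_channels)

# optional weightings; either argument can be None to disable that weighting
weights_points = torch.ones(num_points)
weights_channels = torch.tensor([1.0, 0.5, 2.0])

loss, loss_chs = kernel_crps(targets, preds, weights_channels, weights_points, fair=True)
print(loss.item())     # scalar: overall weighted CRPS
print(loss_chs.shape)  # torch.Size([3]): per-channel CRPS, location-weighted only
```

With `fair=True` the spread term is normalized by 1/(m(m-1)) over member pairs (the unbiased, "fair" estimator); `fair=False` falls back to the 1/m² normalization, mirroring the `ens_n` coefficient in the diff.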

src/weathergen/train/loss_calculator.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -167,7 +167,7 @@ def _loss_per_loss_function(

         # accumulate loss
         loss_lfct = loss_lfct + loss
-        losses_chs += loss_chs.detach()
+        losses_chs = losses_chs + loss_chs.detach() if len(loss_chs) > 0 else losses_chs
         ctr_substeps += 1 if loss > 0.0 else 0

     # normalize over forecast steps in window
```