@@ -70,18 +70,60 @@ def mse_ens(target, ens, mu, stddev):
7070 return torch .stack ([mse_loss (target , mem ) for mem in ens ], 0 ).mean ()
7171
7272
73- def kernel_crps (target , ens , mu , stddev , fair = True ):
74- ens_size = ens .shape [0 ]
75- mae = torch .stack ([(target - mem ).abs ().mean () for mem in ens ], 0 ).mean ()
73+ def kernel_crps (
74+ targets ,
75+ preds ,
76+ weights_channels : torch .Tensor | None ,
77+ weights_points : torch .Tensor | None ,
78+ fair = True ,
79+ ):
80+ """
81+ Compute kernel CRPS
82+
83+ Params:
84+ target : shape ( num_data_points , num_channels )
85+ pred : shape ( ens_dim , num_data_points , num_channels)
86+ weights_channels : shape = (num_channels,)
87+ weights_points : shape = (num_data_points)
88+
89+ Returns:
90+ loss: scalar - overall weighted CRPS
91+ loss_chs: [C] - per-channel CRPS (location-weighted, not channel-weighted)
92+ """
7693
77- if ens_size == 1 :
78- return mae
94+ ens_size = preds .shape [0 ]
95+ assert ens_size > 1 , "Ensemble size has to be greater than 1 for kernel CRPS."
96+ assert len (preds .shape ) == 3 , "if data has batch dimension, remove unsqueeze() below"
7997
80- coef = - 1.0 / (2.0 * ens_size * (ens_size - 1 )) if fair else - 1.0 / (2.0 * ens_size ** 2 )
81- ens_var = coef * torch .tensor ([(p1 - p2 ).abs ().sum () for p1 in ens for p2 in ens ]).sum ()
82- ens_var /= ens .shape [1 ]
98+ # replace NaN by 0
99+ mask_nan = ~ torch .isnan (targets )
100+ targets = torch .where (mask_nan , targets , 0 )
101+ preds = torch .where (mask_nan , preds , 0 )
102+
103+ # permute to enable/simply broadcasting and contractions below
104+ preds = preds .permute ([2 , 1 , 0 ]).unsqueeze (0 ).to (torch .float32 )
105+ targets = targets .permute ([1 , 0 ]).unsqueeze (0 ).to (torch .float32 )
106+
107+ mae = torch .mean (torch .abs (targets [..., None ] - preds ), dim = - 1 )
108+
109+ ens_n = - 1.0 / (ens_size * (ens_size - 1 )) if fair else - 1.0 / (ens_size ** 2 )
110+ abs = torch .abs
111+ ens_var = torch .zeros (size = preds .shape [:- 1 ], device = preds .device )
112+ # loop to reduce memory usage
113+ for i in range (ens_size ):
114+ ens_var += torch .sum (ens_n * abs (preds [..., i ].unsqueeze (- 1 ) - preds [..., i + 1 :]), dim = - 1 )
115+
116+ kcrps_locs_chs = mae + ens_var
117+
118+ # apply point weighting
119+ if weights_points is not None :
120+ kcrps_locs_chs = kcrps_locs_chs * weights_points
121+ # apply channel weighting
122+ kcrps_chs = torch .mean (torch .mean (kcrps_locs_chs , 0 ), - 1 )
123+ if weights_channels is not None :
124+ kcrps_chs = kcrps_chs * weights_channels
83125
84- return mae + ens_var
126+ return torch . mean ( kcrps_chs ), kcrps_chs
85127
86128
87129def mse_channel_location_weighted (
0 commit comments