diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe66d2136..8de738230 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
+- Fixed unscaled `pred_std` when `output_std=True` in `BaseGraphModel`, which caused exponential loss blowup during probabilistic model training [\#347](https://github.com/mllam/neural-lam/issues/347)
 - Initialize `da_forcing_mean` and `da_forcing_std` to `None` when forcing data is absent, fixing `AttributeError` in `WeatherDataset` with `standardize=True` [\#369](https://github.com/mllam/neural-lam/issues/369) @Sir-Sloth-The-Lazy
 - Ensure proper sorting of `analysis_time` in `NpyFilesDatastoreMEPS._get_analysis_times` independent of the order in which files are processed with glob [\#386](https://github.com/mllam/neural-lam/pull/386) @Gopisokk
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index fd38a2e67..6261d494f 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -348,10 +348,10 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
             pred_delta_mean, pred_std_raw = net_output.chunk(
                 2, dim=-1
             )  # both (B, num_grid_nodes, d_f)
-            # NOTE: The predicted std. is not scaled in any way here
+            # Scale predicted std. with one-step difference std.
             # linter for some reason does not think softplus is callable
             # pylint: disable-next=not-callable
-            pred_std = torch.nn.functional.softplus(pred_std_raw)
+            pred_std = torch.nn.functional.softplus(pred_std_raw) * self.diff_std
         else:
             pred_delta_mean = net_output
             pred_std = None
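The change above can be sketched in isolation. This is a minimal standalone reproduction (with toy dimensions and a hypothetical `diff_std` tensor, not the model's actual buffers): the std. channel of the network output is made positive with softplus, then rescaled into the units of the one-step state differences. Without that rescaling, `pred_std` stays in softplus's own roughly O(1) range regardless of the target scale, which is the mismatch the fix addresses.

```python
import torch

# Toy dimensions standing in for (B, num_grid_nodes, d_f)
B, num_grid_nodes, d_f = 2, 5, 3

# Hypothetical per-feature std. of one-step state differences
# (in the model this is a registered buffer, self.diff_std)
diff_std = torch.tensor([0.1, 1.0, 10.0])

# Network output carries both mean and raw std. channels
net_output = torch.randn(B, num_grid_nodes, 2 * d_f)
pred_delta_mean, pred_std_raw = net_output.chunk(2, dim=-1)

# Before the fix: positive, but in softplus's own scale
pred_std_unscaled = torch.nn.functional.softplus(pred_std_raw)

# After the fix: broadcast-multiplied into one-step-difference units
pred_std = pred_std_unscaled * diff_std

assert pred_std.shape == (B, num_grid_nodes, d_f)
assert (pred_std > 0).all()
```

Note that the multiplication broadcasts `diff_std` of shape `(d_f,)` across the batch and grid-node dimensions, so each feature gets its own scale.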