-
Notifications
You must be signed in to change notification settings - Fork 0
Moved calculation of mse to common_step and added logging of mean rmse #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev/gefion
Are you sure you want to change the base?
Changes from 5 commits
ed4b729
7e50f49
49b02cb
09ba710
0b20c02
61d04e7
be469ed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -269,7 +269,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states): | |
|
|
||
| def common_step(self, batch): | ||
| """ | ||
| Predict on single batch batch consists of: init_states: (B, 2, | ||
| Predict on single batch consists of: init_states: (B, 2, | ||
| num_grid_nodes, d_features) target_states: (B, pred_steps, | ||
| num_grid_nodes, d_features) forcing_features: (B, pred_steps, | ||
| num_grid_nodes, d_forcing), | ||
|
|
@@ -283,13 +283,24 @@ def common_step(self, batch): | |
| # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, | ||
| # pred_steps, num_grid_nodes, d_f) or (d_f,) | ||
|
|
||
| return prediction, target_states, pred_std, batch_times | ||
| entry_mses = metrics.mse( | ||
| prediction, | ||
| target_states, | ||
| pred_std, | ||
| mask=self.interior_mask_bool, | ||
| sum_vars=False, | ||
| ) # (B, pred_steps, d_f) | ||
|
|
||
| return prediction, target_states, pred_std, batch_times, entry_mses | ||
|
|
||
| def training_step(self, batch): | ||
| """ | ||
| Train on single batch | ||
| """ | ||
| prediction, target, pred_std, _ = self.common_step(batch) | ||
| prediction, target, pred_std, _, entry_mses = self.common_step(batch) | ||
|
||
|
|
||
| # Compute mean RMSE for first prediction step | ||
| mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0) | ||
|
|
||
| # Compute loss | ||
| batch_loss = torch.mean( | ||
|
|
@@ -298,9 +309,18 @@ def training_step(self, batch): | |
| ) | ||
| ) # mean over unrolled times and batch | ||
|
|
||
| log_dict = {"train_loss": batch_loss} | ||
| # Logging | ||
| train_log_dict = {"train_loss": batch_loss} | ||
| state_var_names = self._datastore.get_vars_names(category="state") | ||
| train_log_dict |= { | ||
|
||
| f"train_rmse_{v}": mean_rmse_ar_step_1[i] | ||
| for (i, v) in enumerate(state_var_names) | ||
| } | ||
| train_log_dict["train_lr"] = self.trainer.optimizers[0].param_groups[0][ | ||
|
||
| "lr" | ||
| ] | ||
| self.log_dict( | ||
| log_dict, | ||
| train_log_dict, | ||
| prog_bar=True, | ||
| on_step=True, | ||
| on_epoch=True, | ||
|
|
@@ -326,7 +346,10 @@ def validation_step(self, batch, batch_idx): | |
| """ | ||
| Run validation on single batch | ||
| """ | ||
| prediction, target, pred_std, _ = self.common_step(batch) | ||
| prediction, target, pred_std, _, entry_mses = self.common_step(batch) | ||
|
|
||
| # Compute mean RMSE for first prediction step | ||
| mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0) | ||
|
|
||
| time_step_loss = torch.mean( | ||
| self.loss( | ||
|
|
@@ -343,6 +366,15 @@ def validation_step(self, batch, batch_idx): | |
| if step <= len(time_step_loss) | ||
| } | ||
| val_log_dict["val_mean_loss"] = mean_loss | ||
mfroelund marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # Log mean RMSE for first prediction step and learning rate | ||
| state_var_names = self._datastore.get_vars_names(category="state") | ||
| val_log_dict |= { | ||
| f"val_rmse_{v}": mean_rmse_ar_step_1[i] | ||
| for (i, v) in enumerate(state_var_names) | ||
| } | ||
| val_log_dict["val_lr"] = self.trainer.optimizers[0].param_groups[0][ | ||
| "lr" | ||
| ] | ||
| self.log_dict( | ||
| val_log_dict, | ||
| on_step=False, | ||
|
|
@@ -352,13 +384,6 @@ def validation_step(self, batch, batch_idx): | |
| ) | ||
|
|
||
| # Store MSEs | ||
mfroelund marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| entry_mses = metrics.mse( | ||
| prediction, | ||
| target, | ||
| pred_std, | ||
| mask=self.interior_mask_bool, | ||
| sum_vars=False, | ||
| ) # (B, pred_steps, d_f) | ||
| self.val_metrics["mse"].append(entry_mses) | ||
|
|
||
| def on_validation_epoch_end(self): | ||
|
|
@@ -378,10 +403,13 @@ def test_step(self, batch, batch_idx): | |
| Run test on single batch | ||
| """ | ||
| # TODO Here batch_times can be used for plotting routines | ||
| prediction, target, pred_std, batch_times = self.common_step(batch) | ||
| prediction, target, pred_std, _, entry_mses = self.common_step(batch) | ||
| # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, | ||
| # pred_steps, num_grid_nodes, d_f) or (d_f,) | ||
|
|
||
| # Compute mean RMSE for first prediction step | ||
| mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0) | ||
|
|
||
| time_step_loss = torch.mean( | ||
| self.loss( | ||
| prediction, target, pred_std, mask=self.interior_mask_bool | ||
|
|
@@ -396,6 +424,15 @@ def test_step(self, batch, batch_idx): | |
| for step in self.args.val_steps_to_log | ||
| } | ||
| test_log_dict["test_mean_loss"] = mean_loss | ||
| # Log mean RMSE for first prediction step and learning rate | ||
| state_var_names = self._datastore.get_vars_names(category="state") | ||
| test_log_dict |= { | ||
| f"test_rmse_{v}": mean_rmse_ar_step_1[i] | ||
| for (i, v) in enumerate(state_var_names) | ||
| } | ||
| test_log_dict["test_lr"] = self.trainer.optimizers[0].param_groups[0][ | ||
| "lr" | ||
| ] | ||
|
|
||
| self.log_dict( | ||
| test_log_dict, | ||
|
|
@@ -405,10 +442,12 @@ def test_step(self, batch, batch_idx): | |
| batch_size=batch[0].shape[0], | ||
| ) | ||
|
|
||
| # Store already computed MSEs | ||
mfroelund marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self.test_metrics["mse"].append(entry_mses) | ||
| # Compute all evaluation metrics for error maps Note: explicitly list | ||
| # metrics here, as test_metrics can contain additional ones, computed | ||
| # differently, but that should be aggregated on_test_epoch_end | ||
| for metric_name in ("mse", "mae"): | ||
| for metric_name in ("mae",): | ||
| metric_func = metrics.get_metric(metric_name) | ||
| batch_metric_vals = metric_func( | ||
| prediction, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.