Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

def common_step(self, batch):
"""
Predict on single batch batch consists of: init_states: (B, 2,
Predict on single batch. Batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
num_grid_nodes, d_forcing),
Expand All @@ -283,13 +283,31 @@ def common_step(self, batch):
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

return prediction, target_states, pred_std, batch_times
# Calculate MSEs
entry_mses = metrics.mse(
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)

# Log mean RMSE for first prediction step
mean_rmse_ar_step_1 = np.mean(
np.sqrt(entry_mses[:, 0, :]), axis=0
) # take mean across all samples in batch but only for first pred step
state_var_names = self._datastore.get_vars_names(category="state")
self.log_dict(
{v: mean_rmse_ar_step_1[i] for (i, v) in enumerate(state_var_names)}
)

return prediction, target_states, pred_std, batch_times, entry_mses

def training_step(self, batch):
"""
Train on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
prediction, target, pred_std, _, _ = self.common_step(batch)

# Compute loss
batch_loss = torch.mean(
Expand Down Expand Up @@ -326,7 +344,7 @@ def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
prediction, target, pred_std, _, entry_mses = self.common_step(batch)

time_step_loss = torch.mean(
self.loss(
Expand All @@ -352,13 +370,6 @@ def validation_step(self, batch, batch_idx):
)

# Store MSEs
entry_mses = metrics.mse(
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.val_metrics["mse"].append(entry_mses)

def on_validation_epoch_end(self):
Expand All @@ -378,7 +389,7 @@ def test_step(self, batch, batch_idx):
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
prediction, target, pred_std, _, entry_mses = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

Expand All @@ -405,10 +416,12 @@ def test_step(self, batch, batch_idx):
batch_size=batch[0].shape[0],
)

# Store already computed MSEs
self.test_metrics["mse"].append(entry_mses)
# Compute all evaluation metrics for error maps. Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
for metric_name in ("mae",):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
prediction,
Expand Down