Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Log evaluation images to `save_dir` and create folder if needed. [\#2](https://github.com/dmidk/neural-lam/pull/2)

- Log learning rate and mean RMSE of all features during training, validation and testing. [\#5](https://github.com/dmidk/neural-lam/pull/5) @mafdmi


*Below follows changelog from the main neural-lam repository:*

Expand Down
162 changes: 114 additions & 48 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
num_grid_nodes, d_f)

Returns
-------
torch.Tensor(B, pred_steps, num_grid_nodes, d_f):
The prediction
torch.Tensor(B, pred_steps, num_grid_nodes, d_f) or torch.Tensor(d_f,)
The prediction standard deviation or per-variable standard deviation
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
Expand Down Expand Up @@ -267,40 +274,48 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

return prediction, pred_std

def common_step(self, batch):
"""
Predict on single batch batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
(init_states, target_states, forcing_features, batch_times) = batch
def training_step(self, batch):
"""Train on single batch"""
init_states, target_states, forcing_features, _ = batch

prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
) # (B, pred_steps, num_grid_nodes, d_f)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
)

return prediction, target_states, pred_std, batch_times
entry_mses = metrics.mse(
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)

def training_step(self, batch):
"""
Train on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

# Compute loss
batch_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
)
) # mean over unrolled times and batch

log_dict = {"train_loss": batch_loss}
# Logging
train_log_dict = {
"train_loss": batch_loss,
**{
f"train_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(
self._datastore.get_vars_names(category="state")
)
},
"train_lr": self.trainer.optimizers[0].param_groups[0]["lr"],
}
self.log_dict(
log_dict,
train_log_dict,
prog_bar=True,
on_step=True,
on_epoch=True,
Expand All @@ -326,23 +341,51 @@ def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
init_states, target_states, forcing_features, _ = batch

prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
)

entry_mses = metrics.mse(
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)

# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
),
dim=0,
) # (time_steps-1)
mean_loss = torch.mean(time_step_loss)

# Log loss per time step forward and mean
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
if step <= len(time_step_loss)
# Log loss per time step forward and mean
**{
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
if step <= len(time_step_loss)
},
"val_mean_loss": mean_loss,
# Log mean RMSE for first prediction step and learning rate
**{
f"val_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(
self._datastore.get_vars_names(category="state")
)
},
"val_lr": self.trainer.optimizers[0].param_groups[0]["lr"],
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
val_log_dict,
on_step=False,
Expand All @@ -351,14 +394,6 @@ def validation_step(self, batch, batch_idx):
batch_size=batch[0].shape[0],
)

# Store MSEs
entry_mses = metrics.mse(
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.val_metrics["mse"].append(entry_mses)

def on_validation_epoch_end(self):
Expand All @@ -378,24 +413,50 @@ def test_step(self, batch, batch_idx):
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
init_states, target_states, forcing_features, _ = batch

prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
)

entry_mses = metrics.mse(
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)

# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
),
dim=0,
) # (time_steps-1,)
mean_loss = torch.mean(time_step_loss)

# Log loss per time step forward and mean
test_log_dict = {
f"test_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
# Log loss per time step forward and mean
**{
f"test_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
},
"test_mean_loss": mean_loss,
# Log mean RMSE for first prediction step and learning rate
**{
f"test_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(
self._datastore.get_vars_names(category="state")
)
},
"test_lr": self.trainer.optimizers[0].param_groups[0]["lr"],
}
test_log_dict["test_mean_loss"] = mean_loss

self.log_dict(
test_log_dict,
Expand All @@ -405,14 +466,15 @@ def test_step(self, batch, batch_idx):
batch_size=batch[0].shape[0],
)

self.test_metrics["mse"].append(entry_mses)
# Compute all evaluation metrics for error maps Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
for metric_name in ("mae",):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
prediction,
target,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
Expand All @@ -428,7 +490,7 @@ def test_step(self, batch, batch_idx):

# Save per-sample spatial loss for specific times
spatial_loss = self.loss(
prediction, target, pred_std, average_grid=False
prediction, target_states, pred_std, average_grid=False
) # (B, pred_steps, num_grid_nodes)
log_spatial_losses = spatial_loss[
:, [step - 1 for step in self.args.val_steps_to_log]
Expand Down Expand Up @@ -464,7 +526,11 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
Generate if None.
"""
if prediction is None:
prediction, target, _, _ = self.common_step(batch)
(init_states, target, forcing_features, _) = batch

prediction, _ = self.unroll_prediction(
init_states, forcing_features, target
)

target = batch[1]
time = batch[3]
Expand Down
5 changes: 4 additions & 1 deletion neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def main(input_args=None):
"--logger_run_name",
type=str,
default=None,
help="Logger run name, for e.g. MLFlow (with default value `None` neural-lam default format string is used)",
help=(
"Logger run name, for e.g. MLFlow (with default value `None`"
" neural-lam default format string is used)"
),
)
parser.add_argument(
"--val_steps_to_log",
Expand Down
5 changes: 5 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Standard library
from pathlib import Path
from unittest.mock import MagicMock

# Third-party
import numpy as np
Expand Down Expand Up @@ -213,6 +214,10 @@ def _create_graph():
dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2)

model = GraphLAM(args=args, datastore=datastore, config=config) # noqa
model.trainer = MagicMock()
optimizer_mock = MagicMock()
optimizer_mock.param_groups = [{"lr": 0.001}]
model.trainer.optimizers = [optimizer_mock]

model_device = model.to(device_name)
data_loader = DataLoader(dataset, batch_size=2)
Expand Down