Changes from 5 commits
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Log evaluation images to `save_dir` and create folder if needed. [\#2](https://github.com/dmidk/neural-lam/pull/2)

- Log learning rate and mean RMSE of all features during training, validation and testing. [\#5](https://github.com/dmidk/neural-lam/pull/5) @mafdmi


*Below follows changelog from the main neural-lam repository:*

69 changes: 54 additions & 15 deletions neural_lam/models/ar_model.py
@@ -269,7 +269,7 @@ def unroll_prediction(self, init_states, forcing_features, true_states):

def common_step(self, batch):
"""
Predict on single batch batch consists of: init_states: (B, 2,
Predict on single batch consists of: init_states: (B, 2,
num_grid_nodes, d_features) target_states: (B, pred_steps,
num_grid_nodes, d_features) forcing_features: (B, pred_steps,
num_grid_nodes, d_forcing),
@@ -283,13 +283,24 @@ def common_step(self, batch):
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

return prediction, target_states, pred_std, batch_times
entry_mses = metrics.mse(
prediction,
target_states,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)

return prediction, target_states, pred_std, batch_times, entry_mses

def training_step(self, batch):
"""
Train on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
prediction, target, pred_std, _, entry_mses = self.common_step(batch)
I think the common step should have a more descriptive name - it should reflect its function rather than the fact that it is being shared.
Also, it has two responsibilities - prediction with the model and processing of the prediction - could these maybe be factored into separate steps?

Author

Agree, maybe somebody else can contribute here, since I don't know enough to properly phrase it.

@matschreiner (Mar 11, 2025)

The problem is that this function unpacks the batch, performs the unroll prediction step, and computes metrics, then returns some elements of the unpacked batch alongside the results - a mix of responsibilities that is hard to name and write a good docstring for, which is what happened here. :D A cleaner implementation, I would say, is to just unpack the batch in the steps themselves and perform the prediction there.

So basically remove the common_step function entirely and replace it by

(init_states, target_states, forcing_features, batch_times) = batch
prediction, pred_std = self.unroll_prediction(init_states, forcing_features, target_states)  
entry_mses = .... 
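As a rough standalone sketch of the suggested flow (using numpy as a stand-in; `metrics.mse`, the interior mask, and the exact tensor shapes are assumptions read off the diff, not the real neural-lam API), the per-variable MSE followed by the step-1 mean RMSE could look like:

```python
import numpy as np

def first_step_mean_rmse(prediction, target_states):
    """Per-variable MSE over grid nodes, then mean RMSE of AR step 1.

    prediction, target_states: (B, pred_steps, num_grid_nodes, d_f)
    Returns an array of shape (d_f,), one RMSE per state variable.
    """
    # (B, pred_steps, d_f): squared error averaged over the grid dimension,
    # mimicking metrics.mse(..., sum_vars=False) without the interior mask
    entry_mses = ((prediction - target_states) ** 2).mean(axis=2)
    # RMSE of the first prediction step, averaged over the batch,
    # matching torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)
    return np.sqrt(entry_mses[:, 0, :]).mean(axis=0)

rng = np.random.default_rng(0)
pred = rng.normal(size=(2, 3, 5, 4))  # B=2, pred_steps=3, nodes=5, d_f=4
target = np.zeros_like(pred)
rmse = first_step_mean_rmse(pred, target)
print(rmse.shape)  # (4,)
```

Note the order of operations: the square root is taken per batch element before averaging, so this is the mean of per-sample RMSEs rather than the RMSE of the pooled errors.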


# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

# Compute loss
batch_loss = torch.mean(
Expand All @@ -298,9 +309,18 @@ def training_step(self, batch):
)
) # mean over unrolled times and batch

log_dict = {"train_loss": batch_loss}
# Logging
train_log_dict = {"train_loss": batch_loss}
state_var_names = self._datastore.get_vars_names(category="state")
train_log_dict |= {
This code is hard to read. Maybe wrap it in a function with a descriptive name? I would imagine that the train_log_dict should be defined in one line, something like

train_log_dict = {"train_loss": batch_loss, "lr": ..., **rmse_dict}

Author

Not sure if it is more readable now, but I made it into one dict. I haven't put it into its own function, since I thought it is part of the training_step or validation_step to log. But we can do that if you think it'd be better.
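The one-expression pattern discussed in this thread can be sketched in plain Python (the values and variable names below are illustrative stand-ins, not the real neural-lam data):

```python
# Stand-ins for batch_loss, the optimizer's lr, and the per-variable RMSEs
batch_loss = 0.42
lr = 1e-3
state_var_names = ["t2m", "u10", "v10"]  # hypothetical state variable names
mean_rmse_ar_step_1 = [0.5, 0.7, 0.6]

# Build the whole log dict in one expression via dict union (Python 3.9+),
# instead of creating it and then extending it with |= afterwards
train_log_dict = {"train_loss": batch_loss, "train_lr": lr} | {
    f"train_rmse_{v}": mean_rmse_ar_step_1[i]
    for i, v in enumerate(state_var_names)
}
print(sorted(train_log_dict))
```

The `|` operator produces a new dict, so the fixed keys and the comprehension-generated RMSE keys end up in a single literal-style definition, which is what the review comment asks for.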

f"train_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(state_var_names)
}
train_log_dict["train_lr"] = self.trainer.optimizers[0].param_groups[0][
isn't learning rate only related to training since it's called "train_lr"?

Author

I'm not sure what you mean... actually, I'm unsure whether we even have a learning rate during the test and validation steps? If not, then yes, let's simply call it "lr".

"lr"
]
self.log_dict(
log_dict,
train_log_dict,
prog_bar=True,
on_step=True,
on_epoch=True,
@@ -326,7 +346,10 @@ def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
prediction, target, pred_std, _ = self.common_step(batch)
prediction, target, pred_std, _, entry_mses = self.common_step(batch)

# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

time_step_loss = torch.mean(
self.loss(
@@ -343,6 +366,15 @@
if step <= len(time_step_loss)
}
val_log_dict["val_mean_loss"] = mean_loss
# Log mean RMSE for first prediction step and learning rate
state_var_names = self._datastore.get_vars_names(category="state")
val_log_dict |= {
f"val_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(state_var_names)
}
val_log_dict["val_lr"] = self.trainer.optimizers[0].param_groups[0][
"lr"
]
self.log_dict(
val_log_dict,
on_step=False,
Expand All @@ -352,13 +384,6 @@ def validation_step(self, batch, batch_idx):
)

# Store MSEs
entry_mses = metrics.mse(
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
sum_vars=False,
) # (B, pred_steps, d_f)
self.val_metrics["mse"].append(entry_mses)

def on_validation_epoch_end(self):
@@ -378,10 +403,13 @@ def test_step(self, batch, batch_idx):
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
prediction, target, pred_std, _, entry_mses = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)

# Compute mean RMSE for first prediction step
mean_rmse_ar_step_1 = torch.mean(torch.sqrt(entry_mses[:, 0, :]), dim=0)

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
@@ -396,6 +424,15 @@
for step in self.args.val_steps_to_log
}
test_log_dict["test_mean_loss"] = mean_loss
# Log mean RMSE for first prediction step and learning rate
state_var_names = self._datastore.get_vars_names(category="state")
test_log_dict |= {
f"test_rmse_{v}": mean_rmse_ar_step_1[i]
for (i, v) in enumerate(state_var_names)
}
test_log_dict["test_lr"] = self.trainer.optimizers[0].param_groups[0][
"lr"
]

self.log_dict(
test_log_dict,
@@ -405,10 +442,12 @@
batch_size=batch[0].shape[0],
)

# Store already computed MSEs
self.test_metrics["mse"].append(entry_mses)
# Compute all evaluation metrics for error maps Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
for metric_name in ("mae",):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
prediction,
5 changes: 4 additions & 1 deletion neural_lam/train_model.py
@@ -199,7 +199,10 @@ def main(input_args=None):
"--logger_run_name",
type=str,
default=None,
help="Logger run name, for e.g. MLFlow (with default value `None` neural-lam default format string is used)",
help=(
"Logger run name, for e.g. MLFlow (with default value `None`"
" neural-lam default format string is used)"
),
)
parser.add_argument(
"--val_steps_to_log",
5 changes: 5 additions & 0 deletions tests/test_datasets.py
@@ -1,5 +1,6 @@
# Standard library
from pathlib import Path
from unittest.mock import MagicMock

# Third-party
import numpy as np
@@ -213,6 +214,10 @@ def _create_graph():
dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2)

model = GraphLAM(args=args, datastore=datastore, config=config) # noqa
model.trainer = MagicMock()
optimizer_mock = MagicMock()
optimizer_mock.param_groups = [{"lr": 0.001}]
model.trainer.optimizers = [optimizer_mock]

model_device = model.to(device_name)
data_loader = DataLoader(dataset, batch_size=2)
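A minimal standalone version of the mocking pattern added in this test (outside neural-lam, using only the standard library) shows why the learning-rate lookup in the model code works without a real Lightning Trainer:

```python
from unittest.mock import MagicMock

# Stand-in for the model; in real runs Lightning attaches a Trainer to it
model = MagicMock()
model.trainer = MagicMock()

# The mocked optimizer exposes exactly the attribute the logging code reads
optimizer_mock = MagicMock()
optimizer_mock.param_groups = [{"lr": 0.001}]
model.trainer.optimizers = [optimizer_mock]

# This is the same attribute chain ar_model.py uses when logging the lr
lr = model.trainer.optimizers[0].param_groups[0]["lr"]
print(lr)  # 0.001
```

Because `param_groups` is set explicitly, the dict lookup returns the configured value instead of another auto-generated mock, so the training step's `param_groups[0]["lr"]` access behaves as it would with a real optimizer.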