diff --git a/CHANGELOG.md b/CHANGELOG.md index fe66d2136..c3a20e98c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add generalized Log-Spectral Distance (LSD) metric for spatial frequency analysis. + Supports both regular grids (via FFT) and irregular grids (via Graph Signal Processing). [\#508](https://github.com/mllam/neural-lam/pull/508) @sohampatil01-svg + - Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture.[\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov - Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py index 7db2cca6d..d70eb355d 100644 --- a/neural_lam/metrics.py +++ b/neural_lam/metrics.py @@ -1,5 +1,6 @@ # Third-party import torch +import torch.fft def get_metric(metric_name): @@ -53,7 +54,9 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): return metric_entry_vals -def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): +def wmse( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs +): """ Weighted Mean Squared Error @@ -84,7 +87,9 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) -def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): +def mse( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs +): """ (Unweighted) Mean Squared Error @@ -108,7 +113,9 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) -def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): +def wmae( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs +): """ Weighted Mean Absolute 
Error @@ -139,7 +146,9 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) -def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): +def mae( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs +): """ (Unweighted) Mean Absolute Error @@ -163,7 +172,9 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): ) -def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): +def nll( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs +): """ Negative Log Likelihood loss, for isotropic Gaussian likelihood @@ -191,7 +202,7 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): def crps_gauss( - pred, target, pred_std, mask=None, average_grid=True, sum_vars=True + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs ): """ (Negative) Continuous Ranked Probability Score (CRPS) @@ -227,6 +238,166 @@ def crps_gauss( ) +def log_spectral_distance( + pred, + target, + pred_std, + mask=None, + average_grid=True, + sum_vars=True, + grid_shape=None, + edge_index=None, + num_moments=10, + eps=1e-8, +): + """ + Log-Spectral Distance (LSD) + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. 
(unused) + mask: (N,), boolean mask describing which grid nodes to use (unused) + average_grid: boolean, if result should be averaged over grid + sum_vars: boolean, if variable dimension -1 should be reduced + grid_shape: tuple (ny, nx), shape of the 2D grid (for regular grids) + edge_index: (2, M), edges in the graph (for unstructured grids) + num_moments: int, number of Laplacian moments to use for unstructured LSD + eps: float, small value to avoid log(0) + + Returns: + metric_val: One of (...,), (..., d_state), depends on reduction arguments. + """ + # Regular grid LSD using FFT + if grid_shape is not None or (edge_index is None and _is_square_grid(pred)): + if grid_shape is None: + num_nodes = pred.shape[-2] + side = int(num_nodes**0.5) + grid_shape = (side, side) + + # Reshape to (..., d_state, ny, nx) for FFT + ny, nx = grid_shape + # Move d_state to before grid dimensions + # pred is (..., N, d_state) -> (..., d_state, N) -> (..., d_state, ny, nx) + p = pred.transpose(-1, -2).reshape( + *pred.shape[:-2], pred.shape[-1], ny, nx + ) + t = target.transpose(-1, -2).reshape( + *target.shape[:-2], target.shape[-1], ny, nx + ) + + # Compute 2D RFFT + f_pred = torch.fft.rfft2(p, norm="ortho") + f_target = torch.fft.rfft2(t, norm="ortho") + + # Power Spectrum: |F(u,v)|^2 + ps_pred = torch.abs(f_pred) ** 2 + ps_target = torch.abs(f_target) ** 2 + + # Average over frequency dimensions + # entry_lsd is (..., d_state, freq_y, freq_x) + # We compute mean( (10 * log10(P_target/P_pred))^2 ) then sqrt + diff_lsd = (10 * torch.log10((ps_target + eps) / (ps_pred + eps))) ** 2 + metric_val = torch.mean(diff_lsd, dim=(-2, -1)) # (..., d_state) + metric_val = torch.sqrt(metric_val) + + # Unstructured grid LSD using Graph Signal Processing + elif edge_index is not None: + # Compute spectral moments using Normalized Laplacian + # moments: (..., d_state, num_moments) + m_pred = _compute_laplacian_moments(pred, edge_index, num_moments) + m_target = _compute_laplacian_moments(target, 
edge_index, num_moments) + + # Log-Spectral Distance over moments: + # RMS of 10 * log10(m_target / m_pred) + # diff_lsd is (..., d_state, num_moments) + diff_lsd = (10 * torch.log10((m_target + eps) / (m_pred + eps))) ** 2 + metric_val = torch.mean(diff_lsd, dim=-1) # (..., d_state) + metric_val = torch.sqrt(metric_val) + + else: + raise ValueError( + "log_spectral_distance requires grid_shape, edge_index, " + "or a square grid" + ) + + if sum_vars: + metric_val = torch.sum(metric_val, dim=-1) # (...,) + + return metric_val + + +def _is_square_grid(pred): + """Check if the grid dimension is a perfect square""" + num_nodes = pred.shape[-2] + side = int(num_nodes**0.5) + return side**2 == num_nodes + + +def _compute_laplacian_moments(x, edge_index, num_moments): + """ + Compute moments of the spectral distribution: m_k = x^T L^k x + where L is the Normalized Laplacian. + """ + # x: (..., N, d_state) + # edge_index: (2, M) + # returns: (..., d_state, num_moments) + N = x.shape[-2] + device = x.device + + # 1. Compute Normalized Laplacian as a sparse matrix + # L = I - D^-1/2 A D^-1/2 + row, col = edge_index + deg = torch.zeros(N, device=device) + # Assume unweighted adjacency for now, or use edge_weight if provided + # For neural-lam, m2m_edge_index is usually unweighted or has features + deg.scatter_add_(0, col, torch.ones_like(row, dtype=torch.float32)) + + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 + + # Normalized weights: -1 / sqrt(di * dj) + val = -deg_inv_sqrt[row] * deg_inv_sqrt[col] + + # Sparse Laplacian L + # Off-diagonal: -D^-1/2 A D^-1/2 + indices = torch.cat( + [edge_index, torch.stack([torch.arange(N, device=device)] * 2)], dim=1 + ) + values = torch.cat([val, torch.ones(N, device=device)]) + L = torch.sparse_coo_tensor(indices, values, (N, N)).coalesce() + + # 2. 
Iteratively compute x_k = L^k x and m_k = x^T x_k + # x is (..., N, d_state) + # Reshape x to (N, -1) for sparse mm + orig_shape = x.shape + x_flat = x.transpose(-2, -1).reshape(-1, N).t() # (N, B * d_state) + + moments = [] + curr_x = x_flat + for _ in range(num_moments): + # m_k = x^T (L^k x) + # dot product per column + m_k = torch.sum(x_flat * curr_x, dim=0) # (B * d_state,) + moments.append(m_k) + # curr_x = L * curr_x + curr_x = torch.sparse.mm(L, curr_x) + + # moments: list of (B * d_state,) + moments = torch.stack(moments, dim=-1) # (B * d_state, num_moments) + + # Reshape back to (..., d_state, num_moments) + res = moments.view(*orig_shape[:-2], orig_shape[-1], num_moments) + + # NOTE: moments are kept unnormalized (absolute), matching standard + # LSD, which compares absolute power spectra directly. Dividing by + # m_0 (total energy) would instead make the metric scale-invariant; + # that is a deliberate design choice, left to a follow-up if needed. + # abs() guards against small negative round-off values before log + return torch.abs(res) + + DEFINED_METRICS = { "mse": mse, "mae": mae, @@ -234,4 +405,5 @@ def crps_gauss( "wmae": wmae, "nll": nll, "crps_gauss": crps_gauss, + "lsd": log_spectral_distance, } diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a411a3afc..84b4ae4ec 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -124,6 +124,14 @@ def __init__( # Instantiate loss function self.loss = metrics.get_metric(args.loss) + # Store grid shape for spectral metrics if datastore is regular grid + from ..datastore.base import BaseRegularGridDatastore + if isinstance(datastore, BaseRegularGridDatastore): + grid_shape = datastore.grid_shape_state + self.grid_shape = (grid_shape.y, grid_shape.x) + else: + self.grid_shape = None + boundary_mask = torch.tensor( da_boundary_mask.values, dtype=torch.float32 ).unsqueeze( @@ -160,6 +168,9 @@ def __init__( self._datastore.step_length ) + + # Graph information for metrics (to be set by
subclasses) + self.edge_index = None + def _create_dataarray_from_tensor( self, tensor: torch.Tensor, @@ -303,7 +314,12 @@ def training_step(self, batch): # Compute loss batch_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + grid_shape=self.grid_shape, + edge_index=self.edge_index, ) ) # mean over unrolled times and batch @@ -346,7 +362,12 @@ def validation_step(self, batch, batch_idx): time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + grid_shape=self.grid_shape, + edge_index=self.edge_index, ), dim=0, ) # (time_steps-1) @@ -400,7 +421,12 @@ def test_step(self, batch, batch_idx): time_step_loss = torch.mean( self.loss( - prediction, target, pred_std, mask=self.interior_mask_bool + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + grid_shape=self.grid_shape, + edge_index=self.edge_index, ), dim=0, ) # (time_steps-1,) @@ -444,7 +470,12 @@ def test_step(self, batch, batch_idx): # Save per-sample spatial loss for specific times spatial_loss = self.loss( - prediction, target, pred_std, average_grid=False + prediction, + target, + pred_std, + average_grid=False, + grid_shape=self.grid_shape, + edge_index=self.edge_index, ) # (B, pred_steps, num_grid_nodes) log_spatial_losses = spatial_loss[ :, [step - 1 for step in self.args.val_steps_to_log] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index fd38a2e67..518711f66 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -32,6 +32,12 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): else: setattr(self, name, attr_value) + # Set edge_index for metrics (from finest level mesh graph) + if self.hierarchical: + self.edge_index = self.m2m_edge_index[0] + else: + 
self.edge_index = self.m2m_edge_index + # Specify dimensions of data self.num_mesh_nodes, _ = self.get_num_mesh() utils.log_on_rank_zero( diff --git a/tests/test_lsd_training.py b/tests/test_lsd_training.py new file mode 100644 index 000000000..5cceee1e5 --- /dev/null +++ b/tests/test_lsd_training.py @@ -0,0 +1,94 @@ + +import pytest +import pytorch_lightning as pl +import torch +import wandb +from pathlib import Path +from neural_lam import config as nlconfig +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.models.graph_lam import GraphLAM +from neural_lam.weather_dataset import WeatherDataModule +from tests.conftest import init_datastore_example + +def run_lsd_training(datastore): + """ + Run one epoch of training with LSD loss. + """ + if torch.cuda.is_available(): + device_name = "cuda" + torch.set_float32_matmul_precision("high") + else: + device_name = "cpu" + + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + num_devices = 2 + else: + num_devices = 1 + + trainer = pl.Trainer( + max_epochs=1, + deterministic=True, + accelerator=device_name, + devices=num_devices, + log_every_n_steps=1, + detect_anomaly=True, + ) + + graph_name = "1level_lsd" + graph_dir_path = Path(datastore.root_path) / "graph" / graph_name + + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + n_max_levels=1, + ) + + data_module = WeatherDataModule( + datastore=datastore, + ar_steps_train=1, + ar_steps_eval=1, + standardize=True, + batch_size=2, + num_workers=1, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + ) + + class ModelArgs: + output_std = False + loss = "lsd" + restore_opt = False + n_example_pred = 0 + graph = graph_name + hidden_dim = 4 + hidden_layers = 1 + processor_layers = 1 + mesh_aggr = "sum" + lr = 1.0e-3 + val_steps_to_log = [1] + metrics_watch = [] + num_past_forcing_steps = 1 + num_future_forcing_steps = 1 + + model_args = 
ModelArgs() + config = nlconfig.NeuralLAMConfig( + datastore=nlconfig.DatastoreSelection( + kind=datastore.SHORT_NAME, config_path=datastore.root_path + ) + ) + + model = GraphLAM( + args=model_args, + datastore=datastore, + config=config, + ) + + # Run one epoch of training; gradients must stay enabled, so do NOT + # wrap trainer.fit in torch.no_grad() (that would break backprop) + trainer.fit(model=model, datamodule=data_module) + +def test_training_lsd(): + """Test training with LSD loss on dummy data""" + datastore = init_datastore_example("dummydata") + run_lsd_training(datastore) diff --git a/tests/test_training.py b/tests/test_training.py index 972740695..78e2a8f28 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -38,14 +38,20 @@ def run_simple_training(datastore, set_output_std): else: device_name = "cpu" + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + num_devices = 2 + else: + num_devices = 1 + trainer = pl.Trainer( max_epochs=1, deterministic=True, accelerator=device_name, # XXX: `devices` has to be set to 2 otherwise # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails - # because it expects to aggregate over multiple devices - devices=2, + # because it expects to aggregate over multiple devices. + # Now fixed in all_gather_cat to handle single-device cases. + devices=num_devices, log_every_n_steps=1, # use `detect_anomaly` to ensure that we don't have NaNs popping up # during training