Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add generalized Log-Spectral Distance (LSD) metric for spatial frequency analysis.
Supports both regular grids (via FFT) and irregular grids (via Graph Signal Processing). [\#508](https://github.com/mllam/neural-lam/pull/508) @sohampatil01-svg

- Add `AGENTS.md` file to the repo to give agents more information about the codebase and the contribution culture. [\#416](https://github.com/mllam/neural-lam/pull/416) @sadamov

- Enable `pin_memory` in DataLoaders when GPU is available for faster async CPU-to-GPU data transfers [\#236](https://github.com/mllam/neural-lam/pull/236) @abhaygoudannavar
Expand Down
184 changes: 178 additions & 6 deletions neural_lam/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Third-party
import torch
import torch.fft


def get_metric(metric_name):
Expand Down Expand Up @@ -53,7 +54,9 @@ def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars):
return metric_entry_vals


def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmse(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
Weighted Mean Squared Error

Expand Down Expand Up @@ -84,7 +87,9 @@ def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mse(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
(Unweighted) Mean Squared Error

Expand All @@ -108,7 +113,9 @@ def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def wmae(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
Weighted Mean Absolute Error

Expand Down Expand Up @@ -139,7 +146,9 @@ def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def mae(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
(Unweighted) Mean Absolute Error

Expand All @@ -163,7 +172,9 @@ def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
)


def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):
def nll(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
Negative Log Likelihood loss, for isotropic Gaussian likelihood

Expand Down Expand Up @@ -191,7 +202,7 @@ def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True):


def crps_gauss(
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True
pred, target, pred_std, mask=None, average_grid=True, sum_vars=True, **kwargs
):
"""
(Negative) Continuous Ranked Probability Score (CRPS)
Expand Down Expand Up @@ -227,11 +238,172 @@ def crps_gauss(
)


def log_spectral_distance(
    pred,
    target,
    pred_std,
    mask=None,
    average_grid=True,
    sum_vars=True,
    grid_shape=None,
    edge_index=None,
    num_moments=10,
    eps=1e-8,
):
    """
    Log-Spectral Distance (LSD)

    Compares the power spectra of prediction and target, either via a 2D
    FFT (regular grids) or via Laplacian spectral moments (graphs).

    (...,) is any number of batch dimensions, potentially different
    but broadcastable
    pred: (..., N, d_state), prediction
    target: (..., N, d_state), target
    pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. (unused)
    mask: (N,), boolean mask describing which grid nodes to use (unused)
    average_grid: boolean, accepted for API compatibility with the other
        metrics, but the spectrum is always reduced over the grid/frequency
        dimensions, so per-node spatial values are never produced
    sum_vars: boolean, if variable dimension -1 should be reduced
    grid_shape: tuple (ny, nx), shape of the 2D grid (for regular grids)
    edge_index: (2, M), edges in the graph (for unstructured grids)
    num_moments: int, number of Laplacian moments to use for unstructured LSD
    eps: float, small value to avoid log(0)

    Returns:
    metric_val: One of (...,), (..., d_state), depends on reduction arguments.

    Raises:
    ValueError: if grid_shape is inconsistent with the number of grid nodes,
        or if neither grid information nor edge_index is available.
    """
    num_nodes = pred.shape[-2]

    # Regular grid LSD using FFT
    if grid_shape is not None or (edge_index is None and _is_square_grid(pred)):
        if grid_shape is None:
            # Square grid inferred from the node count
            side = int(num_nodes**0.5)
            grid_shape = (side, side)

        ny, nx = grid_shape
        if ny * nx != num_nodes:
            # Fail early with a clear message instead of an opaque
            # reshape error below
            raise ValueError(
                f"grid_shape {grid_shape} does not match number of "
                f"grid nodes {num_nodes}"
            )

        # Move d_state before the grid dims and unflatten the grid:
        # (..., N, d_state) -> (..., d_state, N) -> (..., d_state, ny, nx)
        p = pred.transpose(-1, -2).reshape(
            *pred.shape[:-2], pred.shape[-1], ny, nx
        )
        t = target.transpose(-1, -2).reshape(
            *target.shape[:-2], target.shape[-1], ny, nx
        )

        # Power spectra |F(u,v)|^2 from the orthonormal 2D real FFT
        ps_pred = torch.abs(torch.fft.rfft2(p, norm="ortho")) ** 2
        ps_target = torch.abs(torch.fft.rfft2(t, norm="ortho")) ** 2

        # RMS over frequency bins of the power ratio in dB:
        # sqrt( mean( (10 * log10(P_target / P_pred))^2 ) )
        # eps regularizes both numerator and denominator against log(0)
        diff_lsd = (10 * torch.log10((ps_target + eps) / (ps_pred + eps))) ** 2
        metric_val = torch.sqrt(
            torch.mean(diff_lsd, dim=(-2, -1))
        )  # (..., d_state)

    # Unstructured grid LSD using Graph Signal Processing
    elif edge_index is not None:
        # Spectral moments w.r.t. the normalized graph Laplacian,
        # shape (..., d_state, num_moments)
        m_pred = _compute_laplacian_moments(pred, edge_index, num_moments)
        m_target = _compute_laplacian_moments(target, edge_index, num_moments)

        # RMS over moments of 10 * log10(m_target / m_pred)
        diff_lsd = (10 * torch.log10((m_target + eps) / (m_pred + eps))) ** 2
        metric_val = torch.sqrt(torch.mean(diff_lsd, dim=-1))  # (..., d_state)

    else:
        raise ValueError(
            "log_spectral_distance requires grid_shape, edge_index, "
            "or a square grid"
        )

    if sum_vars:
        metric_val = torch.sum(metric_val, dim=-1)  # (...,)

    return metric_val


def _is_square_grid(pred):
"""Check if the grid dimension is a perfect square"""
num_nodes = pred.shape[-2]
side = int(num_nodes**0.5)
return side**2 == num_nodes


def _compute_laplacian_moments(x, edge_index, num_moments):
"""
Compute moments of the spectral distribution: m_k = x^T L^k x
where L is the Normalized Laplacian.
"""
# x: (..., N, d_state)
# edge_index: (2, M)
# returns: (..., d_state, num_moments)
N = x.shape[-2]
device = x.device

# 1. Compute Normalized Laplacian as a sparse matrix
# L = I - D^-1/2 A D^-1/2
row, col = edge_index
deg = torch.zeros(N, device=device)
# Assume unweighted adjacency for now, or use edge_weight if provided
# For neural-lam, m2m_edge_index is usually unweighted or has features
deg.scatter_add_(0, col, torch.ones_like(row, dtype=torch.float32))

deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0

# Normalized weights: -1 / sqrt(di * dj)
val = -deg_inv_sqrt[row] * deg_inv_sqrt[col]

# Sparse Laplacian L
# Off-diagonal: -D^-1/2 A D^-1/2
indices = torch.cat(
[edge_index, torch.stack([torch.arange(N, device=device)] * 2)], dim=1
)
values = torch.cat([val, torch.ones(N, device=device)])
L = torch.sparse_coo_tensor(indices, values, (N, N)).coalesce()

# 2. Iteratively compute x_k = L^k x and m_k = x^T x_k
# x is (..., N, d_state)
# Reshape x to (N, -1) for sparse mm
orig_shape = x.shape
x_flat = x.transpose(-2, -1).reshape(-1, N).t() # (N, B * d_state)

moments = []
curr_x = x_flat
for _ in range(num_moments):
# m_k = x^T (L^k x)
# dot product per column
m_k = torch.sum(x_flat * curr_x, dim=0) # (B * d_state,)
moments.append(m_k)
# curr_x = L * curr_x
curr_x = torch.sparse.mm(L, curr_x)

# moments: list of (B * d_state,)
moments = torch.stack(moments, dim=-1) # (B * d_state, num_moments)

# Reshape back to (..., d_state, num_moments)
res = moments.view(*orig_shape[:-2], orig_shape[-1], num_moments)

# Normalize moments by k=0 (total energy) to get relative distribution?
# Actually, standard LSD compares absolute power spectra.
# But if we want it to be scale-invariant, we could.
# The proposal didn't specify, so we use absolute moments for now.
# We take absolute value to ensure positivity before log
return torch.abs(res)


# Registry of available metric functions, looked up by name
DEFINED_METRICS = {
    "mse": mse,
    "mae": mae,
    "wmse": wmse,
    "wmae": wmae,
    "nll": nll,
    "crps_gauss": crps_gauss,
    "lsd": log_spectral_distance,
}
39 changes: 35 additions & 4 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def __init__(
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)

# Store grid shape for spectral metrics if datastore is regular grid
from ..datastore.base import BaseRegularGridDatastore
if isinstance(datastore, BaseRegularGridDatastore):
grid_shape = datastore.grid_shape_state
self.grid_shape = (grid_shape.y, grid_shape.x)
else:
self.grid_shape = None

boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
).unsqueeze(
Expand Down Expand Up @@ -160,6 +168,9 @@ def __init__(
self._datastore.step_length
)

# Graph information for metrics (to be set by subclasses)
self.edge_index = None

def _create_dataarray_from_tensor(
self,
tensor: torch.Tensor,
Expand Down Expand Up @@ -303,7 +314,12 @@ def training_step(self, batch):
# Compute loss
batch_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
grid_shape=self.grid_shape,
edge_index=self.edge_index,
)
) # mean over unrolled times and batch

Expand Down Expand Up @@ -346,7 +362,12 @@ def validation_step(self, batch, batch_idx):

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
grid_shape=self.grid_shape,
edge_index=self.edge_index,
),
dim=0,
) # (time_steps-1)
Expand Down Expand Up @@ -400,7 +421,12 @@ def test_step(self, batch, batch_idx):

time_step_loss = torch.mean(
self.loss(
prediction, target, pred_std, mask=self.interior_mask_bool
prediction,
target,
pred_std,
mask=self.interior_mask_bool,
grid_shape=self.grid_shape,
edge_index=self.edge_index,
),
dim=0,
) # (time_steps-1,)
Expand Down Expand Up @@ -444,7 +470,12 @@ def test_step(self, batch, batch_idx):

# Save per-sample spatial loss for specific times
spatial_loss = self.loss(
prediction, target, pred_std, average_grid=False
prediction,
target,
pred_std,
average_grid=False,
grid_shape=self.grid_shape,
edge_index=self.edge_index,
) # (B, pred_steps, num_grid_nodes)
log_spatial_losses = spatial_loss[
:, [step - 1 for step in self.args.val_steps_to_log]
Expand Down
6 changes: 6 additions & 0 deletions neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
else:
setattr(self, name, attr_value)

# Set edge_index for metrics (from finest level mesh graph)
if self.hierarchical:
self.edge_index = self.m2m_edge_index[0]
else:
self.edge_index = self.m2m_edge_index

# Specify dimensions of data
self.num_mesh_nodes, _ = self.get_num_mesh()
utils.log_on_rank_zero(
Expand Down
Loading