diff --git a/config/afnonet.yaml b/config/afnonet.yaml index 6eab5f81..02a1d88f 100644 --- a/config/afnonet.yaml +++ b/config/afnonet.yaml @@ -57,9 +57,6 @@ full_field: &BASE_CONFIG nested_skip_fno: !!bool True # whether to nest the inner skip connection or have it be sequential, inside the AFNO block verbose: False - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 diff --git a/config/debug.yaml b/config/debug.yaml index 84744732..cbd7b190 100644 --- a/config/debug.yaml +++ b/config/debug.yaml @@ -90,9 +90,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. - target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/fourcastnet3.yaml b/config/fourcastnet3.yaml index d27c629e..21925c94 100644 --- a/config/fourcastnet3.yaml +++ b/config/fourcastnet3.yaml @@ -87,9 +87,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
- target: "default" # options default, residual - normalize_residual: false - # define channels to be read from data. sp has been removed here channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] # normalization mode zscore but for q diff --git a/config/icml_models.yaml b/config/icml_models.yaml index e135a7f3..ddde08ec 100644 --- a/config/icml_models.yaml +++ b/config/icml_models.yaml @@ -66,9 +66,6 @@ base_config: &BASE_CONFIG N_grid_channels: 0 gridtype: "sinusoidal" #options "sinusoidal" or "linear" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"] normalization: "zscore" #options zscore or minmax or none diff --git a/config/pangu.yaml b/config/pangu.yaml index c2852a14..7f49193e 100644 --- a/config/pangu.yaml +++ b/config/pangu.yaml @@ -78,9 +78,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
- target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "t2m", "msl", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none @@ -131,10 +128,10 @@ base_onnx: &BASE_ONNX # ONNX wrapper related overwrite nettype: "/makani/makani/makani/models/networks/pangu_onnx.py:PanguOnnx" onnx_file: '/model/pangu_weather_6.onnx' - + amp_mode: "none" disable_ddp: True - + # Set Pangu ONNX channel order channel_names: ["msl", "u10m", "v10m", "t2m", "z1000", "z925", "z850", "z700", "z600", "z500", "z400", "z300", "z250", "z200", "z150", "z100", "z50", "q1000", "q925", "q850", "q700", "q600", "q500", "q400", "q300", "q250", "q200", "q150", "q100", "q50", "t1000", "t925", "t850", "t700", "t600", "t500", "t400", "t300", "t250", "t200", "t150", "t100", "t50", "u1000", "u925", "u850", "u700", "u600", "u500", "u400", "u300", "u250", "u200", "u150", "u100", "u50", "v1000", "v925", "v850", "v700", "v600", "v500", "v400", "v300", "v250", "v200", "v150", "v100", "v50"] # Remove input/output normalization diff --git a/config/sfnonet.yaml b/config/sfnonet.yaml index 271a0ca7..b70b1d4d 100644 --- a/config/sfnonet.yaml +++ b/config/sfnonet.yaml @@ -86,9 +86,6 @@ base_config: &BASE_CONFIG add_noise: !!bool False noise_std: 0. 
- target: "default" # options default, residual - normalize_residual: !!bool False - # define channels to be read from data channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" # options zscore or minmax or none diff --git a/config/vit.yaml b/config/vit.yaml index 58fcee09..fc2571c1 100644 --- a/config/vit.yaml +++ b/config/vit.yaml @@ -65,9 +65,6 @@ full_field: &BASE_CONFIG embed_dim: 384 normalization_layer: "layer_norm" - #options default, residual - target: "default" - channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"] normalization: "zscore" #options zscore or minmax hard_thresholding_fraction: 1.0 diff --git a/data_process/get_stats.py b/data_process/get_stats.py index 840b3f78..97e0f3ce 100644 --- a/data_process/get_stats.py +++ b/data_process/get_stats.py @@ -366,6 +366,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str, height, width = 
(data_shape[2], data_shape[3]) # quadrature: + # we normalize the quadrature rule to 4pi quadrature = GridQuadrature(quadrature_rule, (height, width), crop_shape=None, crop_offset=(0, 0), normalize=False, pole_mask=None).to(device) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9af8c19a..5e6c7c8b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -71,9 +71,9 @@ RUN pip install --ignore-installed "git+https://github.com/NVIDIA/mlperf-common. # torch-harmonics ENV FORCE_CUDA_EXTENSION 1 -ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 9.0 10.0 12.0+PTX" -ENV HARMONICS_VERSION 0.8.0 -RUN cd /opt && git clone -b v0.8.0 https://github.com/NVIDIA/torch-harmonics.git && \ +ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 8.7 9.0 10.0+PTX" +ENV HARMONICS_VERSION 0.8.1 +RUN cd /opt && git clone -b bbonev/disco-modal-normalization https://github.com/NVIDIA/torch-harmonics.git && \ cd torch-harmonics && \ pip install --no-build-isolation -e . diff --git a/makani/convert_checkpoint.py b/makani/convert_checkpoint.py index be4f95cb..19fed0ab 100644 --- a/makani/convert_checkpoint.py +++ b/makani/convert_checkpoint.py @@ -100,7 +100,7 @@ def consolidate_checkpoints(input_path, output_path, checkpoint_version=0): print(checkpoint_paths) # load the static data necessary for instantiating the preprocessor (required due to the way the registry works) - LocalPackage._load_static_data(input_path, params) + LocalPackage._load_static_data(LocalPackage(input_path), params) # get the model multistep = params.n_future > 0 diff --git a/makani/models/common/__init__.py b/makani/models/common/__init__.py index 0341114a..440a9130 100644 --- a/makani/models/common/__init__.py +++ b/makani/models/common/__init__.py @@ -16,5 +16,7 @@ from .activations import ComplexReLU, ComplexActivation from .layers import DropPath, LayerScale, PatchEmbed2D, PatchEmbed3D, PatchRecovery2D, PatchRecovery3D, EncoderDecoder, MLP, UpSample3D, DownSample3D, UpSample2D, DownSample2D from .fft import RealFFT1, InverseRealFFT1, 
RealFFT2, InverseRealFFT2, RealFFT3, InverseRealFFT3 +from .imputation import MLPImputation, ConstantImputation from .layer_norm import GeometricInstanceNormS2 from .spectral_convolution import SpectralConv, SpectralAttention +from .pos_embedding import LearnablePositionEmbedding diff --git a/makani/models/common/imputation.py b/makani/models/common/imputation.py new file mode 100644 index 00000000..4c3dcec3 --- /dev/null +++ b/makani/models/common/imputation.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +from typing import Optional + +from makani.utils import comm +from .layers import EncoderDecoder + +# helper module to handle imputation of SST +class MLPImputation(nn.Module): + def __init__( + self, + inp_chans: int = 2, + inpute_chans: torch.Tensor = torch.tensor([0]), + mlp_ratio: float = 2.0, + activation_function: nn.Module = nn.GELU, + ): + super().__init__() + + self.inp_chans = inp_chans + self.inpute_chans = inpute_chans + self.out_chans = inpute_chans.shape[0] + + self.mlp = EncoderDecoder( + num_layers=1, + input_dim=self.inp_chans, + output_dim=self.out_chans, + hidden_dim=int(mlp_ratio * self.out_chans), + act_layer=activation_function, + input_format="nchw", + ) + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + if mask is None: + mask = torch.isnan(x[..., self.inpute_chans, :, :]) + else: + mask = torch.logical_or(mask, torch.isnan(x[..., self.inpute_chans, :, :])) + + x[..., self.inpute_chans, :, :] = torch.where(mask, 0.0, x[..., self.inpute_chans, :, :]) + + # flatten extra batch dims for Conv2d compatibility + batch_shape = x.shape[:-3] + x_flat = x.reshape(-1, *x.shape[-3:]) + mlp_out = self.mlp(x_flat).reshape(*batch_shape, self.out_chans, *x_flat.shape[-2:]) + + x[..., self.inpute_chans, :, :] = torch.where(mask, mlp_out, x[..., self.inpute_chans, :, :]) + + return x + +class ConstantImputation(nn.Module): + def __init__( + self, + inp_chans: int = 2, + ): + super().__init__() + + self.weight = nn.Parameter(torch.randn(inp_chans, 1, 1)) + + if comm.get_size("spatial") > 1: + self.weight.is_shared_mp = ["spatial"] + self.weight.sharded_dims_mp = [None, None, None] + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): + if mask is None: + mask = torch.isnan(x) + else: + mask = torch.logical_or(mask, torch.isnan(x)) + return torch.where(mask, self.weight, x) \ No newline at end of file diff --git a/makani/models/common/layers.py 
b/makani/models/common/layers.py index 3791e636..0cb6a64c 100644 --- a/makani/models/common/layers.py +++ b/makani/models/common/layers.py @@ -628,5 +628,5 @@ def forward(self, x): x = self.norm(x) x = self.linear(x) - + return x diff --git a/makani/models/common/pos_embedding.py b/makani/models/common/pos_embedding.py new file mode 100644 index 00000000..0f65da5e --- /dev/null +++ b/makani/models/common/pos_embedding.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import torch +import torch.nn as nn + +from makani.utils import comm +from physicsnemo.distributed.utils import compute_split_shapes + +class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta): + """ + Abstract base class for position embeddings. + + This class defines the interface for position embedding modules + that add positional information to input tensors. 
+ + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1): + super().__init__() + + self.img_shape = img_shape + self.num_chans = num_chans + + def forward(self): + + return self.position_embeddings + +class LearnablePositionEmbedding(PositionEmbedding): + """ + Learnable position embeddings for spherical transformers. + + This module provides learnable position embeddings that can be either + latitude-only or full latitude-longitude embeddings. + + Parameters + ---------- + img_shape : tuple, optional + Image shape (height, width), by default (480, 960) + grid : str, optional + Grid type, by default "equiangular" + num_chans : int, optional + Number of channels, by default 1 + embed_type : str, optional + Embedding type ("lat" or "latlon"), by default "lat" + """ + + def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"): + super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans) + + # if distributed, make sure to split correctly across ranks: + # in case of model parallelism, we need to make sure that we use the correct shapes per rank + # for h + if comm.get_size("h") > 1: + self.local_shape_h = compute_split_shapes(img_shape[0], comm.get_size("h"))[comm.get_rank("h")] + else: + self.local_shape_h = img_shape[0] + + # for w + if comm.get_size("w") > 1: + self.local_shape_w = compute_split_shapes(img_shape[1], comm.get_size("w"))[comm.get_rank("w")] + else: + self.local_shape_w = img_shape[1] + + if embed_type == "latlon": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, self.local_shape_w)) + self.position_embeddings.is_shared_mp = [] + self.position_embeddings.sharded_dims_mp = [None, None, "h", "w"] + elif 
embed_type == "lat": + self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, 1)) + self.position_embeddings.is_shared_mp = ["w"] + self.position_embeddings.sharded_dims_mp = [None, None, "h", None] + else: + raise ValueError(f"Unknown learnable position embedding type {embed_type}") + + def forward(self): + return self.position_embeddings.expand(-1,-1,self.local_shape_h, self.local_shape_w) diff --git a/makani/models/model_package.py b/makani/models/model_package.py index 63248844..c7e04e08 100644 --- a/makani/models/model_package.py +++ b/makani/models/model_package.py @@ -37,7 +37,7 @@ class LocalPackage: """ - Implements the earth2mip/modulus Package interface. + Implements the modulus Package interface. """ # These define the model package in terms of where makani expects the files to be located @@ -48,7 +48,7 @@ class LocalPackage: MEANS_FILE = "global_means.npy" STDS_FILE = "global_stds.npy" OROGRAPHY_FILE = "orography.nc" - LANDMASK_FILE = "land_mask.nc" + LANDMASK_FILE = "land_sea_mask.nc" SOILTYPE_FILE = "soil_type.nc" def __init__(self, root): @@ -147,11 +147,11 @@ def timestep(self): def update_state(self, replace_state=True): self.model.preprocessor.update_internal_state(replace_state=replace_state) return - + def set_rng(self, reset=True, seed=333): self.model.preprocessor.set_rng(reset=reset, seed=seed) return - + def forward(self, x, time, normalized_data=True, replace_state=None): if not normalized_data: x = (x - self.in_bias) / self.in_scale diff --git a/makani/models/model_registry.py b/makani/models/model_registry.py index 44855a44..f97958a9 100644 --- a/makani/models/model_registry.py +++ b/makani/models/model_registry.py @@ -165,7 +165,27 @@ def get_model(params: ParamsBase, use_stochastic_interpolation: bool = False, mu if hasattr(model_handle, "load") and callable(model_handle.load): model_handle = model_handle.load() - model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, 
inp_chans=inp_chans, out_chans=out_chans, **params.to_dict()) + model_kwargs = params.to_dict() + + # pass normalization statistics to the model + if params.get("normalization", "none") in ["zscore", "minmax"]: + try: + bias, scale = get_data_normalization(params) + # Slice the stats to match the model's output channels + # Assuming the model's output corresponds to params.out_channels + if hasattr(params, "out_channels"): + if bias is not None: + bias = bias.flatten()[params.out_channels] + if scale is not None: + scale = scale.flatten()[params.out_channels] + + if bias is not None and scale is not None: + model_kwargs["normalization_means"] = bias + model_kwargs["normalization_stds"] = scale + except Exception as e: + logging.warning(f"Could not load normalization stats. Error: {e}") + + model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **model_kwargs) else: raise KeyError(f"No model is registered under the name {params.nettype}") diff --git a/makani/models/networks/fourcastnet3.py b/makani/models/networks/fourcastnet3.py index 9244e145..0ab75733 100644 --- a/makani/models/networks/fourcastnet3.py +++ b/makani/models/networks/fourcastnet3.py @@ -14,26 +14,25 @@ # limitations under the License. 
import math +from functools import partial +from typing import Optional + import torch import torch.nn as nn -import torch.special as special import torch.amp as amp from torch.utils.checkpoint import checkpoint -from functools import partial -from itertools import groupby - # helpers -from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv +from makani.models.common import DropPath, LayerScale, MLP, EncoderDecoder, SpectralConv, LearnablePositionEmbedding, ConstantImputation, MLPImputation from makani.utils.features import get_water_channels, get_channel_groups +from makani.utils.grids import compute_spherical_bandlimit # get spectral transforms and spherical convolutions from torch_harmonics import torch_harmonics as th import torch_harmonics.distributed as thd # get pre-formulated layers -#from makani.models.common import GeometricInstanceNormS2 -from makani.mpu.layers import DistributedMLP, DistributedEncoderDecoder +from makani.mpu.layers import DistributedMLP # more distributed stuff from makani.utils import comm @@ -44,10 +43,9 @@ from physicsnemo.models.meta import ModelMetaData # heuristic for finding theta_cutoff -def _compute_cutoff_radius(nlat, kernel_shape, basis_type): - theta_cutoff_factor = {"piecewise linear": 0.5, "morlet": 0.5, "zernike": math.sqrt(2.0)} - - return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1) +def _compute_cutoff_radius(lmax, kernel_shape, basis_type): + margin_factor = {"piecewise linear": 1.0, "morlet": 1.0, "harmonic": 1.0, "zernike": 1.0} + return margin_factor[basis_type] * (kernel_shape[0] + 0.25) * math.pi / float(lmax) # commenting out torch.compile due to long intiial compile times # @torch.compile @@ -57,6 +55,10 @@ def _soft_clamp(x: torch.Tensor, offset: float = 0.0): y = torch.where(x >= 0.5, x - 0.25, y) return y +# heper function to be able to pass Sin as an activation function +class Sin(nn.Module): + def forward(self, x): + return 
torch.sin(x) class DiscreteContinuousEncoder(nn.Module): def __init__( @@ -69,7 +71,8 @@ def __init__( out_chans=2, kernel_shape=(3,3), basis_type="morlet", - basis_norm_mode="mean", + basis_norm_mode="nodal", + lmax=240, use_mlp=False, mlp_ratio=2.0, activation_function=nn.GELU, @@ -79,7 +82,7 @@ def __init__( super().__init__() # heuristic for finding theta_cutoff - theta_cutoff = _compute_cutoff_radius(nlat=inp_shape[0], kernel_shape=kernel_shape, basis_type=basis_type) + theta_cutoff = _compute_cutoff_radius(lmax=lmax, kernel_shape=kernel_shape, basis_type=basis_type) # set up local convolution conv_handle = thd.DistributedDiscreteContinuousConvS2 if comm.get_size("spatial") > 1 else th.DiscreteContinuousConvS2 @@ -119,13 +122,16 @@ def __init__( input_format="nchw", ) - def forward(self, x): + def _conv_forward(self, x): dtype = x.dtype - with amp.autocast(device_type="cuda", enabled=False): x = x.float() x = self.conv(x) x = x.to(dtype=dtype) + return x + + def forward(self, x): + x = self._conv_forward(x) if hasattr(self, "act"): x = self.act(x) @@ -147,7 +153,9 @@ def __init__( out_chans=2, kernel_shape=(3, 3), basis_type="morlet", - basis_norm_mode="mean", + basis_norm_mode="nodal", + theta_cutoff_factor=1.0, + lmax=240, use_mlp=False, mlp_ratio=2.0, activation_function=nn.GELU, @@ -187,7 +195,7 @@ def __init__( # heuristic for finding theta_cutoff # nto entirely clear if out or in shape should be used here with a non-conv method for upsampling - theta_cutoff = _compute_cutoff_radius(nlat=out_shape[0], kernel_shape=kernel_shape, basis_type=basis_type) + theta_cutoff = theta_cutoff_factor * _compute_cutoff_radius(lmax=lmax, kernel_shape=kernel_shape, basis_type=basis_type) # set up DISCO convolution conv_handle = thd.DistributedDiscreteContinuousConvS2 if comm.get_size("spatial") > 1 else th.DiscreteContinuousConvS2 @@ -202,7 +210,7 @@ def __init__( grid_in=grid_out, grid_out=grid_out, groups=groups, - bias=False, + bias=bias, theta_cutoff=theta_cutoff, ) 
if comm.get_size("spatial") > 1: @@ -212,6 +220,17 @@ def __init__( self.conv.bias.is_shared_mp = ["spatial"] self.conv.bias.sharded_dims_mp = [None] + def _conv_forward(self, x): + dtype = x.dtype + + with amp.autocast(device_type="cuda", enabled=False): + x = x.float() + x = self.upsample(x) + x = self.conv(x) + x = x.to(dtype=dtype) + + return x + def forward(self, x): dtype = x.dtype @@ -221,11 +240,7 @@ def forward(self, x): if hasattr(self, "mlp"): x = self.mlp(x) - with amp.autocast(device_type="cuda", enabled=False): - x = x.float() - x = self.upsample(x) - x = self.conv(x) - x = x.to(dtype=dtype) + x = self._conv_forward(x) return x @@ -249,30 +264,37 @@ def __init__( use_mlp=False, kernel_shape=(3, 3), basis_type="morlet", - basis_norm_mode="mean", + basis_norm_mode="nodal", + lmax=240, checkpointing_level=0, bias=False, + stochastic_bias=False, + seed=333, ): super().__init__() + # generator objects: + seed = seed + comm.get_rank("model") + comm.get_size("model") * comm.get_rank("ensemble") + comm.get_size("model") * comm.get_size("ensemble") * comm.get_rank("batch") + self.set_rng(seed=seed) + # determine some shapes self.inp_shape = (forward_transform.nlat, forward_transform.nlon) self.out_shape = (inverse_transform.nlat, inverse_transform.nlon) self.out_chans = out_chans # gain factor for the convolution - gain_factor = 1.0 + gain_factor = 0.5 # disco convolution layer if conv_type == "local": # heuristic for finding theta_cutoff - theta_cutoff = 2 * _compute_cutoff_radius(nlat=self.inp_shape[0], kernel_shape=kernel_shape, basis_type=basis_type) + theta_cutoff = _compute_cutoff_radius(lmax=lmax, kernel_shape=kernel_shape, basis_type=basis_type) conv_handle = thd.DistributedDiscreteContinuousConvS2 if comm.get_size("spatial") > 1 else th.DiscreteContinuousConvS2 self.local_conv = conv_handle( inp_chans, - inp_chans, + inp_chans if use_mlp else out_chans, in_shape=self.inp_shape, out_shape=self.out_shape, kernel_shape=kernel_shape, @@ -300,18 +322,26 @@ 
def __init__( forward_transform, inverse_transform, inp_chans, - inp_chans, + inp_chans if use_mlp else out_chans, operator_type="dhconv", num_groups=num_groups, - bias=bias, + bias=False, gain=gain_factor, ) else: raise ValueError(f"Unknown convolution type {conv_type}") + # stochastic bias + if stochastic_bias: + self.bias_std = nn.Parameter(torch.zeros(inp_chans if use_mlp else out_chans, 1, 1)) + scale = math.sqrt(gain_factor / self.bias_std.shape[0] / 2) + nn.init.normal_(self.bias_std, mean=0.0, std=scale) + self.bias_std.is_shared_mp = ["spatial"] + # norm layer self.norm = norm_layer() + # MLP if use_mlp == True: MLPH = DistributedMLP if (comm.get_size("matmul") > 1) else MLP mlp_hidden_dim = int(inp_chans * mlp_ratio) @@ -353,15 +383,33 @@ def __init__( else: raise ValueError(f"Unknown skip connection type {skip}") + @torch.compiler.disable(recursive=False) + def set_rng(self, seed=333): + self.rng_cpu = torch.Generator(device=torch.device("cpu")) + self.rng_cpu.manual_seed(seed) + if torch.cuda.is_available(): + self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) + self.rng_gpu.manual_seed(seed) + + def _conv_forward(self, x): + if hasattr(self, "global_conv"): + dx, _ = self.global_conv(x) + elif hasattr(self, "local_conv"): + dx = self.local_conv(x) + + return dx + def forward(self, x): """ Updated NO block """ - if hasattr(self, "global_conv"): - dx, _ = self.global_conv(x) - elif hasattr(self, "local_conv"): - dx = self.local_conv(x) + dx = self._conv_forward(x) + + if hasattr(self, "bias_std"): + n = torch.zeros(*dx.shape[:-2], 1, 1, device=dx.device, dtype=dx.dtype) + n.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if n.is_cuda else self.rng_cpu) + dx = dx + self.bias_std * n if hasattr(self, "norm"): dx = self.norm(dx) @@ -400,15 +448,14 @@ def __init__( kernel_shape=(3, 3), filter_basis_type="morlet", filter_basis_norm_mode="mean", - scale_factor=8, encoder_mlp=False, upsample_sht=False, channel_names=["u500", 
"v500"], aux_channel_names=[], n_history=0, - atmo_embed_dim=8, - surf_embed_dim=8, + embed_dim=8, aux_embed_dim=8, + pos_embed_dim=0, num_layers=4, num_groups=1, use_mlp=True, @@ -419,43 +466,53 @@ def __init__( path_drop_rate=0.0, mlp_drop_rate=0.0, normalization_layer="none", - max_modes=None, - hard_thresholding_fraction=1.0, + hard_thresholding_fraction=0.25, + scale_factor=8, + lmax=None, sfno_block_frequency=2, big_skip=False, clamp_water=False, + encoder_bias=False, bias=False, checkpointing_level=0, freeze_encoder=False, freeze_processor=False, + normalization_means=None, + normalization_stds=None, + stochastic_bias=False, + seed=333, **kwargs, ): super().__init__() self.inp_shape = inp_shape self.out_shape = out_shape - self.atmo_embed_dim = atmo_embed_dim - self.surf_embed_dim = surf_embed_dim + self.embed_dim = embed_dim self.aux_embed_dim = aux_embed_dim + self.pos_embed_dim = pos_embed_dim self.big_skip = big_skip self.checkpointing_level = checkpointing_level - - # currently doesn't support neither history nor future: - assert n_history == 0 + self.n_history = n_history # compute the downscaled image size self.h = int(self.inp_shape[0] // scale_factor) self.w = int(self.inp_shape[1] // scale_factor) + if normalization_means is not None: + self.register_buffer("normalization_means", torch.as_tensor(normalization_means)) + if normalization_stds is not None: + self.register_buffer("normalization_stds", torch.as_tensor(normalization_stds)) + # initialize spectral transforms - self._init_spectral_transforms(model_grid_type, sht_grid_type, hard_thresholding_fraction, max_modes) + self._init_spectral_transforms(model_grid_type, sht_grid_type, hard_thresholding_fraction, lmax) # compute static permutations to extract - self._precompute_channel_groups(channel_names, aux_channel_names) + self._precompute_channel_groups(channel_names, aux_channel_names, n_history) # compute the total number of internal groups self.n_out_chans = self.n_atmo_groups * 
self.n_atmo_chans + self.n_surf_chans - self.total_embed_dim = self.n_atmo_groups * self.atmo_embed_dim + self.surf_embed_dim + self.n_in_chans = (self.n_atmo_groups * self.n_atmo_chans + self.n_surf_chans) * (self.n_history + 1) + self.total_aux_embed_dim = self.aux_embed_dim + self.pos_embed_dim # convert kernel shape to tuple kernel_shape = tuple(kernel_shape) @@ -467,115 +524,94 @@ def __init__( activation_function = nn.GELU elif activation_function == "silu": activation_function = nn.SiLU + elif activation_function == "sin": + activation_function = Sin else: raise ValueError(f"Unknown activation function {activation_function}") - # encoder for the atmospheric channels - # TODO: add the groups - self.atmo_encoder = DiscreteContinuousEncoder( + # sst imputation in the case of SST channels + if self.sst_channels_in.shape[0] > 0: + self.sst_imputation = MLPImputation( + inp_chans=self.n_in_chans + self.n_aux_chans, + inpute_chans=self.sst_channels_in, + mlp_ratio=mlp_ratio, + activation_function=activation_function, + ) + + # encoder for the atmospheric and surface channels + self.encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_atmo_chans, - out_chans=self.atmo_embed_dim, + inp_chans=self.n_in_chans, + out_chans=self.embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, kernel_shape=kernel_shape, basis_type=filter_basis_type, basis_norm_mode=filter_basis_norm_mode, + lmax=self.lmax, activation_function=activation_function, - groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + groups=math.gcd(self.n_in_chans, self.embed_dim), + bias=encoder_bias, use_mlp=encoder_mlp, ) # encoder for the auxiliary channels - if self.n_surf_chans > 0: - self.surf_encoder = DiscreteContinuousEncoder( + if self.n_aux_chans > 0: + self.aux_encoder = DiscreteContinuousEncoder( inp_shape=inp_shape, out_shape=(self.h, self.w), - inp_chans=self.n_surf_chans, - out_chans=self.surf_embed_dim, + 
inp_chans=self.n_aux_chans, + out_chans=self.aux_embed_dim, grid_in=model_grid_type, grid_out=sht_grid_type, kernel_shape=kernel_shape, basis_type=filter_basis_type, basis_norm_mode=filter_basis_norm_mode, + lmax=self.lmax, activation_function=activation_function, - groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, + groups=math.gcd(self.n_aux_chans, self.aux_embed_dim), + bias=encoder_bias, use_mlp=encoder_mlp, ) - # decoder for the atmospheric variables - self.atmo_decoder = DiscreteContinuousDecoder( + + # decoder for the atmospheric and surface variables + self.decoder = DiscreteContinuousDecoder( inp_shape=(self.h, self.w), out_shape=out_shape, - inp_chans=self.atmo_embed_dim, - out_chans=self.n_atmo_chans, + inp_chans=self.embed_dim, + out_chans=self.n_out_chans, grid_in=sht_grid_type, grid_out=model_grid_type, kernel_shape=kernel_shape, basis_type=filter_basis_type, basis_norm_mode=filter_basis_norm_mode, + lmax=self.lmax, activation_function=activation_function, - groups=math.gcd(self.n_atmo_chans, self.atmo_embed_dim), - bias=bias, + groups=math.gcd(self.n_out_chans, self.embed_dim), + bias=encoder_bias, use_mlp=encoder_mlp, upsample_sht=upsample_sht, ) - # decoder for the surface variables - if self.n_surf_chans > 0: - self.surf_decoder = DiscreteContinuousDecoder( - inp_shape=(self.h, self.w), - out_shape=out_shape, - inp_chans=self.surf_embed_dim, - out_chans=self.n_surf_chans, - grid_in=sht_grid_type, - grid_out=model_grid_type, - kernel_shape=kernel_shape, - basis_type=filter_basis_type, - basis_norm_mode=filter_basis_norm_mode, - activation_function=activation_function, - groups=math.gcd(self.n_surf_chans, self.surf_embed_dim), - bias=bias, - use_mlp=encoder_mlp, - upsample_sht=upsample_sht, - ) - - # encoder for the auxiliary channels - if self.n_aux_chans > 0: - self.aux_encoder = DiscreteContinuousEncoder( - inp_shape=inp_shape, - out_shape=(self.h, self.w), - inp_chans=self.n_aux_chans, - out_chans=self.aux_embed_dim, - 
grid_in=model_grid_type, - grid_out=sht_grid_type, - kernel_shape=kernel_shape, - basis_type=filter_basis_type, - basis_norm_mode=filter_basis_norm_mode, - activation_function=activation_function, - groups=math.gcd(self.n_aux_chans, self.aux_embed_dim), - bias=bias, - use_mlp=encoder_mlp, - ) + # position embedding + if self.pos_embed_dim > 0: + self.pos_embed = LearnablePositionEmbedding(img_shape=(self.h, self.w), grid=sht_grid_type, num_chans=self.pos_embed_dim, embed_type="lat") # dropout self.pos_drop = nn.Dropout(p=pos_drop_rate) if pos_drop_rate > 0.0 else nn.Identity() dpr = [x.item() for x in torch.linspace(0, path_drop_rate, num_layers)] # get the handle for the normalization layer - norm_layer = self._get_norm_layer_handle(self.h, self.w, self.total_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) + norm_layer = self._get_norm_layer_handle(self.h, self.w, self.embed_dim + self.total_aux_embed_dim, normalization_layer=normalization_layer, sht_grid_type=sht_grid_type) # Internal NO blocks self.blocks = nn.ModuleList([]) for i in range(num_layers): - first_layer = i == 0 - last_layer = i == num_layers - 1 - if i % sfno_block_frequency == 0: - # if True: + # determine the convolution type + if (sfno_block_frequency > 0) and (i % sfno_block_frequency == 0): conv_type = "global" else: conv_type = "local" @@ -583,8 +619,8 @@ def __init__( block = NeuralOperatorBlock( self.sht, self.isht, - self.total_embed_dim + (self.n_aux_chans > 0) * self.aux_embed_dim, - self.total_embed_dim, + self.embed_dim + self.total_aux_embed_dim, + self.embed_dim, conv_type=conv_type, mlp_ratio=mlp_ratio, mlp_drop_rate=mlp_drop_rate, @@ -597,23 +633,15 @@ def __init__( kernel_shape=kernel_shape, basis_type=filter_basis_type, basis_norm_mode=filter_basis_norm_mode, - bias=bias, + lmax=self.lmax, checkpointing_level=checkpointing_level, + bias=bias, + stochastic_bias=stochastic_bias, + seed=seed, ) self.blocks.append(block) - # residual prediction - if 
self.big_skip: - self.residual_transform = nn.Conv2d(self.n_out_chans, self.n_out_chans, 1, bias=False) - self.residual_transform.weight.is_shared_mp = ["spatial"] - self.residual_transform.weight.sharded_dims_mp = [None, None, None, None] - if self.residual_transform.bias is not None: - self.residual_transform.bias.is_shared_mp = ["spatial"] - self.residual_transform.bias.sharded_dims_mp = [None] - scale = math.sqrt(0.5 / self.n_out_chans) - nn.init.normal_(self.residual_transform.weight, mean=0.0, std=scale) - # controlled output normalization of q and tcwv if clamp_water: water_chans = get_water_channels(channel_names) @@ -622,9 +650,7 @@ def __init__( # freeze the encoder/decoder if freeze_encoder: - frozen_params = list(self.atmo_encoder.parameters()) + list(self.atmo_decoder.parameters()) - if hasattr(self, "surf_encoder"): - frozen_params += list(self.surf_encoder.parameters()) + list(self.surf_decoder.parameters()) + frozen_params = list(self.encoder.parameters()) + list(self.decoder.parameters()) if hasattr(self, "aux_encoder"): frozen_params += list(self.aux_encoder.parameters()) if self.big_skip: @@ -645,7 +671,7 @@ def _init_spectral_transforms( model_grid_type="equiangular", sht_grid_type="legendre-gauss", hard_thresholding_fraction=1.0, - max_modes=None, + lmax=None, ): """ Initialize the spectral transforms based on the maximum number of modes to keep. 
Handles the computation @@ -653,11 +679,19 @@ def _init_spectral_transforms( """ # precompute the cutoff frequency on the sphere - if max_modes is not None: - modes_lat, modes_lon = max_modes - else: - modes_lat = int(self.h * hard_thresholding_fraction) - modes_lon = int((self.w // 2 + 1) * hard_thresholding_fraction) + if lmax is None: + lmax = compute_spherical_bandlimit(self.inp_shape, model_grid_type) + lmax = int(lmax * hard_thresholding_fraction) + self.lmax = lmax + + # if sht_grid_type == "equiangular": + # self.h = self.lmax + 1 + # self.w = 2 * self.lmax + # elif sht_grid_type == "legendre-gauss": + # self.h = self.lmax + 1 + # self.w = 2 * self.lmax + # else: + # raise ValueError(f"Unknown SHT grid type {sht_grid_type}") sht_handle = th.RealSHT isht_handle = th.InverseRealSHT @@ -671,8 +705,8 @@ def _init_spectral_transforms( isht_handle = thd.DistributedInverseRealSHT # set up - self.sht = sht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float() - self.isht = isht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid=sht_grid_type).float() + self.sht = sht_handle(self.h, self.w, lmax=self.lmax, mmax=self.lmax, grid=sht_grid_type).float() + self.isht = isht_handle(self.h, self.w, lmax=self.lmax, mmax=self.lmax, grid=sht_grid_type).float() @torch.compiler.disable(recursive=True) def _get_norm_layer_handle( @@ -726,60 +760,101 @@ def _precompute_channel_groups( self, channel_names=[], aux_channel_names=[], + n_history=0, ): """ group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + sst_chans = [channel_names.index("sst")] if "sst" in channel_names else [] + lsml_chans = [len(channel_names) + aux_channel_names.index("xlsml")] if "xlsml" in 
aux_channel_names else [] # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) self.n_atmo_chans = len(atmo_chans) // self.n_atmo_groups + self.n_surf_chans = len(surf_chans) + self.n_dyn_aux_chans = len(dyn_aux_chans) + self.n_stat_aux_chans= len(stat_aux_chans) + self.n_aux_chans = self.n_dyn_aux_chans * (n_history + 1) + self.n_stat_aux_chans # make sure they are divisible. Attention! This does not guarantee that the grrouping is correct if len(atmo_chans) % self.n_atmo_groups: raise ValueError(f"Expected number of atmospheric variables to be divisible by number of atmospheric groups but got {len(atmo_chans)} and {self.n_atmo_groups}") - self.register_buffer("atmo_channels", torch.LongTensor(atmo_chans), persistent=False) - self.register_buffer("surf_channels", torch.LongTensor(surf_chans), persistent=False) - self.register_buffer("aux_channels", torch.LongTensor(aux_chans), persistent=False) - - self.n_surf_chans = self.surf_channels.shape[0] - self.n_aux_chans = self.aux_channels.shape[0] + # if history is included, adapt the channel lists to include the offsets + self.n_atmo_groups = self.n_atmo_groups + n_dyn_chans = len(atmo_chans) + len(surf_chans) + len(dyn_aux_chans) + atmo_chans_in = atmo_chans.copy() + surf_chans_in = surf_chans.copy() + sst_chans_in = sst_chans.copy() + for ih in range(1, n_history+1): + atmo_chans_in += [(c + ih*n_dyn_chans) for c in atmo_chans] + surf_chans_in += [(c + ih*n_dyn_chans) for c in surf_chans] + sst_chans_in += [(c + ih*n_dyn_chans) for c in sst_chans] + dyn_aux_chans += [(c + ih*n_dyn_chans) for c in dyn_aux_chans] + # account for the history offset in the static aux channels + stat_aux_chans = [c + n_history*n_dyn_chans for c in stat_aux_chans] + + self.register_buffer("atmo_channels_in", torch.LongTensor(atmo_chans_in), persistent=False) + self.register_buffer("atmo_channels_out", torch.LongTensor(atmo_chans), persistent=False) + self.register_buffer("surf_channels_in", 
torch.LongTensor(surf_chans_in), persistent=False) + self.register_buffer("surf_channels_out", torch.LongTensor(surf_chans), persistent=False) + self.register_buffer("sst_channels_in", torch.LongTensor(sst_chans_in), persistent=False) + self.register_buffer("sst_channels_out", torch.LongTensor(sst_chans), persistent=False) + self.register_buffer("dyn_aux_channels", torch.LongTensor(dyn_aux_chans), persistent=False) + self.register_buffer("stat_aux_channels", torch.LongTensor(stat_aux_chans), persistent=False) + self.register_buffer("land_mask_channels", torch.LongTensor(lsml_chans), persistent=False) + self.register_buffer("in_channels", torch.LongTensor(surf_chans_in + atmo_chans_in), persistent=False) + self.register_buffer("aux_channels", torch.LongTensor(dyn_aux_chans + stat_aux_chans), persistent=False) + self.register_buffer("pred_channels", torch.LongTensor(surf_chans + atmo_chans), persistent=False) return - def encode(self, x): + def impute_sst_channels(self, x): """ - forward pass for the encoder + Impute the SST channels if applicable """ - batchdims = x.shape[:-3] - # for atmospheric channels the same encoder is applied to each atmospheric level - x_atmo = x[..., self.atmo_channels, :, :].contiguous().reshape(-1, self.n_atmo_chans, *x.shape[-2:]) - x_out = self.atmo_encoder(x_atmo) - x_out = x_out.reshape(*batchdims, self.n_atmo_groups * self.atmo_embed_dim, *x_out.shape[-2:]) + # start by imputing the SST channels if applicable + if hasattr(self, "sst_imputation"): + if self.land_mask_channels.nelement() > 0: + # get a land mask that is broadcastable to the input shape + mask = x[..., self.land_mask_channels, :, :] + else: + mask = None + x = self.sst_imputation(x, mask=mask).clone() - if hasattr(self, "surf_encoder"): - x_surf = x[..., self.surf_channels, :, :].contiguous() - x_surf = self.surf_encoder(x_surf) - x_out = torch.cat((x_out, x_surf), dim=-3) + return x - x_out = x_out.reshape(*batchdims, self.total_embed_dim, *x_out.shape[-2:]) + def 
encode(self, x): + """ + forward pass for the encoder + """ + + x = x[..., self.in_channels, :, :].contiguous() + x = self.encoder(x) - return x_out + return x def encode_auxiliary_channels(self, x): """ returns the embedded auxiliary channels """ - batchdims = x.shape[:-3] + + aux_tensors = [] if hasattr(self, "aux_encoder"): - x_aux = x[..., self.aux_channels, :, :] + x_aux = x[..., self.aux_channels, :, :].contiguous() x_aux = self.aux_encoder(x_aux) - x_aux = x_aux.reshape(*batchdims, self.aux_embed_dim, *x_aux.shape[-2:]) + aux_tensors.append(x_aux) + + if hasattr(self, "pos_embed"): + x_pos = self.pos_embed() + aux_tensors.append(x_pos) + + if len(aux_tensors) > 0: + x_aux = torch.cat(aux_tensors, dim=-3) else: x_aux = None @@ -790,19 +865,10 @@ def decode(self, x): forward pass for the decoder """ - batchdims = x.shape[:-3] - - x_atmo = x[..., : (self.n_atmo_groups * self.atmo_embed_dim), :, :].reshape(-1, self.atmo_embed_dim, *x.shape[-2:]) - x_atmo = self.atmo_decoder(x_atmo) - x_out = torch.zeros(*batchdims, self.n_out_chans, *x_atmo.shape[-2:], dtype=x.dtype, device=x.device) - x_out[..., self.atmo_channels, :, :] = x_atmo.reshape(*batchdims, -1, *x_atmo.shape[-2:]) + x = x[..., : self.embed_dim, :, :] + x = self.decoder(x) - if hasattr(self, "surf_decoder"): - x_surf = x[..., -self.surf_embed_dim :, :, :] - x_surf = self.surf_decoder(x_surf) - x_out[..., self.surf_channels, :, :] = x_surf.reshape(*batchdims, -1, *x_surf.shape[-2:]) - - return x_out + return x def processor_blocks(self, x, x_aux): # maybe clean the padding just in case @@ -823,9 +889,18 @@ def processor_blocks(self, x, x_aux): return x def clamp_water_channels(self, x): - """clamp water channes with a smooth, positive activation function""" + """ + clamp water channes with a smooth, positive activation function + """ + if hasattr(self, "water_channels"): - w = _soft_clamp(x[..., self.water_channels, :, :]) + if hasattr(self, "normalization_means") and hasattr(self, "normalization_stds"): 
+ means = self.normalization_means[self.water_channels].view(1, -1, 1, 1) + stds = self.normalization_stds[self.water_channels].view(1, -1, 1, 1) + offset = means / stds + w = _soft_clamp(x[..., self.water_channels, :, :], offset=offset) - offset + else: + w = _soft_clamp(x[..., self.water_channels, :, :]) # the following eventually leads to spectral instability # w = nn.functional.softplus(x[..., self.water_channels, :, :], beta=5, threshold=5) x[..., self.water_channels, :, :] = w @@ -834,9 +909,12 @@ def clamp_water_channels(self, x): def forward(self, x): + # sst imputation + x = self.impute_sst_channels(x) + # save big skip if self.big_skip: - residual = x[..., : self.n_out_chans, :, :].contiguous() + residual = x[..., self.pred_channels, :, :].contiguous() # extract embeddings for the auxiliary embeddings x_aux = self.encode_auxiliary_channels(x) @@ -857,7 +935,7 @@ def forward(self, x): x = self.decode(x) if self.big_skip: - x = x + self.residual_transform(residual) + x[..., self.pred_channels, :, :] = x + residual # apply output transform x = self.clamp_water_channels(x) @@ -875,4 +953,4 @@ class AtmoSphericNeuralOperatorNetMetaData(ModelMetaData): amp_gpu: bool = True -FCN3 = physicsnemo.Module.from_torch(AtmoSphericNeuralOperatorNet, AtmoSphericNeuralOperatorNetMetaData()) +FCN3 = physicsnemo.Module.from_torch(AtmoSphericNeuralOperatorNet, AtmoSphericNeuralOperatorNetMetaData()) \ No newline at end of file diff --git a/makani/models/networks/pangu.py b/makani/models/networks/pangu.py index b5b8e550..4f31ef69 100644 --- a/makani/models/networks/pangu.py +++ b/makani/models/networks/pangu.py @@ -71,11 +71,11 @@ def get_earth_position_index(window_size, ndim=3): # Change the order of the index to calculate the index in total if ndim == 3: - coords_1 = torch.stack(torch.meshgrid([coords_zi, coords_hi, coords_w])) - coords_2 = torch.stack(torch.meshgrid([coords_zj, coords_hj, coords_w])) + coords_1 = torch.stack(torch.meshgrid([coords_zi, coords_hi, coords_w], 
indexing="ij")) + coords_2 = torch.stack(torch.meshgrid([coords_zj, coords_hj, coords_w], indexing="ij")) elif ndim == 2: - coords_1 = torch.stack(torch.meshgrid([coords_hi, coords_w])) - coords_2 = torch.stack(torch.meshgrid([coords_hj, coords_w])) + coords_1 = torch.stack(torch.meshgrid([coords_hi, coords_w], indexing="ij")) + coords_2 = torch.stack(torch.meshgrid([coords_hj, coords_w], indexing="ij")) coords_flatten_1 = torch.flatten(coords_1, 1) coords_flatten_2 = torch.flatten(coords_2, 1) coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :] @@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None): x: input features with shape of (B * num_lon, num_pl*num_lat, N, C) mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon) """ - + B_, nW_, N, C = x.shape qkv = ( self.qkv(x) @@ -478,7 +478,7 @@ def forward(self, x: torch.Tensor, mask=None): attn = self.attn_drop_fn(attn) x = self.apply_attention(attn, v, B_, nW_, N, C) - + else: if mask is not None: bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0) @@ -486,10 +486,10 @@ def forward(self, x: torch.Tensor, mask=None): #bias = bias.squeeze(2) else: bias = earth_position_bias.unsqueeze(0) - + # extract batch size for q,k,v nLon = self.num_lon - q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) + q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4]) k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4]) v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4]) #### @@ -736,7 +736,7 @@ class Pangu(nn.Module): - https://arxiv.org/abs/2211.02556 """ - def __init__(self, + def __init__(self, inp_shape=(721,1440), out_shape=(721,1440), grid_in="equiangular", @@ -773,14 +773,14 @@ def __init__(self, self.checkpointing_level = checkpointing_level drop_path = np.linspace(0, drop_path_rate, 8).tolist() - + # Add static channels to surface 
self.num_aux = len(self.aux_channel_names) N_total_surface = self.num_aux + self.num_surface # compute static permutations to extract self._precompute_channel_groups(self.channel_names, self.aux_channel_names) - + # Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches self.patchembed2d = PatchEmbed2D( img_size=self.inp_shape, @@ -791,7 +791,7 @@ def __init__(self, flatten=False, norm_layer=None, ) - + self.patchembed3d = PatchEmbed3D( img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size=patch_size, @@ -870,7 +870,7 @@ def __init__(self, self.patchrecovery3d = PatchRecovery3D( (num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric ) - + def _precompute_channel_groups( self, channel_names=[], @@ -880,7 +880,8 @@ def _precompute_channel_groups( Group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, aux_chans, pressure_lvls = features.get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, pressure_lvls = features.get_channel_groups(channel_names, aux_channel_names) + aux_chans = dyn_aux_chans + stat_aux_chans # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) @@ -901,7 +902,7 @@ def _precompute_channel_groups( def prepare_input(self, input): """ - Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, + Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric, and reshaping the atmospheric variables into the required format. 
""" @@ -932,23 +933,23 @@ def prepare_output(self, output_surface, output_atmospheric): level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels} reordered_ids = [idx for level in levels for idx in level_dict[level]] check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]] - + # Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!) flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4]) reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1) for i in range(len(reordered_ids)): reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :] - + # Append the surface output, this has not been reordered. if output_surface is not None: - _, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) + _, surf_chans, _, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names) reordered_atmospheric[:, surf_chans, :, :] = output_surface output = reordered_atmospheric else: output = reordered_atmospheric return output - + def forward(self, input): # Prep the input by splitting into surface and atmospheric variables @@ -959,7 +960,7 @@ def forward(self, input): surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False) atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False) else: - surface = self.patchembed2d(surface_aux) + surface = self.patchembed2d(surface_aux) atmospheric = self.patchembed3d(atmospheric) if surface.shape[1] == 0: @@ -1011,11 +1012,5 @@ def forward(self, input): output_atmospheric = self.patchrecovery3d(output_atmospheric) output = self.prepare_output(output_surface, output_atmospheric) - - return output - - - - - + return output diff --git 
a/makani/models/networks/pangu_onnx.py b/makani/models/networks/pangu_onnx.py index 0805badb..bf3a0065 100644 --- a/makani/models/networks/pangu_onnx.py +++ b/makani/models/networks/pangu_onnx.py @@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper): channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects onnx_file: Path to the ONNX file containing the model ''' - def __init__(self, + def __init__(self, channel_names=[], aux_channel_names=[], onnx_file=None, @@ -58,7 +58,7 @@ def _precompute_channel_groups( group the channels appropriately into atmospheric pressure levels and surface variables """ - atmo_chans, surf_chans, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) + atmo_chans, surf_chans, _, _, pressure_lvls = get_channel_groups(channel_names, aux_channel_names) # compute how many channel groups will be kept internally self.n_atmo_groups = len(pressure_lvls) @@ -78,12 +78,12 @@ def prepare_input(self, input): B,V,Lat,Long=input.shape if B>1: - raise NotImplementedError("Not implemented yet for batch size greater than 1") + raise NotImplementedError("Not implemented yet for batch size greater than 1") input=input.squeeze(0) surface_aux_inp=input[self.surf_channels] atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0) - + return surface_aux_inp, atmospheric_inp def prepare_output(self, output_surface, output_atmospheric): @@ -99,9 +99,9 @@ def prepare_output(self, output_surface, output_atmospheric): return output.unsqueeze(0) - + def forward(self, input): - + surface, atmospheric = self.prepare_input(input) @@ -109,5 +109,5 @@ def forward(self, input): output = self.prepare_output(output_surface, output) - + return output diff --git a/makani/models/networks/vit.py b/makani/models/networks/vit.py index 874c21ab..75a9bf65 100644 --- a/makani/models/networks/vit.py +++ b/makani/models/networks/vit.py @@ -84,6 +84,7 @@ def __init__( 
norm_layer=nn.LayerNorm, comm_inp_name="fin", comm_hidden_name="fout", + seed=333, ): super().__init__() @@ -108,6 +109,16 @@ def __init__( self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) + # generator objects: + seed = seed + comm.get_rank("model") + comm.get_size("model") * comm.get_rank("ensemble") + comm.get_size("model") * comm.get_size("ensemble") * comm.get_rank("batch") + self.set_rng(seed=seed) + + # stochastic bias + self.bias_std = nn.Parameter(torch.zeros(1, 1, dim)) + scale = math.sqrt(2.0 / self.bias_std.shape[-1] / 2) + nn.init.normal_(self.bias_std, mean=0.0, std=scale) + self.bias_std.is_shared_mp = ["spatial"] + mlp_hidden_dim = int(dim * mlp_ratio) # distribute MLP for model parallelism @@ -125,11 +136,26 @@ def __init__( else: self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop_rate=mlp_drop_rate, input_format="traditional") + @torch.compiler.disable(recursive=False) + def set_rng(self, seed=333): + self.rng_cpu = torch.Generator(device=torch.device("cpu")) + self.rng_cpu.manual_seed(seed) + if torch.cuda.is_available(): + self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) + self.rng_gpu.manual_seed(seed) + def forward(self, x): # flatten transpose: y = self.attn(self.norm1(x)) x = x + self.drop_path(y) x = self.norm2(x) + + if hasattr(self, "bias_std"): + with torch.no_grad(): + n = torch.zeros_like(self.bias_std) + n.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if n.is_cuda else self.rng_cpu) + x = x + self.bias_std * n + x = x + self.drop_path(self.mlp(x)) return x @@ -153,6 +179,7 @@ def __init__( norm_layer="layer_norm", comm_inp_name="fin", comm_hidden_name="fout", + seed = 333, **kwargs, ): super().__init__() @@ -243,4 +270,4 @@ def forward(self, x): x = self.norm(x) x = self.forward_head(x) - return x + return x \ No newline at end of file diff --git a/makani/models/noise.py b/makani/models/noise.py index 56e3232b..0d384fc1 100644 --- 
a/makani/models/noise.py +++ b/makani/models/noise.py @@ -100,7 +100,7 @@ def reset(self, batch_size=None): # this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step def update(self, replace_state=False, batch_size=None): - # Update should always create a new state, so + # Update should always create a new state, so # we don't need to check for replace_state # create single occurence with torch.no_grad(): @@ -159,7 +159,7 @@ def __init__( grid_type="equiangular", seed=333, reflect=False, - learnable =False, + learnable=False, **kwargs, ): r""" @@ -193,21 +193,28 @@ def __init__( alpha = float(alpha) # Compute ls, angular power spectrum and sigma_l: - ls = torch.arange(self.lmax) + ls = torch.arange(self.lmax).reshape(-1 ,1) + ms = torch.arange(self.mmax) power_spectrum = torch.pow(2 * ls + 1, -alpha) norm_factor = torch.sum((2 * ls + 1) * power_spectrum / 4.0 / math.pi) sigma_l = sigma * torch.sqrt(power_spectrum / norm_factor) + sigma_l = torch.where(ms <= ls, sigma_l, 0.0) # the new shape is B, T, C, L, M - sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, 1)).to(dtype=torch.float32) + sigma_l = sigma_l.reshape((1, 1, 1, self.lmax, self.mmax)).to(dtype=torch.float32) # split tensor if comm.get_size("h") > 1: sigma_l = split_tensor_along_dim(sigma_l, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + # split tensor + if comm.get_size("w") > 1: + sigma_l = split_tensor_along_dim(sigma_l, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + # register buffer if learnable: self.register_parameter("sigma_l", nn.Parameter(sigma_l)) + self.sigma_l.sharded_dims_mp = [None, None, None, "h", "w"] else: self.register_buffer("sigma_l", sigma_l, persistent=False) diff --git a/makani/models/preprocessor.py b/makani/models/preprocessor.py index 59dbc2fb..071de927 100644 --- a/makani/models/preprocessor.py +++ b/makani/models/preprocessor.py @@ -53,13 +53,7 @@ def __init__(self, params): 
self.history_eps = 1e-6 # residual normalization - self.learn_residual = params.target == "residual" - if self.learn_residual and (params.normalize_residual): - with torch.no_grad(): - residual_scale = torch.as_tensor(np.load(params.time_diff_stds_path)).to(torch.float32) - self.register_buffer("residual_scale", residual_scale, persistent=False) - else: - self.residual_scale = None + self.residual_scale = None # image shape self.img_shape = [params.img_shape_x, params.img_shape_y] @@ -178,20 +172,6 @@ def expand_history(self, x, nhist): x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_)) return x - def add_residual(self, x, dx): - if self.learn_residual: - if self.residual_scale is not None: - dx = dx * self.residual_scale - - # add residual: deal with history - x = self.expand_history(x, nhist=self.n_history + 1) - x[:, -1, ...] = x[:, -1, ...] + dx - x = self.flatten_history(x) - else: - x = dx - - return x - def add_static_features(self, x): if self.do_add_static_features: # we need to replicate the grid for each batch: diff --git a/makani/models/stepper.py b/makani/models/stepper.py index f04590ba..f7edfea0 100644 --- a/makani/models/stepper.py +++ b/makani/models/stepper.py @@ -49,9 +49,6 @@ def forward(self, inp, update_state=True, replace_state=True): # undo normalization y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y @@ -60,7 +57,6 @@ def __init__(self, params, model_handle): super().__init__() self.preprocessor = Preprocessor2D(params) self.model = model_handle() - self.residual_mode = True if (params.target == "target") else False self.push_forward_mode = params.get("multistep_push_forward", False) # collect parameters for history @@ -102,9 +98,6 @@ def _forward_train(self, inp, update_state=True, replace_state=True): # will have been updated later: pred = self.preprocessor.history_denormalize(predn, target=True) - # 
add residual (for residual learning, no-op for direct learning - pred = self.preprocessor.add_residual(inpt, pred) - # append output result.append(pred) @@ -148,15 +141,12 @@ def _forward_eval(self, inp, update_state=True, replace_state=True): # because otherwise normalization stats are already outdated y = self.preprocessor.history_denormalize(yn, target=True) - # add residual (for residual learning, no-op for direct learning - y = self.preprocessor.add_residual(inp, y) - return y def forward(self, inp, update_state=True, replace_state=True): # decide which routine to call if self.training: - y = self._forward_train(inp, update_state=True, replace_state=replace_state) + y = self._forward_train(inp, update_state=update_state, replace_state=replace_state) else: y = self._forward_eval(inp, update_state=update_state, replace_state=replace_state) diff --git a/makani/mpu/layers.py b/makani/mpu/layers.py index 18f470b3..dd67a969 100644 --- a/makani/mpu/layers.py +++ b/makani/mpu/layers.py @@ -266,6 +266,134 @@ def forward(self, x): else: return self.fwd(x) +# Stochastic MLP needs comm datastructure +class StochasticMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + input_format="nchw", + drop_rate=0.0, + drop_type="iid", + checkpointing=False, + gain=1.0, + seed=333, + **kwargs, + ): + super().__init__() + + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # generator objects: + self.set_rng(seed=seed) + + # First fully connected layer + if input_format == "nchw": + self.fc1_weight_std = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_weight_mean = nn.Parameter(torch.zeros(hidden_features, in_features, 1, 1)) + self.fc1_bias = nn.Parameter(torch.zeros(hidden_features)) + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # sharing 
settings + self.fc1_weight_std.is_shared_mp = ["spatial"] + self.fc1_weight_mean.is_shared_mp = ["spatial"] + self.fc1_bias.is_shared_mp = ["spatial"] + + # initialize the weights correctly + scale = math.sqrt(1.0 / in_features) + nn.init.normal_(self.fc1_weight_std, mean=0.0, std=scale) + nn.init.normal_(self.fc1_weight_mean, mean=0.0, std=scale) + + # activation + self.act = act_layer() + + # sanity checks + if (input_format == "traditional") and (drop_type == "features"): + raise NotImplementedError(f"Error, traditional input format and feature dropout cannot be selected simultaneously") + + # output layer + if input_format == "nchw": + self.fc2_weight_std = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 1)) + self.fc2_weight_mean = nn.Parameter(torch.zeros(out_features, hidden_features, 1, 1)) + self.fc2_bias = nn.Parameter(torch.zeros(out_features)) if output_bias else None + else: + raise NotImplementedError(f"Error, input format {input_format} not supported.") + + # sharing settings + self.fc2_weight_std.is_shared_mp = ["spatial"] + self.fc2_weight_mean.is_shared_mp = ["spatial"] + if self.fc2_bias is not None: + self.fc2_bias.is_shared_mp = ["spatial"] + + # gain factor for the output determines the scaling of the output init + scale = math.sqrt(gain / hidden_features / 2) + nn.init.normal_(self.fc2_weight_std, mean=0.0, std=scale) + nn.init.normal_(self.fc2_weight_mean, mean=0.0, std=scale) + if self.fc2_bias is not None: + nn.init.constant_(self.fc2_bias, 0.0) + + if drop_rate > 0.0: + if drop_type == "iid": + self.drop = nn.Dropout(drop_rate) + elif drop_type == "features": + self.drop = nn.Dropout2d(drop_rate) + else: + raise NotImplementedError(f"Error, drop_type {drop_type} not supported") + else: + self.drop = nn.Identity() + + @torch.compiler.disable(recursive=False) + def set_rng(self, seed=333): + self.rng_cpu = torch.Generator(device=torch.device("cpu")) + self.rng_cpu.manual_seed(seed) + if torch.cuda.is_available(): + self.rng_gpu = 
torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}")) + self.rng_gpu.manual_seed(seed) + + @torch.compiler.disable(recursive=False) + def checkpoint_forward(self, x): + return checkpoint(self.fwd, x, use_reentrant=False) + + def fwd(self, x): + + # generate weight1 + weight1 = torch.empty_like(self.fc1_weight_mean) + weight1.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight1.is_cuda else self.rng_cpu) + weight1 = self.fc1_weight_std * weight1 + self.fc1_weight_mean + + # fully connected 1 + x = nn.functional.conv2d(x, weight1, bias=self.fc1_bias) + + # activation + x = self.act(x) + + # dropout + x = self.drop(x) + + # generate weight1 + weight2 = torch.empty_like(self.fc2_weight_mean) + weight2.normal_(mean=0.0, std=1.0, generator=self.rng_gpu if weight2.is_cuda else self.rng_cpu) + weight2 = self.fc2_weight_std * weight2 + self.fc2_weight_mean + + # fully connected 2 + x = nn.functional.conv2d(x, weight2, bias=self.fc2_bias) + + # dropout + x = self.drop(x) + + return x + + def forward(self, x): + if self.checkpointing: + return self.checkpoint_forward(x) + else: + return self.fwd(x) class DistributedPatchEmbed(nn.Module): def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768, input_is_matmul_parallel=False, output_is_matmul_parallel=True): diff --git a/makani/mpu/mappings.py b/makani/mpu/mappings.py index ecfbbcf3..7989ec5d 100644 --- a/makani/mpu/mappings.py +++ b/makani/mpu/mappings.py @@ -58,6 +58,24 @@ def backward(ctx, go): return gi, None, None, None +class gradient_print_wrapper(torch.autograd.Function): + + @staticmethod + @custom_fwd(device_type="cuda") + def forward(ctx, x, msg=""): + ctx.msg = msg + return x + + @staticmethod + @custom_bwd(device_type="cuda") + def backward(ctx, go): + + msg = ctx.msg + print(f"Gradient stats for {msg}: min: {go.min()}, max: {go.max()}, mean: {go.mean()}") + + return go, None + + # handler for additional gradient reductions # helper for gradient reduction 
across channel parallel ranks def init_gradient_reduction_hooks(model, device, reduction_buffer_count=1, broadcast_buffers=True, find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=False, verbose=False): diff --git a/makani/utils/dataloaders/dali_es_helper_2d.py b/makani/utils/dataloaders/dali_es_helper_2d.py index 5a8b2527..bd6627db 100644 --- a/makani/utils/dataloaders/dali_es_helper_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_2d.py @@ -183,7 +183,7 @@ def _generate_indexlist(self, timestamp_boundary_list): if timestamp_boundary_list: #compute list of allowed timestamps timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list] - + # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed timestamp_exclusion_list = get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * self.n_history) @@ -521,7 +521,7 @@ def _compute_zenith_angle(self, inp_times, tar_times): # nvtx range torch.cuda.nvtx.range_pop() - return cos_zenith_inp, cos_zenith_tar + return cos_zenith_inp, cos_zenith_tar def __getstate__(self): del self.aws_connector diff --git a/makani/utils/dataloaders/dali_es_helper_concat_2d.py b/makani/utils/dataloaders/dali_es_helper_concat_2d.py index 21f8b6fe..c2e30a80 100644 --- a/makani/utils/dataloaders/dali_es_helper_concat_2d.py +++ b/makani/utils/dataloaders/dali_es_helper_concat_2d.py @@ -159,7 +159,7 @@ def _generate_indexlist(self, timestamp_boundary_list): if timestamp_boundary_list: #compute list of allowed timestamps timestamp_boundary_list = [get_date_from_string(timestamp_string) for timestamp_string in timestamp_boundary_list] - + # now, based on dt, dh, n_history and n_future, we can build regions where no data is allowed timestamp_exclusion_list = get_date_ranges(timestamp_boundary_list, lookback_hours = dt_total * (self.n_future + 1), lookahead_hours = dt_total * 
self.n_history) diff --git a/makani/utils/dataloaders/data_helpers.py b/makani/utils/dataloaders/data_helpers.py index 753e3985..c4d2fe06 100644 --- a/makani/utils/dataloaders/data_helpers.py +++ b/makani/utils/dataloaders/data_helpers.py @@ -58,6 +58,36 @@ def get_data_normalization(params): return bias, scale +def get_time_diff_stds(params): + + time_diff_stds = None + + if hasattr(params, "time_diff_stds_path"): + time_diff_stds = np.load(params.time_diff_stds_path) + else: + raise ValueError(f"time_diff_stds_path not defined.") + + return time_diff_stds + + +def get_psd_stats(params): + + psd_means = None + psd_stds = None + + if hasattr(params, "psd_means_path") and hasattr(params, "psd_stds_path"): + psd_means = np.load(params.psd_means_path) + psd_stds = np.load(params.psd_stds_path) + + # filter channels if requested + if hasattr(params, "out_channels"): + psd_means = psd_means[..., params.out_channels, :] + psd_stds = psd_stds[..., params.out_channels, :] + else: + raise ValueError(f"psd_means_path or psd_stds_path not defined.") + + return psd_means, psd_stds + def get_climatology(params): """ diff --git a/makani/utils/driver.py b/makani/utils/driver.py index 6e507c81..da36841d 100644 --- a/makani/utils/driver.py +++ b/makani/utils/driver.py @@ -632,15 +632,15 @@ def get_optimizer(self, model, params): if params.optimizer_type == "Adam": if self.log_to_screen: self.logger.info("using Adam optimizer") - optimizer = optim.Adam(all_parameters, betas=betas, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), foreach=True) + optimizer = optim.Adam(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "AdamW": if self.log_to_screen: self.logger.info("using AdamW optimizer") +
optimizer = optim.AdamW(all_parameters, lr=params.get("lr", 1e-3), betas=betas, eps=params.get("optimizer_eps", 1e-8), weight_decay=params.get("weight_decay", 0), foreach=True) elif params.optimizer_type == "SGD": if self.log_to_screen: self.logger.info("using SGD optimizer") - optimizer = optim.SGD(all_parameters, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), momentum=params.get("momentum", 0), foreach=True) + optimizer = optim.SGD(all_parameters, lr=params.get("lr", 1e-3), weight_decay=params.get("weight_decay", 0), momentum=params.get("momentum", 0), nesterov=params.get("nesterov", True), foreach=True) elif params.optimizer_type == "SIRFShampoo": if self.log_to_screen: self.logger.info("using SIRFShampoo optimizer") diff --git a/makani/utils/features.py b/makani/utils/features.py index 9b177500..cab61c37 100644 --- a/makani/utils/features.py +++ b/makani/utils/features.py @@ -88,7 +88,7 @@ def get_wind_channels(channel_names): wind_chans = [] for c, ch in enumerate(channel_names): - if ch[0] == "u": + if ch[0] == "u" and ("v" + ch[1:]) in channel_names: vc = channel_names.index("v" + ch[1:]) wind_chans = wind_chans + [c, vc] @@ -97,13 +97,15 @@ def get_wind_channels(channel_names): def get_channel_groups(channel_names, aux_channel_names=[]): """ - Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups + Helper routine to extract indices of atmospheric, surface and auxiliary variables and group them into their respective groups. + The resulting numbering does NOT respect history. 
""" atmo_groups = OrderedDict() atmo_chans = [] surf_chans = [] - aux_chans = [] + dyn_aux_chans = [] + stat_aux_chans = [] # parse channel names and group variables by pressure level/surface variables for idx, chn in enumerate(channel_names): @@ -127,6 +129,10 @@ def get_channel_groups(channel_names, aux_channel_names=[]): atmo_chans += idx # append the auxiliary variable to the surface channels - aux_chans = [idx + len(channel_names) for idx in range(len(aux_channel_names))] + for idx, chn in enumerate(aux_channel_names): + if chn in ["xoro", "xlsml", "xlsms"]: + stat_aux_chans.append(idx + len(channel_names)) + else: + dyn_aux_chans.append(idx + len(channel_names)) - return atmo_chans, surf_chans, aux_chans, atmo_groups.keys() + return atmo_chans, surf_chans, dyn_aux_chans, stat_aux_chans, atmo_groups.keys() diff --git a/makani/utils/grids.py b/makani/utils/grids.py index 52085296..281fb65e 100644 --- a/makani/utils/grids.py +++ b/makani/utils/grids.py @@ -16,10 +16,12 @@ import numpy as np import torch -from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights +import torch.amp as amp + +from torch_harmonics.quadrature import legendre_gauss_weights, clenshaw_curtiss_weights, precompute_latitudes from makani.utils import comm -from physicsnemo.distributed.utils import compute_split_shapes +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import reduce_from_parallel_region @@ -33,9 +35,23 @@ def grid_to_quadrature_rule(grid_type): return grid_to_quad_dict[grid_type] +def compute_spherical_bandlimit(img_shape, grid_type): + + if grid_type == "equiangular": + lmax = (img_shape[0] - 1) // 2 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + elif grid_type == "legendre-gauss": + lmax = img_shape[0] - 1 + mmax = img_shape[1] // 2 + return min(lmax, mmax) + else: + raise NotImplementedError(f"Unknown type {grid_type} not implemented") + + class 
GridConverter(torch.nn.Module): def __init__(self, src_grid, dst_grid, lat_rad, lon_rad): - super(GridConverter, self).__init__() + super().__init__() self.src = src_grid self.dst = dst_grid self.src_lat = lat_rad @@ -44,7 +60,7 @@ def __init__(self, src_grid, dst_grid, lat_rad, lon_rad): if self.src != self.dst: if self.dst == "legendre-gauss": cost_lg, _ = legendre_gauss_weights(lat_rad.shape[0], -1, 1) - tq = torch.arccos(torch.from_numpy(cost_lg)) - torch.pi / 2.0 + tq = torch.arccos(cost_lg) - torch.pi / 2.0 self.dst_lat = tq.to(lat_rad.device) self.dst_lon = lon_rad @@ -123,7 +139,7 @@ def __init__(self, quadrature_rule, img_shape, crop_shape=None, crop_offset=(0, # apply pole mask if (pole_mask is not None) and (pole_mask > 0): quad_weight[:pole_mask, :] = 0.0 - quad_weight[sizes[0] - pole_mask :, :] = 0.0 + quad_weight[img_shape[0] - pole_mask :, :] = 0.0 # if distributed, make sure to split correctly across ranks: # in case of model parallelism, we need to make sure that we use the correct shapes per rank @@ -165,3 +181,70 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: quad = reduce_from_parallel_region(quad.contiguous(), "spatial") return quad + + +class BandLimitMask(torch.nn.Module): + def __init__(self, img_shape, grid_type, lmax = None, type="sht"): + super().__init__() + self.img_shape = img_shape + self.grid_type = grid_type + self.lmax = lmax if lmax is not None else compute_spherical_bandlimit(img_shape, grid_type) + self.type = type + + if self.type == "sht": + # SHT for the computation of SH coefficients + if (comm.get_size("spatial") > 1): + from torch_harmonics.distributed import DistributedRealSHT, DistributedInverseRealSHT + import torch_harmonics.distributed as thd + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.forward_transform = 
DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = DistributedInverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + else: + from torch_harmonics import RealSHT, InverseRealSHT + + self.forward_transform = RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + self.inverse_transform = InverseRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + + elif self.type == "fft": + + # get the cutoff frequency in m for each latitude + lats, _ = precompute_latitudes(self.img_shape[0], grid=self.grid_type) + # get the grid spacing at the equator + delta_equator = 2 * torch.pi / (self.lmax-1) + mlim = torch.ceil(2 * torch.pi * torch.sin(lats) / delta_equator).reshape(self.img_shape[0], 1) + ms = torch.arange(self.lmax).reshape(1, -1) + mask = (ms <= mlim) + mask = split_tensor_along_dim(mask, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + mask = split_tensor_along_dim(mask, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + self.register_buffer("mask", mask, persistent=False) + + if (comm.get_size("spatial") > 1): + from makani.mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1 + self.forward_transform = DistributedRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = DistributedInverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + from makani.models.common.fft import RealFFT1, InverseRealFFT1 + self.forward_transform = RealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + self.inverse_transform = InverseRealFFT1(img_shape[1], lmax=lmax, mmax=lmax).float() + else: + raise ValueError(f"Unknown truncation type {self.type}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + with amp.autocast(device_type="cuda", enabled=False): + dtype = x.dtype + x = x.float() + + x = self.forward_transform(x) + + if hasattr(self, "mask"): + x = torch.where(self.mask, x, torch.zeros_like(x)) + + x = 
self.inverse_transform(x) + + x = x.to(dtype=dtype) + + return x \ No newline at end of file diff --git a/makani/utils/inference/inferencer.py b/makani/utils/inference/inferencer.py index ea416345..3828d103 100644 --- a/makani/utils/inference/inferencer.py +++ b/makani/utils/inference/inferencer.py @@ -441,10 +441,11 @@ def inference_indexlist( return logs - def _initialize_noise_states(self): + def _initialize_noise_states(self, seed_offset=666): noise_states = [] - for _ in range(self.params.local_ensemble_size): - self.preprocessor.update_internal_state(replace_state=True) + for ide in range(self.params.local_ensemble_size): + member_seed = seed_offset + self.preprocessor.get_base_seed(default=333) * ide + self.preprocessor.set_rng(seed=member_seed, reset=True) noise_states.append(self.preprocessor.get_internal_state(tensor=True)) return noise_states @@ -510,7 +511,7 @@ def _inference_indexlist( climatology_iterator = iter(self.climatology_dataloader) # create loader for the full epoch - noise_states = [] + noise_states = self._initialize_noise_states() inptlist = None idt = 0 with torch.inference_mode(): @@ -567,7 +568,7 @@ def _inference_indexlist( self.preprocessor.update_internal_state(replace_state=True, batch_size=inp.shape[0]) # reset noise states and input list - noise_states = self._initialize_noise_states() + noise_states = self._initialize_noise_states(seed_offset=idt) inptlist = [inp.clone() for _ in range(self.params.local_ensemble_size)] if rollout_buffer is not None: @@ -597,9 +598,8 @@ def _inference_indexlist( # retrieve input inpt = inptlist[e] - # this is different, depending on local ensemble size + # restore noise state if (self.params.local_ensemble_size > 1): - # restore noise belonging to this ensemble member self.preprocessor.set_internal_state(noise_states[e]) # forward pass: never replace state since we do that manually diff --git a/makani/utils/loss.py b/makani/utils/loss.py index 4b356535..77bbc8bf 100644 --- a/makani/utils/loss.py 
+++ b/makani/utils/loss.py @@ -23,12 +23,17 @@ from torch import nn from makani.utils import comm -from makani.utils.dataloaders.data_helpers import get_data_normalization +from makani.utils.grids import GridQuadrature, BandLimitMask +from makani.utils.dataloaders.data_helpers import get_data_normalization, get_time_diff_stds, get_psd_stats +from physicsnemo.distributed.utils import compute_split_shapes from physicsnemo.distributed.mappings import gather_from_parallel_region, reduce_from_parallel_region -from .losses import LossType, GeometricLpLoss, SpectralH1Loss, SpectralAMSELoss, HydrostaticBalanceLoss -from .losses import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss, EnsembleNLLLoss, EnsembleMMDLoss -from .losses import DriftRegularization +from .losses import LossType, GeometricLpLoss, SpectralLpLoss, SpectralH1Loss, SpectralAMSELoss +from .losses import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss, KernelScoreLoss +from .losses import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss, SpectralCoherenceLoss +from .losses import GaussianMMDLoss +from .losses import EnsembleNLLLoss +from .losses import DriftRegularization, HydrostaticBalanceLoss, SpectralRegularization class LossHandler(nn.Module): @@ -42,6 +47,7 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps self.rank = comm.get_rank("matmul") self.n_future = params.n_future + self.n_history = params.n_history self.spatial_distributed = comm.is_distributed("spatial") and (comm.get_size("spatial") > 1) self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) @@ -52,11 +58,12 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps # check whether dynamic loss weighting is required self.uncertainty_weighting = params.get("uncertainty_weighting", False) + self.balanced_weighting = params.get("balanced_weighting", False) self.randomized_loss_weights = 
params.get("randomized_loss_weights", False) self.random_slice_loss = params.get("random_slice_loss", False) # whether to keep running stats - self.track_running_stats = track_running_stats or self.uncertainty_weighting + self.track_running_stats = track_running_stats or self.uncertainty_weighting or self.balanced_weighting self.eps = eps n_channels = len(params.channel_names) @@ -81,14 +88,29 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps else: scale = torch.ones((1, len(params.out_channels), 1, 1), dtype=torch.float32) + # load PSD stats + try: + psd_means, psd_stds = get_psd_stats(params) + if psd_means is not None: + psd_means = torch.from_numpy(psd_means).to(torch.float32) + if psd_stds is not None: + psd_stds = torch.from_numpy(psd_stds).to(torch.float32) + except ValueError: + psd_means = None + psd_stds = None + # create module list self.loss_fn = nn.ModuleList([]) + self.loss_requires_input = [] # track which losses need input state channel_weights = [] for loss in losses: loss_type = loss["type"] + # check if this is a tendency loss (from explicit field, not string parsing) + requires_input = loss.get("tendency", False) + # get pole mask if it was specified pole_mask = loss.get("pole_mask", 0) @@ -105,6 +127,8 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps pole_mask=pole_mask, bias=bias, scale=scale, + psd_means=psd_means, + psd_stds=psd_stds, grid_type=params.model_grid_type, spatial_distributed=self.spatial_distributed, ensemble_distributed=self.ensemble_distributed, @@ -112,42 +136,34 @@ def __init__(self, params, track_running_stats: bool = False, seed: int = 0, eps ) # append to dict and compile before: - # TODO: fix the compile issue - # self.loss_fn[loss_type] = torch.compile(loss_fn) self.loss_fn.append(loss_fn) + self.loss_requires_input.append(requires_input) + # TODO: the entire channel weighting logic should be moved to the loss function base class # determine channel 
weighting if "channel_weights" not in loss.keys(): channel_weight_type = "constant" else: channel_weight_type = loss["channel_weights"] + # check if time difference weighting is required + if loss.get("temp_diff_normalization", False): + time_diff_scale = get_time_diff_stds(params).flatten() + time_diff_scale = torch.clamp(torch.from_numpy(time_diff_scale[params.out_channels]), min=1e-4) + time_diff_scale = scale.flatten() / time_diff_scale + else: + time_diff_scale = None + + # get channel weights either directly or through the compute routine if isinstance(channel_weight_type, List): - chw = torch.tensor(channel_weight_type, dtype=torch.float32).reshape(1, -1) + chw = torch.tensor(channel_weight_type, dtype=torch.float32) + if time_diff_scale is not None: + chw = chw * time_diff_scale assert chw.shape[1] == loss_fn.n_channels else: - chw = loss_fn.compute_channel_weighting(channel_weight_type) - - # the option to normalize outputs with stds of the time difference rather than th - - # extract relevant stds - time_diff_stds = torch.from_numpy(np.load(params.time_diff_stds_path)).reshape(1, -1)[:, params.out_channels] - # the time differences are computed between two consecutive datapoints, - # so we need to account for the number of timesteps used in the prediction - # this is now commebnted out as we expect the stats to be computed with the correct dt - # time_diff_stds *= np.sqrt(params.dt) - - # to avoid division by very small numbers, we clamp the time differences from below - time_diff_stds = torch.clamp(time_diff_stds, min=1e-4) - - time_var_weights = scale.reshape(1, -1) / time_diff_stds - - if hasattr(loss_fn, "squared") and loss_fn.squared: - time_var_weights = time_var_weights**2 - - chw = chw * time_var_weights + chw = loss_fn.compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) + # reshape channel weights for proper broadcasting chw =
chw.reshape(1, -1) # check for a relative weight that weights the loss relative to other losses @@ -195,6 +211,7 @@ def _compute_multistep_weight(self, multistep_weight_type: str) -> torch.Tensor: # linear weighting factor for the case of multistep training multistep_weight = torch.arange(1, self.n_future + 2, dtype=torch.float32) / float(self.n_future + 1) elif multistep_weight_type == "last-n-1": + print(f"using last n-1") # weighting factor for the last n steps, with the first step weighted 0 multistep_weight = torch.ones(self.n_future + 1, dtype=torch.float32) / float(self.n_future) multistep_weight[0] = 0.0 @@ -215,13 +232,18 @@ def _parse_loss_type(self, loss_type: str): loss_type = set(loss_type.split()) + # this can probably all be moved to the loss function itself relative = "relative" in loss_type squared = "squared" in loss_type jacobian = "s2" if "geometric" in loss_type else "flat" # decide which loss to use - if "l2" in loss_type: + if "spectral" in loss_type and "l2" in loss_type: + loss_handle = partial(SpectralLpLoss, p=2, relative=relative, squared=squared) + elif "spectral" in loss_type and "l1" in loss_type: + loss_handle = partial(SpectralLpLoss, p=1, relative=relative, squared=squared) + elif "l2" in loss_type: loss_handle = partial(GeometricLpLoss, p=2, relative=relative, squared=squared, jacobian=jacobian) elif "l1" in loss_type: loss_handle = partial(GeometricLpLoss, p=1, relative=relative, squared=squared, jacobian=jacobian) @@ -242,17 +264,31 @@ def _parse_loss_type(self, loss_type: str): p_max = int(x.replace("p_max=", "")) loss_handle = partial(HydrostaticBalanceLoss, p_min=p_min, p_max=p_max, use_moist_air_formula=use_moist_air_formula) elif "ensemble_crps" in loss_type: - loss_handle = partial(EnsembleCRPSLoss, crps_type="cdf") + loss_handle = partial(CRPSLoss) elif "ensemble_spectral_crps" in loss_type: - loss_handle = partial(EnsembleSpectralCRPSLoss, crps_type="cdf") - elif "gauss_crps" in loss_type: - loss_handle = 
partial(EnsembleCRPSLoss, crps_type="gauss") + loss_handle = partial(SpectralCRPSLoss) + elif "ensemble_vort_div_crps" in loss_type: + loss_handle = partial(VortDivCRPSLoss) + elif "ensemble_gradient_crps" in loss_type: + loss_handle = partial(GradientCRPSLoss) + elif "ensemble_kernel_score" in loss_type: + loss_handle = partial(KernelScoreLoss) elif "ensemble_nll" in loss_type: loss_handle = EnsembleNLLLoss - elif "ensemble_mmd" in loss_type: - loss_handle = EnsembleMMDLoss + elif "gaussian_mmd" in loss_type: + loss_handle = GaussianMMDLoss + elif "l2_energy_score" in loss_type: + loss_handle = partial(L2EnergyScoreLoss) + elif "sobolev_energy_score" in loss_type: + loss_handle = partial(SobolevEnergyScoreLoss) + elif "spectral_l2_energy_score" in loss_type: + loss_handle = partial(SpectralL2EnergyScoreLoss) + elif "spectral_coherence_loss" in loss_type: + loss_handle = partial(SpectralCoherenceLoss) elif "drift_regularization" in loss_type: loss_handle = DriftRegularization + elif "spectral_regularization" in loss_type: + loss_handle = SpectralRegularization else: raise NotImplementedError(f"Unknown loss function: {loss_type}") @@ -309,7 +345,23 @@ def reset_running_stats(self): self.running_var.fill_(1) self.num_batches_tracked.zero_() - def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): + def _extract_input_state(self, inp: torch.Tensor) -> torch.Tensor: + """ + Extract last timestep from flattened history input. 
+ + Args: + inp: Input tensor with shape (B, (n_history+1)*C, H, W) + + Returns: + Last timestep with shape (B, C, H, W) + """ + # inp shape: (B, (n_history+1)*C, H, W) + # we want: (B, C, H, W) - the last timestep + n_channels_per_step = inp.shape[1] // (self.n_history + 1) + inp_last = inp[..., -n_channels_per_step:, :, :] + return inp_last + + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None, inp: Optional[torch.Tensor] = None): # we assume the following: # if prd is 5D, we assume that the dims are # batch, ensemble, channel, h, w @@ -347,13 +399,48 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens else: prdm = prd + # transform to tendency space if any loss requires it + if inp is not None and any(self.loss_requires_input): + inp_state = self._extract_input_state(inp) + + # validate channel counts for single-step predictions + if self.n_future == 0: + n_pred_channels = prdm.shape[1] + n_inp_channels = inp_state.shape[1] + assert n_pred_channels == n_inp_channels, \ + f"Channel mismatch: prediction has {n_pred_channels} channels but input has {n_inp_channels} channels" + + # transform predictions and targets to tendency space + # this allows ANY loss function to compute tendency-based metrics + prdm_tendency = prdm - inp_state + tar_tendency = tar - inp_state + + # also transform ensemble predictions if present + if prd.dim() == 5: + # expand inp_state to match ensemble dim + inp_state_expanded = inp_state.unsqueeze(1) + prd_tendency = prd - inp_state_expanded + else: + prd_tendency = prdm_tendency + else: + prdm_tendency = prdm + tar_tendency = tar + prd_tendency = prd + # compute loss contributions from each loss loss_vals = [] - for lfn in self.loss_fn: + for lfn, requires_inp in zip(self.loss_fn, self.loss_requires_input): if lfn.type == LossType.Deterministic: - loss_vals.append(lfn(prdm, tar, wgt)) + if requires_inp: + loss_vals.append(lfn(prdm_tendency, tar_tendency, wgt)) + else: 
+ loss_vals.append(lfn(prdm, tar, wgt)) else: - loss_vals.append(lfn(prd, tar, wgt)) + # probabilistic losses: use tendency-transformed ensemble if needed + if requires_inp: + loss_vals.append(lfn(prd_tendency, tar_tendency, wgt)) + else: + loss_vals.append(lfn(prd, tar, wgt)) all_losses = torch.cat(loss_vals, dim=-1) if self.training and self.track_running_stats: @@ -363,7 +450,14 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens chw = self.channel_weights if self.uncertainty_weighting and self.training: var, _ = self.get_running_stats() + if self.num_batches_tracked.item() <= 100: + var = torch.ones_like(var) chw = chw / (torch.sqrt(2 * var) + self.eps) + elif self.balanced_weighting and self.training: + _, mean = self.get_running_stats() + if self.num_batches_tracked.item() <= 100: + mean = torch.ones_like(mean) + chw = chw / (mean + self.eps) if self.randomized_loss_weights: rmask = torch.zeros_like(chw) diff --git a/makani/utils/losses/__init__.py b/makani/utils/losses/__init__.py index 48d63f3a..44cbe295 100644 --- a/makani/utils/losses/__init__.py +++ b/makani/utils/losses/__init__.py @@ -15,10 +15,11 @@ from .base_loss import LossType from .h1_loss import SpectralH1Loss -from .lp_loss import GeometricLpLoss, SpectralL2Loss +from .lp_loss import GeometricLpLoss, SpectralLpLoss from .amse_loss import SpectralAMSELoss from .hydrostatic_loss import HydrostaticBalanceLoss -from .crps_loss import EnsembleCRPSLoss, EnsembleSpectralCRPSLoss -from .mmd_loss import EnsembleMMDLoss +from .crps_loss import CRPSLoss, SpectralCRPSLoss, GradientCRPSLoss, VortDivCRPSLoss, KernelScoreLoss +from .energy_score import L2EnergyScoreLoss, SobolevEnergyScoreLoss, SpectralL2EnergyScoreLoss, SpectralCoherenceLoss +from .mmd_loss import GaussianMMDLoss from .likelihood_loss import EnsembleNLLLoss -from .drift_regularization import DriftRegularization +from .regularization import DriftRegularization, SpectralRegularization diff --git 
a/makani/utils/losses/amse_loss.py b/makani/utils/losses/amse_loss.py index 89db8fb6..8ae392ca 100644 --- a/makani/utils/losses/amse_loss.py +++ b/makani/utils/losses/amse_loss.py @@ -61,7 +61,7 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens with amp.autocast(device_type="cuda", enabled=False): xcoeffs = self.sht(prd) ycoeffs = self.sht(tar) - + # compute the SHT: xcoeffssq = torch.square(torch.abs(xcoeffs)) ycoeffssq = torch.square(torch.abs(ycoeffs)) @@ -100,5 +100,5 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = torch.sum(loss, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): loss = reduce_from_parallel_region(loss, "h") - + return loss diff --git a/makani/utils/losses/base_loss.py b/makani/utils/losses/base_loss.py index b71006cd..f6460b1c 100644 --- a/makani/utils/losses/base_loss.py +++ b/makani/utils/losses/base_loss.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from abc import ABCMeta, abstractmethod +import math import torch import torch.nn as nn @@ -26,9 +27,11 @@ from makani.utils.grids import grid_to_quadrature_rule, GridQuadrature from makani.utils import comm +from makani.utils.features import get_wind_channels +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim -def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str) -> torch.Tensor: +def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: """ auxiliary routine for predetermining channel weighting """ @@ -43,7 +46,7 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t elif channel_weight_type == "auto": for c, chn in enumerate(channel_names): - if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv"]: + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]: 
channel_weights[c] = 0.1 elif chn in ["t2m", "2d"]: channel_weights[c] = 1.0 @@ -53,6 +56,19 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t else: channel_weights[c] = 0.01 + elif channel_weight_type == "new auto": + + for c, chn in enumerate(channel_names): + if chn in ["u10m", "v10m", "u100m", "v100m", "tp", "sp", "msl", "tcwv", "sst"]: + channel_weights[c] = 0.1 + elif chn in ["t2m", "2d"]: + channel_weights[c] = 2.0 + elif chn[0] in ["z", "u", "v", "t", "r", "q"]: + pressure_level = float(chn[1:]) + channel_weights[c] = max(0.3, 0.001 * pressure_level) + else: + channel_weights[c] = 0.01 + elif channel_weight_type == "custom": weight_dict = { @@ -216,6 +232,10 @@ def _compute_channel_weighting_helper(channel_names: List[str], channel_weight_t # normalize channel_weights = channel_weights / torch.sum(channel_weights) + # get the time differences and weigh them additionally + if time_diff_scale is not None: + channel_weights = channel_weights * time_diff_scale + return channel_weights @@ -250,9 +270,8 @@ def __init__( self.pole_mask = pole_mask self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + # get the quadrature rule for the corresponding grid quadrature_rule = grid_to_quadrature_rule(grid_type) - - # get the quadrature self.quadrature = GridQuadrature( quadrature_rule, img_shape=self.img_shape, @@ -272,8 +291,8 @@ def n_channels(self): return len(self.channel_names) @torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor: - return _compute_channel_weighting_helper(self.channel_names, channel_weight_type) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: 
Optional[torch.Tensor] = None) -> torch.Tensor: @@ -292,6 +311,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + lmax: Optional[int] = None, spatial_distributed: Optional[bool] = False, ): super().__init__() @@ -302,14 +322,36 @@ def __init__( self.channel_names = channel_names self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + # SHT for the computation of SH coefficients if self.spatial_distributed and (comm.get_size("spatial") > 1): if not thd.is_initialized(): polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") thd.init(polar_group, azimuth_group) - self.sht = thd.DistributedRealSHT(*img_shape, grid=grid_type) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) else: - self.sht = th.RealSHT(*img_shape, grid=grid_type).float() + self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type).float() + + # get the local l weights + l_weights = torch.ones(self.sht.lmax, dtype=torch.float32) + m_weights = 2 * torch.ones(self.sht.mmax, dtype=torch.float32) + m_weights[0] = 1.0 + + # get meshgrid of weights: + l_weights, m_weights = torch.meshgrid(l_weights, m_weights, indexing="ij") + + # use the product weights + lm_weights = l_weights * m_weights + + # split the tensors along all dimensions: + if self.spatial_distributed and comm.get_size("h") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + if self.spatial_distributed and comm.get_size("w") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + lm_weights = lm_weights.contiguous() + + # register + self.register_buffer("lm_weights", lm_weights, persistent=False) @property def type(self): @@ -320,8 +362,150 @@ def n_channels(self): return len(self.n_channels) 
@torch.compiler.disable(recursive=False) - def compute_channel_weighting(self, channel_weight_type: str) -> torch.Tensor: - return _compute_channel_weighting_helper(self.channel_names, channel_weight_type) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + + @abstractmethod + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: + pass + + +class VortDivBaseLoss(nn.Module, metaclass=ABCMeta): + """ + Geometric base loss class used by all geometric losses + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ): + super().__init__() + + self.img_shape = img_shape + self.crop_shape = crop_shape + self.crop_offset = crop_offset + self.channel_names = channel_names + self.pole_mask = pole_mask + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + + # get the wind channels + wind_chans = get_wind_channels(self.channel_names) + self.register_buffer("wind_chans", torch.LongTensor(wind_chans)) + + if self.spatial_distributed and (comm.get_size("spatial") > 1): + if not thd.is_initialized(): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.vsht = thd.DistributedRealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht = thd.DistributedInverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + else: + self.vsht = th.RealVectorSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.isht 
= th.InverseRealVectorSHT(nlat=self.vsht.nlat, nlon=self.vsht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + + # get the quadrature rule for the corresponding grid + quadrature_rule = grid_to_quadrature_rule(grid_type) + self.quadrature = GridQuadrature( + quadrature_rule, + img_shape=self.img_shape, + crop_shape=self.crop_shape, + crop_offset=self.crop_offset, + normalize=True, + pole_mask=self.pole_mask, + distributed=self.spatial_distributed, + ) + + @property + def type(self): + return LossType.Deterministic + + @property + def n_channels(self): + return len(self.n_channels) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + + chw = _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + chw = chw[self.wind_chans.to(chw.device)] + + # average u and v component weightings to weight vort and div equally + chw[1::2] = (chw[1::2] + chw[0::2]) / 2 + chw[0::2] = chw[1::2] + + return chw + + @abstractmethod + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: + pass + +class GradientBaseLoss(nn.Module, metaclass=ABCMeta): + """ + Gradient base loss class used by all gradient based losses + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ): + super().__init__() + + self.img_shape = img_shape + self.crop_shape = crop_shape + self.crop_offset = crop_offset + self.channel_names = channel_names + self.pole_mask = pole_mask + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + + if self.spatial_distributed and (comm.get_size("spatial") > 1): + if not thd.is_initialized(): + polar_group = None if 
(comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + self.sht = thd.DistributedRealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.ivsht = thd.DistributedInverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + else: + self.sht = th.RealSHT(*img_shape, lmax=lmax, mmax=lmax, grid=grid_type) + self.ivsht = th.InverseRealVectorSHT(nlat=self.sht.nlat, nlon=self.sht.nlon, lmax=lmax, mmax=lmax, grid=grid_type) + + # get the quadrature rule for the corresponding grid + quadrature_rule = grid_to_quadrature_rule(grid_type) + self.quadrature = GridQuadrature( + quadrature_rule, + img_shape=self.img_shape, + crop_shape=self.crop_shape, + crop_offset=self.crop_offset, + normalize=True, + pole_mask=self.pole_mask, + distributed=self.spatial_distributed, + ) + + @property + def type(self): + return LossType.Deterministic + + @property + def n_channels(self): + return len(self.n_channels) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + return _compute_channel_weighting_helper(self.channel_names, channel_weight_type, time_diff_scale=time_diff_scale) + @abstractmethod def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None) -> torch.Tensor: diff --git a/makani/utils/losses/crps_loss.py b/makani/utils/losses/crps_loss.py index 61e89ce0..e4cf06d5 100644 --- a/makani/utils/losses/crps_loss.py +++ b/makani/utils/losses/crps_loss.py @@ -22,7 +22,7 @@ import torch.nn as nn from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils.losses.base_loss import LossType, GeometricBaseLoss, SpectralBaseLoss, VortDivBaseLoss, GradientBaseLoss from makani.utils import comm # distributed stuff @@ -30,6 +30,10 
@@ from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region, copy_to_parallel_region from makani.mpu.mappings import distributed_transpose +# torch-harmonics for convolutions +import torch_harmonics as th +import torch_harmonics.distributed as thd + def rankdata(x: torch.Tensor, dim: int) -> torch.Tensor: """ @@ -49,6 +53,9 @@ def _crps_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor, we CRPS ensemble score from integrating the PDF piecewise compare https://github.com/properscoring/properscoring/blob/master/properscoring/_gufuncs.py#L7 disabling torch compile for the moment due to very long startup times when training large ensembles with ensemble parallelism + + forecasts: [ensemble, ...], observation: [...], weights: [ensemble, ...] + Assumes forecasts are sorted along ensemble dimension 0. """ # beware: forecasts are assumed sorted in sorted order @@ -110,14 +117,17 @@ def _crps_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor, we def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor: """ - alternative CRPS variant that uses spread and skill + fair CRPS variant that uses spread and skill. 
Assumes pre-sorted ensemble """ observation = observation.unsqueeze(0) - # get nanmask - nanmasks = torch.logical_or(torch.isnan(forecasts), torch.isnan(weights)) - nanmask = torch.sum(nanmasks, dim=0).bool() + # get nanmask from observations and forecasts + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) + nanmask_bool = nanmasks.sum(dim=0) != 0 + + # impute NaN before computation to avoid 0 * NaN = NaN in backward pass + observation = torch.where(torch.isnan(observation), 0.0, observation) # compute total weights nweights = torch.where(nanmasks, 0.0, weights) @@ -133,11 +143,79 @@ def _crps_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, espread = 2 * torch.mean((2 * rank - num_ensemble - 1) * forecasts, dim=0) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * (num_ensemble - 1)) eskill = (observation - forecasts).abs().mean(dim=0) - # crps = torch.where(nanmasks.sum(dim=0) != 0, torch.nan, eskill - 0.5 * espread) + crps = torch.where(nanmask_bool, 0.0, eskill - 0.5 * espread) + + return crps + + +def _crps_probability_weighted_moment_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """ + CRPS estimator based on the probability weighted moment. see [1]. + + [1] Michael Zamo, Phillippe Naveau. Estimation of the Continuous Ranked Probability Score with Limited Information and Applications to Ensemble Weather Forecasts. Mathematical Geosciences. Volume 50 pp. 209-234. 2018. 
+ """ + + observation = observation.unsqueeze(0) + + # get nanmask from observations and forecasts + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) + nanmask_bool = nanmasks.sum(dim=0) != 0 + + # impute NaN before computation to avoid 0 * NaN = NaN in backward pass + observation = torch.where(torch.isnan(observation), 0.0, observation) + + # compute total weights + nweights = torch.where(nanmasks, 0.0, weights) + total_weight = torch.sum(nweights, dim=0, keepdim=True) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get the ranks for the pwm computation + rank = torch.arange(num_ensemble, device=forecasts.device).reshape((num_ensemble,) + (1,) * (forecasts.dim() - 1)) + + # get the ensemble spread (total_weight is ensemble size here) + beta0 = forecasts.mean(dim=0) + beta1 = (rank * forecasts).sum(dim=0) / float(num_ensemble * (num_ensemble - 1)) + eskill = (observation - forecasts).abs().mean(dim=0) + + crps = eskill + beta0 - 2 * beta1 + + # zero out masked positions + crps = torch.where(nanmask_bool, 0.0, crps) + + return crps + + +def _crps_naive_skillspread_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weights: torch.Tensor, alpha: float) -> torch.Tensor: + """ + alternative fair CRPS variant that uses spread and skill. Uses naive computation which is O(N^2) in the number of ensemble members. 
Useful for complex-valued data, where sorting along the ensemble dimension is not well-defined. + """ + + observation = observation.unsqueeze(0) + + # get nanmask from observations and forecasts + nanmasks = torch.logical_or(torch.isnan(observation), torch.isnan(weights)) + nanmask_bool = nanmasks.sum(dim=0) != 0 + + # impute NaN before computation to avoid 0 * NaN = NaN in backward pass + observation = torch.where(torch.isnan(observation), 0.0, observation) + + # compute total weights + nweights = torch.where(nanmasks, 0.0, weights) + total_weight = torch.sum(nweights, dim=0, keepdim=True) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # use broadcasting semantics to compute spread and skill + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = (observation - forecasts).abs().mean(dim=0) + crps = eskill - 0.5 * espread - # set to nan for first forecasts nan - crps = torch.where(nanmask, torch.nan, crps) + # zero out masked positions + crps = torch.where(nanmask_bool, 0.0, crps) return crps @@ -173,7 +251,7 @@ def _crps_gauss_kernel(observation: torch.Tensor, forecasts: torch.Tensor, weigh return crps -class EnsembleCRPSLoss(GeometricBaseLoss): +class CRPSLoss(GeometricBaseLoss): def __init__( self, @@ -278,6 +356,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # compute score crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "probability weighted moment": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights) elif self.crps_type ==
"skillspread": if self.ensemble_weights is not None: raise NotImplementedError("currently only constant ensemble weights are supported") @@ -286,6 +375,14 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # compute score crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "naive skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -316,7 +413,7 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w return crps -class EnsembleSpectralCRPSLoss(SpectralBaseLoss): +class SpectralCRPSLoss(SpectralBaseLoss): def __init__( self, @@ -325,6 +422,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + lmax: Optional[int] = None, crps_type: str = "skillspread", spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, @@ -341,6 +439,7 @@ def __init__( crop_offset=crop_offset, channel_names=channel_names, grid_type=grid_type, + lmax=lmax, spatial_distributed=spatial_distributed, ) @@ -361,32 +460,6 @@ def __init__( # if absolute is true, the loss is computed only on the absolute value of the spectral coefficient self.absolute = absolute - # get the local l weights - lmax = self.sht.lmax - # l_weights = 1 / (2*ls+1) - l_weights = torch.ones(lmax) - - # get the local m weights - mmax = self.sht.mmax - m_weights = 2 * torch.ones(mmax)#.reshape(1, -1) - m_weights[0] = 1.0 - - # get meshgrid of weights: - l_weights, m_weights = torch.meshgrid(l_weights, m_weights, indexing="ij") - - # use the 
product weights - lm_weights = l_weights * m_weights - - # split the tensors along all dimensions: - lm_weights = l_weights * m_weights - if spatial_distributed and comm.get_size("h") > 1: - lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] - if spatial_distributed and comm.get_size("w") > 1: - lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] - - # register - self.register_buffer("lm_weights", lm_weights, persistent=False) - @property def type(self): return LossType.Probabilistic @@ -409,20 +482,15 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ forecasts = forecasts.float() observations = observations.float() with amp.autocast(device_type="cuda", enabled=False): - forecasts = self.sht(forecasts) / 4.0 / math.pi - observations = self.sht(observations) / 4.0 / math.pi + forecasts = self.sht(forecasts) / math.sqrt(4.0 * math.pi) + observations = self.sht(observations) / math.sqrt(4.0 * math.pi) if self.absolute: forecasts = torch.abs(forecasts).to(dtype) observations = torch.abs(observations).to(dtype) else: - forecasts = torch.view_as_real(forecasts).to(dtype) - observations = torch.view_as_real(observations).to(dtype) - - # merge complex dimension after channel dimension and flatten - # this needs to be undone at the end - forecasts = torch.movedim(forecasts, 5, 3).flatten(2, 3) - observations = torch.movedim(observations, 4, 2).flatten(1, 2) + # since the other kernels require sorting, this approach only works with the naive CRPS kernel + assert self.crps_type == "skillspread" # we assume the following shapes: # forecasts: batch, ensemble, channels, mmax, lmax @@ -472,6 +540,17 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ # compute score crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "probability weighted moment": + # now, E 
dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights) elif self.crps_type == "skillspread": if self.ensemble_weights is not None: raise NotImplementedError("currently only constant ensemble weights are supported") @@ -479,7 +558,10 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) # compute score - crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + if self.absolute: + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + else: + crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) elif self.crps_type == "gauss": if self.ensemble_weights is not None: ensemble_weights = self.ensemble_weights[idx] @@ -500,9 +582,572 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spectral_ if self.ensemble_distributed: crps = reduce_from_parallel_region(crps, "ensemble") - # finally undo the folding of the complex dimension into the channel dimension - if not self.absolute: - crps = crps.reshape(B, -1, 2).sum(dim=-1) + # the resulting tensor should have dimension B, C, which is what we return + return crps + +class GradientCRPSLoss(GradientBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + lmax: Optional[int] = None, + crps_type: str = "skillspread", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, 
+ ensemble_weights: Optional[torch.Tensor] = None, + absolute: Optional[bool] = True, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + lmax=lmax, + spatial_distributed=spatial_distributed, + ) + + # if absolute is true, the loss is computed only on the absolute value of the gradient + self.absolute = absolute + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type + self.alpha = alpha + self.eps = eps + + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + if self.absolute: + return len(self.channel_names) + else: + return 2 * len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: torch.Tensor = None) -> torch.Tensor: + chw = 
= super().compute_channel_weighting(channel_weight_type, time_diff_scale=time_diff_scale) + + if self.absolute: + return chw + else: + return chw.repeat_interleave(2) + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, C, H, W = forecasts.shape + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat real and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + + # compute the SH coefficients of the forecasts and observations + forecasts = self.sht(forecasts.float()).unsqueeze(-3) + observations = self.sht(observations.float()).unsqueeze(-3) + + # append zeros, so that we can use the inverse vector SHT + forecasts = torch.cat([forecasts, torch.zeros_like(forecasts)], dim=-3) + observations = torch.cat([observations, torch.zeros_like(observations)], dim=-3) + + forecasts = self.ivsht(forecasts) + observations = self.ivsht(observations) + + forecasts = forecasts.to(dtype) + observations = observations.to(dtype) + + if self.absolute: + forecasts = forecasts.pow(2).sum(dim=-3).sqrt() + observations = observations.pow(2).sum(dim=-3).sqrt() + else: + C = 2
* C + + forecasts = forecasts.reshape(B, E, C, H, W) + observations = observations.reshape(B, C, H, W) + + # if ensemble dim is one dimensional then computing the score is quick: + if (not self.ensemble_distributed) and (E == 1): + # in this case, CRPS is straightforward + crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W) + else: + # transpose forecasts: ensemble, batch, channels, lat, lon + forecasts = torch.moveaxis(forecasts, 1, 0) + + # now we need to transpose the forecasts into ensemble direction. + # ideally we split spatial dims + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # observations does not need a transpose, but just a split + observations = observations.reshape(B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + # run appropriate crps kernel to compute it pointwise + if self.crps_type == "cdf": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # 
compute score + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "gauss": + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps) + else: + raise ValueError(f"Unknown CRPS crps_type {self.crps_type}") + + # perform ensemble and spatial average of crps score + if spatial_weights is not None: + crps = torch.sum(crps * self.quad_weight_split * spatial_weights_split, dim=-1) + else: + crps = torch.sum(crps * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well + if self.ensemble_distributed: + crps = reduce_from_parallel_region(crps, "ensemble") + + # we need to do the spatial averaging manually since + # we are not calling he quadrature forward function + if self.spatial_distributed: + crps = reduce_from_parallel_region(crps, "spatial") # the resulting tensor should have dimension B, C, which is what we return return crps + +class VortDivCRPSLoss(VortDivBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + pole_mask: int, + crps_type: str = "skillspread", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-5, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed 
= comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.crps_type = crps_type + self.alpha = alpha + self.eps = eps + + if (self.crps_type != "skillspread") and (self.alpha < 1.0): + raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.") + + # we also need a variant of the weights split in ensemble direction: + quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) + if self.ensemble_distributed: + quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] + quad_weight_split = quad_weight_split.contiguous() + self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + @property + def type(self): + return LossType.Probabilistic + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # we assume that spatial_weights have NO ensemble dim + if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, lat, lon + # observations: batch, channels, lat, lon + B, E, _, H, W = forecasts.shape + C = self.wind_chans.shape[0] + + # extract wind channels + forecasts = forecasts[..., self.wind_chans, :, :].reshape(B, E, C//2, 2, H, W) + observations = observations[..., 
self.wind_chans, :, :].reshape(B, C//2, 2, H, W) + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + forecasts = self.isht(self.vsht(forecasts.float())) + observations = self.isht(self.vsht(observations.float())) + + # extract wind channels + forecasts = forecasts.reshape(B, E, C, H, W) + observations = observations.reshape(B, C, H, W) + + # if ensemble dim is one dimensional then computing the score is quick: + if (not self.ensemble_distributed) and (E == 1): + # in this case, CRPS is straightforward + crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W) + else: + # transpose forecasts: ensemble, batch, channels, lat, lon + forecasts = torch.moveaxis(forecasts, 1, 0) + + # now we need to transpose the forecasts into ensemble direction. 
+ # ideally we split spatial dims + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + # observations does not need a transpose, but just a split + observations = observations.reshape(B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + # run appropriate crps kernel to compute it pointwise + if self.crps_type == "cdf": + # now, E dimension is local and spatial dim is split further + # we need to sort the forecasts now + forecasts, idx = torch.sort(forecasts, dim=0) + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights) + elif self.crps_type == "skillspread": + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha) + elif self.crps_type == "gauss": + if self.ensemble_weights is not None: + ensemble_weights = self.ensemble_weights[idx] + else: + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # compute score + crps = _crps_gauss_kernel(observations, forecasts, ensemble_weights, self.eps) + else: + raise ValueError(f"Unknown CRPS crps_type {self.crps_type}") + + # perform ensemble and spatial average of crps score + 
class KernelScoreLoss(GeometricBaseLoss):
    """
    Computes the kernel score defined in Gneiting and Raftery (2007) with kernels
    defined by the discrete-continuous convolutions.

    The input fields are first lifted into ``n_channels * kernel_basis_size``
    feature channels via a fixed (non-trainable) grouped DISCO convolution on the
    sphere, and a CRPS-type score is then computed channel-wise on the lifted
    representation and summed over the kernel-basis dimension.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        pole_mask: int,
        crps_type: str = "skillspread",
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        alpha: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-5,
        kernel_basis_type: str = "harmonic",
        kernel_basis_norm_mode: str = "nodal",
        kernel_shape: Tuple[int, int] = (3, 3),
        **kwargs,
    ):
        """
        Parameters
        ----------
        img_shape, crop_shape, crop_offset : Tuple[int, int]
            Global grid shape and crop specification, forwarded to the base class.
        channel_names : List[str]
            Names of the physical channels; determines ``self.n_channels``.
        grid_type : str
            Spherical grid type, used both for the quadrature and the convolution.
        pole_mask : int
            Number of latitude rings to mask at the poles (base-class behavior).
        crps_type : str
            Which pointwise CRPS kernel to use; see ``forward`` for options.
        spatial_distributed, ensemble_distributed : Optional[bool]
            Enable model-parallel reductions over the respective comm groups.
        ensemble_weights : Optional[torch.Tensor]
            Optional per-member weights; only supported by the "cdf" and
            "probability weighted moment" kernels.
        alpha : Optional[float]
            Almost-fair CRPS factor; only supported by the skillspread kernels.
        eps : Optional[float]
            Numerical epsilon (kept for API symmetry with the other losses).
        kernel_basis_type, kernel_basis_norm_mode, kernel_shape :
            Parameters of the DISCO filter basis used for the lifting convolution.
        """

        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            pole_mask=pole_mask,
            spatial_distributed=spatial_distributed,
        )

        self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed
        self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed
        self.crps_type = crps_type
        self.alpha = alpha
        self.eps = eps

        if (self.crps_type != "skillspread") and (self.alpha < 1.0):
            raise NotImplementedError("The alpha parameter (almost fair CRPS factor) is only supported for the skillspread kernel.")

        # we also need a variant of the quadrature weights split in ensemble direction,
        # because the spatial dim is sharded across the ensemble group during scoring
        quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1)
        if self.ensemble_distributed:
            quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")]
            quad_weight_split = quad_weight_split.contiguous()
        self.register_buffer("quad_weight_split", quad_weight_split, persistent=False)

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = ensemble_weights

        # init distributed torch-harmonics if needed
        if self.spatial_distributed and (comm.get_size("spatial") > 1):
            if not thd.is_initialized():
                polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h")
                azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w")
                thd.init(polar_group, azimuth_group)

        # set up DISCO convolution (one per kernel)
        conv_handle = thd.DistributedDiscreteContinuousConvS2 if self.spatial_distributed else th.DiscreteContinuousConvS2

        fb = th.filter_basis.get_filter_basis(tuple(kernel_shape), kernel_basis_type)
        self.kernel_basis_size = fb.kernel_size

        self.conv = conv_handle(
            self.n_channels,
            self.n_channels * self.kernel_basis_size,
            in_shape=img_shape,
            out_shape=img_shape,
            kernel_shape=tuple(kernel_shape),
            basis_type=kernel_basis_type,
            basis_norm_mode=kernel_basis_norm_mode,
            grid_in=grid_type,
            grid_out=grid_type,
            groups=self.n_channels,
            bias=False,
            theta_cutoff=2 * kernel_shape[0] * math.pi / float(img_shape[0] - 1),
        )

        # initialize the weight to identity: with groups == n_channels, output
        # channel (i, k) lives at flat index i * kernel_basis_size + k and sees a
        # single input channel, so we set exactly one basis coefficient to 1.
        # BUGFIX: the previous index `i * k` collapsed all bases of channel 0
        # onto row 0 and produced colliding/empty rows for the other channels.
        weight = torch.zeros_like(self.conv.weight.data)
        for i in range(self.n_channels):
            for k in range(self.kernel_basis_size):
                weight[i * self.kernel_basis_size + k, 0, k] = 1.0

        # convert weight to buffer to avoid issues with distributed training
        delattr(self.conv, "weight")
        self.conv.register_buffer("weight", weight)

    @property
    def type(self):
        return LossType.Probabilistic

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the kernel score.

        Parameters
        ----------
        forecasts : torch.Tensor
            Ensemble forecast of shape (batch, ensemble, channels, lat, lon).
        observations : torch.Tensor
            Ground truth of shape (batch, channels, lat, lon).
        spatial_weights : Optional[torch.Tensor]
            Optional pointwise weights with the same dimensionality as
            ``observations`` (no ensemble dim).

        Returns
        -------
        torch.Tensor
            Score of shape (batch, n_channels).
        """

        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # we assume that spatial_weights have NO ensemble dim
        if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()):
            spdim = spatial_weights.dim()
            odim = observations.dim()
            raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).")

        # get the data type before stripping amp types
        dtype = forecasts.dtype

        # we assume the following shapes:
        # forecasts: batch, ensemble, channels, lat, lon
        # observations: batch, channels, lat, lon
        B, E, _, H, W = forecasts.shape

        # before anything else compute the transform
        # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same
        with amp.autocast(device_type="cuda", enabled=False):
            forecasts = self.conv(forecasts.float().reshape(B * E, -1, H, W))
            observations = self.conv(observations.float())

        forecasts = forecasts.reshape(B, E, -1, H, W)
        C = forecasts.shape[2]

        # if ensemble dim is one dimensional then computing the score is quick:
        if (not self.ensemble_distributed) and (E == 1):
            # in this case, CRPS degenerates to the MAE against the single member
            crps = torch.abs(observations - forecasts.squeeze(1)).reshape(B, C, H * W)
            # BUGFIX: the shared spatial reduction below uses spatial_weights_split,
            # which previously was only defined on the ensemble branch (NameError here)
            if spatial_weights is not None:
                spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
        else:
            # transpose forecasts: ensemble, batch, channels, lat, lon
            forecasts = torch.moveaxis(forecasts, 1, 0)

            # now we need to transpose the forecasts into ensemble direction.
            # ideally we split spatial dims
            forecasts = forecasts.reshape(E, B, C, H * W)
            if self.ensemble_distributed:
                ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
                forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")
            # observations does not need a transpose, but just a split
            observations = observations.reshape(B, C, H * W)
            if self.ensemble_distributed:
                observations = scatter_to_parallel_region(observations, -1, "ensemble")
            if spatial_weights is not None:
                spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
                spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble")

            # run appropriate crps kernel to compute it pointwise
            if self.crps_type == "cdf":
                # now, E dimension is local and spatial dim is split further;
                # the CDF kernel requires sorted members
                forecasts, idx = torch.sort(forecasts, dim=0)
                if self.ensemble_weights is not None:
                    ensemble_weights = self.ensemble_weights[idx]
                else:
                    ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)

                crps = _crps_ensemble_kernel(observations, forecasts, ensemble_weights)
            elif self.crps_type == "probability weighted moment":
                # this kernel also requires sorted members
                forecasts, idx = torch.sort(forecasts, dim=0)
                if self.ensemble_weights is not None:
                    ensemble_weights = self.ensemble_weights[idx]
                else:
                    ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)

                crps = _crps_probability_weighted_moment_kernel(observations, forecasts, ensemble_weights)
            elif self.crps_type == "skillspread":
                if self.ensemble_weights is not None:
                    raise NotImplementedError("currently only constant ensemble weights are supported")
                else:
                    ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)

                crps = _crps_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha)
            elif self.crps_type == "naive skillspread":
                if self.ensemble_weights is not None:
                    raise NotImplementedError("currently only constant ensemble weights are supported")
                else:
                    ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)

                crps = _crps_naive_skillspread_kernel(observations, forecasts, ensemble_weights, self.alpha)
            else:
                raise ValueError(f"Unknown CRPS crps_type {self.crps_type}")

        # perform spatial average of crps score (shared by both branches)
        if spatial_weights is not None:
            crps = torch.sum(crps * self.quad_weight_split * spatial_weights_split, dim=-1)
        else:
            crps = torch.sum(crps * self.quad_weight_split, dim=-1)

        # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
        if self.ensemble_distributed:
            crps = reduce_from_parallel_region(crps, "ensemble")

        # we need to do the spatial averaging manually since
        # we are not calling the quadrature forward function
        if self.spatial_distributed:
            crps = reduce_from_parallel_region(crps, "spatial")

        # reduce the kernel dimensions
        crps = crps.reshape(B, self.n_channels, self.kernel_basis_size).sum(dim=-1)

        # the resulting tensor should have dimension B, C, which is what we return
        return crps
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Tuple, List - -import torch -import torch.nn as nn - -from makani.utils.losses.base_loss import GeometricBaseLoss, LossType - -from makani.utils import comm -from physicsnemo.distributed.mappings import reduce_from_parallel_region - - -class DriftRegularization(GeometricBaseLoss): - """ - Computes the Lp loss on the sphere. - """ - - def __init__( - self, - img_shape: Tuple[int, int], - crop_shape: Tuple[int, int], - crop_offset: Tuple[int, int], - channel_names: List[str], - p: Optional[float] = 1.0, - pole_mask: Optional[int] = 0, - grid_type: Optional[str] = "equiangular", - spatial_distributed: Optional[bool] = False, - ensemble_distributed: Optional[bool] = False, - **kwargs, - ): - super().__init__( - img_shape=img_shape, - crop_shape=crop_shape, - crop_offset=crop_offset, - channel_names=channel_names, - grid_type=grid_type, - pole_mask=pole_mask, - spatial_distributed=spatial_distributed, - ) - - self.p = p - self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed - self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) - - @property - def type(self): - return LossType.Probabilistic - - def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): - - if prd.dim() > tar.dim(): - tar = tar.unsqueeze(1) - - # compute difference between the means output has dims - loss = torch.abs(self.quadrature(prd) - self.quadrature(tar)).pow(self.p) - - # if ensemble - if prd.dim() == 5: - loss = torch.mean(loss, 
dim=1) - if self.ensemble_distributed: - loss = reduce_from_parallel_region(loss, "ensemble") / float(comm.get_size("ensemble")) - - return loss diff --git a/makani/utils/losses/energy_score.py b/makani/utils/losses/energy_score.py new file mode 100644 index 00000000..1f1b22fb --- /dev/null +++ b/makani/utils/losses/energy_score.py @@ -0,0 +1,747 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class L2EnergyScoreLoss(GeometricBaseLoss):
    """
    Energy score (Gneiting et al. 2005) based on the grid-point L2 distance on the
    sphere: score = E|X - y|^beta - 0.5 * E|X - X'|^beta, with an almost-fair
    correction factor ``alpha`` on the spread term.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        pole_mask: int,
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        channel_reduction: Optional[bool] = True,
        alpha: Optional[float] = 1.0,
        beta: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-5,
        **kwargs,
    ):
        """
        Parameters
        ----------
        channel_reduction : Optional[bool]
            If True, sum over channels before taking the root, producing a single
            score; otherwise one score per channel.
        alpha : Optional[float]
            Almost-fair correction factor on the spread term.
        beta : Optional[float]
            Exponent applied to the L2 distance.
        eps : Optional[float]
            Clamp threshold protecting sqrt/pow against zero (non-differentiable) inputs.
        """

        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            pole_mask=pole_mask,
            spatial_distributed=spatial_distributed,
        )

        self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed
        self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed
        self.channel_reduction = channel_reduction
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

        # we also need a variant of the weights split in ensemble direction:
        quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1)
        if self.ensemble_distributed:
            quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")]
            quad_weight_split = quad_weight_split.contiguous()
        self.register_buffer("quad_weight_split", quad_weight_split, persistent=False)

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = ensemble_weights

    @property
    def type(self):
        return LossType.Probabilistic

    @property
    def n_channels(self):
        return 1 if self.channel_reduction else len(self.channel_names)

    @torch.compiler.disable(recursive=False)
    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor:
        # with channel reduction there is only one (trivial) channel weight
        if self.channel_reduction:
            chw = torch.ones(1)
        else:
            chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale)
        return chw

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the energy score.

        Parameters
        ----------
        forecasts : torch.Tensor
            Ensemble forecast of shape (batch, ensemble, channels, lat, lon).
        observations : torch.Tensor
            Ground truth of shape (batch, channels, lat, lon).
        spatial_weights : Optional[torch.Tensor]
            Optional pointwise weights with the same dimensionality as
            ``observations`` (no ensemble dim).

        Returns
        -------
        torch.Tensor
            Score of shape (batch, n_channels).
        """

        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # we assume that spatial_weights have NO ensemble dim
        if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()):
            spdim = spatial_weights.dim()
            odim = observations.dim()
            raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).")

        # we assume the following shapes:
        # forecasts: batch, ensemble, channels, lat, lon
        # observations: batch, channels, lat, lon
        B, E, C, H, W = forecasts.shape

        # get the data type before stripping amp types
        dtype = forecasts.dtype

        # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction.
        # ideally we split spatial dims
        forecasts = torch.moveaxis(forecasts, 1, 0)
        forecasts = forecasts.reshape(E, B, C, H * W)

        if self.ensemble_distributed:
            ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
            forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")

        # observations does not need a transpose, but just a split
        observations = observations.reshape(1, B, C, H * W)
        if self.ensemble_distributed:
            observations = scatter_to_parallel_region(observations, -1, "ensemble")

        # for correct spatial reduction we need to do the same with spatial weights
        if spatial_weights is not None:
            spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1)
            spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble")

        if self.ensemble_weights is not None:
            raise NotImplementedError("currently only constant ensemble weights are supported")
        else:
            ensemble_weights = torch.ones_like(forecasts, device=forecasts.device)

        # ensemble size
        num_ensemble = forecasts.shape[0]

        # get nanmask from observations and forecasts
        # BUGFIX: the mask previously checked only observations and (always finite)
        # ensemble_weights, so NaNs in forecasts were imputed to 0 below and then
        # silently entered the score; mask all three like SobolevEnergyScoreLoss does
        nanmasks = torch.isnan(observations) | torch.isnan(forecasts) | torch.isnan(ensemble_weights)
        nanmask_bool = nanmasks.sum(dim=0) != 0

        # impute NaN before computation to avoid 0 * NaN = NaN in backward pass
        observations = torch.where(torch.isnan(observations), 0.0, observations)
        forecasts = torch.where(torch.isnan(forecasts), 0.0, forecasts)

        # use broadcasting semantics to compute spread and skill and sum over channels (vector norm)
        espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().square()
        eskill = (observations - forecasts).abs().square()

        # zero out masked positions
        espread = torch.where(nanmask_bool, 0.0, espread)
        eskill = torch.where(nanmask_bool, 0.0, eskill)

        # do the spatial reduction
        if spatial_weights is not None:
            espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1)
            eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1)
        else:
            espread = torch.sum(espread * self.quad_weight_split, dim=-1)
            eskill = torch.sum(eskill * self.quad_weight_split, dim=-1)

        # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
        if self.ensemble_distributed:
            espread = reduce_from_parallel_region(espread, "ensemble")
            eskill = reduce_from_parallel_region(eskill, "ensemble")

        # we need to do the spatial averaging manually since
        # we are not calling the quadrature forward function
        if self.spatial_distributed:
            espread = reduce_from_parallel_region(espread, "spatial")
            eskill = reduce_from_parallel_region(eskill, "spatial")

        # do the channel reduction while ignoring NaNs
        # if channel weights are required they should be added here to the reduction
        if self.channel_reduction:
            espread = espread.sum(dim=-1, keepdim=True)
            eskill = eskill.sum(dim=-1, keepdim=True)

        # clamp tiny values to eps so sqrt/pow stay differentiable, remembering
        # which entries were clamped so they can be zeroed again afterwards
        espread_mask = torch.where(espread < self.eps, True, False)
        eskill_mask = torch.where(eskill < self.eps, True, False)

        espread = torch.where(espread_mask, self.eps, espread)
        eskill = torch.where(eskill_mask, self.eps, eskill)

        with amp.autocast(device_type="cuda", enabled=False):

            espread = espread.float()
            eskill = eskill.float()

            # This is according to the definition in Gneiting et al. 2005
            espread = torch.sqrt(espread).pow(self.beta)
            eskill = torch.sqrt(eskill).pow(self.beta)

            # mask espread and sum; the (E - 1 + alpha) / (E^2 (E - 1)) factor is
            # the almost-fair normalization of the pairwise spread sum
            espread = torch.where(espread_mask, 0.0, espread)
            eskill = torch.where(eskill_mask, 0.0, eskill)
            espread = espread.sum(dim=(0, 1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))

            # sum over ensemble
            eskill = eskill.sum(dim=0) / float(num_ensemble)

            # the resulting tensor should have dimension B, C which is what we return
            loss = eskill - 0.5 * espread

        return loss
class SobolevEnergyScoreLoss(SpectralBaseLoss):
    """
    Energy score computed in spectral (SHT) space with Sobolev-type mode weights
    (offset + relative_weight * l(l+1))**fraction, i.e. the distance underlying the
    score is a Sobolev norm on the sphere. Spread term uses the almost-fair
    correction factor ``alpha``; ``beta`` is the exponent on the distance.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        lmax: Optional[int] = None,
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        channel_reduction: Optional[bool] = True,
        alpha: Optional[float] = 1.0,
        beta: Optional[float] = 1.0,
        offset: Optional[float] = 1.0,
        fraction: Optional[float] = 1.0,
        relative_weight: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-6,
        **kwargs,
    ):

        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            lmax=lmax,
            spatial_distributed=spatial_distributed,
        )

        # effective distribution flags: only active when the comm group exists
        self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial")
        self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)
        self.channel_reduction = channel_reduction
        self.alpha = alpha
        self.beta = beta
        self.fraction = fraction
        self.offset = offset
        self.relative_weight = relative_weight
        self.eps = eps

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = ensemble_weights

        # get the local l weights; m weight 2 accounts for the conjugate m<0 modes
        # of the real SHT (m=0 counted once)
        l_weights = torch.arange(self.sht.lmax, dtype=torch.float32)
        m_weights = 2 * torch.ones(self.sht.mmax, dtype=torch.float32)
        m_weights[0] = 1.0
        # get meshgrid of weights:
        l_weights, m_weights = torch.meshgrid(l_weights, m_weights, indexing="ij")

        # use the product weights: Sobolev factor (offset + c*l(l+1))^fraction times m multiplicity
        lm_weights = (self.offset + self.relative_weight * l_weights * (l_weights + 1)).pow(self.fraction) * m_weights

        # split the tensors along all dimensions so each rank holds only its local (l, m) patch
        if self.spatial_distributed and comm.get_size("h") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")]
        if self.spatial_distributed and comm.get_size("w") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")]
        lm_weights = lm_weights.contiguous()

        self.register_buffer("lm_weights", lm_weights, persistent=False)

    @property
    def type(self):
        return LossType.Probabilistic

    @property
    def n_channels(self):
        # with channel reduction the score collapses to a single pseudo-channel
        return 1 if self.channel_reduction else len(self.channel_names)

    @torch.compiler.disable(recursive=False)
    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor:
        # channel weights are trivial when the score is already channel-reduced
        if self.channel_reduction:
            chw = torch.ones(1)
        else:
            chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale)
        return chw

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the Sobolev energy score.

        Parameters
        ----------
        forecasts : torch.Tensor
            Ensemble forecast of shape (batch, ensemble, channels, lat, lon).
        observations : torch.Tensor
            Ground truth of shape (batch, channels, lat, lon).
        ensemble_weights : Optional[torch.Tensor]
            Unused; constant weights are assumed. NOTE(review): the other losses in
            this file take ``spatial_weights`` as third argument here — confirm the
            intended signature.

        Returns
        -------
        torch.Tensor
            Score of shape (batch, n_channels).
        """

        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # get the data type before stripping amp types
        dtype = forecasts.dtype

        # before anything else compute the transform
        # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same
        with amp.autocast(device_type="cuda", enabled=False):
            # TODO: check 4 pi normalization
            forecasts = self.sht(forecasts.float()) / math.sqrt(4 * math.pi)
            observations = self.sht(observations.float()) / math.sqrt(4 * math.pi)

        # we assume the following shapes:
        # forecasts: batch, ensemble, channels, mmax, lmax
        # observations: batch, channels, mmax, lmax
        B, E, C, H, W = forecasts.shape

        # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction.
        # ideally we split spatial dims
        forecasts = torch.moveaxis(forecasts, 1, 0)
        forecasts = forecasts.reshape(E, B, C, H * W)
        if self.ensemble_distributed:
            ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
            forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")  # for correct spatial reduction we need to do the same with spatial weights

        # spectral weights must be sharded the same way as the flattened mode dim
        lm_weights_split = self.lm_weights.flatten(start_dim=-2, end_dim=-1)
        if self.ensemble_distributed:
            lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble")

        # observations does not need a transpose, but just a split
        observations = observations.reshape(1, B, C, H * W)
        if self.ensemble_distributed:
            observations = scatter_to_parallel_region(observations, -1, "ensemble")

        num_ensemble = forecasts.shape[0]

        # get nanmask from observations and forecasts
        nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts))
        nanmask_bool = nanmasks.sum(dim=0) != 0

        # impute NaN before computation to avoid 0 * NaN = NaN in backward pass
        observations = torch.where(torch.isnan(observations), 0.0, observations)
        forecasts = torch.where(torch.isnan(forecasts), 0.0, forecasts)

        # compute the individual distances (pairwise spread, per-member skill)
        espread = lm_weights_split * (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().square()
        eskill = lm_weights_split * (observations - forecasts).abs().square()

        # zero out masked positions
        espread = torch.where(nanmask_bool, 0.0, espread)
        eskill = torch.where(nanmask_bool, 0.0, eskill)

        # do the channel reduction first
        if self.channel_reduction:
            espread = espread.sum(dim=-2, keepdim=True)
            eskill = eskill.sum(dim=-2, keepdim=True)

        # do the spatial reduction
        espread = espread.sum(dim=-1, keepdim=False)
        eskill = eskill.sum(dim=-1, keepdim=False)

        # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well
        if self.ensemble_distributed:
            espread = reduce_from_parallel_region(espread, "ensemble")
            eskill = reduce_from_parallel_region(eskill, "ensemble")

        # we need to do the spatial averaging manually since
        # we are not calling the quadrature forward function
        if self.spatial_distributed:
            espread = reduce_from_parallel_region(espread, "spatial")
            eskill = reduce_from_parallel_region(eskill, "spatial")

        # clamp tiny values to eps so sqrt/pow stay differentiable; remember which
        # entries were clamped so they can be zeroed again afterwards
        #espread = torch.where(torch.eye(num_ensemble, device=espread.device).bool().reshape(num_ensemble, num_ensemble, 1, 1), self.eps, espread)
        # get the masks
        espread_mask = torch.where(espread < self.eps, True, False)
        eskill_mask = torch.where(eskill < self.eps, True, False)

        # mask the data
        espread = torch.where(espread_mask, self.eps, espread)
        eskill = torch.where(eskill_mask, self.eps, eskill)

        with amp.autocast(device_type="cuda", enabled=False):

            espread = espread.float()
            eskill = eskill.float()

            # This is according to the definition in Gneiting et al. 2005
            espread = torch.sqrt(espread).pow(self.beta)
            eskill = torch.sqrt(eskill).pow(self.beta)

            # mask espread and sum; (E - 1 + alpha) / (E^2 (E - 1)) is the
            # almost-fair normalization of the pairwise spread sum
            espread = torch.where(espread_mask, 0.0, espread)
            eskill = torch.where(eskill_mask, 0.0, eskill)
            #espread = torch.where(torch.eye(num_ensemble, device=espread.device).bool().reshape(num_ensemble, num_ensemble, 1, 1), 0.0, espread)
            espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))

            # compute the skill term
            eskill = eskill.sum(dim=0) / float(num_ensemble)

        return (eskill - 0.5 * espread)
class SpectralL2EnergyScoreLoss(SpectralBaseLoss):
    """
    Energy score computed per spherical-harmonic degree l: the L2 distances are
    accumulated over the order m first, the score is formed per (batch, channel, l)
    and finally summed over l. Uses the almost-fair correction factor ``alpha`` on
    the spread term and exponent ``beta`` on the distance.
    """

    def __init__(
        self,
        img_shape: Tuple[int, int],
        crop_shape: Tuple[int, int],
        crop_offset: Tuple[int, int],
        channel_names: List[str],
        grid_type: str,
        lmax: Optional[int] = None,
        spatial_distributed: Optional[bool] = False,
        ensemble_distributed: Optional[bool] = False,
        ensemble_weights: Optional[torch.Tensor] = None,
        channel_reduction: Optional[bool] = True,
        alpha: Optional[float] = 1.0,
        beta: Optional[float] = 1.0,
        eps: Optional[float] = 1.0e-3,
        **kwargs,
    ):

        super().__init__(
            img_shape=img_shape,
            crop_shape=crop_shape,
            crop_offset=crop_offset,
            channel_names=channel_names,
            grid_type=grid_type,
            lmax=lmax,
            spatial_distributed=spatial_distributed,
        )

        self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial")
        self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)
        self.channel_reduction = channel_reduction
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

        if ensemble_weights is not None:
            self.register_buffer("ensemble_weights", ensemble_weights, persistent=False)
        else:
            self.ensemble_weights = ensemble_weights

        # BUGFIX: forward() reads self.lm_weights, but it was never constructed here,
        # so every call raised AttributeError. Build the spectral quadrature weights
        # the same way SpectralCoherenceLoss does: factor 2 for m > 0 (conjugate
        # modes of the real SHT), zero for the invalid triangle m > l, then split
        # along the h/w comm groups to match the local (l, m) patch.
        ls = torch.arange(self.sht.lmax).reshape(-1, 1)
        ms = torch.arange(self.sht.mmax).reshape(1, -1)

        lm_weights = torch.ones((self.sht.lmax, self.sht.mmax))
        lm_weights[:, 1:] *= 2.0
        lm_weights = torch.where(ms > ls, 0.0, lm_weights)
        if self.spatial_distributed and comm.get_size("h") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")]
        if self.spatial_distributed and comm.get_size("w") > 1:
            lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")]
        lm_weights = lm_weights.contiguous()

        self.register_buffer("lm_weights", lm_weights, persistent=False)

    @property
    def type(self):
        return LossType.Probabilistic

    @property
    def n_channels(self):
        return 1 if self.channel_reduction else len(self.channel_names)

    @torch.compiler.disable(recursive=False)
    def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor:
        # channel weights are trivial when the score is already channel-reduced
        if self.channel_reduction:
            chw = torch.ones(1)
        else:
            chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale)
        return chw

    def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the spectral L2 energy score.

        Parameters
        ----------
        forecasts : torch.Tensor
            Ensemble forecast of shape (batch, ensemble, channels, lat, lon).
        observations : torch.Tensor
            Ground truth of shape (batch, channels, lat, lon).
        ensemble_weights : Optional[torch.Tensor]
            Unused; constant weights are assumed.

        Returns
        -------
        torch.Tensor
            Score reduced over the spectral dimensions.
        """

        # sanity checks
        if forecasts.dim() != 5:
            raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.")

        # get the data type before stripping amp types
        dtype = forecasts.dtype

        # before anything else compute the transform
        # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same
        forecasts = forecasts.float()
        observations = observations.float()
        with amp.autocast(device_type="cuda", enabled=False):
            # TODO: check 4 pi normalization
            forecasts = self.sht(forecasts) / math.sqrt(4.0 * math.pi)
            observations = self.sht(observations) / math.sqrt(4.0 * math.pi)

        forecasts = forecasts.to(dtype)
        observations = observations.to(dtype)

        # we assume the following shapes:
        # forecasts: batch, ensemble, channels, lmax, mmax
        # observations: batch, channels, lmax, mmax
        B, E, C, H, W = forecasts.shape

        # transpose the forecasts to ensemble, batch, channels, l, m and then do
        # distributed transpose into ensemble direction (the m dim gets sharded)
        forecasts = torch.moveaxis(forecasts, 1, 0)
        if self.ensemble_distributed:
            ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))]
            forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble")  # for correct spatial reduction we need to do the same with spatial weights

        lm_weights_split = self.lm_weights
        if self.ensemble_distributed:
            lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble")

        # observations does not need a transpose, but just a split
        if self.ensemble_distributed:
            observations = scatter_to_parallel_region(observations, -1, "ensemble")

        num_ensemble = forecasts.shape[0]

        # get nanmask from observations and forecasts
        nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts))
        nanmask_bool = nanmasks.sum(dim=0) != 0

        # impute NaN before computation to avoid 0 * NaN = NaN in backward pass
        observations = torch.where(torch.isnan(observations), 0.0, observations)
        forecasts = torch.where(torch.isnan(forecasts), 0.0, forecasts)

        # pairwise spread and per-member skill, weighted per (l, m) mode
        espread = lm_weights_split * (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().square()
        eskill = lm_weights_split * (observations - forecasts).abs().square()

        # zero out masked positions
        espread = torch.where(nanmask_bool, 0.0, espread)
        eskill = torch.where(nanmask_bool, 0.0, eskill)

        # do the channel reduction first
        if self.channel_reduction:
            espread = espread.sum(dim=-3, keepdim=True)
            eskill = eskill.sum(dim=-3, keepdim=True)

        # do the spectral m reduction
        espread = espread.sum(dim=-1, keepdim=False)
        eskill = eskill.sum(dim=-1, keepdim=False)

        # since we split m dim into ensemble dim, we need to do an ensemble sum as well
        if self.ensemble_distributed:
            espread = reduce_from_parallel_region(espread, "ensemble")
            eskill = reduce_from_parallel_region(eskill, "ensemble")

        # m is sharded over the "w" comm group, so finish the m sum across ranks
        if self.spatial_distributed:
            espread = reduce_from_parallel_region(espread, "w")
            eskill = reduce_from_parallel_region(eskill, "w")

        # clamp tiny values to eps so sqrt/pow stay differentiable; remember which
        # entries were clamped so they can be zeroed again afterwards
        espread_mask = torch.where(espread < self.eps, True, False)
        eskill_mask = torch.where(eskill < self.eps, True, False)

        espread = torch.where(espread_mask, self.eps, espread)
        eskill = torch.where(eskill_mask, self.eps, eskill)

        with amp.autocast(device_type="cuda", enabled=False):

            espread = espread.float()
            eskill = eskill.float()

            # This is according to the definition in Gneiting et al. 2005
            espread = torch.sqrt(espread).pow(self.beta)
            eskill = torch.sqrt(eskill).pow(self.beta)

            # mask espread and sum
            espread = torch.where(espread_mask, 0.0, espread)
            eskill = torch.where(eskill_mask, 0.0, eskill)

            # now we have reduced everything and need to sum appropriately (B, C, H);
            # (E - 1 + alpha) / (E^2 (E - 1)) is the almost-fair spread normalization
            espread = espread.sum(dim=(0, 1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1))
            eskill = eskill.sum(dim=0) / float(num_ensemble)

            # we now have the loss per wavenumber, which we can normalize
            loss = (eskill - 0.5 * espread)

            # finish the l sum; l is sharded over the "h" comm group
            loss = loss.sum(dim=-1)
            if self.spatial_distributed:
                loss = reduce_from_parallel_region(loss, "h")

        return loss
self.spatial_distributed: + espread = reduce_from_parallel_region(espread, "w") + eskill = reduce_from_parallel_region(eskill, "w") + + # get the masks + espread_mask = torch.where(espread < self.eps, True, False) + eskill_mask = torch.where(eskill < self.eps, True, False) + + # mask the data + espread = torch.where(espread_mask, self.eps, espread) + eskill = torch.where(eskill_mask, self.eps, eskill) + + with amp.autocast(device_type="cuda", enabled=False): + + espread = espread.float() + eskill = eskill.float() + + # This is according to the definition in Gneiting et al. 2005 + espread = torch.sqrt(espread).pow(self.beta) + eskill = torch.sqrt(eskill).pow(self.beta) + + # mask espread and sum + espread = torch.where(espread_mask, 0.0, espread) + eskill = torch.where(eskill_mask, 0.0, eskill) + + # now we have reduced everything and need to sum appropriately (B, C, H) + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + # we now have the loss per wavenumber, which we can normalize + loss = (eskill - 0.5 * espread) + + # we need to do the spatial averaging manually since + loss = loss.sum(dim=-1) + if self.spatial_distributed: + loss = reduce_from_parallel_region(loss, "h") + + return loss + + +class SpectralCoherenceLoss(SpectralBaseLoss): + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + lmax: Optional[int] = None, + relative: Optional[bool] = False, + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + alpha: Optional[float] = 1.0, + eps: Optional[float] = 1.0e-6, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + 
grid_type=grid_type, + lmax=lmax, + spatial_distributed=spatial_distributed, + ) + + self.relative = relative + self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial") + self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) + self.eps = eps + + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + + # prep ls and ms for broadcasting + ls = torch.arange(self.sht.lmax).reshape(-1, 1) + ms = torch.arange(self.sht.mmax).reshape(1, -1) + + lm_weights = torch.ones((self.sht.lmax, self.sht.mmax)) + lm_weights[:, 1:] *= 2.0 + lm_weights = torch.where(ms > ls, 0.0, lm_weights) + if comm.get_size("h") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + if comm.get_size("w") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + lm_weights = lm_weights.contiguous() + + self.register_buffer("lm_weights", lm_weights, persistent=False) + + @property + def type(self): + return LossType.Probabilistic + + @property + def n_channels(self): + return 1 if self.channel_reduction else len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + if self.channel_reduction: + chw = torch.ones(1) + else: + chw = super().compute_channel_weighting(channel_weight_type, time_diff_scale) + return chw + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() != 5: + raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + + # get the data type before stripping amp types + dtype = 
forecasts.dtype + + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + # TODO: check 4 pi normalization + forecasts = self.sht(forecasts.float()) / math.sqrt(4.0 * math.pi) + observations = self.sht(observations.float()) / math.sqrt(4.0 * math.pi) + + # we assume the following shapes: + # forecasts: batch, ensemble, channels, mmax, lmax + # observations: batch, channels, mmax, lmax + B, E, C, H, W = forecasts.shape + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. + # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") # for correct spatial reduction we need to do the same with spatial weights + + lm_weights_split = self.lm_weights + if self.ensemble_distributed: + lm_weights_split = scatter_to_parallel_region(lm_weights_split, -1, "ensemble") + + # observations does not need a transpose, but just a split and broadcast to ensemble dimension + observations = observations.unsqueeze(0) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + num_ensemble = forecasts.shape[0] + + # compute power spectral densities of forecasts and observations + psd_forecasts = (lm_weights_split * forecasts.abs().square()).sum(dim=-1) + psd_observations = (lm_weights_split * observations.abs().square()).sum(dim=-1) + + # reduce over ensemble parallel region and m spatial dimensions + if self.ensemble_distributed: + psd_forecasts = reduce_from_parallel_region(psd_forecasts, "ensemble") + psd_observations = 
reduce_from_parallel_region(psd_observations, "ensemble") + + if self.spatial_distributed: + psd_forecasts = reduce_from_parallel_region(psd_forecasts, "w") + psd_observations = reduce_from_parallel_region(psd_observations, "w") + + + # compute coherence between forecasts and observations + coherence_forecasts = (lm_weights_split * (forecasts.unsqueeze(0).conj() * forecasts.unsqueeze(1)).real).sum(dim=-1) + coherence_observations = (lm_weights_split * (forecasts.conj() * observations).real).sum(dim=-1) + + # reduce over ensemble parallel region and m spatial dimensions + if self.ensemble_distributed: + coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "ensemble") + coherence_observations = reduce_from_parallel_region(coherence_observations, "ensemble") + + if self.spatial_distributed: + coherence_forecasts = reduce_from_parallel_region(coherence_forecasts, "w") + coherence_observations = reduce_from_parallel_region(coherence_observations, "w") + + # divide the coherence by the product of the norms (with epsilon for numerical stability) + coherence_observations = coherence_observations / torch.sqrt(psd_forecasts * psd_observations + self.eps) + coherence_forecasts = coherence_forecasts / torch.sqrt(psd_forecasts.unsqueeze(0) * psd_forecasts.unsqueeze(1) + self.eps) + + # compute the error in the power spectral density + psd_skill = (psd_forecasts - psd_observations).square() + if self.relative: + psd_skill = psd_skill / (psd_observations + self.eps) + psd_skill = psd_skill.sum(dim=0) / float(num_ensemble) + + # compute the coherence skill and spread + coherence_skill = (1.0 - coherence_observations).sum(dim=0) / float(num_ensemble) + + # mask the diagonal of coherence_spread with 0.0 + coherence_spread = torch.where(torch.eye(num_ensemble, device=coherence_forecasts.device).bool().reshape(num_ensemble, num_ensemble, 1, 1, 1), 0.0, 1.0 - coherence_forecasts) + coherence_spread = coherence_spread.sum(dim=(0, 1)) / float(num_ensemble * 
(num_ensemble - 1)) + + # compute the loss + if self.relative: + loss = psd_skill + 2.0 * (coherence_skill - 0.5 * coherence_spread) + else: + loss = psd_skill + 2.0 * psd_observations.squeeze(0) * (coherence_skill - 0.5 * coherence_spread) + + # reduce the loss over the l dimensions + loss = loss.sum(dim=-1) + if self.spatial_distributed: + loss = reduce_from_parallel_region(loss, "h") + + # reduce over the channel dimension + loss = loss.sum(dim=-1) + + return loss \ No newline at end of file diff --git a/makani/utils/losses/h1_loss.py b/makani/utils/losses/h1_loss.py index 0cac68b4..75acff49 100644 --- a/makani/utils/losses/h1_loss.py +++ b/makani/utils/losses/h1_loss.py @@ -134,7 +134,7 @@ def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) if self.spatial_distributed and (comm.get_size("w") > 1): tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") - + # compute target norms tar_norm2 = tar_norm2.reshape(B, C, -1) tar_h1_norm2 = torch.sum(tar_norm2 * self.h1_weights, dim=-1) diff --git a/makani/utils/losses/lp_loss.py b/makani/utils/losses/lp_loss.py index 2c31b2fc..21c6ed48 100644 --- a/makani/utils/losses/lp_loss.py +++ b/makani/utils/losses/lp_loss.py @@ -24,6 +24,8 @@ from makani.utils import comm +from physicsnemo.distributed.mappings import reduce_from_parallel_region + class GeometricLpLoss(GeometricBaseLoss): """ @@ -114,9 +116,9 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens return loss -class SpectralL2Loss(SpectralBaseLoss): +class SpectralLpLoss(SpectralBaseLoss): """ - Computes the geometric L2 loss but using the spherical Harmonic transform + Computes the Lp loss in spectral (SH coefficients) space """ def __init__( @@ -126,6 +128,7 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, + p: Optional[float] = 2.0, relative: Optional[bool] = False, squared: Optional[bool] = False, 
spatial_distributed: Optional[bool] = False, @@ -140,6 +143,7 @@ def __init__( spatial_distributed=spatial_distributed, ) + self.p = p self.relative = relative self.squared = squared self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed @@ -147,80 +151,95 @@ def __init__( def abs(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + + # compute |coeffs|^p (orthonormal convention) + coeffsp = torch.abs(coeffs) ** self.p if wgt is not None: - coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt + # sum over m: m=0 contributes once, m!=0 contribute twice (due to conjugate symmetry) if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) + # sum over l (degrees) + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") + # take p-th root unless squared is True if not self.squared: - norm2 = torch.sqrt(norm2) + normp = normp ** (1.0 / self.p) - return norm2 + return normp def rel(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): B, C, H, W = prd.shape - coeffssq = torch.square(torch.abs(self.sht(prd - tar))) / torch.pi / 4.0 + # compute SH coefficients of the difference + coeffs = self.sht(prd - tar) + coeffsp 
= torch.abs(coeffs) ** self.p if wgt is not None: - coeffssq = coeffssq * wgt + coeffsp = coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for numerator if comm.get_rank("w") == 0: - norm2 = coeffssq[..., 0] + 2 * torch.sum(coeffssq[..., 1:], dim=-1) + normp = coeffsp[..., 0] + 2 * torch.sum(coeffsp[..., 1:], dim=-1) else: - norm2 = 2 * torch.sum(coeffssq, dim=-1) + normp = 2 * torch.sum(coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - norm2 = reduce_from_parallel_region(norm2, "w") + normp = reduce_from_parallel_region(normp, "w") + + # sum over l + normp = normp.reshape(B, C, -1) + normp = torch.sum(normp, dim=-1) - # compute norms - norm2 = norm2.reshape(B, C, -1) - norm2 = torch.sum(norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - norm2 = reduce_from_parallel_region(norm2, "h") + normp = reduce_from_parallel_region(normp, "h") - # target - tar_coeffssq = torch.square(torch.abs(self.sht(tar))) / torch.pi / 4.0 + # compute target norm + tar_coeffs = self.sht(tar) + tar_coeffsp = torch.abs(tar_coeffs) ** self.p if wgt is not None: - tar_coeffssq = tar_coeffssq * wgt + tar_coeffsp = tar_coeffsp * wgt - # sum m != 0 coeffs: + # sum m != 0 coeffs for denominator if comm.get_rank("w") == 0: - tar_norm2 = tar_coeffssq[..., 0] + 2 * torch.sum(tar_coeffssq[..., 1:], dim=-1) + tar_normp = tar_coeffsp[..., 0] + 2 * torch.sum(tar_coeffsp[..., 1:], dim=-1) else: - tar_norm2 = 2 * torch.sum(tar_coeffssq, dim=-1) + tar_normp = 2 * torch.sum(tar_coeffsp, dim=-1) + if self.spatial_distributed and (comm.get_size("w") > 1): - tar_norm2 = reduce_from_parallel_region(tar_norm2, "w") + tar_normp = reduce_from_parallel_region(tar_normp, "w") + + # sum over l + tar_normp = tar_normp.reshape(B, C, -1) + tar_normp = torch.sum(tar_normp, dim=-1) - # compute target norms - tar_norm2 = tar_norm2.reshape(B, C, -1) - tar_norm2 = torch.sum(tar_norm2, dim=-1) if self.spatial_distributed and (comm.get_size("h") > 1): - tar_norm2 = 
reduce_from_parallel_region(tar_norm2, "h") + tar_normp = reduce_from_parallel_region(tar_normp, "h") + # take p-th root unless squared is True if not self.squared: - diff_norms = torch.sqrt(norm2) - tar_norms = torch.sqrt(tar_norm2) + diff_norms = normp ** (1.0 / self.p) + tar_norms = tar_normp ** (1.0 / self.p) else: - diff_norms = norm2 - tar_norms = tar_norm2 + diff_norms = normp + tar_norms = tar_normp - # setup return value + # compute relative error retval = diff_norms / tar_norms return retval @@ -233,3 +252,4 @@ def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tens loss = self.abs(prd, tar, wgt) return loss + diff --git a/makani/utils/losses/mmd_loss.py b/makani/utils/losses/mmd_loss.py index 21d2d8d5..dd192a29 100644 --- a/makani/utils/losses/mmd_loss.py +++ b/makani/utils/losses/mmd_loss.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,58 +15,24 @@ from typing import Optional, Tuple, List -import numpy as np - +import math import torch import torch.nn as nn -from torch.cuda import amp +from torch import amp -from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils.losses.base_loss import GeometricBaseLoss, GradientBaseLoss, LossType from makani.utils import comm +import torch_harmonics as th +import torch_harmonics.distributed as thd + # distributed stuff from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region from makani.mpu.mappings import distributed_transpose -# @torch.compile -# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): -# return torch.abs(x - y) - - -@torch.compile -def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): - return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - - -# Computes the squared maximum mean discrepancy -# @torch.compile -def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: - - # initial values - spread_term = torch.zeros_like(observation) - disc_term = torch.zeros_like(observation) - - num_forecasts = forecasts.shape[0] - - for m in range(num_forecasts): - # get the forecast - ym = forecasts[m] - - # account for contributions on the off-diasgonal assuming that the kernel is symmetric - spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) - - # contributions to the discrepancy term - disc_term = disc_term + _mmd_rbf_kernel(ym, observation) - - # compute the squared mmd - mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts - - return mmd2 - - -class EnsembleMMDLoss(GeometricBaseLoss): +class 
GaussianMMDLoss(GeometricBaseLoss): r""" Computes the maximum mean discrepancy loss for a specific kernel. For details see [1] @@ -80,10 +46,15 @@ def __init__( crop_offset: Tuple[int, int], channel_names: List[str], grid_type: str, - squared: Optional[bool] = False, - pole_mask: Optional[int] = 0, + pole_mask: int, spatial_distributed: Optional[bool] = False, ensemble_distributed: Optional[bool] = False, + ensemble_weights: Optional[torch.Tensor] = None, + sigma: Optional[float] = 1.0, + alpha: Optional[float] = 1.0, + beta: Optional[float] = 2.0, + eps: Optional[float] = 1.0e-5, + channel_reduction: Optional[bool] = False, **kwargs, ): @@ -97,10 +68,13 @@ def __init__( spatial_distributed=spatial_distributed, ) - self.squared = squared - self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + self.alpha = alpha + self.beta = beta + self.eps = eps + self.channel_reduction = channel_reduction + self.sigma = sigma # we also need a variant of the weights split in ensemble direction: quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) @@ -109,10 +83,26 @@ def __init__( quad_weight_split = quad_weight_split.contiguous() self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + if ensemble_weights is not None: + self.register_buffer("ensemble_weights", ensemble_weights, persistent=False) + else: + self.ensemble_weights = ensemble_weights + @property def type(self): return LossType.Probabilistic + @property + def n_channels(self): + if self.channel_reduction: + return 1 + else: + return len(self.channel_names) + + @torch.compiler.disable(recursive=False) + def compute_channel_weighting(self, channel_weight_type: str, time_diff_scale: str) -> torch.Tensor: + return torch.ones(1) + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] 
= None) -> torch.Tensor: # sanity checks @@ -121,51 +111,236 @@ def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_w # we assume that spatial_weights have NO ensemble dim if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): - raise ValueError("the weights have to have the same number of dimensions as observations") + spdim = spatial_weights.dim() + odim = observations.dim() + raise ValueError(f"the weights have to have the same number of dimensions (found {spdim}) as observations (found {odim}).") # we assume the following shapes: # forecasts: batch, ensemble, channels, lat, lon # observations: batch, channels, lat, lon B, E, C, H, W = forecasts.shape - # if ensemble dim is one dimensional then computing the score is quick: - if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): - # in this case, CRPS is straightforward - mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) + # get the data type before stripping amp types + dtype = forecasts.dtype + + # transpose the forecasts to ensemble, batch, channels, lat, lon and then do distributed transpose into ensemble direction. 
+ # ideally we split spatial dims + forecasts = torch.moveaxis(forecasts, 1, 0) + forecasts = forecasts.reshape(E, B, C, H * W) + if self.ensemble_distributed: + ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] + forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") + + # observations does not need a transpose, but just a split + observations = observations.reshape(1, B, C, H * W) + if self.ensemble_distributed: + observations = scatter_to_parallel_region(observations, -1, "ensemble") + + # for correct spatial reduction we need to do the same with spatial weights + if spatial_weights is not None: + spatial_weights_split = spatial_weights.flatten(start_dim=-2, end_dim=-1) + spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + + if self.ensemble_weights is not None: + raise NotImplementedError("currently only constant ensemble weights are supported") else: - # transpose forecasts: ensemble, batch, channels, lat, lon - forecasts = torch.moveaxis(forecasts, 1, 0) - - # now we need to transpose the forecasts into ensemble direction. - # ideally we split spatial dims - forecasts = forecasts.reshape(E, B, C, H * W) - if self.ensemble_distributed: - ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] - forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") - # observations does not need a transpose, but just a split - observations = observations.reshape(B, C, H * W) - if self.ensemble_distributed: - observations = scatter_to_parallel_region(observations, -1, "ensemble") - if spatial_weights is not None: - spatial_weights_split = spatial_weights.flatten(-2, -1) - spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") - - # now, E dimension is local and spatial dim is split further. 
Compute the mmd - mmd = _mmd2_ensemble_kernel(observations, forecasts) - - # perform spatial average of crps score + ensemble_weights = torch.ones_like(forecasts, device=forecasts.device) + + # ensemble size + num_ensemble = forecasts.shape[0] + + # get nanmask from observations and forecasts + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts)) + nanmask_bool = nanmasks.sum(dim=0) != 0 + + # impute NaN before computation to avoid 0 * NaN = NaN in backward pass + observations = torch.where(torch.isnan(observations), 0.0, observations) + forecasts = torch.where(torch.isnan(forecasts), 0.0, forecasts) + + # use broadcasting semantics to compute spread and skill and sum over channels (vector norm) + espread = (forecasts.unsqueeze(1) - forecasts.unsqueeze(0)).abs().pow(self.beta) + eskill = (observations - forecasts).abs().pow(self.beta) + + # zero out masked positions + espread = torch.where(nanmask_bool, 0.0, espread) + eskill = torch.where(nanmask_bool, 0.0, eskill) + + # do the spatial reduction if spatial_weights is not None: - mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split * spatial_weights_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split * spatial_weights_split, dim=-1) else: - mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) + espread = torch.sum(espread * self.quad_weight_split, dim=-1) + eskill = torch.sum(eskill * self.quad_weight_split, dim=-1) + + # since we split spatial dim into ensemble dim, we need to do an ensemble sum as well if self.ensemble_distributed: - mmd = reduce_from_parallel_region(mmd, "ensemble") + espread = reduce_from_parallel_region(espread, "ensemble") + eskill = reduce_from_parallel_region(eskill, "ensemble") + # we need to do the spatial averaging manually since + # we are not calling the quadrature forward function if self.spatial_distributed: - mmd = reduce_from_parallel_region(mmd, "spatial") + 
espread = reduce_from_parallel_region(espread, "spatial") + eskill = reduce_from_parallel_region(eskill, "spatial") + + # do the channel reduction while ignoring NaNs + # if channel weights are required they should be added here to the reduction + if self.channel_reduction: + espread = espread.sum(dim=-2, keepdim=True) + eskill = eskill.sum(dim=-2, keepdim=True) + + # apply the Gaussian kernel + espread = torch.exp(-0.5 * torch.square(espread) / self.sigma) + eskill = torch.exp(-0.5 * torch.square(eskill) / self.sigma) + + # mask out the diagonal elements in the spread term + espread = torch.where(torch.eye(num_ensemble, device=espread.device).bool().reshape(num_ensemble, num_ensemble, 1, 1), 0.0, espread) + + # now we have reduced everything and need to sum appropriately + espread = espread.sum(dim=(0,1)) * (float(num_ensemble) - 1.0 + self.alpha) / float(num_ensemble * num_ensemble * (num_ensemble - 1)) + eskill = eskill.sum(dim=0) / float(num_ensemble) + + # the resulting tensor should have dimension B, C which is what we return + return eskill - 0.5 * espread - if not self.squared: - mmd = torch.sqrt(mmd) +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor): +# return torch.abs(x - y) + + +# @torch.compile +# def _mmd_rbf_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0): +# return torch.exp(-0.5 * torch.square(torch.abs(x - y)) / bandwidth) - # the resulting tensor should have dimension B, C, which is what we return - return mmd + +# # Computes the squared maximum mean discrepancy +# # @torch.compile +# def _mmd2_ensemble_kernel(observation: torch.Tensor, forecasts: torch.Tensor) -> torch.Tensor: + +# # initial values +# spread_term = torch.zeros_like(observation) +# disc_term = torch.zeros_like(observation) + +# num_forecasts = forecasts.shape[0] + +# for m in range(num_forecasts): + +# # get the forecast +# ym = forecasts[m] + +# # account for contributions on the off-diasgonal assuming that the kernel is symmetric +# 
spread_term = spread_term + 2.0 * torch.sum(_mmd_rbf_kernel(ym, forecasts[m:]), dim=0) + +# # contributions to the discrepancy term +# disc_term = disc_term + _mmd_rbf_kernel(ym, observation) + +# # compute the squared mmd +# mmd2 = spread_term / (num_forecasts - 1) / num_forecasts - 2.0 * disc_term / num_forecasts + +# return mmd2 + + +# class EnsembleMMDLoss(GeometricBaseLoss): +# r""" +# Computes the maximum mean discrepancy loss for a specific kernel. For details see [1] + +# [1] Dziugaite, Gintare Karolina; Roy, Daniel M.; Ghahramani, Zhoubin; Training generative neural networks via Maximum Mean Discrepancy optimization; arXiv:1505.03906 +# """ + +# def __init__( +# self, +# img_shape: Tuple[int, int], +# crop_shape: Tuple[int, int], +# crop_offset: Tuple[int, int], +# channel_names: List[str], +# grid_type: str, +# squared: Optional[bool] = False, +# pole_mask: Optional[int] = 0, +# spatial_distributed: Optional[bool] = False, +# ensemble_distributed: Optional[bool] = False, +# **kwargs, +# ): + +# super().__init__( +# img_shape=img_shape, +# crop_shape=crop_shape, +# crop_offset=crop_offset, +# channel_names=channel_names, +# grid_type=grid_type, +# pole_mask=pole_mask, +# spatial_distributed=spatial_distributed, +# ) + +# self.squared = squared + +# self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed +# self.ensemble_distributed = comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) and ensemble_distributed + +# # we also need a variant of the weights split in ensemble direction: +# quad_weight_split = self.quadrature.quad_weight.reshape(1, 1, -1) +# if self.ensemble_distributed: +# quad_weight_split = split_tensor_along_dim(quad_weight_split, dim=-1, num_chunks=comm.get_size("ensemble"))[comm.get_rank("ensemble")] +# quad_weight_split = quad_weight_split.contiguous() +# self.register_buffer("quad_weight_split", quad_weight_split, persistent=False) + +# @property +# def type(self): +# return 
LossType.Probabilistic + +# def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, spatial_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + +# # sanity checks +# if forecasts.dim() != 5: +# raise ValueError(f"Error, forecasts tensor expected to have 5 dimensions but found {forecasts.dim()}.") + +# # we assume that spatial_weights have NO ensemble dim +# if (spatial_weights is not None) and (spatial_weights.dim() != observations.dim()): +# raise ValueError("the weights have to have the same number of dimensions as observations") + +# # we assume the following shapes: +# # forecasts: batch, ensemble, channels, lat, lon +# # observations: batch, channels, lat, lon +# B, E, C, H, W = forecasts.shape + +# # if ensemble dim is one dimensional then computing the score is quick: +# if (not self.ensemble_distributed) and (forecasts.shape[1] == 1): +# # in this case, CRPS is straightforward +# mmd = _mmd_rbf_kernel(observations, forecasts.squeeze(1)).reshape(B, C, H * W) +# else: +# # transpose forecasts: ensemble, batch, channels, lat, lon +# forecasts = torch.moveaxis(forecasts, 1, 0) + +# # now we need to transpose the forecasts into ensemble direction. +# # ideally we split spatial dims +# forecasts = forecasts.reshape(E, B, C, H * W) +# if self.ensemble_distributed: +# ensemble_shapes = [forecasts.shape[0] for _ in range(comm.get_size("ensemble"))] +# forecasts = distributed_transpose.apply(forecasts, (-1, 0), ensemble_shapes, "ensemble") +# # observations does not need a transpose, but just a split +# observations = observations.reshape(B, C, H * W) +# if self.ensemble_distributed: +# observations = scatter_to_parallel_region(observations, -1, "ensemble") +# if spatial_weights is not None: +# spatial_weights_split = spatial_weights.flatten(-2, -1) +# spatial_weights_split = scatter_to_parallel_region(spatial_weights_split, -1, "ensemble") + +# # now, E dimension is local and spatial dim is split further. 
Compute the mmd +# mmd = _mmd2_ensemble_kernel(observations, forecasts) + +# # perform spatial average of crps score +# if spatial_weights is not None: +# mmd = torch.sum(mmd * self.quad_weight_split * spatial_weights_split, dim=-1) +# else: +# mmd = torch.sum(mmd * self.quad_weight_split, dim=-1) +# if self.ensemble_distributed: +# mmd = reduce_from_parallel_region(mmd, "ensemble") + +# if self.spatial_distributed: +# mmd = reduce_from_parallel_region(mmd, "spatial") + +# if not self.squared: +# mmd = torch.sqrt(mmd) + +# # the resulting tensor should have dimension B, C, which is what we return +# return mmd diff --git a/makani/utils/losses/regularization.py b/makani/utils/losses/regularization.py new file mode 100644 index 00000000..42fb07c6 --- /dev/null +++ b/makani/utils/losses/regularization.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Tuple, List + +import math + +import torch +import torch.nn as nn +from torch import amp + +from makani.utils.losses.base_loss import GeometricBaseLoss, SpectralBaseLoss, LossType +from makani.utils import comm + +# distributed stuff +from physicsnemo.distributed.utils import compute_split_shapes, split_tensor_along_dim +from physicsnemo.distributed.mappings import scatter_to_parallel_region, reduce_from_parallel_region + + +class DriftRegularization(GeometricBaseLoss): + """ + Computes the Lp loss on the sphere. + """ + + def __init__( + self, + img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + p: Optional[float] = 1.0, + pole_mask: Optional[int] = 0, + grid_type: Optional[str] = "equiangular", + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + **kwargs, + ): + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + pole_mask=pole_mask, + spatial_distributed=spatial_distributed, + ) + + self.p = p + self.spatial_distributed = comm.is_distributed("spatial") and spatial_distributed + self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) + + @property + def type(self): + return LossType.Probabilistic + + def forward(self, prd: torch.Tensor, tar: torch.Tensor, wgt: Optional[torch.Tensor] = None): + + if prd.dim() > tar.dim(): + tar = tar.unsqueeze(1) + + # compute difference between the means output has dims + loss = torch.abs(self.quadrature(prd) - self.quadrature(tar)).pow(self.p) + + # if ensemble + if prd.dim() == 5: + loss = torch.mean(loss, dim=1) + if self.ensemble_distributed: + loss = reduce_from_parallel_region(loss, "ensemble") / float(comm.get_size("ensemble")) + + return loss + +class SpectralRegularization(SpectralBaseLoss): + + def __init__( + self, 
+ img_shape: Tuple[int, int], + crop_shape: Tuple[int, int], + crop_offset: Tuple[int, int], + channel_names: List[str], + grid_type: str, + lmax: Optional[int] = None, + spatial_distributed: Optional[bool] = False, + ensemble_distributed: Optional[bool] = False, + eps: Optional[float] = 1.0e-10, + logarithmic: Optional[bool] = False, + **kwargs, + ): + + super().__init__( + img_shape=img_shape, + crop_shape=crop_shape, + crop_offset=crop_offset, + channel_names=channel_names, + grid_type=grid_type, + lmax=lmax, + spatial_distributed=spatial_distributed, + ) + + self.spatial_distributed = spatial_distributed and comm.is_distributed("spatial") + self.ensemble_distributed = ensemble_distributed and comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1) + self.eps = eps + self.logarithmic = logarithmic + + # prep ls and ms for broadcasting + ls = torch.arange(self.sht.lmax).reshape(-1, 1) + ms = torch.arange(self.sht.mmax).reshape(1, -1) + + lm_weights = torch.ones((self.sht.lmax, self.sht.mmax)) + lm_weights[:, 1:] *= 2.0 + lm_weights = torch.where(ms > ls, 0.0, lm_weights) + + if comm.get_size("h") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-2, num_chunks=comm.get_size("h"))[comm.get_rank("h")] + + if comm.get_size("w") > 1: + lm_weights = split_tensor_along_dim(lm_weights, dim=-1, num_chunks=comm.get_size("w"))[comm.get_rank("w")] + + self.register_buffer("lm_weights", lm_weights, persistent=False) + + @property + def type(self): + return LossType.Probabilistic + + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, ensemble_weights: Optional[torch.Tensor] = None) -> torch.Tensor: + + # sanity checks + if forecasts.dim() == 5: + B, E = forecasts.shape[0:2] + observations = observations.unsqueeze(1) + elif forecasts.dim() == 4: + B = forecasts.shape[0] + E = -1 + forecasts = forecasts.unsqueeze(1) + observations = observations.unsqueeze(1) + else: + raise ValueError(f"Error, forecasts tensor expected to have 4 or 5 
dimensions but found {forecasts.dim()}.") + + # get the data type before stripping amp types + dtype = forecasts.dtype + + # before anything else compute the transform + # as the CDF definition doesn't generalize well to more than one-dimensional variables, we treat complex and imaginary part as the same + with amp.autocast(device_type="cuda", enabled=False): + # TODO: check 4 pi normalization + forecasts = self.sht(forecasts.float()).abs().pow(2) / (4.0 * math.pi) + observations = self.sht(observations.float()).abs().pow(2) / (4.0 * math.pi) + + # we assume the following shapes: + # B, E, C, H, W (where H, W are spectral dims now) + C, H, W = forecasts.shape[-3:] + + # get nanmask from the observarions + nanmasks = torch.logical_or(torch.isnan(observations), torch.isnan(forecasts)) + + # do the summation over the ms first to obtain the PSDs + forecasts = (self.lm_weights * forecasts).sum(dim=-1) + observations = (self.lm_weights * observations).sum(dim=-1) + + if self.spatial_distributed: + forecasts = reduce_from_parallel_region(forecasts, "w") + observations = reduce_from_parallel_region(observations, "w") + + if self.logarithmic: + forecasts = torch.log(forecasts) + observations = torch.log(observations) + + diff = (forecasts - observations).abs() + + if E > 0: + diff = diff.sum(dim=1) / float(E) + if self.ensemble_distributed: + diff = reduce_from_parallel_region(diff, "ensemble") / float(comm.get_size("ensemble")) + + # do the l reduction + diff = diff.sum(dim=-1) + if self.spatial_distributed: + diff = reduce_from_parallel_region(diff, "h") + + return diff / float(self.sht.lmax) \ No newline at end of file diff --git a/makani/utils/metric.py b/makani/utils/metric.py index 8e06260c..fd8118e0 100644 --- a/makani/utils/metric.py +++ b/makani/utils/metric.py @@ -73,14 +73,14 @@ def __init__(self, metric_name, metric_channels, metric_handle, channel_names, n # CPU buffers pin_memory = self.device.type == "cuda" - + if self.aux_shape_finalized is None: 
data_shape_finalized = (self.num_rollout_steps, self.num_channels) integral_shape = (self.num_channels) else: data_shape_finalized = (self.num_rollout_steps, self.num_channels, *self.aux_shape_finalized) integral_shape = (self.num_channels, *self.aux_shape_finalized) - + self.rollout_curve_cpu = torch.zeros(data_shape_finalized, dtype=torch.float32, device="cpu", pin_memory=pin_memory) if self.integrate: @@ -213,12 +213,12 @@ def __init__( climatology, num_rollout_steps, device, - l1_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - rmse_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - acc_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - crps_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - spread_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], - ssr_var_names=["u10m", "t2m", "u500", "z500", "q500", "sp"], + l1_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + rmse_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + acc_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + crps_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + spread_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], + ssr_var_names=["u10m", "t2m", "sp", "sst", "u500", "z500", "q500", "q50"], rh_var_names=[], wb2_compatible=False, ): diff --git a/makani/utils/metrics/functions.py b/makani/utils/metrics/functions.py index ea4ae515..d99038c5 100644 --- a/makani/utils/metrics/functions.py +++ b/makani/utils/metrics/functions.py @@ -22,7 +22,7 @@ from physicsnemo.distributed.utils import split_tensor_along_dim from makani.mpu.mappings import distributed_transpose -from makani.utils.losses import EnsembleCRPSLoss, LossType +from makani.utils.losses import CRPSLoss, LossType from makani.utils.metrics.base_metric import _sanitize_shapes, _welford_reduction_helper, GeometricBaseMetric class GeometricL1(GeometricBaseMetric): @@ -197,7 +197,7 @@ def 
forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tenso # stack along dim -1: # we form the ratio in the finalization step acc = torch.stack([cov_xy, var_x, var_y], dim=-1) - + # reduce if self.channel_reduction == "mean": acc = torch.mean(acc, dim=1) @@ -252,12 +252,12 @@ def __init__( def combine(self, vals, counts, dim=0): # sanitize shapes vals, counts = _sanitize_shapes(vals, counts, dim=dim) - + # extract parameters covs = vals[..., 0].unsqueeze(-1) m2s = vals[..., 1:3] means = vals[..., 3:5] - + # counts are: n = sum_k n_k counts_agg = torch.sum(counts, dim=0) # means are: mu = sum_i n_i * mu_i / n @@ -280,7 +280,7 @@ def finalize(self, vals, counts): return vals[..., 0] / torch.sqrt(vals[..., 1] * vals[..., 2]) def forward(self, x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: - + if hasattr(self, "bias"): x = x - self.bias y = y - self.bias @@ -349,7 +349,7 @@ def __init__( @property def type(self): return LossType.Probabilistic - + def forward(self, forecasts: torch.Tensor, observations: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor: # sanity checks @@ -498,7 +498,7 @@ def __init__( ): super().__init__() - self.metric_func = EnsembleCRPSLoss( + self.metric_func = CRPSLoss( img_shape=img_shape, crop_shape=crop_shape, crop_offset=crop_offset, diff --git a/makani/utils/training/autoencoder_trainer.py b/makani/utils/training/autoencoder_trainer.py index e2fc07fa..3b427a75 100644 --- a/makani/utils/training/autoencoder_trainer.py +++ b/makani/utils/training/autoencoder_trainer.py @@ -482,7 +482,8 @@ def train_one_epoch(self, profiler=None): train_steps = 0 train_start = time.perf_counter_ns() self.model_train.zero_grad(set_to_none=True) - for data in tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen) 
+ for data in progress_bar: train_steps += 1 self.iters += 1 @@ -519,6 +520,9 @@ def train_one_epoch(self, profiler=None): accumulated_loss[0] += loss.detach().clone() * inp.shape[0] accumulated_loss[1] += inp.shape[0] + # log the loss + pbar_postfix = {"loss": loss.item()} + # perform weight update if do_update: if self.max_grad_norm > 0.0: @@ -526,6 +530,7 @@ def train_one_epoch(self, profiler=None): grad_norm = clip_grads(self.model_train, self.max_grad_norm) accumulated_grad_norm[0] += grad_norm.detach() accumulated_grad_norm[1] += 1.0 + pbar_postfix["grad norm"] = grad_norm.item() self.gscaler.step(self.optimizer) self.gscaler.update() @@ -548,6 +553,9 @@ def train_one_epoch(self, profiler=None): self.logger.info(f"Dumping weights and gradients to {weights_and_grads_path}") self.dump_weights_and_grads(weights_and_grads_path, self.model, step=(self.epoch * self.params.num_samples_per_epoch + self.iters)) + # set progress bar prefix + progress_bar.set_postfix(**pbar_postfix) + if profiler is not None: profiler.step() @@ -606,7 +614,8 @@ def validate_one_epoch(self, epoch, profiler=None): with torch.inference_mode(): with torch.no_grad(): eval_steps = 0 - for data in tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: eval_steps += 1 # map to gpu @@ -661,6 +670,9 @@ def validate_one_epoch(self, epoch, profiler=None): self.visualizer.add(tag, pred_cpu, targ_cpu) + # log the loss + progress_bar.set_postfix({"loss": loss.item()}) + # put in the metrics handler self.metrics.update(pred, inpt, loss, 0) diff --git a/makani/utils/training/deterministic_trainer.py b/makani/utils/training/deterministic_trainer.py index c8b1a456..0e0b8938 100644 --- a/makani/utils/training/deterministic_trainer.py +++ b/makani/utils/training/deterministic_trainer.py @@ -171,11 
+171,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup with Timer() as timer: @@ -227,7 +227,12 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "z500" + # plot_channel = "q50" + # plot_index = self.params.channel_names.index(plot_channel) + plot_index = 0 + print(self.params.channel_names) + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, @@ -481,7 +486,8 @@ def train_one_epoch(self, profiler=None): train_steps = 0 train_start = time.perf_counter_ns() self.model_train.zero_grad(set_to_none=True) - for data in tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: train_steps += 1 self.iters += 1 @@ -512,12 +518,12 @@ def train_one_epoch(self, profiler=None): if 
do_update: # regular forward pass including DDP pred = self.model_train(inp) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: # disable sync step with self.model_train.no_sync(): pred = self.model_train(inp) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact # backward pass @@ -527,6 +533,9 @@ def train_one_epoch(self, profiler=None): accumulated_loss[0] += loss.detach().clone() * inp.shape[0] accumulated_loss[1] += inp.shape[0] + # log the loss + pbar_postfix = {"loss": loss.item()} + # perform weight update if requested: we do not need to add 1 here because we already do that before the step if do_update: if self.max_grad_norm > 0.0: @@ -534,6 +543,7 @@ def train_one_epoch(self, profiler=None): grad_norm = clip_grads(self.model_train, self.max_grad_norm) accumulated_grad_norm[0] += grad_norm.detach() accumulated_grad_norm[1] += 1.0 + pbar_postfix["grad norm"] = grad_norm.item() self.gscaler.step(self.optimizer) self.gscaler.update() @@ -556,6 +566,9 @@ def train_one_epoch(self, profiler=None): self.logger.info(f"Dumping weights and gradients to {weights_and_grads_path}") self.dump_weights_and_grads(weights_and_grads_path, self.model, step=(self.epoch * self.params.num_samples_per_epoch + self.iters)) + # set progress bar prefix + progress_bar.set_postfix(**pbar_postfix) + if torch.cuda.is_available(): torch.cuda.nvtx.range_pop() @@ -613,7 +626,8 @@ def validate_one_epoch(self, epoch, profiler=None): with torch.inference_mode(): with torch.no_grad(): eval_steps = 0 - for data in tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: eval_steps += 1 if torch.cuda.is_available(): @@ -664,6 +678,9 @@ def validate_one_epoch(self, epoch, profiler=None): tag = 
f"step{eval_steps}_time{str(idt).zfill(3)}" self.visualizer.add(tag, pred_cpu, targ_cpu) + # log the loss + progress_bar.set_postfix({"loss": loss.item()}) + # put in the metrics handler self.metrics.update(pred, targ, loss, idt) diff --git a/makani/utils/training/ensemble_trainer.py b/makani/utils/training/ensemble_trainer.py index 3c232fd0..c2c61dcf 100644 --- a/makani/utils/training/ensemble_trainer.py +++ b/makani/utils/training/ensemble_trainer.py @@ -171,11 +171,11 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = self.loss_obj = self.loss_obj.to(self.device) self.timers["loss handler init"] = timer.time - # channel weights: - if self.log_to_screen: - chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() - chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} - self.logger.info(f"Channel weights: {chw_output}") + # # channel weights: + # if self.log_to_screen: + # chw_weights = self.loss_obj.channel_weights.squeeze().cpu().numpy().tolist() + # chw_output = {k: v for k,v in zip(self.params.channel_names, chw_weights)} + # self.logger.info(f"Channel weights: {chw_output}") # optimizer and scheduler setup # model @@ -231,7 +231,10 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] = # visualization wrapper: with Timer() as timer: - plot_list = [{"name": "windspeed_uv10", "functor": "lambda x: np.sqrt(np.square(x[0, ...]) + np.square(x[1, ...]))", "diverging": False}] + plot_channel = "sst" + plot_index = self.params.channel_names.index(plot_channel) + # plot_index = 0 + plot_list = [{"name": plot_channel, "functor": f"lambda x: x[{plot_index}, ...]", "diverging": False}] out_bias, out_scale = self.train_dataloader.get_output_normalization() self.visualizer = visualize.VisualizationWrapper( self.params.log_to_wandb, @@ -476,7 +479,7 @@ def _ensemble_step(self, inp: torch.Tensor, tar: torch.Tensor): # stack predictions along new dim (ensemble dim): pred = 
torch.stack(predlist, dim=1) # compute loss - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) return pred, loss @@ -496,7 +499,8 @@ def train_one_epoch(self, profiler=None): train_steps = 0 train_start = time.perf_counter_ns() self.model_train.zero_grad(set_to_none=True) - for data in tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: train_steps += 1 self.iters += 1 @@ -538,13 +542,17 @@ def train_one_epoch(self, profiler=None): accumulated_loss[0] += loss.detach().clone() * inp.shape[0] accumulated_loss[1] += inp.shape[0] + # log the loss + pbar_postfix = {"loss": loss.item()} + # perform weight update if requested if do_update: if self.max_grad_norm > 0.0: self.gscaler.unscale_(self.model_optimizer) - grad_norm = clip_grads(self.model_train, self.max_grad_norm) + grad_norm = clip_grads(self.model_train, self.max_grad_norm, verbose=self.log_to_screen) accumulated_grad_norm[0] += grad_norm.detach() accumulated_grad_norm[1] += 1.0 + pbar_postfix["grad norm"] = grad_norm.item() self.gscaler.step(self.model_optimizer) self.gscaler.update() @@ -567,6 +575,9 @@ def train_one_epoch(self, profiler=None): self.logger.info(f"Dumping weights and gradients to {weights_and_grads_path}") self.dump_weights_and_grads(weights_and_grads_path, self.model, step=(self.epoch * self.params.num_samples_per_epoch + self.iters)) + # set progress bar prefix + progress_bar.set_postfix(**pbar_postfix) + torch.cuda.nvtx.range_pop() # profiler step @@ -633,7 +644,8 @@ def validate_one_epoch(self, epoch, profiler=None): with torch.no_grad(): eval_steps = 0 - for data in tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.valid_dataloader, desc=f"Validation progress epoch 
{self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: eval_steps += 1 if torch.cuda.is_available(): @@ -695,6 +707,9 @@ def validate_one_epoch(self, epoch, profiler=None): pred = torch.stack(predlist, dim=1) loss = self.loss_obj(pred, targ) + # log the loss + progress_bar.set_postfix({"loss": loss.item()}) + # TODO: move all of this into the visualization handler if (eval_steps <= 1) and visualize: # create average prediction for deterministic metrics diff --git a/makani/utils/training/stochastic_trainer.py b/makani/utils/training/stochastic_trainer.py index 3b263a1d..aa57f6d4 100644 --- a/makani/utils/training/stochastic_trainer.py +++ b/makani/utils/training/stochastic_trainer.py @@ -463,7 +463,8 @@ def train_one_epoch(self): train_steps = 0 train_start = time.perf_counter_ns() self.model_train.zero_grad(set_to_none=True) - for data in tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.train_dataloader, desc=f"Training progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: train_steps += 1 self.iters += 1 @@ -489,11 +490,11 @@ def train_one_epoch(self): with amp.autocast(device_type="cuda", enabled=self.amp_enabled, dtype=self.amp_dtype): if do_update: pred, tar = self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) else: with self.model_train.no_sync(): pred, tar = self.model_train(inp, tar, n_samples=self.params.stochastic_size) - loss = self.loss_obj(pred, tar) + loss = self.loss_obj(pred, tar, inp=inp) loss = loss * loss_scaling_fact self.gscaler.scale(loss).backward() @@ -502,6 +503,9 @@ def train_one_epoch(self): accumulated_loss[0] += loss.detach().clone() * inp.shape[0] accumulated_loss[1] += inp.shape[0] + # log the loss + pbar_postfix = {"loss": loss.item()} + # gradient clipping if do_update: if self.max_grad_norm > 0.0: @@ 
-509,6 +513,7 @@ def train_one_epoch(self): grad_norm = clip_grads(self.model_train, self.max_grad_norm) accumulated_grad_norm[0] += grad_norm.detach() accumulated_grad_norm[1] += 1.0 + pbar_postfix["grad norm"] = grad_norm.item() # perform weight update self.gscaler.step(self.optimizer) @@ -537,6 +542,9 @@ def train_one_epoch(self): self.logger.info(f"Dumping weights and gradients to {weights_and_grads_path}") self.dump_weights_and_grads(weights_and_grads_path, self.model, step=(self.epoch * self.params.num_samples_per_epoch + self.iters)) + # set progress bar prefix + progress_bar.set_postfix(**pbar_postfix) + # average the loss over ranks and steps if dist.is_initialized(): dist.all_reduce(accumulated_loss, op=dist.ReduceOp.SUM, group=comm.get_group("data")) @@ -593,7 +601,8 @@ def validate_one_epoch(self, epoch): normalize_weights(self.model, eps=1e-4) eval_steps = 0 - for data in tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen): + progress_bar = tqdm(self.valid_dataloader, desc=f"Validation progress epoch {self.epoch}", disable=not self.log_to_screen) + for data in progress_bar: eval_steps += 1 # map to gpu @@ -659,6 +668,9 @@ def validate_one_epoch(self, epoch): tag = f"step{eval_steps}_time{str(idt).zfill(3)}" self.visualizer.add(tag, pred_cpu, targ_cpu) + # log the loss + progress_bar.set_postfix({"loss": loss.item()}) + # update metrics self.metrics.update(pred, targ, loss, idt) diff --git a/makani/utils/training/training_helpers.py b/makani/utils/training/training_helpers.py index f64b621d..e7e61b30 100644 --- a/makani/utils/training/training_helpers.py +++ b/makani/utils/training/training_helpers.py @@ -53,10 +53,10 @@ def normalize_weights(model, eps=1e-5): return -def _compute_total_grad_norm(model, norm_type=2.0): +def _compute_total_grad_norm(model, norm_type=2.0, verbose=False): # iterate over parameters gnorms = [] - for param in model.parameters(): + for name, param in 
model.named_parameters(): if param.grad is None: continue @@ -79,6 +79,11 @@ def _compute_total_grad_norm(model, norm_type=2.0): gnorms.append(gnorm) + if verbose: + for gnorm in gnorms: + if torch.any(torch.isnan(gnorm)): + print(f"Gradient norm is NaN for parameter {name}") + # compute total norm if gnorms: total_gnorm = torch.sum(torch.stack(gnorms)) @@ -92,11 +97,11 @@ def _compute_total_grad_norm(model, norm_type=2.0): return total_gnorm -def clip_grads(model, max_grad_norm, norm_type=2.0): +def clip_grads(model, max_grad_norm, norm_type=2.0, verbose=False): # iterate over parameters with torch.no_grad(): - total_gnorm = _compute_total_grad_norm(model, norm_type) + total_gnorm = _compute_total_grad_norm(model, norm_type=norm_type, verbose=verbose) clip_factor = max_grad_norm / (total_gnorm + 1e-6) # add small epsilon to avoid division by zero clip_factor = torch.clamp(clip_factor, max=1.0) diff --git a/tests/distributed/distributed_helpers.py b/tests/distributed/distributed_helpers.py index 794f657d..85b2f41b 100644 --- a/tests/distributed/distributed_helpers.py +++ b/tests/distributed/distributed_helpers.py @@ -13,10 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import torch import torch.distributed as dist +import torch_harmonics.distributed as thd + +from makani.utils import comm + from physicsnemo.distributed.utils import split_tensor_along_dim from physicsnemo.distributed.mappings import gather_from_parallel_region, scatter_to_parallel_region, \ reduce_from_parallel_region @@ -60,7 +65,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1 @@ -82,6 +86,52 @@ def get_default_parameters(): return params +def init_grid(cls): + # set up distributed + cls.grid_size_h = int(os.getenv("GRID_H", 1)) + cls.grid_size_w = int(os.getenv("GRID_W", 1)) + cls.grid_size_e = int(os.getenv("GRID_E", 1)) + cls.world_size = cls.grid_size_h * cls.grid_size_w * cls.grid_size_e + + # init groups + comm.init( + model_parallel_sizes=[cls.grid_size_h, cls.grid_size_w, 1, 1], + model_parallel_names=["h", "w", "fin", "fout"], + data_parallel_sizes=[cls.grid_size_e, -1], + data_parallel_names=["ensemble", "batch"], + ) + cls.world_rank = comm.get_world_rank() + + if torch.cuda.is_available(): + if cls.world_rank == 0: + print("Running test on GPU") + local_rank = comm.get_local_rank() + cls.device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(cls.device) + torch.cuda.manual_seed(333) + else: + if cls.world_rank == 0: + print("Running test on CPU") + cls.device = torch.device("cpu") + torch.manual_seed(333) + + # store comm group parameters + cls.wrank = comm.get_rank("w") + cls.hrank = comm.get_rank("h") + cls.erank = comm.get_rank("ensemble") + cls.w_group = comm.get_group("w") + cls.h_group = comm.get_group("h") + cls.e_group = comm.get_group("ensemble") + + # initializing sht process groups just to be sure + thd.init(cls.h_group, cls.w_group) + + if cls.world_rank == 0: + print(f"Running distributed tests on grid H x W x E = 
{cls.grid_size_h} x {cls.grid_size_w} x {cls.grid_size_e}") + + return + + def split_helper(tensor, dim=None, group=None): with torch.no_grad(): if (dim is not None) and dist.get_world_size(group=group): @@ -92,7 +142,7 @@ def split_helper(tensor, dim=None, group=None): tensor_local = tensor_list_local[grank] else: tensor_local = tensor.clone() - + return tensor_local @@ -116,5 +166,5 @@ def gather_helper(tensor, dim=None, group=None): tensor_gather = torch.cat(tens_gather, dim=dim) else: tensor_gather = tensor.clone() - + return tensor_gather diff --git a/tests/distributed/tests_distributed_fft.py b/tests/distributed/tests_distributed_fft.py index 26599fd7..d5d91014 100644 --- a/tests/distributed/tests_distributed_fft.py +++ b/tests/distributed/tests_distributed_fft.py @@ -19,58 +19,20 @@ from parameterized import parameterized import torch -import torch_harmonics.distributed as thd - from makani.models.common import RealFFT1, InverseRealFFT1, RealFFT2, InverseRealFFT2, RealFFT3, InverseRealFFT3 - -from makani.utils import comm from makani.mpu.fft import DistributedRealFFT1, DistributedInverseRealFFT1, DistributedRealFFT2, DistributedInverseRealFFT2, DistributedRealFFT3, DistributedInverseRealFFT3 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from .distributed_helpers import split_helper, gather_helper +from .distributed_helpers import init_grid, split_helper, gather_helper from ..testutils import compare_tensors class TestDistributedRealFFT(unittest.TestCase): @classmethod def setUpClass(cls): + init_grid(cls) - # set up distributed - cls.grid_size_h = int(os.getenv('GRID_H', 1)) - cls.grid_size_w = int(os.getenv('GRID_W', 1)) - cls.world_size = cls.grid_size_h * cls.grid_size_w - - # init groups - comm.init(model_parallel_sizes=[cls.grid_size_h, cls.grid_size_w, 1, 1], - model_parallel_names=["h", "w", "fin", "fout"]) - cls.world_rank = comm.get_world_rank() - - if torch.cuda.is_available(): - if cls.world_rank == 0: - 
print("Running test on GPU") - local_rank = comm.get_local_rank() - cls.device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(cls.device) - torch.cuda.manual_seed(333) - else: - if cls.world_rank == 0: - print("Running test on CPU") - cls.device = torch.device('cpu') - torch.manual_seed(333) - # store comm group parameters - cls.wrank = comm.get_rank("w") - cls.hrank = comm.get_rank("h") - cls.w_group = comm.get_group("w") - cls.h_group = comm.get_group("h") - - # initializing sht process groups - thd.init(cls.h_group, cls.w_group) - - if cls.world_rank == 0: - print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}") - - def _split_helper(self, tensor): tensor_local = split_helper(tensor, dim=-1, group=self.w_group) tensor_local = split_helper(tensor_local, dim=-2, group=self.h_group) diff --git a/tests/distributed/tests_distributed_layers.py b/tests/distributed/tests_distributed_layers.py index ab3afe95..6953c68b 100644 --- a/tests/distributed/tests_distributed_layers.py +++ b/tests/distributed/tests_distributed_layers.py @@ -21,15 +21,11 @@ import torch import torch.nn as nn -import torch.nn.functional as F import torch.distributed as dist import torch_harmonics as th import torch_harmonics.distributed as thd -from makani.utils import comm -from makani.utils import functions as fn - from makani.mpu.mappings import init_gradient_reduction_hooks # layer norm imports @@ -37,48 +33,14 @@ from makani.mpu.layer_norm import DistributedGeometricInstanceNormS2, DistributedInstanceNorm2d sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from .distributed_helpers import split_helper, gather_helper -from ..testutils import compare_tensors +from .distributed_helpers import init_grid, split_helper, gather_helper +from ..testutils import compare_tensors, disable_tf32 class TestDistributedLayers(unittest.TestCase): @classmethod def setUpClass(cls): - - # set up distributed - cls.grid_size_h = 
int(os.getenv('GRID_H', 1)) - cls.grid_size_w = int(os.getenv('GRID_W', 1)) - cls.world_size = cls.grid_size_h * cls.grid_size_w - - # init groups - comm.init(model_parallel_sizes=[cls.grid_size_h, cls.grid_size_w, 1, 1, 1], - model_parallel_names=["h", "w", "fin", "fout", "batch"]) - cls.world_rank = comm.get_world_rank() - - torch.manual_seed(333) - if torch.cuda.is_available(): - if cls.world_rank == 0: - print("Running test on GPU") - local_rank = comm.get_local_rank() - cls.device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(cls.device) - torch.cuda.manual_seed(333) - else: - if cls.world_rank == 0: - print("Running test on CPU") - cls.device = torch.device('cpu') - - # store comm group parameters - cls.wrank = comm.get_rank("w") - cls.hrank = comm.get_rank("h") - cls.w_group = comm.get_group("w") - cls.h_group = comm.get_group("h") - - # initializing sht process groups - thd.init(cls.h_group, cls.w_group) - - if cls.world_rank == 0: - print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}") + init_grid(cls) def _init_seed(self, seed): @@ -87,13 +49,13 @@ def _init_seed(self, seed): torch.cuda.manual_seed(seed) return - + def _split_helper(self, tensor, hdim=-2, wdim=-1): tensor_local = split_helper(tensor, dim=hdim, group=self.h_group) tensor_local = split_helper(tensor_local, dim=wdim, group=self.w_group) return tensor_local - - + + def _gather_helper(self, tensor, hdim=-2, wdim=-1): tensor_gather = gather_helper(tensor, dim=hdim, group=self.h_group) tensor_gather = gather_helper(tensor_gather, dim=wdim, group=self.w_group) @@ -113,6 +75,10 @@ def _gather_helper(self, tensor, hdim=-2, wdim=-1): skip_on_empty=True, ) def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, tol, verbose=True): + + # disable tf32 + disable_tf32() + B, C, Hi, Wi, Ho, Wo = batch_size, num_chan, nlat_in, nlon_in, nlat_out, nlon_out from makani.models.common import SpectralConv @@ -232,14 
+198,18 @@ def test_distributed_spectral_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, b @parameterized.expand( [ - [256, 512, 32, 8, True, 1e-5], - [181, 360, 1, 10, True, 1e-5], - [256, 512, 32, 8, False, 1e-5], - [181, 360, 1, 10, False, 1e-5], + [256, 512, 32, 8, True, 1e-4], + [181, 360, 1, 10, True, 1e-4], + [256, 512, 32, 8, False, 1e-4], + [181, 360, 1, 10, False, 1e-4], ], skip_on_empty=True, ) def test_distributed_instance_norm_2d(self, nlat, nlon, batch_size, num_chan, affine, tol, verbose=True): + + # disable tf32 + disable_tf32() + B, C, H, W = batch_size, num_chan, nlat, nlon self._init_seed(333) @@ -360,14 +330,18 @@ def test_distributed_instance_norm_2d(self, nlat, nlon, batch_size, num_chan, af @parameterized.expand( [ - [181, 360, 1, 4, "equiangular", True, 1e-5], - [181, 360, 1, 4, "equiangular", False, 1e-5], - [180, 360, 1, 10, "legendre-gauss", True, 1e-5], - [180, 360, 1, 10, "legendre-gauss", False, 1e-5], + [181, 360, 1, 4, "equiangular", True, 1e-4], + [181, 360, 1, 4, "equiangular", False, 1e-4], + [180, 360, 1, 10, "legendre-gauss", True, 1e-4], + [180, 360, 1, 10, "legendre-gauss", False, 1e-4], ], skip_on_empty=True, ) def test_distributed_geometric_instance_norm_s2(self, nlat, nlon, batch_size, num_chan, grid_type, affine, tol, verbose=True): + + # disable tf32 + disable_tf32() + B, C, H, W = batch_size, num_chan, nlat, nlon # set up layer norm parameters diff --git a/tests/distributed/tests_distributed_losses.py b/tests/distributed/tests_distributed_losses.py index cd4213e3..4f2f9233 100644 --- a/tests/distributed/tests_distributed_losses.py +++ b/tests/distributed/tests_distributed_losses.py @@ -20,68 +20,30 @@ from parameterized import parameterized import torch -import torch.nn.functional as F import torch.distributed as dist -import torch_harmonics.distributed as thd - from makani.utils import comm -from makani.utils import functions as fn from makani.utils.grids import GridQuadrature -from makani.utils.losses import 
EnsembleCRPSLoss, EnsembleNLLLoss, EnsembleSpectralCRPSLoss +from makani.utils.losses import ( + CRPSLoss, + EnsembleNLLLoss, + SpectralCRPSLoss, + L2EnergyScoreLoss, + SpectralL2EnergyScoreLoss, + SobolevEnergyScoreLoss, +) # Add parent directory to path for testutils import sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from .distributed_helpers import split_helper, gather_helper -from ..testutils import compare_tensors +from tests.distributed.distributed_helpers import init_grid, split_helper, gather_helper +from tests.testutils import compare_tensors, disable_tf32 class TestDistributedLoss(unittest.TestCase): @classmethod def setUpClass(cls): - - # set up distributed - cls.grid_size_h = int(os.getenv("GRID_H", 1)) - cls.grid_size_w = int(os.getenv("GRID_W", 1)) - cls.grid_size_e = int(os.getenv("GRID_E", 1)) - cls.world_size = cls.grid_size_h * cls.grid_size_w * cls.grid_size_e - - # init groups - comm.init( - model_parallel_sizes=[cls.grid_size_h, cls.grid_size_w, 1, 1], - model_parallel_names=["h", "w", "fin", "fout"], - data_parallel_sizes=[cls.grid_size_e, -1], - data_parallel_names=["ensemble", "batch"], - ) - cls.world_rank = comm.get_world_rank() - - if torch.cuda.is_available(): - if cls.world_rank == 0: - print("Running test on GPU") - local_rank = comm.get_local_rank() - cls.device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(cls.device) - torch.cuda.manual_seed(333) - else: - if cls.world_rank == 0: - print("Running test on CPU") - cls.device = torch.device("cpu") - torch.manual_seed(333) - - # store comm group parameters - cls.wrank = comm.get_rank("w") - cls.hrank = comm.get_rank("h") - cls.erank = comm.get_rank("ensemble") - cls.w_group = comm.get_group("w") - cls.h_group = comm.get_group("h") - cls.e_group = comm.get_group("ensemble") - - # initializing sht process groups just to be sure - thd.init(cls.h_group, cls.w_group) - - if cls.world_rank == 0: - print(f"Running distributed tests on grid H x W 
x E = {cls.grid_size_h} x {cls.grid_size_w} x {cls.grid_size_e}") + init_grid(cls) def _split_helper(self, tensor): with torch.no_grad(): @@ -117,7 +79,7 @@ def _gather_helper_bwd(self, tensor, ensemble=False): return tensor_gather - + @parameterized.expand( [ [128, 256, 32, 8, "naive", False, 1e-6], @@ -132,6 +94,10 @@ def _gather_helper_bwd(self, tensor, ensemble=False): ], skip_on_empty=True ) def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, quad_rule, normalize, tol, verbose=False): + + # disable tf32 for deterministic comparison + disable_tf32() + B, C, H, W = batch_size, num_chan, nlat, nlon quad_local = GridQuadrature(quadrature_rule=quad_rule, img_shape=(H, W), normalize=normalize, distributed=False).to(self.device) @@ -149,7 +115,7 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, quad_rul ograd_full = torch.randn_like(out_full) out_full.backward(ograd_full) igrad_full = inp_full.grad.clone() - + # distributed out_local = quad_dist(inp_local) out_local.backward(ograd_full) @@ -171,19 +137,23 @@ def test_distributed_quadrature(self, nlat, nlon, batch_size, num_chan, quad_rul @parameterized.expand( [ - [128, 256, 32, 8, 4, "ensemble_crps", 1e-5], - [129, 256, 1, 10, 4, "ensemble_crps", 1e-5], - [128, 256, 32, 8, 4, "ensemble_crps", 1e-5], - [129, 256, 1, 10, 4, "ensemble_crps", 1e-5], - [128, 256, 32, 8, 4, "skillspread_crps", 1e-5], - [129, 256, 1, 10, 4, "skillspread_crps", 1e-5], - [128, 256, 32, 8, 4, "gauss_crps", 1e-5], - [129, 256, 1, 10, 4, "gauss_crps", 1e-5], - [128, 256, 32, 8, 4, "ensemble_nll", 1e-5], - [129, 256, 1, 10, 4, "ensemble_nll", 1e-5], + [128, 256, 32, 8, 4, "cdf", 1e-5], + [129, 256, 1, 10, 4, "cdf", 1e-5], + [128, 256, 32, 8, 4, "cdf", 1e-5], + [129, 256, 1, 10, 4, "cdf", 1e-5], + [128, 256, 32, 8, 4, "skillspread", 1e-5], + [129, 256, 1, 10, 4, "skillspread", 1e-5], + [128, 256, 32, 8, 4, "gauss", 1e-5], + [129,
 256, 1, 10, 4, "gauss", 1e-5], + [128, 256, 32, 8, 4, "nll", 1e-5], + [129, 256, 1, 10, 4, "nll", 1e-5], ], skip_on_empty=True ) def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, loss_type, tol, verbose=False): + + # disable tf32 for deterministic comparison + disable_tf32() + B, E, C, H, W = batch_size, ens_size, num_chan, nlat, nlon # generate gauss random distributed around 1, with sigma=2 @@ -191,91 +161,37 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, loss inp_full = torch.randn((B, E, C, H, W), dtype=torch.float32, device=self.device) * sigma + mean obs_full = torch.full((B, C, H, W), fill_value=mean, dtype=torch.float32, device=self.device) - if loss_type == "ensemble_crps": + if loss_type != "nll": # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_local = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), channel_names=(), grid_type="equiangular", pole_mask=0, - crps_type="cdf", + crps_type=loss_type, + eps=1.0e-5, spatial_distributed=False, ensemble_distributed=False, ensemble_weights=None, ).to(self.device) # distributed loss - loss_fn_dist = EnsembleCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - pole_mask=0, - crps_type="cdf", - spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), - ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - ensemble_weights=None, - ).to(self.device) - elif loss_type == "gauss_crps": - # local loss - loss_fn_local = EnsembleCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - pole_mask=0, - crps_type="gauss", - spatial_distributed=False, - ensemble_distributed=False, - eps=1.0e-5, - ).to(self.device) - - # distributed loss - loss_fn_dist = EnsembleCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - 
channel_names=(), - grid_type="equiangular", - pole_mask=0, - crps_type="gauss", - spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), - ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - eps=1.0e-5, - ).to(self.device) - elif loss_type == "skillspread_crps": - # local loss - loss_fn_local = EnsembleCRPSLoss( + loss_fn_dist = CRPSLoss( img_shape=(H, W), crop_shape=None, crop_offset=(0, 0), channel_names=(), grid_type="equiangular", pole_mask=0, - crps_type="skillspread", - spatial_distributed=False, - ensemble_distributed=False, + crps_type=loss_type, eps=1.0e-5, - ).to(self.device) - - # distributed loss - loss_fn_dist = EnsembleCRPSLoss( - img_shape=(H, W), - crop_shape=(H, W), - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - pole_mask=0, - crps_type="skillspread", spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - eps=1.0e-5, + ensemble_weights=None, ).to(self.device) - elif loss_type == "ensemble_nll": + else: # local loss loss_fn_local = EnsembleNLLLoss( img_shape=(H, W), @@ -358,107 +274,56 @@ def test_distributed_crps(self, nlat, nlon, batch_size, num_chan, ens_size, loss @parameterized.expand( [ - [128, 256, 32, 8, 4, "ensemble_crps", False, 1e-4], - [129, 256, 1, 10, 4, "ensemble_crps", False, 1e-4], - [128, 256, 32, 8, 4, "ensemble_crps", True, 1e-4], - [128, 256, 32, 8, 4, "skillspread_crps", False, 1e-4], - [129, 256, 1, 10, 4, "skillspread_crps", False, 1e-4], - [128, 256, 32, 8, 4, "skillspread_crps", True, 1e-4], - [129, 256, 1, 10, 4, "skillspread_crps", True, 1e-4], + [128, 256, 32, 8, 4, "cdf", True, 1e-4], + [129, 256, 1, 10, 4, "cdf", True, 1e-4], + [128, 256, 32, 8, 4, "skillspread", False, 1e-4], + [129, 256, 1, 10, 4, "skillspread", False, 1e-4], + [128, 256, 32, 8, 4, "skillspread", True, 1e-4], + [129, 256, 1, 
10, 4, "skillspread", True, 1e-4], ], skip_on_empty=True ) def test_distributed_spectral_crps(self, nlat, nlon, batch_size, num_chan, ens_size, loss_type, absolute, tol, verbose=True): + + # disable tf32 + disable_tf32() + + # extract shapes B, E, C, H, W = batch_size, ens_size, num_chan, nlat, nlon # generate gauss random distributed around 1, with sigma=2 mean, sigma = (1.0, 2.0) inp_full = torch.randn((B, E, C, H, W), dtype=torch.float32, device=self.device) * sigma + mean - obs_full = torch.full((B, C, H, W), fill_value=mean, dtype=torch.float32, device=self.device) - - if loss_type == "ensemble_crps": - # local loss - loss_fn_local = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="cdf", - spatial_distributed=False, - ensemble_distributed=False, - ensemble_weights=None, - absolute=absolute, - ).to(self.device) + obs_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) * sigma * 0.01 + mean - # distributed loss - loss_fn_dist = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="cdf", - spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), - ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - ensemble_weights=None, - absolute=absolute, - ).to(self.device) - elif loss_type == "gauss_crps": - # local loss - loss_fn_local = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="gauss", - spatial_distributed=False, - ensemble_distributed=False, - eps=1.0e-5, - absolute=absolute, - ).to(self.device) - - # distributed loss - loss_fn_dist = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="gauss", 
- spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), - ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - eps=1.0e-5, - absolute=absolute, - ).to(self.device) - elif loss_type == "skillspread_crps": - # local loss - loss_fn_local = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="skillspread", - spatial_distributed=False, - ensemble_distributed=False, - eps=1.0e-5, - absolute=absolute, - ).to(self.device) + # local loss + loss_fn_local = SpectralCRPSLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + crps_type=loss_type, + spatial_distributed=False, + ensemble_distributed=False, + ensemble_weights=None, + eps=1.0e-5, + absolute=absolute, + ).to(self.device) - # distributed loss - loss_fn_dist = EnsembleSpectralCRPSLoss( - img_shape=(H, W), - crop_shape=None, - crop_offset=(0, 0), - channel_names=(), - grid_type="equiangular", - crps_type="skillspread", - spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), - ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), - eps=1.0e-5, - absolute=absolute, - ).to(self.device) + # distributed loss + loss_fn_dist = SpectralCRPSLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + crps_type=loss_type, + spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), + ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), + ensemble_weights=None, + eps=1.0e-5, + absolute=absolute, + ).to(self.device) ############################################################# # local loss @@ -510,10 +375,297 @@ def test_distributed_spectral_crps(self, nlat, nlon, batch_size, num_chan, ens_s 
self.assertTrue(compare_tensors("forecast gradients", igrad_gather_full, igrad_full, tol, tol, verbose=verbose)) # observation grads + with self.subTest(desc="observation gradients"): + obsgrad_gather_full = self._gather_helper_bwd(obsgrad_local, False) + if self.world_rank == 0: + print("obsgrad_gather_full", obsgrad_gather_full[0, 0, ...], "obsgrad_full", obsgrad_full[0, 0, ...]) + self.assertTrue(compare_tensors("observation gradients", obsgrad_gather_full, obsgrad_full, tol, tol, verbose=verbose)) + + + @parameterized.expand( + [ + [128, 256, 8, 3, 4, 1e-4], + [129, 256, 2, 5, 4, 1e-4], + ], skip_on_empty=True + ) + def test_distributed_l2_energy_score(self, nlat, nlon, batch_size, num_chan, ens_size, tol, verbose=False): + + # disable tf32 for deterministic comparison + disable_tf32() + + B, E, C, H, W = batch_size, ens_size, num_chan, nlat, nlon + + # inputs + mean, sigma = (1.0, 2.0) + forecasts_full = torch.randn((B, E, C, H, W), dtype=torch.float32, device=self.device) * sigma + mean + obs_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) * sigma * 0.01 + mean + + # local loss + loss_fn_local = L2EnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + pole_mask=0, + alpha=1.0, + beta=1.0, + eps=1.0e-5, + spatial_distributed=False, + ensemble_distributed=False, + ensemble_weights=None, + ).to(self.device) + + # distributed loss + loss_fn_dist = L2EnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + pole_mask=0, + alpha=1.0, + beta=1.0, + eps=1.0e-5, + spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), + ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), + ensemble_weights=None, + ).to(self.device) + + ############################################################# + # local loss + 
############################################################# + forecasts_full.requires_grad = True + obs_full.requires_grad = True + loss_full = loss_fn_local(forecasts_full, obs_full) + + with torch.no_grad(): + ograd_full = torch.randn_like(loss_full) + ograd_local = ograd_full.clone() + + loss_full.backward(ograd_full) + fgrad_full = forecasts_full.grad.clone() + obsgrad_full = obs_full.grad.clone() + + ############################################################# + # distributed loss + ############################################################# + forecasts_local = self._split_helper(forecasts_full.clone()) + obs_local = self._split_helper(obs_full.clone()) + forecasts_local.requires_grad = True + obs_local.requires_grad = True + + loss_local = loss_fn_dist(forecasts_local, obs_local) + loss_local.backward(ograd_local) + fgrad_local = forecasts_local.grad.clone() + obsgrad_local = obs_local.grad.clone() + + ############################################################# + # evaluate FWD pass + ############################################################# + with self.subTest(desc="outputs"): + self.assertTrue(compare_tensors("outputs", loss_local, loss_full, tol, tol, verbose=verbose)) + + ############################################################# + # evaluate BWD pass + ############################################################# + with self.subTest(desc="forecast gradients"): + fgrad_gather_full = self._gather_helper_bwd(fgrad_local, True) + self.assertTrue(compare_tensors("forecast gradients", fgrad_gather_full, fgrad_full, tol, tol, verbose=verbose)) + + with self.subTest(desc="observation gradients"): + obsgrad_gather_full = self._gather_helper_bwd(obsgrad_local, False) + self.assertTrue(compare_tensors("observation gradients", obsgrad_gather_full, obsgrad_full, tol, tol, verbose=verbose)) + + + @parameterized.expand( + [ + [128, 256, 8, 13, 4, 1e-4], + [129, 256, 2, 12, 4, 1e-4], + ], skip_on_empty=True + ) + def 
test_distributed_spectral_l2_energy_score(self, nlat, nlon, batch_size, num_chan, ens_size, tol, verbose=True): + + # disable tf32 for deterministic comparison + disable_tf32() + + # shapes + B, E, C, H, W = batch_size, ens_size, num_chan, nlat, nlon + + mean, sigma = (1.0, 2.0) + forecasts_full = torch.randn((B, E, C, H, W), dtype=torch.float32, device=self.device) * sigma + mean + obs_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) * sigma * 0.01 + mean + + # local loss + loss_fn_local = SpectralL2EnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + alpha=1.0, + eps=1.0e-3, + spatial_distributed=False, + ensemble_distributed=False, + ensemble_weights=None, + ).to(self.device) + + # distributed loss + loss_fn_dist = SpectralL2EnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + alpha=1.0, + eps=1.0e-3, + spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), + ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), + ensemble_weights=None, + ).to(self.device) + + ############################################################# + # local loss + ############################################################# + forecasts_full.requires_grad = True + obs_full.requires_grad = True + loss_full = loss_fn_local(forecasts_full, obs_full) + + with torch.no_grad(): + ograd_full = torch.randn_like(loss_full) + ograd_local = ograd_full.clone() + + loss_full.backward(ograd_full) + fgrad_full = forecasts_full.grad.clone() + obsgrad_full = obs_full.grad.clone() + + ############################################################# + # distributed loss + ############################################################# + forecasts_local = self._split_helper(forecasts_full.clone()) + obs_local = self._split_helper(obs_full.clone()) + 
forecasts_local.requires_grad = True + obs_local.requires_grad = True + + loss_local = loss_fn_dist(forecasts_local, obs_local) + loss_local.backward(ograd_local) + fgrad_local = forecasts_local.grad.clone() + obsgrad_local = obs_local.grad.clone() + + ############################################################# + # evaluate FWD pass + ############################################################# + with self.subTest(desc="outputs"): + self.assertTrue(compare_tensors("outputs", loss_local, loss_full, tol, tol, verbose=verbose)) + + ############################################################# + # evaluate BWD pass + ############################################################# + with self.subTest(desc="forecast gradients"): + fgrad_gather_full = self._gather_helper_bwd(fgrad_local, True) + self.assertTrue(compare_tensors("forecast gradients", fgrad_gather_full, fgrad_full, tol, tol, verbose=verbose)) + + with self.subTest(desc="observation gradients"): + obsgrad_gather_full = self._gather_helper_bwd(obsgrad_local, False) + self.assertTrue(compare_tensors("observation gradients", obsgrad_gather_full, obsgrad_full, tol, tol, verbose=verbose)) + + + @parameterized.expand( + [ + [128, 256, 8, 13, 4, 1.0, 1.0, 1.0, 1.0, 1e-5], + [129, 256, 2, 12, 4, 0.8, 1.2, 0.5, 0.7, 1e-5], + ], skip_on_empty=True + ) + def test_distributed_sobolev_energy_score(self, nlat, nlon, batch_size, num_chan, ens_size, alpha, beta, offset, fraction, tol, verbose=True): + + disable_tf32() + + B, E, C, H, W = batch_size, ens_size, num_chan, nlat, nlon + + mean, sigma = (0.3, 1.1) + forecasts_full = torch.randn((B, E, C, H, W), dtype=torch.float32, device=self.device) * sigma + mean + obs_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device) * sigma * 0.05 + mean + + # local loss + loss_fn_local = SobolevEnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + spatial_distributed=False, + 
ensemble_distributed=False, + ensemble_weights=None, + alpha=alpha, + beta=beta, + offset=offset, + fraction=fraction, + eps=1.0e-6, + ).to(self.device) + + # distributed loss + loss_fn_dist = SobolevEnergyScoreLoss( + img_shape=(H, W), + crop_shape=None, + crop_offset=(0, 0), + channel_names=(), + grid_type="equiangular", + spatial_distributed=(comm.is_distributed("spatial") and (comm.get_size("spatial") > 1)), + ensemble_distributed=(comm.is_distributed("ensemble") and (comm.get_size("ensemble") > 1)), + ensemble_weights=None, + alpha=alpha, + beta=beta, + offset=offset, + fraction=fraction, + eps=1.0e-6, + ).to(self.device) + + ############################################################# + # local loss + ############################################################# + forecasts_full.requires_grad = True + obs_full.requires_grad = True + loss_full = loss_fn_local(forecasts_full, obs_full) + + with torch.no_grad(): + ograd_full = torch.randn_like(loss_full) + ograd_local = ograd_full.clone() + + loss_full.backward(ograd_full) + fgrad_full = forecasts_full.grad.clone() + obsgrad_full = obs_full.grad.clone() + + ############################################################# + # distributed loss + ############################################################# + forecasts_local = self._split_helper(forecasts_full.clone()) + obs_local = self._split_helper(obs_full.clone()) + forecasts_local.requires_grad = True + obs_local.requires_grad = True + + loss_local = loss_fn_dist(forecasts_local, obs_local) + loss_local.backward(ograd_local) + fgrad_local = forecasts_local.grad.clone() + obsgrad_local = obs_local.grad.clone() + + ############################################################# + # evaluate FWD pass + ############################################################# + with self.subTest(desc="outputs"): + self.assertTrue(compare_tensors("outputs", loss_local, loss_full, tol, tol, verbose=verbose)) + + ############################################################# + # 
evaluate BWD pass + ############################################################# + with self.subTest(desc="forecast gradients"): + fgrad_gather_full = self._gather_helper_bwd(fgrad_local, True) + self.assertTrue(compare_tensors("forecast gradients", fgrad_gather_full, fgrad_full, tol, tol, verbose=verbose)) + with self.subTest(desc="observation gradients"): obsgrad_gather_full = self._gather_helper_bwd(obsgrad_local, False) self.assertTrue(compare_tensors("observation gradients", obsgrad_gather_full, obsgrad_full, tol, tol, verbose=verbose)) if __name__ == "__main__": + disable_tf32() unittest.main() diff --git a/tests/distributed/tests_distributed_metrics.py b/tests/distributed/tests_distributed_metrics.py index faed988e..ec19be04 100644 --- a/tests/distributed/tests_distributed_metrics.py +++ b/tests/distributed/tests_distributed_metrics.py @@ -33,7 +33,7 @@ from makani.utils import MetricsHandler sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) -from .distributed_helpers import split_helper, get_default_parameters +from .distributed_helpers import init_grid, split_helper, get_default_parameters from ..testutils import compare_arrays # because of physicsnemo/NCCL tear down issues, we can only run one test at a time @@ -72,6 +72,7 @@ def setUpClass(cls, path="/tmp"): @classmethod def tearDownClass(cls): cls.tmpdir.cleanup() + cls.mpi_comm.finalize() def _init_comms(self): diff --git a/tests/test_losses.py b/tests/test_losses.py index 437dd0fe..5f44756f 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -24,10 +24,11 @@ import torch from makani.utils import LossHandler -from makani.utils.losses import EnsembleCRPSLoss +from makani.utils.losses import CRPSLoss +from makani.utils.losses.energy_score import SobolevEnergyScoreLoss sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from .testutils import get_default_parameters, compare_tensors, compare_arrays +from .testutils import get_default_parameters, 
compare_tensors, compare_arrays, disable_tf32 from properscoring import crps_ensemble, crps_gaussian @@ -139,7 +140,7 @@ def test_loss_batchsize_independence(self, losses, uncertainty_weighting=False): # test initialization of loss object loss_obj = LossHandler(self.params) - + shape = (self.params.batch_size, self.params.N_out_channels, self.params.img_shape_x, self.params.img_shape_y) inp = torch.randn(*shape) @@ -158,7 +159,7 @@ def test_loss_weighted(self, losses, uncertainty_weighting=False): """ Tests initialization of loss, as well as the forward and backward pass """ - + self.params.losses = losses self.params.uncertainty_weighting = uncertainty_weighting @@ -180,7 +181,7 @@ def test_loss_weighted(self, losses, uncertainty_weighting=False): # compute weighted loss out_weighted = loss_obj(tar, inp, wgt) - + self.assertTrue(compare_tensors("loss", out, out_weighted)) @@ -252,7 +253,7 @@ def test_running_stats(self): self.assertTrue(compare_tensors("var", var, expected_var)) def test_ensemble_crps(self): - crps_func = EnsembleCRPSLoss( + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -264,7 +265,7 @@ def test_ensemble_crps(self): ensemble_distributed=False, ensemble_weights=None, ) - + for ensemble_size in [1, 10]: with self.subTest(desc=f"ensemble size {ensemble_size}"): # generate input tensor @@ -293,15 +294,15 @@ def test_ensemble_crps(self): result_proper = crps_ensemble(tar_arr, inp_arr, weights=None, issorted=False, axis=axis) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(compare_arrays("output", result, result_proper)) def test_gauss_crps(self): - + # protext against sigma=0 eps = 1.0e-5 - - crps_func = EnsembleCRPSLoss( + + crps_func = CRPSLoss( img_shape=(self.params.img_shape_x, self.params.img_shape_y), 
crop_shape=(self.params.img_shape_x, self.params.img_shape_y), crop_offset=(0, 0), @@ -313,34 +314,111 @@ def test_gauss_crps(self): ensemble_distributed=False, eps=eps, ) - + for ensemble_size in [1, 10]: with self.subTest(desc=f"ensemble size {ensemble_size}"): # generate input tensor inp = torch.empty((self.params.batch_size, ensemble_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) with torch.no_grad(): inp.normal_(1.0, 1.0) - + # target tensor tar = torch.ones((self.params.batch_size, self.params.N_in_channels, self.params.img_shape_x, self.params.img_shape_y), dtype=torch.float32) - + # torch result result = crps_func(inp, tar).cpu().numpy() - + # properscoring result tar_arr = tar.cpu().numpy() inp_arr = inp.cpu().numpy() - + # compute mu, sigma, guard against underflows mu = np.mean(inp_arr, axis=1) sigma = np.maximum(np.sqrt(np.var(inp_arr, axis=1)), eps) - + result_proper = crps_gaussian(tar_arr, mu, sigma, grad=False) quad_weight_arr = crps_func.quadrature.quad_weight.cpu().numpy() result_proper = np.sum(result_proper * quad_weight_arr, axis=(2, 3)) - + self.assertTrue(compare_arrays("output", result, result_proper)) + @parameterized.expand([ + # (beta, alpha, offset, fraction, channel_reduction) + (0.5, 1.0, 1.0, 1.0, True), + (1.0, 1.0, 1.0, 1.0, True), + (2.0, 1.0, 1.0, 1.0, True), + (1.0, 0.5, 1.0, 1.0, True), + (1.0, 2.0, 1.0, 1.0, True), + (1.0, 1.0, 0.5, 1.0, True), + (1.0, 1.0, 2.0, 1.0, True), + (1.0, 1.0, 1.0, 0.5, True), + (1.0, 1.0, 1.0, 2.0, True), + (1.0, 1.0, 1.0, 1.0, False), + (0.5, 0.5, 0.5, 0.5, True), + (2.0, 2.0, 2.0, 2.0, True), + ]) + def test_sobolev_energy_score(self, beta, alpha, offset, fraction, channel_reduction): + """ + Tests SobolevEnergyScoreLoss for different parameter combinations, + verifying that output and gradients are not NaN or inf. 
+ """ + sobolev_loss = SobolevEnergyScoreLoss( + img_shape=(self.params.img_shape_x, self.params.img_shape_y), + crop_shape=(self.params.img_shape_x, self.params.img_shape_y), + crop_offset=(0, 0), + channel_names=self.params.channel_names, + grid_type=self.params.model_grid_type, + lmax=None, + spatial_distributed=False, + ensemble_distributed=False, + channel_reduction=channel_reduction, + alpha=alpha, + beta=beta, + offset=offset, + fraction=fraction, + ).to(self.device) + + for ensemble_size in [2, 6]: + with self.subTest(desc=f"beta={beta}, alpha={alpha}, offset={offset}, fraction={fraction}, channel_reduction={channel_reduction}, ensemble_size={ensemble_size}"): + # Generate forecast tensor: (batch, ensemble, channels, lat, lon) + forecasts = torch.randn( + self.params.batch_size, + ensemble_size, + self.params.N_in_channels, + self.params.img_shape_x, + self.params.img_shape_y, + device=self.device, + dtype=torch.float32, + requires_grad=True, + ) + + # Generate observation tensor: (batch, channels, lat, lon) + observations = torch.randn( + self.params.batch_size, + self.params.N_in_channels, + self.params.img_shape_x, + self.params.img_shape_y, + device=self.device, + dtype=torch.float32, + ) + + # Forward pass + result = sobolev_loss(forecasts, observations) + + # Check output is not NaN or inf + self.assertFalse(torch.isnan(result).any(), f"Output contains NaN values") + self.assertFalse(torch.isinf(result).any(), f"Output contains inf values") + + # Backward pass + loss = result.sum() + loss.backward() + + # Check gradients are not NaN or inf + self.assertIsNotNone(forecasts.grad, "Gradients are None") + self.assertFalse(torch.isnan(forecasts.grad).any(), f"Gradients contain NaN values") + self.assertFalse(torch.isinf(forecasts.grad).any(), f"Gradients contain inf values") + if __name__ == "__main__": + disable_tf32() unittest.main() diff --git a/tests/test_models.py b/tests/test_models.py index fca13aad..4234f077 100644 --- a/tests/test_models.py +++ 
b/tests/test_models.py @@ -26,7 +26,7 @@ from makani.utils import LossHandler sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from .testutils import get_default_parameters, compare_tensors +from .testutils import get_default_parameters, compare_tensors, disable_tf32 class TestModels(unittest.TestCase): @@ -190,4 +190,5 @@ def test_gradient_accumulation(self, nettype, atol, rtol, verbose=True): if __name__ == "__main__": + disable_tf32() unittest.main() diff --git a/tests/testutils.py b/tests/testutils.py index fdaa7dc5..c9bcedbb 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -17,6 +17,7 @@ import json import datetime as dt from typing import List, Optional +from packaging import version import numpy as np import h5py as h5 @@ -63,7 +64,6 @@ def get_default_parameters(): params.N_in_channels = len(params.in_channels) params.N_out_channels = len(params.out_channels) - params.target = "default" params.batch_size = 1 params.valid_autoreg_steps = 0 params.num_data_workers = 1 @@ -297,4 +297,17 @@ def compare_arrays(msg, array1, array2, atol=1e-8, rtol=1e-5, verbose=False): array2_abs_bad = np.abs(array2).flatten()[worst_diff].item() print(f"Worst allclose condition violation: {diff_bad} <= {atol} + {rtol} * {array2_abs_bad} = {atol + rtol * array2_abs_bad}") - return allclose \ No newline at end of file + return allclose + +def disable_tf32(): + # the api for this was changed lately in pytorch + if torch.cuda.is_available(): + if version.parse(torch.__version__) >= version.parse("2.9.0"): + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + else: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + return \ No newline at end of file