Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fme/ace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from fme.ace.registry.land_net import LandNetBuilder
from fme.ace.registry.m2lines import FloeNetBuilder, SamudraBuilder
from fme.ace.registry.sfno import SFNO_V0_1_0, SphericalFourierNeuralOperatorBuilder
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNO
from fme.ace.registry.stochastic_sfno import NoiseConditionedSFNOBuilder
from fme.ace.stepper import DerivedForcingsConfig, StepperOverrideConfig
from fme.ace.stepper.insolation.config import InsolationConfig, NameConfig, ValueConfig
from fme.ace.stepper.parameter_init import (
Expand Down
163 changes: 24 additions & 139 deletions fme/ace/registry/stochastic_sfno.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,10 @@
import dataclasses
import math
from collections.abc import Callable
from typing import Literal

import torch

from fme.ace.registry.registry import ModuleConfig, ModuleSelector
from fme.core.dataset_info import DatasetInfo
from fme.core.models.conditional_sfno.sfnonet import (
Context,
ContextConfig,
get_lat_lon_sfnonet,
)
from fme.core.models.conditional_sfno.sfnonet import (
SphericalFourierNeuralOperatorNet as ConditionalSFNO,
)


def isotropic_noise(
leading_shape: tuple[int, ...],
lmax: int, # length of the ℓ axis expected by isht
mmax: int, # length of the m axis expected by isht
isht: Callable[[torch.Tensor], torch.Tensor],
device: torch.device,
) -> torch.Tensor:
# --- draw independent N(0,1) parts --------------------------------------
coeff_shape = (*leading_shape, lmax, mmax)
real = torch.randn(coeff_shape, dtype=torch.float32, device=device)
imag = torch.randn(coeff_shape, dtype=torch.float32, device=device)
imag[..., :, 0] = 0.0 # m = 0 ⇒ purely real

# m > 0: make Re and Im each N(0,½) → |a_{ℓ m}|² has variance 1
sqrt2 = math.sqrt(2.0)
real[..., :, 1:] /= sqrt2
imag[..., :, 1:] /= sqrt2

# --- global scale that makes Var[T(θ,φ)] = 1 ---------------------------
scale = math.sqrt(4.0 * math.pi) / lmax # (Unsöld theorem ⇒ L = lmax)
alm = (real + 1j * imag) * scale

return isht(alm)


class NoiseConditionedSFNO(torch.nn.Module):
def __init__(
self,
conditional_model: ConditionalSFNO,
img_shape: tuple[int, int],
noise_type: Literal["isotropic", "gaussian"] = "gaussian",
embed_dim_noise: int = 256,
embed_dim_pos: int = 0,
embed_dim_labels: int = 0,
):
super().__init__()
self.conditional_model = conditional_model
self.embed_dim = embed_dim_noise
self.noise_type = noise_type
self.label_pos_embed: torch.nn.Parameter | None = None
# register pos embed if pos_embed_dim != 0
if embed_dim_pos != 0:
self.pos_embed = torch.nn.Parameter(
torch.zeros(
1, embed_dim_pos, img_shape[0], img_shape[1], requires_grad=True
)
)
# initialize pos embed with std=0.02
torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
if embed_dim_labels > 0:
self.label_pos_embed = torch.nn.Parameter(
torch.zeros(
embed_dim_labels,
embed_dim_pos,
img_shape[0],
img_shape[1],
requires_grad=True,
)
)
torch.nn.init.trunc_normal_(self.label_pos_embed, std=0.02)
else:
self.pos_embed = None

def forward(
self, x: torch.Tensor, labels: torch.Tensor | None = None
) -> torch.Tensor:
x = x.reshape(-1, *x.shape[-3:])
if self.noise_type == "isotropic":
lmax = self.conditional_model.itrans_up.lmax
mmax = self.conditional_model.itrans_up.mmax
noise = isotropic_noise(
(x.shape[0], self.embed_dim),
lmax,
mmax,
self.conditional_model.itrans_up,
device=x.device,
)
elif self.noise_type == "gaussian":
noise = torch.randn(
[x.shape[0], self.embed_dim, *x.shape[-2:]],
device=x.device,
dtype=x.dtype,
)
else:
raise ValueError(f"Invalid noise type: {self.noise_type}")

if self.pos_embed is not None:
embedding_pos = self.pos_embed.repeat(noise.shape[0], 1, 1, 1)
if self.label_pos_embed is not None and labels is not None:
label_embedding_pos = torch.einsum(
"bl, lpxy -> bpxy", labels, self.label_pos_embed
)
embedding_pos = embedding_pos + label_embedding_pos
else:
embedding_pos = None

return self.conditional_model(
x,
Context(
embedding_scalar=None,
embedding_pos=embedding_pos,
labels=labels,
noise=noise,
),
)
from fme.core.models.conditional_sfno.v0.stochastic_sfno import build as build_v0
from fme.core.models.conditional_sfno.v1.stochastic_sfno import build as build_v1


# this is based on the call signature of SphericalFourierNeuralOperatorNet at
Expand All @@ -135,6 +18,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig):
Noise is provided as conditioning input to conditional layer normalization.

Attributes:
version: Version of the model.
spectral_transform: Type of spherical transform to use.
Kept for backwards compatibility.
filter_type: Type of filter to use.
Expand Down Expand Up @@ -186,6 +70,7 @@ class NoiseConditionedSFNOBuilder(ModuleConfig):
Defaults to spectral_lora_rank.
"""

version: Literal["v0", "v1", "latest"] = "v0"
spectral_transform: Literal["sht"] = "sht"
filter_type: Literal["linear", "makani-linear"] = "linear"
operator_type: Literal["dhconv"] = "dhconv"
Expand Down Expand Up @@ -236,30 +121,30 @@ def __post_init__(self):
"Only 'dhconv' operator_type is supported for "
"NoiseConditionedSFNO models."
)
if self.version == "latest":
# must replace as eventual newer versions break backwards compatibility
# v1 is not stable yet, keep using v0 as default for now
self.version = "v0"

def build(
self,
n_in_channels: int,
n_out_channels: int,
dataset_info: DatasetInfo,
):
sfno_net = get_lat_lon_sfnonet(
params=self,
in_chans=n_in_channels,
out_chans=n_out_channels,
img_shape=dataset_info.img_shape,
context_config=ContextConfig(
embed_dim_scalar=0,
embed_dim_pos=self.context_pos_embed_dim,
embed_dim_noise=self.noise_embed_dim,
embed_dim_labels=len(dataset_info.all_labels),
),
)
return NoiseConditionedSFNO(
sfno_net,
noise_type=self.noise_type,
embed_dim_noise=self.noise_embed_dim,
embed_dim_pos=self.context_pos_embed_dim,
embed_dim_labels=len(dataset_info.all_labels),
img_shape=dataset_info.img_shape,
)
if self.version == "v0":
return build_v0(
self,
n_in_channels=n_in_channels,
n_out_channels=n_out_channels,
dataset_info=dataset_info,
)
elif self.version == "v1":
return build_v1(
self,
n_in_channels=n_in_channels,
n_out_channels=n_out_channels,
dataset_info=dataset_info,
)
else:
raise ValueError(f"Unsupported version: {self.version}")
4 changes: 4 additions & 0 deletions fme/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import models as _ # to trigger registrations
from .atmosphere_data import AtmosphereData
from .device import get_device, using_gpu
from .gridded_ops import GriddedOperations
Expand All @@ -14,11 +15,14 @@
from .rand import set_seed
from .registry import Registry

del _

__all__ = [
"spherical_area_weights",
"weighted_mean",
"weighted_mean_bias",
"weighted_nanmean",
"weighted_sum",
"root_mean_squared_error",
"get_device",
"using_gpu",
Expand Down
1 change: 1 addition & 0 deletions fme/core/benchmark/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
results
Loading
Loading