Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
9f8dd2e
adding index exclusion list support in dali loaders, making index han…
azrael417 Sep 17, 2025
d400d36
some small fixes for non combined dali loader and unannotated files
azrael417 Sep 17, 2025
a83e177
distinguishing between read_shape and return_shape. read_shape will …
azrael417 Sep 18, 2025
d332199
fixing a bug where samples per year was not computed correctly
azrael417 Sep 18, 2025
9877126
making nicer prints
azrael417 Sep 18, 2025
2502aa0
fixed dataloader indexing
azrael417 Sep 18, 2025
02b12f0
fixed history with noise
azrael417 Sep 18, 2025
bcaf66d
setting explicit noise seeds for ensemble serial settings. without th…
azrael417 Sep 19, 2025
981319d
adding accessor functions for some RNG features when noise is enabled
azrael417 Sep 19, 2025
a238877
removing unnecessary imports
azrael417 Sep 19, 2025
03b3e0b
removing more imports
azrael417 Sep 19, 2025
a2c157d
removing more imports
azrael417 Sep 19, 2025
db995f4
resolving rebase conflict
bonevbs Feb 19, 2026
b0b147c
some changes to the model and losses
bonevbs Oct 30, 2025
4353bc7
restoring logic from main regarding noise states
bonevbs Oct 30, 2025
26cbc07
fixing model package
bonevbs Nov 1, 2025
1ccf085
some changes to deterministic trainer
bonevbs Nov 12, 2025
846d3b6
removing the residual training option and instead extending loss han…
bonevbs Nov 24, 2025
67553a6
updating other configs
bonevbs Nov 24, 2025
74869b3
added routine to compute spherical bandlimit
bonevbs Nov 24, 2025
e3ebc5b
resolving rebase conflict
bonevbs Feb 19, 2026
bc7a5d2
adding energy score loss
bonevbs Dec 2, 2025
b603f8f
removing the random CRPS
bonevbs Dec 2, 2025
d46ec29
resolving rebase conflict
bonevbs Feb 19, 2026
5242d7e
more fixes for energy score loss
bonevbs Dec 2, 2025
354deb5
resolving rebase conflict
bonevbs Feb 19, 2026
465dca8
fixing another bug
bonevbs Dec 2, 2025
c9c295e
fixing energy score
bonevbs Dec 2, 2025
3e55c95
cleaning up some losses
bonevbs Dec 10, 2025
ae84efd
implemented improved Sobolev energy score
bonevbs Dec 21, 2025
61f5b99
resolving rebase conflict
bonevbs Feb 19, 2026
9cde914
fixing bugs in implementations of the energy score
bonevbs Jan 9, 2026
3f9ffa0
adding stochasticity to vit
bonevbs Jan 19, 2026
29f5bc4
adding spectral coherence loss
bonevbs Jan 20, 2026
6a1aa45
incorporating updates to FCN3 and energy score
bonevbs Jan 26, 2026
8e5a3d6
small refactor
azrael417 Jan 27, 2026
eaa1e41
fixing Pangu and adding tf32 disabling
bonevbs Jan 27, 2026
01115c3
adding default indexing mode to pangu meshgrid
bonevbs Jan 27, 2026
0690f01
making stochastic FCN3 an option
bonevbs Jan 27, 2026
052b28e
some bug fixes to spectral CRPS and associated loss tests
bonevbs Jan 28, 2026
286c954
updating lm calc
azrael417 Jan 28, 2026
cf5fb2d
making observations in CRPS test non-constant to mitigate SHT precisi…
azrael417 Jan 29, 2026
d123d44
streamlining tests
azrael417 Jan 29, 2026
33ce80c
adding energy score
azrael417 Jan 29, 2026
d28a1f7
adding disable_tf32
bonevbs Jan 29, 2026
28267ac
getting rid of debug print
bonevbs Jan 29, 2026
5f81793
test cleanup
azrael417 Jan 29, 2026
e4c6806
working losses
azrael417 Jan 29, 2026
1baee8c
moving lm weights computation to base loss
bonevbs Jan 29, 2026
24f7158
working sobolev test
azrael417 Jan 29, 2026
aa95add
fixing sobolev loss
bonevbs Jan 29, 2026
2e05e38
fixing sobolev loss
bonevbs Jan 29, 2026
a0f3f15
adding sobolev loss nan test and removing redundant code
azrael417 Jan 30, 2026
7d4a6d3
testing only even ensemble sizes
azrael417 Jan 30, 2026
7be36e9
adding relative weight to sobolev energy score loss term and adding l…
azrael417 Feb 2, 2026
30b38cf
full 0 masking in energy score losses
azrael417 Feb 2, 2026
71bfee0
working energy scores
azrael417 Feb 2, 2026
767855b
various cleanups. Factoring out SST imputation
bonevbs Feb 16, 2026
47d1a89
various fixes
bonevbs Feb 16, 2026
1d836a8
fixing imputation and masking in losses
bonevbs Feb 18, 2026
38d081f
compatibility with newest torch-harmonics
bonevbs Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions config/afnonet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ full_field: &BASE_CONFIG
nested_skip_fno: !!bool True # whether to nest the inner skip connection or have it be sequential, inside the AFNO block
verbose: False

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"]
normalization: "zscore" #options zscore or minmax
hard_thresholding_fraction: 1.0
Expand Down
3 changes: 0 additions & 3 deletions config/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: false

# define channels to be read from data
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
# normalization mode zscore but for q
Expand Down
3 changes: 0 additions & 3 deletions config/fourcastnet3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: false

# define channels to be read from data. sp has been removed here
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
# normalization mode zscore but for q
Expand Down
3 changes: 0 additions & 3 deletions config/icml_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ base_config: &BASE_CONFIG
N_grid_channels: 0
gridtype: "sinusoidal" #options "sinusoidal" or "linear"

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "t2m", "sp", "msl", "t850", "u1000", "v1000", "z1000", "u850", "v850", "z850", "u500", "v500", "z500", "t500", "z50", "r500", "r850", "tcwv", "u100m", "v100m", "u250", "v250", "z250", "t250", "u100", "v100", "z100", "t100", "u900", "v900", "z900", "t900"]
normalization: "zscore" #options zscore or minmax or none

Expand Down
7 changes: 2 additions & 5 deletions config/pangu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: !!bool False

# define channels to be read from data
channel_names: ["u10m", "v10m", "t2m", "msl", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" # options zscore or minmax or none
Expand Down Expand Up @@ -131,10 +128,10 @@ base_onnx: &BASE_ONNX
# ONNX wrapper related overwrite
nettype: "/makani/makani/makani/models/networks/pangu_onnx.py:PanguOnnx"
onnx_file: '/model/pangu_weather_6.onnx'

amp_mode: "none"
disable_ddp: True

# Set Pangu ONNX channel order
channel_names: ["msl", "u10m", "v10m", "t2m", "z1000", "z925", "z850", "z700", "z600", "z500", "z400", "z300", "z250", "z200", "z150", "z100", "z50", "q1000", "q925", "q850", "q700", "q600", "q500", "q400", "q300", "q250", "q200", "q150", "q100", "q50", "t1000", "t925", "t850", "t700", "t600", "t500", "t400", "t300", "t250", "t200", "t150", "t100", "t50", "u1000", "u925", "u850", "u700", "u600", "u500", "u400", "u300", "u250", "u200", "u150", "u100", "u50", "v1000", "v925", "v850", "v700", "v600", "v500", "v400", "v300", "v250", "v200", "v150", "v100", "v50"]
# Remove input/output normalization
Expand Down
3 changes: 0 additions & 3 deletions config/sfnonet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ base_config: &BASE_CONFIG
add_noise: !!bool False
noise_std: 0.

target: "default" # options default, residual
normalize_residual: !!bool False

# define channels to be read from data
channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" # options zscore or minmax or none
Expand Down
3 changes: 0 additions & 3 deletions config/vit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ full_field: &BASE_CONFIG
embed_dim: 384
normalization_layer: "layer_norm"

#options default, residual
target: "default"

channel_names: ["u10m", "v10m", "u100m", "v100m", "t2m", "sp", "msl", "tcwv", "u50", "u100", "u150", "u200", "u250", "u300", "u400", "u500", "u600", "u700", "u850", "u925", "u1000", "v50", "v100", "v150", "v200", "v250", "v300", "v400", "v500", "v600", "v700", "v850", "v925", "v1000", "z50", "z100", "z150", "z200", "z250", "z300", "z400", "z500", "z600", "z700", "z850", "z925", "z1000", "t50", "t100", "t150", "t200", "t250", "t300", "t400", "t500", "t600", "t700", "t850", "t925", "t1000", "q50", "q100", "q150", "q200", "q250", "q300", "q400", "q500", "q600", "q700", "q850", "q925", "q1000"]
normalization: "zscore" #options zscore or minmax
hard_thresholding_fraction: 1.0
Expand Down
1 change: 1 addition & 0 deletions data_process/get_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def get_stats(input_path: str, output_path: str, metadata_file: str,
height, width = (data_shape[2], data_shape[3])

# quadrature:
# we normalize the quadrature rule to 4pi
quadrature = GridQuadrature(quadrature_rule, (height, width),
crop_shape=None, crop_offset=(0, 0),
normalize=False, pole_mask=None).to(device)
Expand Down
6 changes: 3 additions & 3 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ RUN pip install --ignore-installed "git+https://github.com/NVIDIA/mlperf-common.

# torch-harmonics
ENV FORCE_CUDA_EXTENSION 1
ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 9.0 10.0 12.0+PTX"
ENV HARMONICS_VERSION 0.8.0
RUN cd /opt && git clone -b v0.8.0 https://github.com/NVIDIA/torch-harmonics.git && \
ENV TORCH_CUDA_ARCH_LIST "8.0 8.6 8.7 9.0 10.0+PTX"
ENV HARMONICS_VERSION 0.8.1
RUN cd /opt && git clone -b bbonev/disco-modal-normalization https://github.com/NVIDIA/torch-harmonics.git && \
cd torch-harmonics && \
pip install --no-build-isolation -e .

Expand Down
2 changes: 1 addition & 1 deletion makani/convert_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def consolidate_checkpoints(input_path, output_path, checkpoint_version=0):
print(checkpoint_paths)

# load the static data necessary for instantiating the preprocessor (required due to the way the registry works)
LocalPackage._load_static_data(input_path, params)
LocalPackage._load_static_data(LocalPackage(input_path), params)

# get the model
multistep = params.n_future > 0
Expand Down
2 changes: 2 additions & 0 deletions makani/models/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@
from .activations import ComplexReLU, ComplexActivation
from .layers import DropPath, LayerScale, PatchEmbed2D, PatchEmbed3D, PatchRecovery2D, PatchRecovery3D, EncoderDecoder, MLP, UpSample3D, DownSample3D, UpSample2D, DownSample2D
from .fft import RealFFT1, InverseRealFFT1, RealFFT2, InverseRealFFT2, RealFFT3, InverseRealFFT3
from .imputation import MLPImputation, ConstantImputation
from .layer_norm import GeometricInstanceNormS2
from .spectral_convolution import SpectralConv, SpectralAttention
from .pos_embedding import LearnablePositionEmbedding
82 changes: 82 additions & 0 deletions makani/models/common/imputation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from typing import Optional

from makani.utils import comm
from .layers import EncoderDecoder

# helper module to handle imputation of SST
class MLPImputation(nn.Module):
    """Imputes missing (NaN/masked) values in selected channels with a small MLP.

    The MLP receives all ``inp_chans`` input channels and predicts replacement
    values for the channels listed in ``inpute_chans``. Flagged positions are
    first zeroed (so the MLP never sees NaNs) and then overwritten with the
    MLP prediction.

    Parameters
    ----------
    inp_chans : int, optional
        Total number of input channels fed to the MLP, by default 2.
    inpute_chans : torch.Tensor, optional
        1D index tensor of channels to impute. Defaults to ``torch.tensor([0])``.
    mlp_ratio : float, optional
        Hidden-dimension multiplier relative to the number of imputed channels.
    activation_function : nn.Module, optional
        Activation class used inside the MLP, by default ``nn.GELU``.
    """

    def __init__(
        self,
        inp_chans: int = 2,
        inpute_chans: Optional[torch.Tensor] = None,
        mlp_ratio: float = 2.0,
        activation_function: nn.Module = nn.GELU,
    ):
        super().__init__()

        # avoid a tensor default argument: it would be created once at import
        # time and shared across all instances; build the default per instance
        if inpute_chans is None:
            inpute_chans = torch.tensor([0])

        self.inp_chans = inp_chans
        # register the index tensor as a non-persistent buffer so it follows
        # the module across .to(device) calls without entering the state_dict
        # (keeps checkpoint compatibility with the previous attribute layout)
        self.register_buffer("inpute_chans", inpute_chans, persistent=False)
        self.out_chans = inpute_chans.shape[0]

        self.mlp = EncoderDecoder(
            num_layers=1,
            input_dim=self.inp_chans,
            output_dim=self.out_chans,
            hidden_dim=int(mlp_ratio * self.out_chans),
            act_layer=activation_function,
            input_format="nchw",
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Impute the selected channels of ``x`` and return it.

        NOTE(review): ``x`` is modified in place (index assignment on the
        imputed channels); callers that need the original values should pass
        a clone.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (..., C, H, W); imputed channels indexed along C.
        mask : torch.Tensor, optional
            Boolean mask of positions to impute in the selected channels;
            combined (logical OR) with NaN positions found in the data.
        """
        # combine any caller-provided mask with NaNs present in the data
        if mask is None:
            mask = torch.isnan(x[..., self.inpute_chans, :, :])
        else:
            mask = torch.logical_or(mask, torch.isnan(x[..., self.inpute_chans, :, :]))

        # zero out masked entries so the MLP input is NaN-free
        x[..., self.inpute_chans, :, :] = torch.where(mask, 0.0, x[..., self.inpute_chans, :, :])

        # flatten extra batch dims for Conv2d compatibility
        batch_shape = x.shape[:-3]
        x_flat = x.reshape(-1, *x.shape[-3:])
        mlp_out = self.mlp(x_flat).reshape(*batch_shape, self.out_chans, *x_flat.shape[-2:])

        # replace only the masked positions with the MLP prediction
        x[..., self.inpute_chans, :, :] = torch.where(mask, mlp_out, x[..., self.inpute_chans, :, :])

        return x

class ConstantImputation(nn.Module):
    """Replaces missing (NaN/masked) entries with a learnable per-channel constant.

    A single learnable value per channel, broadcast over the spatial
    dimensions, is substituted wherever the input is NaN or the supplied
    mask is set.

    Parameters
    ----------
    inp_chans : int, optional
        Number of channels; one learnable constant is kept per channel.
    """

    def __init__(
        self,
        inp_chans: int = 2,
    ):
        super().__init__()

        # one learnable fill value per channel, broadcastable over (H, W)
        self.weight = nn.Parameter(torch.randn(inp_chans, 1, 1))

        # annotate sharing semantics for spatial model parallelism
        if comm.get_size("spatial") > 1:
            self.weight.is_shared_mp = ["spatial"]
            self.weight.sharded_dims_mp = [None, None, None]

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Return ``x`` with NaN/masked entries replaced by the learned constants."""
        nan_mask = torch.isnan(x)
        combined = nan_mask if mask is None else torch.logical_or(mask, nan_mask)
        return torch.where(combined, self.weight, x)
2 changes: 1 addition & 1 deletion makani/models/common/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,5 +628,5 @@ def forward(self, x):

x = self.norm(x)
x = self.linear(x)

return x
99 changes: 99 additions & 0 deletions makani/models/common/pos_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

import torch
import torch.nn as nn

from makani.utils import comm
from physicsnemo.distributed.utils import compute_split_shapes

class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for position embeddings.

    Subclasses are expected to create a ``position_embeddings`` tensor
    (typically an ``nn.Parameter``), which the default ``forward``
    returns unchanged.

    Parameters
    ----------
    img_shape : tuple, optional
        Image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular". Not stored by the base class;
        accepted for subclass interface compatibility.
    num_chans : int, optional
        Number of embedding channels, by default 1
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1):
        super().__init__()

        # record geometry; grid is accepted but intentionally not stored here
        self.num_chans = num_chans
        self.img_shape = img_shape

    def forward(self):
        # subclasses are responsible for defining self.position_embeddings
        return self.position_embeddings

class LearnablePositionEmbedding(PositionEmbedding):
    """Learnable position embeddings for spherical transformers.

    Provides a trainable embedding tensor that is either latitude-only
    ("lat", shared across all longitudes) or fully resolved in latitude
    and longitude ("latlon").

    Parameters
    ----------
    img_shape : tuple, optional
        Global image shape (height, width), by default (480, 960)
    grid : str, optional
        Grid type, by default "equiangular"
    num_chans : int, optional
        Number of embedding channels, by default 1
    embed_type : str, optional
        Embedding type ("lat" or "latlon"), by default "lat"

    Raises
    ------
    ValueError
        If ``embed_type`` is neither "lat" nor "latlon".
    """

    def __init__(self, img_shape=(480, 960), grid="equiangular", num_chans=1, embed_type="lat"):
        super().__init__(img_shape=img_shape, grid=grid, num_chans=num_chans)

        # under spatial model parallelism each rank only holds its local
        # slice of the global (h, w) grid; compute those local extents
        self.local_shape_h = self._local_extent(img_shape[0], "h")
        self.local_shape_w = self._local_extent(img_shape[1], "w")

        if embed_type == "latlon":
            # full lat-lon embedding, sharded over both spatial dimensions
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, self.local_shape_w))
            self.position_embeddings.is_shared_mp = []
            self.position_embeddings.sharded_dims_mp = [None, None, "h", "w"]
        elif embed_type == "lat":
            # latitude-only embedding, shared across the w communicator
            self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_chans, self.local_shape_h, 1))
            self.position_embeddings.is_shared_mp = ["w"]
            self.position_embeddings.sharded_dims_mp = [None, None, "h", None]
        else:
            raise ValueError(f"Unknown learnable position embedding type {embed_type}")

    @staticmethod
    def _local_extent(global_size, dim_name):
        # size of this rank's slice along the given communicator dimension
        if comm.get_size(dim_name) > 1:
            return compute_split_shapes(global_size, comm.get_size(dim_name))[comm.get_rank(dim_name)]
        return global_size

    def forward(self):
        """Return the embedding, broadcast to the full local (h, w) extent."""
        return self.position_embeddings.expand(-1, -1, self.local_shape_h, self.local_shape_w)
8 changes: 4 additions & 4 deletions makani/models/model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

class LocalPackage:
"""
Implements the earth2mip/modulus Package interface.
Implements the modulus Package interface.
"""

# These define the model package in terms of where makani expects the files to be located
Expand All @@ -48,7 +48,7 @@ class LocalPackage:
MEANS_FILE = "global_means.npy"
STDS_FILE = "global_stds.npy"
OROGRAPHY_FILE = "orography.nc"
LANDMASK_FILE = "land_mask.nc"
LANDMASK_FILE = "land_sea_mask.nc"
SOILTYPE_FILE = "soil_type.nc"

def __init__(self, root):
Expand Down Expand Up @@ -147,11 +147,11 @@ def timestep(self):
def update_state(self, replace_state=True):
self.model.preprocessor.update_internal_state(replace_state=replace_state)
return

def set_rng(self, reset=True, seed=333):
self.model.preprocessor.set_rng(reset=reset, seed=seed)
return

def forward(self, x, time, normalized_data=True, replace_state=None):
if not normalized_data:
x = (x - self.in_bias) / self.in_scale
Expand Down
22 changes: 21 additions & 1 deletion makani/models/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,27 @@ def get_model(params: ParamsBase, use_stochastic_interpolation: bool = False, mu
if hasattr(model_handle, "load") and callable(model_handle.load):
model_handle = model_handle.load()

model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict())
model_kwargs = params.to_dict()

# pass normalization statistics to the model
if params.get("normalization", "none") in ["zscore", "minmax"]:
try:
bias, scale = get_data_normalization(params)
# Slice the stats to match the model's output channels
# Assuming the model's output corresponds to params.out_channels
if hasattr(params, "out_channels"):
if bias is not None:
bias = bias.flatten()[params.out_channels]
if scale is not None:
scale = scale.flatten()[params.out_channels]

if bias is not None and scale is not None:
model_kwargs["normalization_means"] = bias
model_kwargs["normalization_stds"] = scale
except Exception as e:
logging.warning(f"Could not load normalization stats. Error: {e}")

model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **model_kwargs)
else:
raise KeyError(f"No model is registered under the name {params.nettype}")

Expand Down
Loading
Loading