From 3840d2c7092ec6b647b09785f3e31cdb59317d12 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Tue, 24 Feb 2026 15:04:22 -0800 Subject: [PATCH 01/41] move noise related functions to separate module --- fme/downscaling/models.py | 47 +----------------------------------- fme/downscaling/noise.py | 50 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 46 deletions(-) create mode 100644 fme/downscaling/noise.py diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index bdf56f382..731fcf03b 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -11,11 +11,11 @@ from fme.core.normalizer import NormalizationConfig, StandardNormalizer from fme.core.optimization import NullOptimization, Optimization from fme.core.packer import Packer -from fme.core.rand import randn, randn_like from fme.core.typing_ import TensorDict, TensorMapping from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector +from fme.downscaling.noise import condition_with_noise_for_training from fme.downscaling.requirements import DataRequirements from fme.downscaling.samplers import stochastic_sampler as edm_sampler from fme.downscaling.typing_ import FineResCoarseResPair @@ -204,51 +204,6 @@ def _separate_interleaved_samples(tensor: torch.Tensor, n_samples: int) -> torch return tensor.reshape(n_batch, n_samples, *tensor.shape[1:]) -@dataclasses.dataclass -class ConditionedTarget: - """ - A class to hold the conditioned targets and the loss weighting. - - Attributes: - latents: The normalized targets with noise added. - sigma: The noise level. - weight: The loss weighting. - """ - - latents: torch.Tensor - sigma: torch.Tensor - weight: torch.Tensor - - -def condition_with_noise_for_training( - targets_norm: torch.Tensor, - p_std: float, - p_mean: float, - sigma_data: float, -) -> ConditionedTarget: - """ - Condition the targets with noise for training. - - Args: - targets_norm: The normalized targets. - p_std: The standard deviation of the noise distribution used during training. - p_mean: The mean of the noise distribution used during training. - sigma_data: The standard deviation of the data, - used to determine loss weighting. - - Returns: - The conditioned targets and the loss weighting. - """ - rnd_normal = randn([targets_norm.shape[0], 1, 1, 1], device=targets_norm.device) - # This is taken from EDM's original implementation in EDMLoss: - # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80 # noqa: E501 - sigma = (rnd_normal * p_std + p_mean).exp() - weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 - noise = randn_like(targets_norm) * sigma - latents = targets_norm + noise - return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) - - class DiffusionModel: def __init__( self, diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py new file mode 100644 index 000000000..853e38d0f --- /dev/null +++ b/fme/downscaling/noise.py @@ -0,0 +1,50 @@ +import dataclasses + +import torch + +from fme.core.rand import randn, randn_like + + +@dataclasses.dataclass +class ConditionedTarget: + """ + A class to hold the conditioned targets and the loss weighting. + + Attributes: + latents: The normalized targets with noise added. + sigma: The noise level. + weight: The loss weighting. 
+ """ + + latents: torch.Tensor + sigma: torch.Tensor + weight: torch.Tensor + + +def condition_with_noise_for_training( + targets_norm: torch.Tensor, + p_std: float, + p_mean: float, + sigma_data: float, +) -> ConditionedTarget: + """ + Condition the targets with noise for training. + + Args: + targets_norm: The normalized targets. + p_std: The standard deviation of the noise distribution used during training. + p_mean: The mean of the noise distribution used during training. + sigma_data: The standard deviation of the data, + used to determine loss weighting. + + Returns: + The conditioned targets and the loss weighting. + """ + rnd_normal = randn([targets_norm.shape[0], 1, 1, 1], device=targets_norm.device) + # This is taken from EDM's original implementation in EDMLoss: + # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80 # noqa: E501 + sigma = (rnd_normal * p_std + p_mean).exp() + weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + noise = randn_like(targets_norm) * sigma + latents = targets_norm + noise + return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) From 26807aab018f3eef24b2a4e1eb68da7ff13800b3 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 25 Feb 2026 08:11:51 -0800 Subject: [PATCH 02/41] add classes for noise distribution --- fme/downscaling/noise.py | 44 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index 853e38d0f..73b257158 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -1,5 +1,7 @@ +import abc import dataclasses +import numpy as np import torch from fme.core.rand import randn, randn_like @@ -21,6 +23,48 @@ class ConditionedTarget: weight: torch.Tensor +class NoiseDistribution(abc.ABC): + @abc.abstractmethod + def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: + pass + + +@dataclasses.dataclass +class LogNormalNoiseDistribution: + p_mean: float + p_std: float + + def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: + rnd = randn([batch_size, 1, 1, 1], device=device) + # This is taken from EDM's original implementation in EDMLoss: + # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80 # noqa: E501 + return (rnd * self.p_std + self.p_mean).exp() + + +@dataclasses.dataclass +class LogUniformNoiseDistribution: + p_min: float + p_max: float + + def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: + sigma = np.exp( + np.random.uniform(np.log(self.p_min), np.log(self.p_max), batch_size) + ) + return torch.tensor(sigma, device=device).reshape(batch_size, 1, 1, 1) + + +def condition_with_noise_for_training_from_distribution( + targets_norm: torch.Tensor, + noise_distribution: NoiseDistribution, + sigma_data: float, +) -> ConditionedTarget: + sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device) + weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + noise = randn_like(targets_norm) * sigma + latents = targets_norm + noise + return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) + + def condition_with_noise_for_training( targets_norm: torch.Tensor, p_std: float, From 7ebc3a2d6dc020fe62686cc245b94c50fc6575ce Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 25 Feb 2026 08:59:11 -0800 Subject: [PATCH 03/41] modify diffusionmodel to use noise distribution --- fme/downscaling/models.py | 40 
++++++++++++++++++++++++++++++++++++----
 fme/downscaling/noise.py  | 27 +++++----------------------
 2 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py
index 731fcf03b..b38b2a807 100644
--- a/fme/downscaling/models.py
+++ b/fme/downscaling/models.py
@@ -1,4 +1,5 @@
 import dataclasses
+import warnings
 from collections.abc import Mapping
 from typing import Any
 
@@ -15,7 +16,11 @@
 from fme.downscaling.data import BatchData, PairedBatchData, StaticInputs
 from fme.downscaling.metrics_and_maths import filter_tensor_mapping, interpolate
 from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector
-from fme.downscaling.noise import condition_with_noise_for_training
+from fme.downscaling.noise import (
+    LogNormalNoiseDistribution,
+    LogUniformNoiseDistribution,
+    condition_with_noise_for_training,
+)
 from fme.downscaling.requirements import DataRequirements
 from fme.downscaling.samplers import stochastic_sampler as edm_sampler
 from fme.downscaling.typing_ import FineResCoarseResPair
@@ -102,8 +107,6 @@ class DiffusionModelConfig:
     in_names: list[str]
     out_names: list[str]
     normalization: PairedNormalizationConfig
-    p_mean: float
-    p_std: float
     sigma_min: float
     sigma_max: float
     churn: float
@@ -111,6 +114,11 @@ class DiffusionModelConfig:
     predict_residual: bool
     use_fine_topography: bool = False
     use_amp_bf16: bool = False
+    training_noise_distribution: (
+        LogNormalNoiseDistribution | LogUniformNoiseDistribution | None
+    ) = None
+    p_mean: float | None = None
+    p_std: float | None = None
 
     def __post_init__(self):
         self._interpolate_input = self.module.expects_interpolated_input
@@ -119,6 +127,30 @@ def __post_init__(self):
             "Fine topography can only be used when predicting on interpolated"
             " coarse input"
         )
+        if self.p_mean is not None and self.p_std is not None:
+            if self.training_noise_distribution is None:
+                warnings.warn(
+                    "p_mean and p_std are deprecated. "
+                    "Use training_noise_distribution field instead."
+                )
+            else:
+                raise ValueError(
+                    "Training noise should be specified in training_noise_distribution "
+                    "field only. Both training_noise_distribution and p_mean, p_std "
+                    "were specified. The latter two fields are deprecated."
+                )
+
+    @property
+    def noise(self) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution:
+        if self.training_noise_distribution is not None:
+            return self.training_noise_distribution
+        elif self.p_mean is not None and self.p_std is not None:
+            return LogNormalNoiseDistribution(p_mean=self.p_mean, p_std=self.p_std)
+        else:
+            raise ValueError(
+                "Noise distribution must be specified in training_noise_distribution "
+                "or in p_mean and p_std fields."
+ ) def build( self, @@ -319,7 +351,7 @@ def train_on_batch( targets_norm = targets_norm - base_prediction conditioned_target = condition_with_noise_for_training( - targets_norm, self.config.p_std, self.config.p_mean, self.sigma_data + targets_norm, self.config.noise, self.sigma_data ) denoised_norm = self.module( diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index 73b257158..f3dde5c26 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -30,7 +30,7 @@ def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: @dataclasses.dataclass -class LogNormalNoiseDistribution: +class LogNormalNoiseDistribution(NoiseDistribution): p_mean: float p_std: float @@ -42,7 +42,7 @@ def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: @dataclasses.dataclass -class LogUniformNoiseDistribution: +class LogUniformNoiseDistribution(NoiseDistribution): p_min: float p_max: float @@ -53,22 +53,9 @@ def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: return torch.tensor(sigma, device=device).reshape(batch_size, 1, 1, 1) -def condition_with_noise_for_training_from_distribution( - targets_norm: torch.Tensor, - noise_distribution: NoiseDistribution, - sigma_data: float, -) -> ConditionedTarget: - sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device) - weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 - noise = randn_like(targets_norm) * sigma - latents = targets_norm + noise - return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) - - def condition_with_noise_for_training( targets_norm: torch.Tensor, - p_std: float, - p_mean: float, + noise_distribution: NoiseDistribution, sigma_data: float, ) -> ConditionedTarget: """ @@ -76,18 +63,14 @@ def condition_with_noise_for_training( Args: targets_norm: The normalized targets. - p_std: The standard deviation of the noise distribution used during training. - p_mean: The mean of the noise distribution used during training. + noise_distribution: The noise distribution to use for conditioning. sigma_data: The standard deviation of the data, used to determine loss weighting. Returns: The conditioned targets and the loss weighting. 
""" - rnd_normal = randn([targets_norm.shape[0], 1, 1, 1], device=targets_norm.device) - # This is taken from EDM's original implementation in EDMLoss: - # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80 # noqa: E501 - sigma = (rnd_normal * p_std + p_mean).exp() + sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device) weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 noise = randn_like(targets_norm) * sigma latents = targets_norm + noise From 7dd094792ce883d0ff636ba7377e4b2a823748d5 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 26 Feb 2026 08:18:46 -0800 Subject: [PATCH 04/41] fix usage in diffusion --- fme/diffusion/stepper.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/fme/diffusion/stepper.py b/fme/diffusion/stepper.py index c3a541fb0..4766cd5e0 100644 --- a/fme/diffusion/stepper.py +++ b/fme/diffusion/stepper.py @@ -47,7 +47,10 @@ from fme.core.weight_ops import strip_leading_module from fme.diffusion.loss import WeightedMappingLossConfig from fme.diffusion.registry import ModuleSelector -from fme.downscaling.models import condition_with_noise_for_training +from fme.downscaling.models import ( + LogNormalNoiseDistribution, + condition_with_noise_for_training, +) from fme.downscaling.modules.physicsnemo_unets_v1 import Linear, PositionalEmbedding DEFAULT_TIMESTEP = datetime.timedelta(hours=6) @@ -956,7 +959,11 @@ def _train_on_step( target_norm = self.normalizer.normalize(target) target_tensor = self.out_packer.pack(target_norm, axis=self.CHANNEL_DIM) conditioned = condition_with_noise_for_training( - target_tensor, self._config.p_std, self._config.p_mean, sigma_data=1.0 + target_tensor, + LogNormalNoiseDistribution( + p_std=self._config.p_std, p_mean=self._config.p_mean + ), + sigma_data=1.0, ) output_tensor = self.module( conditioned.latents, input_tensor, conditioned.sigma From ac62e90760c9205ef75e176fef84a4d91f661553 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 26 Feb 2026 08:39:01 -0800 Subject: [PATCH 05/41] exp config --- .../train-100-to-3km-prmsl-output.yaml | 144 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/train.sh | 48 ++++++ 2 files changed, 192 insertions(+) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml create mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/train.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml new file mode 100644 index 000000000..97ecd2e4b --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml @@ -0,0 +1,144 @@ +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +model: + use_amp_bf16: true + out_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: 
/climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 8 + churn: 0.0 + training_noise_distribution: + p_min: 0.002 + p_max: 150.0 + predict_residual: true + sigma_max: 150.0 + sigma_min: 0.002 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 320 + batch_size: 40 + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: -66.0 + stop: 70.0 + # lon_extent: + # start: 0 #230.0 + # stop: 16 #246.0 + strict_ensemble: false +validation_data: + batch_size: 36 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + lat_extent: + start: -20.0 + stop: 20.0 + strict_ensemble: false +coarse_patch_extent_lat: 16 +coarse_patch_extent_lon: 16 +max_epochs: 300 +validate_interval: 10 +experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 +save_checkpoints: false +logging: + project: multivariate-downscaling + entity: ai2cm + log_to_wandb: true +generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh 
new file mode 100755 index 000000000..50b288416 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# uses the augusta cluster which doesn't have weka access but has GCS access and is +# typically more available than cirrascale clusters + +set -e + +# recommended but not required to change this + +JOB_NAME="xshield-downscaling-100km-to-3km-prmsl-output-loguniform-noise" +CONFIG_FILENAME="train-100-to-3km-prmsl-output.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME +wandb_group="" + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) +N_GPUS=4 # TODO: change to 8 after testing + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt) + +gantry run \ + --name $JOB_NAME \ + --description 'Run downscaling 100km to 3km multivar training' \ + --workspace ai2/downscaling \ + --priority low \ + --preemptible \ + --cluster ai2/jupiter \ + --cluster ai2/titan \ + --beaker-image $IMAGE \ + --env WANDB_USERNAME=$BEAKER_USERNAME \ + --env WANDB_NAME=$JOB_NAME \ + --env WANDB_JOB_TYPE=training \ + --env WANDB_RUN_GROUP=$wandb_group \ + --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ + --env-secret WANDB_API_KEY=wandb-api-key-annak \ + --dataset-secret google-credentials:/tmp/google_application_credentials.json \ + --weka climate-default:/climate-default \ + --gpus $N_GPUS \ + --shared-memory 400GiB \ + --budget ai2/climate \ + --no-conda \ + --install "pip install --no-deps ." 
\ + --allow-dirty \ + -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH \ No newline at end of file From 7154e61d54c48da69b9a514126d55b35c7f0327a Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 27 Feb 2026 11:03:57 -0800 Subject: [PATCH 06/41] eval configs --- .../eval-100-to-3km-prmsl-output.yaml | 204 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/eval.sh | 48 +++++ 2 files changed, 252 insertions(+) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml create mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml new file mode 100644 index 000000000..7b6c766ae --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml @@ -0,0 +1,204 @@ +experiment_dir: /results +n_samples: 2 +patch: + divide_generation: true + composite_prediction: true + coarse_horizontal_overlap: 1 +model: + checkpoint_path: /checkpoints/best.ckpt + #model_updates: + # churn: 2.5 +data: + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start: 0 + stop: 2 + #start_time: '2023-01-01T00:00:00' + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + # CONUS + lat_extent: + start: 0 + stop: 16 + # start: 22.0 + # stop: 50.0 + lon_extent: + #start: 230.0 + #stop: 295.0 + start: 0 + stop: 16 + batch_size: 2 + num_data_workers: 2 + strict_ensemble: false +logging: + log_to_screen: true + log_to_wandb: true + log_to_file: true + project: multivariate-downscaling + entity: ai2cm +events: +- name: NE_US_Quebec_20230206 + date: 2023-02-06T00:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 283 + stop: 299 + save_generated_samples: true + n_samples: 16 +- name: WA_AR_20230101 + date: 2023-01-01T06:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 228 + stop: 244 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_20230425 + date: 2023-04-25T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 130.0 + stop: 146.0 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_landfall_china_20230510 + date: 2023-05-10T12:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 104 + stop: 120 + save_generated_samples: true + n_samples: 16 +- name: extratropical_cyclone_US_20230403 + date: 2023-04-03T12:00 + lat_extent: + start: 34 + stop: 50 + lon_extent: + start: 254 + stop: 270 + save_generated_samples: true + n_samples: 16 +- name: santa_ana_winds_20231221 + date: 2023-12-21T06:00 + lat_extent: + start: 26 + stop: 42 + lon_extent: + start: 234 + stop: 250 + save_generated_samples: true + n_samples: 16 +- name: 
alpine_foehn_20230330 + date: 2023-03-30T18:00 + lat_extent: + start: 37 + stop: 53 + lon_extent: + start: 2 + stop: 18 + save_generated_samples: true + n_samples: 16 +- name: hindu_kush_20230122 + date: 2023-01-22T06:00 + lat_extent: + start: 28 + stop: 44 + lon_extent: + start: 60 + stop: 76 + save_generated_samples: true + n_samples: 16 +- name: WPac_tc_20230426T06 + date: 2023-04-26T06:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 130 + stop: 146 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20230514T06 + date: 2023-05-14T06:00 + lat_extent: + start: 4 + stop: 20 + lon_extent: + start: 117 + stop: 133 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20230517T18 + date: 2023-05-17T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 133 + stop: 149 + save_generated_samples: true + n_samples: 16 +- name: Taiwan_tc_landfall_20230707T18 + date: 2023-07-07T18:00 + lat_extent: + start: 14 + stop: 30 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 +- name: Japan_tc_landfall_20230919T18 + date: 2023-09-19T18:00 + lat_extent: + start: 22 + stop: 38 + lon_extent: + start: 123 + stop: 139 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20231027T00 + date: 2023-10-27T00:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh new file mode 100755 index 000000000..5105f5eea --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -e + +JOB_NAME="eval-xshield-amip-100km-to-3km-prmsl-loguni-events" +CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +N_NODES=1 +NGPU=2 + +IMAGE="$(cat latest_deps_only_image.txt)" + +EXISTING_RESULTS_DATASET=01KJG6E05ZVP82W8EMSV39SQA3 # best hist checkpoint from cont training job 01K51T9H7V9HGZR501XYN5VNGV +wandb_group="" + +gantry run \ + --name $JOB_NAME \ + --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \ + --workspace ai2/climate-titan \ + --priority urgent \ + --not-preemptible \ + --cluster ai2/titan \ + --beaker-image $IMAGE \ + --env WANDB_USERNAME=$BEAKER_USERNAME \ + --env WANDB_NAME=$JOB_NAME \ + --env WANDB_JOB_TYPE=inference \ + --env WANDB_RUN_GROUP=$wandb_group \ + --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ + --env-secret WANDB_API_KEY=wandb-api-key-annak \ + --dataset-secret google-credentials:/tmp/google_application_credentials.json \ + --dataset $EXISTING_RESULTS_DATASET:checkpoints:/checkpoints \ + --weka climate-default:/climate-default \ + --gpus $NGPU \ + --shared-memory 400GiB \ + --budget ai2/climate \ + --no-conda \ + --install "pip install --no-deps ." 
\ + --allow-dirty \ + -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH \ No newline at end of file From c8ec02cb90f763a01a6ab062e99dc7ab14760530 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 27 Feb 2026 11:04:53 -0800 Subject: [PATCH 07/41] delete experiment configs --- .../eval-100-to-3km-prmsl-output.yaml | 204 ------------------ .../2026-02-10-downsc-add-pressfc/eval.sh | 48 ----- .../train-100-to-3km-prmsl-output.yaml | 144 ------------- .../2026-02-10-downsc-add-pressfc/train.sh | 48 ----- 4 files changed, 444 deletions(-) delete mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml delete mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh delete mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml delete mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/train.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml deleted file mode 100644 index 7b6c766ae..000000000 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml +++ /dev/null @@ -1,204 +0,0 @@ -experiment_dir: /results -n_samples: 2 -patch: - divide_generation: true - composite_prediction: true - coarse_horizontal_overlap: 1 -model: - checkpoint_path: /checkpoints/best.ckpt - #model_updates: - # churn: 2.5 -data: - coarse: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 100km.zarr - engine: zarr - subset: - #start_time: '2023-01-01T00:00:00' - start: 0 - stop: 2 - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: pressfc_renamed_to_prmsl_100km.zarr - engine: zarr - subset: - start: 0 - stop: 2 - #start_time: '2023-01-01T00:00:00' - fine: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 3km.zarr - engine: zarr - subset: - #start_time: '2023-01-01T00:00:00' - start: 0 - stop: 2 - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr - engine: zarr - subset: - #start_time: '2023-01-01T00:00:00' - start: 0 - stop: 2 - # CONUS - lat_extent: - start: 0 - stop: 16 - # start: 22.0 - # stop: 50.0 - lon_extent: - #start: 230.0 - #stop: 295.0 - start: 0 - stop: 16 - batch_size: 2 - num_data_workers: 2 - strict_ensemble: false -logging: - log_to_screen: true - log_to_wandb: true - log_to_file: true - project: multivariate-downscaling - entity: ai2cm -events: -- name: NE_US_Quebec_20230206 - date: 2023-02-06T00:00 - lat_extent: - start: 36 - stop: 52 - lon_extent: - start: 283 - stop: 299 - save_generated_samples: true - n_samples: 16 -- name: WA_AR_20230101 - date: 2023-01-01T06:00 - lat_extent: - start: 36 - stop: 52 - lon_extent: - start: 228 - stop: 244 - save_generated_samples: true - n_samples: 16 -- name: WPacific_hurricane_20230425 - date: 2023-04-25T18:00 - lat_extent: - start: 7 - stop: 23 - lon_extent: - start: 130.0 - stop: 146.0 - save_generated_samples: true - n_samples: 16 -- name: WPacific_hurricane_landfall_china_20230510 - date: 2023-05-10T12:00 - lat_extent: - start: 7 - stop: 23 - lon_extent: - start: 104 - stop: 120 - save_generated_samples: true - n_samples: 16 -- name: extratropical_cyclone_US_20230403 - date: 2023-04-03T12:00 - 
lat_extent: - start: 34 - stop: 50 - lon_extent: - start: 254 - stop: 270 - save_generated_samples: true - n_samples: 16 -- name: santa_ana_winds_20231221 - date: 2023-12-21T06:00 - lat_extent: - start: 26 - stop: 42 - lon_extent: - start: 234 - stop: 250 - save_generated_samples: true - n_samples: 16 -- name: alpine_foehn_20230330 - date: 2023-03-30T18:00 - lat_extent: - start: 37 - stop: 53 - lon_extent: - start: 2 - stop: 18 - save_generated_samples: true - n_samples: 16 -- name: hindu_kush_20230122 - date: 2023-01-22T06:00 - lat_extent: - start: 28 - stop: 44 - lon_extent: - start: 60 - stop: 76 - save_generated_samples: true - n_samples: 16 -- name: WPac_tc_20230426T06 - date: 2023-04-26T06:00 - lat_extent: - start: 8 - stop: 24 - lon_extent: - start: 130 - stop: 146 - save_generated_samples: true - n_samples: 16 -- name: Phl_tc_landfall_20230514T06 - date: 2023-05-14T06:00 - lat_extent: - start: 4 - stop: 20 - lon_extent: - start: 117 - stop: 133 - save_generated_samples: true - n_samples: 16 -- name: Phl_tc_landfall_20230517T18 - date: 2023-05-17T18:00 - lat_extent: - start: 7 - stop: 23 - lon_extent: - start: 133 - stop: 149 - save_generated_samples: true - n_samples: 16 -- name: Taiwan_tc_landfall_20230707T18 - date: 2023-07-07T18:00 - lat_extent: - start: 14 - stop: 30 - lon_extent: - start: 115 - stop: 131 - save_generated_samples: true - n_samples: 16 -- name: Japan_tc_landfall_20230919T18 - date: 2023-09-19T18:00 - lat_extent: - start: 22 - stop: 38 - lon_extent: - start: 123 - stop: 139 - save_generated_samples: true - n_samples: 16 -- name: Phl_tc_landfall_20231027T00 - date: 2023-10-27T00:00 - lat_extent: - start: 8 - stop: 24 - lon_extent: - start: 115 - stop: 131 - save_generated_samples: true - n_samples: 16 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh deleted file mode 100755 index 5105f5eea..000000000 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash - -set -e - -JOB_NAME="eval-xshield-amip-100km-to-3km-prmsl-loguni-events" -CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml" - -SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') -CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME - - # since we use a service account API key for wandb, we use the beaker username to set the wandb username -BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') -REPO_ROOT=$(git rev-parse --show-toplevel) - -cd $REPO_ROOT # so config path is valid no matter where we are running this script - -N_NODES=1 -NGPU=2 - -IMAGE="$(cat latest_deps_only_image.txt)" - -EXISTING_RESULTS_DATASET=01KJG6E05ZVP82W8EMSV39SQA3 # best hist checkpoint from cont training job 01K51T9H7V9HGZR501XYN5VNGV -wandb_group="" - -gantry run \ - --name $JOB_NAME \ - --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \ - --workspace ai2/climate-titan \ - --priority urgent \ - --not-preemptible \ - --cluster ai2/titan \ - --beaker-image $IMAGE \ - --env WANDB_USERNAME=$BEAKER_USERNAME \ - --env WANDB_NAME=$JOB_NAME \ - --env WANDB_JOB_TYPE=inference \ - --env WANDB_RUN_GROUP=$wandb_group \ - --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ - --env-secret WANDB_API_KEY=wandb-api-key-annak \ - --dataset-secret google-credentials:/tmp/google_application_credentials.json \ - --dataset $EXISTING_RESULTS_DATASET:checkpoints:/checkpoints \ - --weka climate-default:/climate-default \ - --gpus $NGPU \ - 
--shared-memory 400GiB \ - --budget ai2/climate \ - --no-conda \ - --install "pip install --no-deps ." \ - --allow-dirty \ - -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH \ No newline at end of file diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml deleted file mode 100644 index 97ecd2e4b..000000000 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output.yaml +++ /dev/null @@ -1,144 +0,0 @@ -static_inputs: - HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr - land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr -model: - use_amp_bf16: true - out_names: - - PRATEsfc - - eastward_wind_at_ten_meters - - northward_wind_at_ten_meters - - PRMSL - in_names: - - PRATEsfc - - eastward_wind_at_ten_meters - - northward_wind_at_ten_meters - - PRMSL - loss: - type: MSE - module: - config: - model_channels: 128 - attn_resolutions: [] - num_blocks: 1 - channel_mult_emb: 6 - channel_mult: - - 1 - - 2 - - 2 - - 2 - - 2 - - 2 - - 2 - use_apex_gn: true - type: unet_diffusion_song_v2 - normalization: - coarse: - global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc - global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc - fine: - global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc - global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc - num_diffusion_generation_steps: 8 - churn: 0.0 - training_noise_distribution: - p_min: 0.002 - p_max: 150.0 - predict_residual: true - sigma_max: 150.0 - sigma_min: 0.002 - use_fine_topography: true -optimization: - lr: 0.0001 - optimizer_type: Adam -ema: - decay: 0.999 -validate_using_ema: true -train_data: - sample_with_replacement: 320 - batch_size: 40 - num_data_workers: 2 - fine: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 3km.zarr - engine: zarr - subset: - start_time: '2014-01-01T00:00:00' - stop_time: '2022-12-31T23:59:00' - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr - engine: zarr - subset: - start_time: '2014-01-01T00:00:00' - stop_time: '2022-12-31T23:59:00' - coarse: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 100km.zarr - engine: zarr - subset: - start_time: '2014-01-01T00:00:00' - stop_time: '2022-12-31T23:59:00' - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: pressfc_renamed_to_prmsl_100km.zarr - engine: zarr - subset: - start_time: '2014-01-01T00:00:00' - stop_time: '2022-12-31T23:59:00' - lat_extent: - start: -66.0 - stop: 70.0 - # lon_extent: - # start: 0 #230.0 - # stop: 16 #246.0 - strict_ensemble: false -validation_data: - batch_size: 36 - num_data_workers: 4 - fine: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 3km.zarr - engine: zarr - subset: - start_time: '2023-01-01T00:00:00' - stop_time: 
'2024-01-01T00:00:00' - step: 39 - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr - engine: zarr - subset: - start_time: '2023-01-01T00:00:00' - stop_time: '2024-01-01T00:00:00' - step: 39 - coarse: - - merge: - - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling - file_pattern: 100km.zarr - engine: zarr - subset: - start_time: '2023-01-01T00:00:00' - stop_time: '2024-01-01T00:00:00' - step: 39 - - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data - file_pattern: pressfc_renamed_to_prmsl_100km.zarr - engine: zarr - subset: - start_time: '2023-01-01T00:00:00' - stop_time: '2024-01-01T00:00:00' - step: 39 - lat_extent: - start: -20.0 - stop: 20.0 - strict_ensemble: false -coarse_patch_extent_lat: 16 -coarse_patch_extent_lon: 16 -max_epochs: 300 -validate_interval: 10 -experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 -save_checkpoints: false -logging: - project: multivariate-downscaling - entity: ai2cm - log_to_wandb: true -generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh deleted file mode 100755 index 50b288416..000000000 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# uses the augusta cluster which doesn't have weka access but has GCS access and is -# typically more available than cirrascale clusters - -set -e - -# recommended but not required to change this - -JOB_NAME="xshield-downscaling-100km-to-3km-prmsl-output-loguniform-noise" -CONFIG_FILENAME="train-100-to-3km-prmsl-output.yaml" - -SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') -CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME -wandb_group="" - - # since we use a service account API key for wandb, we use the beaker username to set the wandb username -BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') -REPO_ROOT=$(git rev-parse --show-toplevel) -N_GPUS=4 # TODO: change to 8 after testing - -cd $REPO_ROOT # so config path is valid no matter where we are running this script - -IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt) - -gantry run \ - --name $JOB_NAME \ - --description 'Run downscaling 100km to 3km multivar training' \ - --workspace ai2/downscaling \ - --priority low \ - --preemptible \ - --cluster ai2/jupiter \ - --cluster ai2/titan \ - --beaker-image $IMAGE \ - --env WANDB_USERNAME=$BEAKER_USERNAME \ - --env WANDB_NAME=$JOB_NAME \ - --env WANDB_JOB_TYPE=training \ - --env WANDB_RUN_GROUP=$wandb_group \ - --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ - --env-secret WANDB_API_KEY=wandb-api-key-annak \ - --dataset-secret google-credentials:/tmp/google_application_credentials.json \ - --weka climate-default:/climate-default \ - --gpus $N_GPUS \ - --shared-memory 400GiB \ - --budget ai2/climate \ - --no-conda \ - --install "pip install --no-deps ." 
\ - --allow-dirty \ - -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH \ No newline at end of file From f21999ba6ed495346f82844c9a8692ed52c1c27b Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 27 Feb 2026 11:22:56 -0800 Subject: [PATCH 08/41] Fix event evaluator bug --- fme/downscaling/evaluator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fme/downscaling/evaluator.py b/fme/downscaling/evaluator.py index 5aa4e634f..5e789c80c 100644 --- a/fme/downscaling/evaluator.py +++ b/fme/downscaling/evaluator.py @@ -239,7 +239,9 @@ def _build_event_evaluator( evaluator_model: DiffusionModel | PatchPredictor dataset = event_config.get_paired_gridded_data( - base_data_config=self.data, requirements=self.model.data_requirements + base_data_config=self.data, + requirements=self.model.data_requirements, + static_inputs_from_checkpoint=model.static_inputs, ) if (dataset.coarse_shape[0] > model.coarse_shape[0]) or ( From 53906a5cf167ce03652af52d6c83009f711449e3 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 27 Feb 2026 12:49:42 -0800 Subject: [PATCH 09/41] update docstring --- fme/downscaling/models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index b38b2a807..c38206331 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -91,8 +91,7 @@ class DiffusionModelConfig: in_names: The input variable names for the diffusion model. out_names: The output variable names for the diffusion model. normalization: The normalization configurations for the diffusion model. - p_mean: The mean of noise distribution used during training. - p_std: The std of the noise distribution used during training. + sigma_min: Min noise level for generation. sigma_max: Max noise level for generation. churn: The amount of stochasticity during generation. @@ -100,6 +99,13 @@ class DiffusionModelConfig: use_fine_topography: Whether to use fine topography in the model. use_amp_bf16: Whether to use automatic mixed precision (bfloat16) in the UNetDiffusionModule. + training_noise_distribution: Noise distribution to use during training. + p_mean: The mean of noise distribution used during training. + Deprecated. Use training_noise_distribution field instead. + This is kept for backwards compatibility. + p_std: The std of the noise distribution used during training. + Deprecated. Use training_noise_distribution field instead. + This is kept for backwards compatibility. """ module: DiffusionModuleRegistrySelector @@ -139,6 +145,13 @@ def __post_init__(self): "field only. Both training_noise_distribution and p_mean, p_std " "were specified. The latter two fields are deprecated." ) + if self.training_noise_distribution is None and ( + self.p_mean is None or self.p_std is None + ): + raise ValueError( + "Noise distribution must be specified in training_noise_distribution " + "field or in p_mean and p_std fields." 
+ ) @property def noise(self) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution: From c0ea35171a381fc34aaf2641e7378ea622d00783 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 27 Feb 2026 12:49:53 -0800 Subject: [PATCH 10/41] update test config --- fme/downscaling/test_train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fme/downscaling/test_train.py b/fme/downscaling/test_train.py index 3e65c8bff..5ed89d36d 100644 --- a/fme/downscaling/test_train.py +++ b/fme/downscaling/test_train.py @@ -137,8 +137,7 @@ def default_trainer_config( model_config_kwargs = { "num_diffusion_generation_steps": 2, "churn": 0.0, - "p_mean": -1.2, - "p_std": 1.2, + "training_noise_distribution": {"p_mean": -1.2, "p_std": 1.2}, "predict_residual": True, "sigma_max": 80.0, "sigma_min": 0.002, From 396aa5f36d9fb7a6d1d8e12e9748c00433d07541 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 2 Mar 2026 14:18:43 -0800 Subject: [PATCH 11/41] adds backwards compatibility test --- fme/downscaling/test_models.py | 40 ++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 4b5e74c23..2051ae6de 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -19,6 +19,7 @@ _separate_interleaved_samples, ) from fme.downscaling.modules.diffusion_registry import DiffusionModuleRegistrySelector +from fme.downscaling.noise import LogNormalNoiseDistribution from fme.downscaling.typing_ import FineResCoarseResPair @@ -403,3 +404,42 @@ def test_DiffusionModel_generate_on_batch_no_target_arbitrary_input_size(): n_ensemble, *fine_shape, ) + + +def test_lognorm_noise_backwards_compatibility(): + normalizer = PairedNormalizationConfig( + NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), + NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), + ) + + model_config = DiffusionModelConfig( + module=DiffusionModuleRegistrySelector( + "unet_diffusion_song", {"model_channels": 4} + ), + loss=LossConfig(type="MSE"), + in_names=["x"], + out_names=["x"], + normalization=normalizer, + p_mean=-1.0, + p_std=1.0, + sigma_min=0.1, + sigma_max=1.0, + churn=0.5, + num_diffusion_generation_steps=3, + training_noise_distribution=None, + use_fine_topography=False, + predict_residual=True, + ) + assert model_config.noise == LogNormalNoiseDistribution(p_mean=-1.0, p_std=1.0) + model = model_config.build( + (32, 32), + 2, + ) + state = model.get_state() + + # test from_state on checkpoints saved prior to noise distribution classes + del state["config"]["training_noise_distribution"] + model_from_state = DiffusionModel.from_state(state) + assert model_from_state.config.noise == LogNormalNoiseDistribution( + p_mean=-1.0, p_std=1.0 + ) From de17a12ef0611d719ffbc9fbe302f5c79a1dd6c1 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Tue, 3 Mar 2026 11:18:56 -0800 Subject: [PATCH 12/41] log loss vs noise --- fme/downscaling/aggregators/main.py | 77 ++++++++++++++++++- .../aggregators/test_aggregators.py | 38 +++++++++ fme/downscaling/models.py | 9 +++ 3 files changed, 123 insertions(+), 1 deletion(-) diff --git a/fme/downscaling/aggregators/main.py b/fme/downscaling/aggregators/main.py index 23b2a3ce9..152a4f9fe 100644 --- a/fme/downscaling/aggregators/main.py +++ b/fme/downscaling/aggregators/main.py @@ -18,7 +18,7 @@ from fme.core.device import get_device from fme.core.distributed import Distributed from fme.core.histogram import ComparedDynamicHistograms -from fme.core.typing_ import TensorMapping +from 
fme.core.typing_ import TensorDict, TensorMapping from fme.core.wandb import WandB from fme.downscaling.aggregators.adapters import ComparedDynamicHistogramsAdapter from fme.downscaling.data import PairedBatchData @@ -50,6 +50,78 @@ def _tensor_mapping_to_numpy(data: TensorMapping) -> TensorMapping: return {k: v.cpu().numpy() for k, v in data.items()} +class LossVsNoiseAggregator: + """ + Aggregates per-sample diffusion loss as a function of sampled noise level. + """ + + def __init__(self, name: str = "metrics/loss_vs_noise") -> None: + self._name = ensure_trailing_slash(name) + self._sigmas: list[torch.Tensor] = [] + self._per_sample_channel_losses: list[TensorDict] = [] + + @torch.no_grad() + def record_batch(self, outputs: ModelOutputs) -> None: + if outputs.sigma is None or not outputs.per_sample_channel_loss: + return + sigma = outputs.sigma.detach().flatten().cpu() + per_channel = { + name: loss.detach().flatten().cpu() + for name, loss in outputs.per_sample_channel_loss.items() + } + if any(loss.shape != sigma.shape for loss in per_channel.values()): + raise ValueError( + "Expected per-sample channel losses and sigma to share batch shape" + ) + self._sigmas.append(sigma) + self._per_sample_channel_losses.append(per_channel) + + def _plot_loss_vs_noise( + self, x_values: np.ndarray, y_values: np.ndarray, title: str + ) -> Any: + fig, ax = plt.subplots() + ax.scatter(x_values, y_values, s=4, alpha=0.25) + ax.set_xlabel("log10(sigma)") + ax.set_ylabel("weighted loss") + ax.set_title(title) + ax.grid(True, alpha=0.3) + plt.close(fig) + return fig + + def get_wandb(self, prefix: str = "") -> Mapping[str, Any]: + prefix = ensure_trailing_slash(prefix) + if not self._sigmas: + return {} + + sigma = torch.cat(self._sigmas) + if torch.any(sigma <= 0): + raise ValueError("Sigma must be strictly positive for log10 plotting") + + channel_names = sorted(self._per_sample_channel_losses[0].keys()) + channel_loss = { + name: torch.cat([batch[name] for batch in self._per_sample_channel_losses]) + for name in channel_names + } + total_loss = torch.stack([channel_loss[name] for name in channel_names], dim=-1) + total_loss = torch.mean(total_loss, dim=-1) + + x_values = np.log10(sigma.numpy()) + ret: dict[str, Any] = { + f"{prefix}{self._name}total": self._plot_loss_vs_noise( + x_values=x_values, + y_values=total_loss.numpy(), + title="Total weighted loss vs noise", + ) + } + for name in channel_names: + ret[f"{prefix}{self._name}{name}"] = self._plot_loss_vs_noise( + x_values=x_values, + y_values=channel_loss[name].numpy(), + title=f"{name} weighted loss vs noise", + ) + return ret + + def _get_spectrum_metrics( gen_spectrum: Mapping[str, np.ndarray], target_spectrum: Mapping[str, np.ndarray], @@ -798,6 +870,7 @@ def __init__( self.loss = Mean(torch.mean) self.channel_loss = Mean(torch.mean) + self.loss_vs_noise = LossVsNoiseAggregator() self._fine_latlon_coordinates: LatLonCoordinates | None = None @torch.no_grad() @@ -841,6 +914,7 @@ def weighted_rmse(truth, pred): self.loss.record_batch({"loss": outputs.loss}) if outputs.channel_losses: self.channel_loss.record_batch(outputs.channel_losses) + self.loss_vs_noise.record_batch(outputs) def get_wandb( self, @@ -855,6 +929,7 @@ def get_wandb( ret.update(self.loss.get_wandb(prefix)) if self.channel_loss._count > 0: ret.update(self.channel_loss.get_wandb(f"{prefix}channel_loss/")) + ret.update(self.loss_vs_noise.get_wandb(prefix)) for comparison in self._comparisons: ret.update(comparison.get_wandb(prefix)) for coarse_comparison in 
self._coarse_comparisons: diff --git a/fme/downscaling/aggregators/test_aggregators.py b/fme/downscaling/aggregators/test_aggregators.py index f24fdf7ba..be7c39485 100644 --- a/fme/downscaling/aggregators/test_aggregators.py +++ b/fme/downscaling/aggregators/test_aggregators.py @@ -13,6 +13,7 @@ from ..models import ModelOutputs from .generation import GenerationAggregator from .main import ( + LossVsNoiseAggregator, Mean, MeanComparison, MeanMapAggregator, @@ -200,6 +201,43 @@ def test_map_aggregator(n_steps: int): aggregator.get_wandb() +def test_loss_vs_noise_aggregator_get_wandb(): + aggregator = LossVsNoiseAggregator() + outputs_a = ModelOutputs( + prediction={}, + target={}, + latent_steps=[], + loss=torch.tensor(0.0), + sigma=torch.tensor([0.1, 1.0]), + per_sample_channel_loss={ + "x": torch.tensor([1.0, 2.0]), + "y": torch.tensor([2.0, 4.0]), + }, + ) + outputs_b = ModelOutputs( + prediction={}, + target={}, + latent_steps=[], + loss=torch.tensor(0.0), + sigma=torch.tensor([10.0]), + per_sample_channel_loss={ + "x": torch.tensor([3.0]), + "y": torch.tensor([6.0]), + }, + ) + aggregator.record_batch(outputs_a) + aggregator.record_batch(outputs_b) + + logs = aggregator.get_wandb(prefix="train") + assert set(logs.keys()) == { + "train/metrics/loss_vs_noise/total", + "train/metrics/loss_vs_noise/x", + "train/metrics/loss_vs_noise/y", + } + for value in logs.values(): + assert hasattr(value, "savefig") + + @pytest.mark.parametrize("n_latent_steps", [0, 2]) def test_aggregator_integration(n_latent_steps, percentiles=[99.999]): downscale_factor = 2 diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 05dc0b587..d84b2f2bc 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -28,6 +28,8 @@ class ModelOutputs: loss: torch.Tensor latent_steps: list[torch.Tensor] = dataclasses.field(default_factory=list) channel_losses: TensorDict = dataclasses.field(default_factory=dict) + sigma: torch.Tensor | None = None + per_sample_channel_loss: TensorDict = dataclasses.field(default_factory=dict) def _rename_normalizer( @@ -387,6 +389,11 @@ def train_on_batch( name: torch.mean(weighted_loss[:, i, :, :]) for i, name in enumerate(self.out_packer.names) } + per_sample_channel_loss = { + name: torch.mean(weighted_loss[:, i, :, :], dim=(-2, -1)).detach() + for i, name in enumerate(self.out_packer.names) + } + sigma = conditioned_target.sigma[:, 0, 0, 0].detach() if self.config.predict_residual: denoised_norm = denoised_norm + base_prediction @@ -400,6 +407,8 @@ def train_on_batch( target=target, loss=loss, channel_losses=channel_losses, + sigma=sigma, + per_sample_channel_loss=per_sample_channel_loss, latent_steps=[], ) From e8e332087f959aa75b0c5f337d81101db5f64aac Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Tue, 3 Mar 2026 12:50:40 -0800 Subject: [PATCH 13/41] bin losses by sigma --- fme/downscaling/aggregators/main.py | 127 ++++++++++++------ .../aggregators/test_aggregators.py | 18 ++- 2 files changed, 96 insertions(+), 49 deletions(-) diff --git a/fme/downscaling/aggregators/main.py b/fme/downscaling/aggregators/main.py index 152a4f9fe..953aa6520 100644 --- a/fme/downscaling/aggregators/main.py +++ b/fme/downscaling/aggregators/main.py @@ -52,37 +52,78 @@ def _tensor_mapping_to_numpy(data: TensorMapping) -> TensorMapping: class LossVsNoiseAggregator: """ - Aggregates per-sample diffusion loss as a function of sampled noise level. + Aggregates binned diffusion losses as a function of sampled noise level. 
""" - def __init__(self, name: str = "metrics/loss_vs_noise") -> None: + def __init__( + self, + name: str = "metrics/loss_vs_noise", + n_bins: int = 40, + log10_sigma_min: float = -4.0, + log10_sigma_max: float = 4.0, + ) -> None: + if n_bins < 1: + raise ValueError("n_bins must be >= 1") + if log10_sigma_min >= log10_sigma_max: + raise ValueError("log10_sigma_min must be less than log10_sigma_max") + self._name = ensure_trailing_slash(name) - self._sigmas: list[torch.Tensor] = [] - self._per_sample_channel_losses: list[TensorDict] = [] + self._n_bins = n_bins + edges = torch.linspace(log10_sigma_min, log10_sigma_max, n_bins + 1) + self._inner_edges = edges[1:-1] + self._bin_centers = ((edges[:-1] + edges[1:]) / 2).numpy() + + self._total_sum = torch.zeros(n_bins, dtype=torch.float64) + self._total_count = torch.zeros(n_bins, dtype=torch.int64) + self._channel_sum: dict[str, torch.Tensor] = {} + self._channel_count: dict[str, torch.Tensor] = {} + + def _accumulate( + self, values: torch.Tensor, bin_indices: torch.Tensor, name: str + ) -> None: + if name not in self._channel_sum: + self._channel_sum[name] = torch.zeros(self._n_bins, dtype=torch.float64) + self._channel_count[name] = torch.zeros(self._n_bins, dtype=torch.int64) + self._channel_sum[name].scatter_add_(0, bin_indices, values.to(torch.float64)) + self._channel_count[name].scatter_add_( + 0, bin_indices, torch.ones_like(bin_indices, dtype=torch.int64) + ) @torch.no_grad() def record_batch(self, outputs: ModelOutputs) -> None: if outputs.sigma is None or not outputs.per_sample_channel_loss: return + sigma = outputs.sigma.detach().flatten().cpu() - per_channel = { - name: loss.detach().flatten().cpu() - for name, loss in outputs.per_sample_channel_loss.items() - } - if any(loss.shape != sigma.shape for loss in per_channel.values()): - raise ValueError( - "Expected per-sample channel losses and sigma to share batch shape" - ) - self._sigmas.append(sigma) - self._per_sample_channel_losses.append(per_channel) + if torch.any(sigma <= 0): + raise ValueError("Sigma must be strictly positive for log10 binning") + log_sigma = torch.log10(sigma) + # Indices in [0, n_bins-1], with out-of-range values placed in edge bins. 
+ bin_indices = torch.bucketize(log_sigma, self._inner_edges) + + per_channel: TensorDict = {} + for name, loss in outputs.per_sample_channel_loss.items(): + per_channel[name] = loss.detach().flatten().cpu() + if per_channel[name].shape != sigma.shape: + raise ValueError( + "Expected per-sample channel losses and sigma to share batch shape" + ) - def _plot_loss_vs_noise( - self, x_values: np.ndarray, y_values: np.ndarray, title: str - ) -> Any: + stacked = torch.stack([value for value in per_channel.values()], dim=-1) + total_loss = torch.mean(stacked, dim=-1).to(torch.float64) + self._total_sum.scatter_add_(0, bin_indices, total_loss) + self._total_count.scatter_add_( + 0, bin_indices, torch.ones_like(bin_indices, dtype=torch.int64) + ) + for name, values in per_channel.items(): + self._accumulate(values=values, bin_indices=bin_indices, name=name) + + def _plot_binned(self, y_values: np.ndarray, counts: np.ndarray, title: str) -> Any: fig, ax = plt.subplots() - ax.scatter(x_values, y_values, s=4, alpha=0.25) + mask = counts > 0 + ax.plot(self._bin_centers[mask], y_values[mask], marker="o", linewidth=1.0) ax.set_xlabel("log10(sigma)") - ax.set_ylabel("weighted loss") + ax.set_ylabel("mean weighted loss") ax.set_title(title) ax.grid(True, alpha=0.3) plt.close(fig) @@ -90,33 +131,33 @@ def _plot_loss_vs_noise( def get_wandb(self, prefix: str = "") -> Mapping[str, Any]: prefix = ensure_trailing_slash(prefix) - if not self._sigmas: + if torch.sum(self._total_count) == 0: return {} - sigma = torch.cat(self._sigmas) - if torch.any(sigma <= 0): - raise ValueError("Sigma must be strictly positive for log10 plotting") - - channel_names = sorted(self._per_sample_channel_losses[0].keys()) - channel_loss = { - name: torch.cat([batch[name] for batch in self._per_sample_channel_losses]) - for name in channel_names - } - total_loss = torch.stack([channel_loss[name] for name in channel_names], dim=-1) - total_loss = torch.mean(total_loss, dim=-1) - - x_values = np.log10(sigma.numpy()) - ret: dict[str, Any] = { - f"{prefix}{self._name}total": self._plot_loss_vs_noise( - x_values=x_values, - y_values=total_loss.numpy(), - title="Total weighted loss vs noise", + ret: dict[str, Any] = {} + total_count = self._total_count.numpy() + total_mean = np.divide( + self._total_sum.numpy(), + total_count, + out=np.zeros_like(self._total_sum.numpy()), + where=total_count > 0, + ) + ret[f"{prefix}{self._name}total"] = self._plot_binned( + y_values=total_mean, + counts=total_count, + title="Total weighted loss vs noise", + ) + for name in sorted(self._channel_sum): + count = self._channel_count[name].numpy() + mean = np.divide( + self._channel_sum[name].numpy(), + count, + out=np.zeros_like(self._channel_sum[name].numpy()), + where=count > 0, ) - } - for name in channel_names: - ret[f"{prefix}{self._name}{name}"] = self._plot_loss_vs_noise( - x_values=x_values, - y_values=channel_loss[name].numpy(), + ret[f"{prefix}{self._name}{name}"] = self._plot_binned( + y_values=mean, + counts=count, title=f"{name} weighted loss vs noise", ) return ret diff --git a/fme/downscaling/aggregators/test_aggregators.py b/fme/downscaling/aggregators/test_aggregators.py index be7c39485..e6a301fb8 100644 --- a/fme/downscaling/aggregators/test_aggregators.py +++ b/fme/downscaling/aggregators/test_aggregators.py @@ -201,8 +201,9 @@ def test_map_aggregator(n_steps: int): aggregator.get_wandb() -def test_loss_vs_noise_aggregator_get_wandb(): - aggregator = LossVsNoiseAggregator() +@pytest.mark.parametrize("prefix", ["train", "validation"]) +def 
test_loss_vs_noise_aggregator_get_wandb(prefix: str): + aggregator = LossVsNoiseAggregator(n_bins=8) outputs_a = ModelOutputs( prediction={}, target={}, @@ -228,11 +229,16 @@ def test_loss_vs_noise_aggregator_get_wandb(): aggregator.record_batch(outputs_a) aggregator.record_batch(outputs_b) - logs = aggregator.get_wandb(prefix="train") + # Binning happens in record_batch, not get_wandb. + assert int(aggregator._total_count.sum().item()) == 3 + assert int(aggregator._channel_count["x"].sum().item()) == 3 + assert int(aggregator._channel_count["y"].sum().item()) == 3 + + logs = aggregator.get_wandb(prefix=prefix) assert set(logs.keys()) == { - "train/metrics/loss_vs_noise/total", - "train/metrics/loss_vs_noise/x", - "train/metrics/loss_vs_noise/y", + f"{prefix}/metrics/loss_vs_noise/total", + f"{prefix}/metrics/loss_vs_noise/x", + f"{prefix}/metrics/loss_vs_noise/y", } for value in logs.values(): assert hasattr(value, "savefig") From 51d67961c9edc49bf1856788d382a0102c341b61 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Tue, 3 Mar 2026 13:01:39 -0800 Subject: [PATCH 14/41] reduce sigma range in agg --- fme/downscaling/aggregators/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fme/downscaling/aggregators/main.py b/fme/downscaling/aggregators/main.py index 953aa6520..9b23a8d6c 100644 --- a/fme/downscaling/aggregators/main.py +++ b/fme/downscaling/aggregators/main.py @@ -59,8 +59,8 @@ def __init__( self, name: str = "metrics/loss_vs_noise", n_bins: int = 40, - log10_sigma_min: float = -4.0, - log10_sigma_max: float = 4.0, + log10_sigma_min: float = -3.0, + log10_sigma_max: float = 3.0, ) -> None: if n_bins < 1: raise ValueError("n_bins must be >= 1") From 8fd871035b1ccd916bb165fdd995fbbcf9dbd6be Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 5 Mar 2026 18:58:10 +0000 Subject: [PATCH 15/41] respond to PR comments --- fme/core/rand.py | 12 +++++++++++- fme/downscaling/models.py | 8 +++++--- fme/downscaling/noise.py | 16 +++++++--------- fme/downscaling/test_models.py | 4 ++-- fme/downscaling/test_noise.py | 14 ++++++++++++++ 5 files changed, 39 insertions(+), 15 deletions(-) create mode 100644 fme/downscaling/test_noise.py diff --git a/fme/core/rand.py b/fme/core/rand.py index 896f1b9ff..efc399c22 100644 --- a/fme/core/rand.py +++ b/fme/core/rand.py @@ -37,7 +37,7 @@ def randn_like(x: torch.Tensor, **kwargs): return torch.randn_like(x, **kwargs) -def randn(shape: torch.Size, **kwargs): +def randn(shape: torch.Size, **kwargs) -> torch.Tensor: if USE_CPU_RANDN: device = kwargs.pop("device", None) return torch.randn(shape, device="cpu", **kwargs).to(device) @@ -45,6 +45,16 @@ def randn(shape: torch.Size, **kwargs): return torch.randn(shape, **kwargs) +def log_normal_sample(p_mean: float, p_std: float, shape: torch.Size, **randn_kwargs) -> torch.Tensor: + rnd = randn(shape, **randn_kwargs) + return (rnd * p_std + p_mean).exp() + + +def log_uniform_sample(p_min: float, p_max: float, shape: torch.Size) -> torch.Tensor: + return torch.exp( + torch.empty(shape).uniform_(np.log(p_min), np.log(p_max)) + ) + @contextlib.contextmanager def use_cpu_randn(): """ diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index c38206331..455376026 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -91,7 +91,6 @@ class DiffusionModelConfig: in_names: The input variable names for the diffusion model. out_names: The output variable names for the diffusion model. normalization: The normalization configurations for the diffusion model. 
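+        training_noise_distribution: Noise distribution used to sample sigma
+            during training; takes precedence over p_mean and p_std when set.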
- sigma_min: Min noise level for generation. sigma_max: Max noise level for generation. churn: The amount of stochasticity during generation. @@ -154,7 +153,10 @@ def __post_init__(self): ) @property - def noise(self) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution: + def noise_distribution(self) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution: + """ + Returns NoiseDistribution object to use for sampling noise in training. + """ if self.training_noise_distribution is not None: return self.training_noise_distribution elif self.p_mean is not None and self.p_std is not None: @@ -364,7 +366,7 @@ def train_on_batch( targets_norm = targets_norm - base_prediction conditioned_target = condition_with_noise_for_training( - targets_norm, self.config.noise, self.sigma_data + targets_norm, self.config.noise_distribution, self.sigma_data ) denoised_norm = self.module( diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index f3dde5c26..bb1d9bcf7 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -4,7 +4,7 @@ import numpy as np import torch -from fme.core.rand import randn, randn_like +from fme.core.rand import randn_like, log_normal_sample, log_uniform_sample @dataclasses.dataclass @@ -35,10 +35,9 @@ class LogNormalNoiseDistribution(NoiseDistribution): p_std: float def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: - rnd = randn([batch_size, 1, 1, 1], device=device) - # This is taken from EDM's original implementation in EDMLoss: - # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L72-L80 # noqa: E501 - return (rnd * self.p_std + self.p_mean).exp() + return log_normal_sample( + p_mean=self.p_mean, p_std=self.p_std, shape=(batch_size, 1, 1, 1) + ).to(device) @dataclasses.dataclass @@ -47,10 +46,9 @@ class LogUniformNoiseDistribution(NoiseDistribution): p_max: float def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: - sigma = np.exp( - np.random.uniform(np.log(self.p_min), np.log(self.p_max), batch_size) - ) - return torch.tensor(sigma, device=device).reshape(batch_size, 1, 1, 1) + return log_uniform_sample( + p_min=self.p_min, p_max=self.p_max, shape=(batch_size, 1, 1, 1) + ).to(device) def condition_with_noise_for_training( diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 2051ae6de..22276378b 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -430,7 +430,7 @@ def test_lognorm_noise_backwards_compatibility(): use_fine_topography=False, predict_residual=True, ) - assert model_config.noise == LogNormalNoiseDistribution(p_mean=-1.0, p_std=1.0) + assert model_config.noise_distribution == LogNormalNoiseDistribution(p_mean=-1.0, p_std=1.0) model = model_config.build( (32, 32), 2, @@ -440,6 +440,6 @@ def test_lognorm_noise_backwards_compatibility(): # test from_state on checkpoints saved prior to noise distribution classes del state["config"]["training_noise_distribution"] model_from_state = DiffusionModel.from_state(state) - assert model_from_state.config.noise == LogNormalNoiseDistribution( + assert model_from_state.config.noise_distribution == LogNormalNoiseDistribution( p_mean=-1.0, p_std=1.0 ) diff --git a/fme/downscaling/test_noise.py b/fme/downscaling/test_noise.py new file mode 100644 index 000000000..694981c70 --- /dev/null +++ b/fme/downscaling/test_noise.py @@ -0,0 +1,14 @@ +import pytest +from fme.downscaling.noise import LogNormalNoiseDistribution, LogUniformNoiseDistribution + + 
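+# Both distributions should emit one sigma per batch element, shaped
+# (batch_size, 1, 1, 1) so it broadcasts over the channel and spatial dims.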
+@pytest.mark.parametrize( + "noise_distribution", + [ + LogNormalNoiseDistribution(p_mean=0.0, p_std=1.0), + LogUniformNoiseDistribution(p_min=0.01, p_max=100) + ] +) +def test_noise_distribution(noise_distribution): + batch_size = 10 + assert noise_distribution.sample(batch_size=batch_size, device="cpu").shape == (batch_size, 1, 1, 1) From c6094f0f8d90d1d0f60fd3aca0b4996114181d41 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 5 Mar 2026 18:58:45 +0000 Subject: [PATCH 16/41] reformat --- fme/core/rand.py | 9 +++++---- fme/downscaling/models.py | 4 +++- fme/downscaling/noise.py | 3 +-- fme/downscaling/test_models.py | 4 +++- fme/downscaling/test_noise.py | 19 ++++++++++++++----- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/fme/core/rand.py b/fme/core/rand.py index efc399c22..8f692c091 100644 --- a/fme/core/rand.py +++ b/fme/core/rand.py @@ -45,15 +45,16 @@ def randn(shape: torch.Size, **kwargs) -> torch.Tensor: return torch.randn(shape, **kwargs) -def log_normal_sample(p_mean: float, p_std: float, shape: torch.Size, **randn_kwargs) -> torch.Tensor: +def log_normal_sample( + p_mean: float, p_std: float, shape: torch.Size, **randn_kwargs +) -> torch.Tensor: rnd = randn(shape, **randn_kwargs) return (rnd * p_std + p_mean).exp() def log_uniform_sample(p_min: float, p_max: float, shape: torch.Size) -> torch.Tensor: - return torch.exp( - torch.empty(shape).uniform_(np.log(p_min), np.log(p_max)) - ) + return torch.exp(torch.empty(shape).uniform_(np.log(p_min), np.log(p_max))) + @contextlib.contextmanager def use_cpu_randn(): diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 455376026..297da14c9 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -153,7 +153,9 @@ def __post_init__(self): ) @property - def noise_distribution(self) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution: + def noise_distribution( + self, + ) -> LogNormalNoiseDistribution | LogUniformNoiseDistribution: """ Returns NoiseDistribution object to use for sampling noise in training. 
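+
+        Prefers ``training_noise_distribution`` when set, and otherwise falls
+        back to a log-normal distribution built from ``p_mean`` and ``p_std``
+        for backwards compatibility with older configs.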
""" diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index bb1d9bcf7..1cd473dab 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -1,10 +1,9 @@ import abc import dataclasses -import numpy as np import torch -from fme.core.rand import randn_like, log_normal_sample, log_uniform_sample +from fme.core.rand import log_normal_sample, log_uniform_sample, randn_like @dataclasses.dataclass diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index 22276378b..a5406dc39 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -430,7 +430,9 @@ def test_lognorm_noise_backwards_compatibility(): use_fine_topography=False, predict_residual=True, ) - assert model_config.noise_distribution == LogNormalNoiseDistribution(p_mean=-1.0, p_std=1.0) + assert model_config.noise_distribution == LogNormalNoiseDistribution( + p_mean=-1.0, p_std=1.0 + ) model = model_config.build( (32, 32), 2, diff --git a/fme/downscaling/test_noise.py b/fme/downscaling/test_noise.py index 694981c70..50d929ac5 100644 --- a/fme/downscaling/test_noise.py +++ b/fme/downscaling/test_noise.py @@ -1,14 +1,23 @@ import pytest -from fme.downscaling.noise import LogNormalNoiseDistribution, LogUniformNoiseDistribution + +from fme.downscaling.noise import ( + LogNormalNoiseDistribution, + LogUniformNoiseDistribution, +) @pytest.mark.parametrize( - "noise_distribution", + "noise_distribution", [ LogNormalNoiseDistribution(p_mean=0.0, p_std=1.0), - LogUniformNoiseDistribution(p_min=0.01, p_max=100) - ] + LogUniformNoiseDistribution(p_min=0.01, p_max=100), + ], ) def test_noise_distribution(noise_distribution): batch_size = 10 - assert noise_distribution.sample(batch_size=batch_size, device="cpu").shape == (batch_size, 1, 1, 1) + assert noise_distribution.sample(batch_size=batch_size, device="cpu").shape == ( + batch_size, + 1, + 1, + 1, + ) From 9a0aea1b750b2f3816ccf48c4b451c51e7f10d1f Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Thu, 5 Mar 2026 12:21:25 -0800 Subject: [PATCH 17/41] enforce float32 dtype --- fme/core/rand.py | 12 ++++++++---- fme/downscaling/noise.py | 10 ++++++++-- fme/downscaling/test_models.py | 27 +++++++++++++++++++++++++++ fme/downscaling/test_noise.py | 10 ++++------ 4 files changed, 47 insertions(+), 12 deletions(-) diff --git a/fme/core/rand.py b/fme/core/rand.py index 8f692c091..729b4b95d 100644 --- a/fme/core/rand.py +++ b/fme/core/rand.py @@ -46,14 +46,18 @@ def randn(shape: torch.Size, **kwargs) -> torch.Tensor: def log_normal_sample( - p_mean: float, p_std: float, shape: torch.Size, **randn_kwargs + p_mean: float, p_std: float, shape: torch.Size, dtype: torch.dtype ) -> torch.Tensor: - rnd = randn(shape, **randn_kwargs) + rnd = randn(shape, dtype=dtype) return (rnd * p_std + p_mean).exp() -def log_uniform_sample(p_min: float, p_max: float, shape: torch.Size) -> torch.Tensor: - return torch.exp(torch.empty(shape).uniform_(np.log(p_min), np.log(p_max))) +def log_uniform_sample( + p_min: float, p_max: float, shape: torch.Size, dtype: torch.dtype +) -> torch.Tensor: + return torch.exp( + torch.empty(shape, dtype=dtype).uniform_(np.log(p_min), np.log(p_max)) + ) @contextlib.contextmanager diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index 1cd473dab..48f5b5157 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -35,7 +35,10 @@ class LogNormalNoiseDistribution(NoiseDistribution): def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: return log_normal_sample( - 
p_mean=self.p_mean, p_std=self.p_std, shape=(batch_size, 1, 1, 1) + p_mean=self.p_mean, + p_std=self.p_std, + shape=(batch_size, 1, 1, 1), + dtype=torch.float32, ).to(device) @@ -46,7 +49,10 @@ class LogUniformNoiseDistribution(NoiseDistribution): def sample(self, batch_size: int, device: torch.device) -> torch.Tensor: return log_uniform_sample( - p_min=self.p_min, p_max=self.p_max, shape=(batch_size, 1, 1, 1) + p_min=self.p_min, + p_max=self.p_max, + shape=(batch_size, 1, 1, 1), + dtype=torch.float32, ).to(device) diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index a5406dc39..a367f9610 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -445,3 +445,30 @@ def test_lognorm_noise_backwards_compatibility(): assert model_from_state.config.noise_distribution == LogNormalNoiseDistribution( p_mean=-1.0, p_std=1.0 ) + + +def test_noise_config_error(): + normalizer = PairedNormalizationConfig( + NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), + NormalizationConfig(means={"x": 0.0}, stds={"x": 1.0}), + ) + + with pytest.raises(ValueError): + DiffusionModelConfig( + module=DiffusionModuleRegistrySelector( + "unet_diffusion_song", {"model_channels": 4} + ), + loss=LossConfig(type="MSE"), + in_names=["x"], + out_names=["x"], + normalization=normalizer, + p_mean=-1.0, + p_std=1.0, + sigma_min=0.1, + sigma_max=1.0, + churn=0.5, + num_diffusion_generation_steps=3, + training_noise_distribution=LogNormalNoiseDistribution(-1.0, 1.0), + use_fine_topography=False, + predict_residual=True, + ) diff --git a/fme/downscaling/test_noise.py b/fme/downscaling/test_noise.py index 50d929ac5..5b12e53f7 100644 --- a/fme/downscaling/test_noise.py +++ b/fme/downscaling/test_noise.py @@ -1,4 +1,5 @@ import pytest +import torch from fme.downscaling.noise import ( LogNormalNoiseDistribution, @@ -15,9 +16,6 @@ ) def test_noise_distribution(noise_distribution): batch_size = 10 - assert noise_distribution.sample(batch_size=batch_size, device="cpu").shape == ( - batch_size, - 1, - 1, - 1, - ) + noise = noise_distribution.sample(batch_size=batch_size, device="cpu") + assert noise.shape == (batch_size, 1, 1, 1) + assert noise.dtype == torch.float32 From 5dacf6bc51bf9f3a392c3b995bb96d5d996558e9 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 6 Mar 2026 15:35:02 -0800 Subject: [PATCH 18/41] add loss_weights config field --- fme/downscaling/models.py | 2 ++ fme/downscaling/test_models.py | 5 ++++- fme/downscaling/test_train.py | 20 +++++++++++++++++++ fme/downscaling/train.py | 36 ++++++++++++++++++++++++++++++++-- 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index ca73c792e..358574d63 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -352,6 +352,7 @@ def train_on_batch( batch: PairedBatchData, static_inputs: StaticInputs | None, optimizer: Optimization | NullOptimization, + loss_weights: torch.Tensor, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" coarse, fine = batch.coarse.data, batch.fine.data @@ -382,6 +383,7 @@ def train_on_batch( weighted_loss = conditioned_target.weight * self.loss( denoised_norm, targets_norm ) + weighted_loss = weighted_loss * loss_weights loss = torch.mean(weighted_loss) optimizer.accumulate_loss(loss) optimizer.step_weights() diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index adf015e10..4efcbd38a 100644 --- a/fme/downscaling/test_models.py +++ 
b/fme/downscaling/test_models.py @@ -224,7 +224,10 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph [batch_size, *coarse_shape], [batch_size, *fine_shape] ) optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) - train_outputs = model.train_on_batch(batch, static_inputs, optimization) + loss_weights = torch.ones(1, len(model.out_packer.names), 1, 1, device=get_device()) + train_outputs = model.train_on_batch( + batch, static_inputs, optimization, loss_weights=loss_weights + ) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) n_generated_samples = 2 diff --git a/fme/downscaling/test_train.py b/fme/downscaling/test_train.py index 5ed89d36d..ac31491e1 100644 --- a/fme/downscaling/test_train.py +++ b/fme/downscaling/test_train.py @@ -181,6 +181,26 @@ def test_train_main_only( main(config_path=config_path) +def test_train_main_with_loss_weights( + default_trainer_config, tmp_path, very_fast_only: bool +): + """Check that training loop runs with per-variable loss weighting.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") + + config = _update_in_out_names( + default_trainer_config, ["var0", "var1"], ["var0", "var1"] + ) + config["max_epochs"] = 1 + config["loss_weights"] = {"weights": [{"var0": 2.0}, {"var1": 0.5}]} + config_path = _store_config( + tmp_path, config, filename="train-config-loss-weights.yaml" + ) + + with mock_wandb(): + main(config_path=config_path) + + def test_train_main_logs(default_trainer_config, tmp_path, very_fast_only: bool): """Check that training loop records the appropriate logs.""" if very_fast_only: diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 6c5a23881..fd41e2866 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -87,6 +87,20 @@ def restore_checkpoint(trainer: "Trainer") -> None: trainer.ema = EMATracker.from_state(ema_checkpoint["ema"], ema_model.modules) +@dataclasses.dataclass +class LossWeights: + weights: list[dict[str, float]] + + def get_weight_tensor( + self, variable_names: list[str], device: torch.device + ) -> torch.Tensor: + weight_map = {} + for mapping in self.weights: + weight_map.update(mapping) + weights = [weight_map.get(name, 1.0) for name in variable_names] + return torch.tensor(weights, device=device).reshape(1, -1, 1, 1) + + class Trainer: def __init__( self, @@ -108,6 +122,15 @@ def __init__( wandb.watch(self.model.modules) self.num_batches_seen = 0 self.config = config + if config.loss_weights is None: + self.loss_weight_tensor = torch.ones( + 1, len(self.model.out_packer.names), 1, 1, device=get_device() + ) + else: + self.loss_weight_tensor = config.loss_weights.get_weight_tensor( + variable_names=self.model.out_packer.names, + device=get_device(), + ) self.patch_data = ( True if (config.coarse_patch_extent_lat and config.coarse_patch_extent_lon) @@ -187,7 +210,12 @@ def train_one_epoch(self) -> None: self.num_batches_seen += 1 if i % 10 == 0: logging.info(f"Training on batch {i+1}") - outputs = self.model.train_on_batch(batch, static_inputs, self.optimization) + outputs = self.model.train_on_batch( + batch, + static_inputs, + self.optimization, + loss_weights=self.loss_weight_tensor, + ) self.ema(self.model.modules) with torch.no_grad(): train_aggregator.record_batch( @@ -261,7 +289,10 @@ def valid_one_epoch(self) -> dict[str, float]: ) for batch, static_inputs in validation_batch_generator: outputs = self.model.train_on_batch( - batch, static_inputs, self.null_optimization + batch, + 
static_inputs, + self.null_optimization, + loss_weights=self.loss_weight_tensor, ) validation_aggregator.record_batch( outputs=outputs, @@ -405,6 +436,7 @@ class TrainerConfig: experiment_dir: str save_checkpoints: bool logging: LoggingConfig + loss_weights: LossWeights | None = None static_inputs: dict[str, str] | None = None ema: EMAConfig = dataclasses.field(default_factory=EMAConfig) validate_using_ema: bool = False From 0ab8760e868ce80e16e132e53b6b20fff35e8157 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 6 Mar 2026 15:44:08 -0800 Subject: [PATCH 19/41] exp configs --- .../train-100-to-3km-prmsl-output-loguni.yaml | 148 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/train.sh | 47 ++++++ 2 files changed, 195 insertions(+) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml create mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/train.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml new file mode 100644 index 000000000..846b943f4 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml @@ -0,0 +1,148 @@ +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +loss_weights: + weights: + - PRATEsfc: 0.0 +model: + use_amp_bf16: true + out_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.002 + p_max: 2000.0 + predict_residual: true + sigma_max: 2000.0 + sigma_min: 0.002 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: 
/climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: -66.0 + stop: 70.0 + # lon_extent: + # start: 0 #230.0 + # stop: 16 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 39 + lat_extent: + start: -20.0 + stop: 20.0 + strict_ensemble: false + drop_last: true +coarse_patch_extent_lat: 16 +coarse_patch_extent_lon: 16 +max_epochs: 150 +validate_interval: 15 +experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 +save_checkpoints: false +logging: + project: multivariate-downscaling + entity: ai2cm + log_to_wandb: true +generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh new file mode 100755 index 000000000..5745da668 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh @@ -0,0 +1,47 @@ +#!/bin/bash +# uses the augusta cluster which doesn't have weka access but has GCS access and is +# typically more available than cirrascale clusters + +set -e + +# recommended but not required to change this + +JOB_NAME="xshield-downscaling-100km-to-3km-zeroweight-prate" +CONFIG_FILENAME="train-100-to-3km-prmsl-output-loguni.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME +wandb_group="" + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) +N_GPUS=4 # TODO: change to 8 after testing + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt) + +gantry run \ + --name $JOB_NAME \ + --description 'Run downscaling 100km to 3km multivar training' \ + --workspace ai2/climate-titan \ + --priority urgent \ + --preemptible \ + --cluster ai2/titan \ + --beaker-image $IMAGE \ + --env WANDB_USERNAME=$BEAKER_USERNAME \ + --env WANDB_NAME=$JOB_NAME \ + --env WANDB_JOB_TYPE=training \ + --env WANDB_RUN_GROUP=$wandb_group \ + --env 
GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ + --env-secret WANDB_API_KEY=wandb-api-key-annak \ + --dataset-secret google-credentials:/tmp/google_application_credentials.json \ + --weka climate-default:/climate-default \ + --gpus $N_GPUS \ + --shared-memory 400GiB \ + --budget ai2/climate \ + --no-conda \ + --install "pip install --no-deps ." \ + --allow-dirty \ + -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH \ No newline at end of file From 33030be5b8161353800328f6f1e8cc4a6f529b40 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Fri, 6 Mar 2026 15:45:07 -0800 Subject: [PATCH 20/41] up gpus --- configs/experiments/2026-02-10-downsc-add-pressfc/train.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh index 5745da668..10b96ac45 100755 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh @@ -6,7 +6,7 @@ set -e # recommended but not required to change this -JOB_NAME="xshield-downscaling-100km-to-3km-zeroweight-prate" +JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate" CONFIG_FILENAME="train-100-to-3km-prmsl-output-loguni.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') @@ -16,7 +16,7 @@ wandb_group="" # since we use a service account API key for wandb, we use the beaker username to set the wandb username BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') REPO_ROOT=$(git rev-parse --show-toplevel) -N_GPUS=4 # TODO: change to 8 after testing +N_GPUS=8 # TODO: change to 8 after testing cd $REPO_ROOT # so config path is valid no matter where we are running this script From f6fa1f39bd2aab6c3e98ce5272436f067e42be2f Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Sat, 7 Mar 2026 10:38:23 -0800 Subject: [PATCH 21/41] fix val step --- .../train-100-to-3km-prmsl-output-loguni.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml index 846b943f4..c60aa5c33 100644 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml @@ -122,14 +122,14 @@ validation_data: subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 39 + step: 29 - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data file_pattern: pressfc_renamed_to_prmsl_100km.zarr engine: zarr subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 39 + step: 29 lat_extent: start: -20.0 stop: 20.0 From 69c558147302843a9b73bf4a8dfd1d56e3bd58ed Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 9 Mar 2026 08:38:45 -0700 Subject: [PATCH 22/41] eval --- .../eval-100-to-3km-prmsl-output.yaml | 207 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/eval.sh | 50 +++++ 2 files changed, 257 insertions(+) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml create mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml 
b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml new file mode 100644 index 000000000..7aaed272d --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml @@ -0,0 +1,207 @@ +experiment_dir: /results +n_samples: 2 +patch: + divide_generation: true + composite_prediction: true + coarse_horizontal_overlap: 1 +model: + checkpoint_path: /checkpoints/ema_ckpt.tar #best.ckpt + model_updates: + num_diffusion_generation_steps: 18 + #sigma_max: 2000.0 + #sigma_min: 0.002 + #churn: 2.0 +data: + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start: 0 + stop: 2 + #start_time: '2023-01-01T00:00:00' + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + # CONUS + lat_extent: + start: 0 + stop: 16 + # start: 22.0 + # stop: 50.0 + lon_extent: + #start: 230.0 + #stop: 295.0 + start: 0 + stop: 16 + batch_size: 2 + num_data_workers: 2 + strict_ensemble: false +logging: + log_to_screen: true + log_to_wandb: true + log_to_file: true + project: multivariate-downscaling + entity: ai2cm +events: +- name: NE_US_Quebec_20230206 + date: 2023-02-06T00:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 283 + stop: 299 + save_generated_samples: true + n_samples: 16 +- name: WA_AR_20230101 + date: 2023-01-01T06:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 228 + stop: 244 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_20230425 + date: 2023-04-25T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 130.0 + stop: 146.0 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_landfall_china_20230510 + date: 2023-05-10T12:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 104 + stop: 120 + save_generated_samples: true + n_samples: 16 +- name: extratropical_cyclone_US_20230403 + date: 2023-04-03T12:00 + lat_extent: + start: 34 + stop: 50 + lon_extent: + start: 254 + stop: 270 + save_generated_samples: true + n_samples: 16 +- name: santa_ana_winds_20231221 + date: 2023-12-21T06:00 + lat_extent: + start: 26 + stop: 42 + lon_extent: + start: 234 + stop: 250 + save_generated_samples: true + n_samples: 16 +- name: alpine_foehn_20230330 + date: 2023-03-30T18:00 + lat_extent: + start: 37 + stop: 53 + lon_extent: + start: 2 + stop: 18 + save_generated_samples: true + n_samples: 16 +- name: hindu_kush_20230122 + date: 2023-01-22T06:00 + lat_extent: + start: 28 + stop: 44 + lon_extent: + start: 60 + stop: 76 + save_generated_samples: true + n_samples: 16 +- name: WPac_tc_20230426T06 + date: 2023-04-26T06:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 130 + stop: 146 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20230514T06 + date: 2023-05-14T06:00 + lat_extent: + start: 4 + stop: 20 + lon_extent: 
+ start: 117 + stop: 133 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20230517T18 + date: 2023-05-17T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 133 + stop: 149 + save_generated_samples: true + n_samples: 16 +- name: Taiwan_tc_landfall_20230707T18 + date: 2023-07-07T18:00 + lat_extent: + start: 14 + stop: 30 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 +- name: Japan_tc_landfall_20230919T18 + date: 2023-09-19T18:00 + lat_extent: + start: 22 + stop: 38 + lon_extent: + start: 123 + stop: 139 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20231027T00 + date: 2023-10-27T00:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh new file mode 100755 index 000000000..0d629d31b --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -e + +JOB_NAME="eval-xshield-amip-100km-to-3km-0weight-prate-events" +CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +N_NODES=1 +NGPU=2 + +IMAGE="$(cat latest_deps_only_image.txt)" + +EXISTING_RESULTS_DATASET=01KK4SG3X0VBJ7MG7JMT2J7TJX +wandb_group="" + +#--not-preemptible \ + +gantry run \ + --name $JOB_NAME \ + --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \ + --workspace ai2/climate-titan \ + --priority urgent \ + --cluster ai2/jupiter \ + --cluster ai2/titan \ + --beaker-image $IMAGE \ + --env WANDB_USERNAME=$BEAKER_USERNAME \ + --env WANDB_NAME=$JOB_NAME \ + --env WANDB_JOB_TYPE=inference \ + --env WANDB_RUN_GROUP=$wandb_group \ + --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ + --env-secret WANDB_API_KEY=wandb-api-key-annak \ + --dataset-secret google-credentials:/tmp/google_application_credentials.json \ + --dataset $EXISTING_RESULTS_DATASET:checkpoints:/checkpoints \ + --weka climate-default:/climate-default \ + --gpus $NGPU \ + --shared-memory 400GiB \ + --budget ai2/climate \ + --no-conda \ + --install "pip install --no-deps ." 
\ + --allow-dirty \ + -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH \ No newline at end of file From 63a877411af4f41558901ce2fedd427f3f848330 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 9 Mar 2026 12:03:41 -0700 Subject: [PATCH 23/41] update config range --- .../2026-02-10-downsc-add-pressfc/resume.sh | 51 +++++++++++++++++++ .../train-100-to-3km-prmsl-output-loguni.yaml | 15 +++--- 2 files changed, 59 insertions(+), 7 deletions(-) create mode 100755 configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh new file mode 100755 index 000000000..4d0974203 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# uses the augusta cluster which doesn't have weka access but has GCS access and is +# typically more available than cirrascale clusters + +set -e + +# recommended but not required to change this + +JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate-resume" +CONFIG_FILENAME="train-100-to-3km-prmsl-output-loguni.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME +wandb_group="" + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) +N_GPUS=8 # TODO: change to 8 after testing + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt) + +PREVIOUS_RESULTS_DATASET="01KK4SG3X0VBJ7MG7JMT2J7TJX" + + +gantry run \ + --name $JOB_NAME \ + --description 'Run downscaling 100km to 3km multivar training' \ + --workspace ai2/climate-titan \ + --priority urgent \ + --preemptible \ + --cluster ai2/titan \ + --beaker-image $IMAGE \ + --env WANDB_USERNAME=$BEAKER_USERNAME \ + --env WANDB_NAME=$JOB_NAME \ + --env WANDB_JOB_TYPE=training \ + --env WANDB_RUN_GROUP=$wandb_group \ + --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \ + --env-secret WANDB_API_KEY=wandb-api-key-annak \ + --dataset $PREVIOUS_RESULTS_DATASET:/previous_results \ + --dataset-secret google-credentials:/tmp/google_application_credentials.json \ + --weka climate-default:/climate-default \ + --gpus $N_GPUS \ + --shared-memory 400GiB \ + --budget ai2/climate \ + --no-conda \ + --install "pip install --no-deps ." 
\ + --allow-dirty \ + -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH \ No newline at end of file diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml index c60aa5c33..6b5228363 100644 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml @@ -1,3 +1,4 @@ +resume_results_dir: /previous_results static_inputs: HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr @@ -44,11 +45,11 @@ model: num_diffusion_generation_steps: 18 churn: 0.0 training_noise_distribution: - p_min: 0.002 - p_max: 2000.0 + p_min: 0.01 + p_max: 200.0 predict_residual: true - sigma_max: 2000.0 - sigma_min: 0.002 + sigma_max: 200.0 + sigma_min: 0.01 use_fine_topography: true optimization: lr: 0.0001 @@ -89,8 +90,8 @@ train_data: start_time: '2014-01-01T00:00:00' stop_time: '2022-12-31T23:59:00' lat_extent: - start: -66.0 - stop: 70.0 + start: -30 #-66.0 + stop: 30 #70.0 # lon_extent: # start: 0 #230.0 # stop: 16 #246.0 @@ -137,7 +138,7 @@ validation_data: drop_last: true coarse_patch_extent_lat: 16 coarse_patch_extent_lon: 16 -max_epochs: 150 +max_epochs: 300 validate_interval: 15 experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 save_checkpoints: false From e8fd3f1aa53c95dbb156bae8d2a584548c59fc30 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 9 Mar 2026 13:11:33 -0700 Subject: [PATCH 24/41] clamp max noise --- ...in-100-to-3km-prmsl-clamp-loss-weight.yaml | 150 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/train.sh | 4 +- fme/downscaling/models.py | 6 +- fme/downscaling/noise.py | 6 + fme/downscaling/train.py | 3 + 5 files changed, 166 insertions(+), 3 deletions(-) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml new file mode 100644 index 000000000..8d05f660b --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml @@ -0,0 +1,150 @@ +resume_results_dir: /previous_results +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +loss_weights: + weights: + - PRATEsfc: 0.0 +max_loss_weight: 10.0 +model: + use_amp_bf16: true + out_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: 
/climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.01 + p_max: 200.0 + predict_residual: true + sigma_max: 200.0 + sigma_min: 0.01 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: -66.0 + stop: 70.0 + # lon_extent: + # start: 0 #230.0 + # stop: 16 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + lat_extent: + start: -20.0 + stop: 20.0 + strict_ensemble: false + drop_last: true +coarse_patch_extent_lat: 16 +coarse_patch_extent_lon: 16 +max_epochs: 150 +validate_interval: 15 +experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 +save_checkpoints: false +logging: + project: multivariate-downscaling + entity: ai2cm + log_to_wandb: true +generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh index 10b96ac45..d1ff8b0b3 100755 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh +++ 
b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh @@ -6,8 +6,8 @@ set -e # recommended but not required to change this -JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate" -CONFIG_FILENAME="train-100-to-3km-prmsl-output-loguni.yaml" +JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate-clamp-loss-weight" +CONFIG_FILENAME="train-100-to-3km-prmsl-clamp-loss-weight.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index 358574d63..007d5484d 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -353,6 +353,7 @@ def train_on_batch( static_inputs: StaticInputs | None, optimizer: Optimization | NullOptimization, loss_weights: torch.Tensor, + max_loss_weight: float | None = None, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" coarse, fine = batch.coarse.data, batch.fine.data @@ -374,7 +375,10 @@ def train_on_batch( targets_norm = targets_norm - base_prediction conditioned_target = condition_with_noise_for_training( - targets_norm, self.config.noise_distribution, self.sigma_data + targets_norm, + self.config.noise_distribution, + self.sigma_data, + max_loss_weight=max_loss_weight, ) denoised_norm = self.module( diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index 48f5b5157..dfbc25775 100644 --- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -60,6 +60,7 @@ def condition_with_noise_for_training( targets_norm: torch.Tensor, noise_distribution: NoiseDistribution, sigma_data: float, + max_loss_weight: float | None = None, ) -> ConditionedTarget: """ Condition the targets with noise for training. @@ -69,12 +70,17 @@ def condition_with_noise_for_training( noise_distribution: The noise distribution to use for conditioning. sigma_data: The standard deviation of the data, used to determine loss weighting. + max_loss_weight: Optional upper bound on the loss weight. Low sigma + values produce large weights (~1/sigma^2); this clamps the maximum + weight to prevent those samples from dominating the loss. Returns: The conditioned targets and the loss weighting. 
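+
+    Example (illustrative): with ``sigma_data = 0.5`` and a sampled
+    ``sigma = 0.01``, the unclamped EDM weight
+    ``(0.01**2 + 0.5**2) / (0.01 * 0.5)**2`` is roughly 1e4, so a
+    ``max_loss_weight`` of 10.0 would cap that sample's weight.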
""" sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device) weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + if max_loss_weight is not None: + weight = torch.clamp(weight, max=max_loss_weight) noise = randn_like(targets_norm) * sigma latents = targets_norm + noise return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index fd41e2866..bbd18e6ff 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -215,6 +215,7 @@ def train_one_epoch(self) -> None: static_inputs, self.optimization, loss_weights=self.loss_weight_tensor, + max_loss_weight=self.config.max_loss_weight, ) self.ema(self.model.modules) with torch.no_grad(): @@ -293,6 +294,7 @@ def valid_one_epoch(self) -> dict[str, float]: static_inputs, self.null_optimization, loss_weights=self.loss_weight_tensor, + max_loss_weight=self.config.max_loss_weight, ) validation_aggregator.record_batch( outputs=outputs, @@ -437,6 +439,7 @@ class TrainerConfig: save_checkpoints: bool logging: LoggingConfig loss_weights: LossWeights | None = None + max_loss_weight: float | None = None static_inputs: dict[str, str] | None = None ema: EMAConfig = dataclasses.field(default_factory=EMAConfig) validate_using_ema: bool = False From cd4ead7ef601021690d84ee884b09c92e55aa40c Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 9 Mar 2026 14:10:39 -0700 Subject: [PATCH 25/41] rm previous_results line --- .../train-100-to-3km-prmsl-clamp-loss-weight.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml index 8d05f660b..eaafd7d09 100644 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml @@ -1,4 +1,4 @@ -resume_results_dir: /previous_results +#resume_results_dir: /previous_results static_inputs: HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr From d4fa1a4e3473c285af65f50e416aaed03bafc62e Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Mon, 9 Mar 2026 15:29:21 -0700 Subject: [PATCH 26/41] winds only training --- .../train-100-to-3km-prmsl-winds-only.yaml | 150 ++++++++++++++++++ .../2026-02-10-downsc-add-pressfc/train.sh | 4 +- 2 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml new file mode 100644 index 000000000..ab69a09f4 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml @@ -0,0 +1,150 @@ +#resume_results_dir: /previous_results +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +#loss_weights: +# weights: +# - PRATEsfc: 0.0 +#max_loss_weight: 10.0 +model: + use_amp_bf16: true + out_names: + #- PRATEsfc + - 
eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.01 + p_max: 200.0 + predict_residual: true + sigma_max: 200.0 + sigma_min: 0.01 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: -66.0 + stop: 70.0 + # lon_extent: + # start: 0 #230.0 + # stop: 16 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 29 + lat_extent: + start: -20.0 + stop: 20.0 + strict_ensemble: false + 
+drop_last: true
+coarse_patch_extent_lat: 16
+coarse_patch_extent_lon: 16
+max_epochs: 150
+validate_interval: 15
+experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16
+save_checkpoints: false
+logging:
+  project: multivariate-downscaling
+  entity: ai2cm
+  log_to_wandb: true
+generate_n_samples: 2
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
index d1ff8b0b3..e5143922e 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
@@ -6,8 +6,8 @@ set -e
 
 # recommended but not required to change this
-JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate-clamp-loss-weight"
-CONFIG_FILENAME="train-100-to-3km-prmsl-clamp-loss-weight.yaml"
+JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only"
+CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
 CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME

From db83e316bc9d725feb5368fe854286f5fbb7f488 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Mon, 9 Mar 2026 16:57:21 -0700
Subject: [PATCH 27/41] tropics only, use 1/sigma weighting

---
 .../train-100-to-3km-prmsl-winds-only.yaml | 21 ++++++++++++-------
 fme/downscaling/models.py                  |  2 ++
 fme/downscaling/noise.py                   | 13 +++++++++---
 fme/downscaling/train.py                   |  3 +++
 4 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
index ab69a09f4..4d13a41a0 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
@@ -6,6 +6,7 @@ static_inputs:
 # weights:
 #  - PRATEsfc: 0.0
 #max_loss_weight: 10.0
+loss_weight_exponent: 0.5
 model:
   use_amp_bf16: true
   out_names:
diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py
index 0c56d7dd2..bb5f48b06 100644
--- a/fme/downscaling/models.py
+++ b/fme/downscaling/models.py
@@ -356,6 +356,7 @@ def train_on_batch(
         optimizer: Optimization | NullOptimization,
         loss_weights: torch.Tensor,
         max_loss_weight: float | None = None,
+        loss_weight_exponent: float = 1.0,
     ) -> ModelOutputs:
         """Performs a denoising training step on a batch of data."""
         coarse, fine = batch.coarse.data, batch.fine.data
@@ -381,6 +382,7 @@ def train_on_batch(
             self.config.noise_distribution,
             self.sigma_data,
             max_loss_weight=max_loss_weight,
+            loss_weight_exponent=loss_weight_exponent,
         )
 
         denoised_norm = self.module(
diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py
index dfbc25775..1355bee9b 100644
--- a/fme/downscaling/noise.py
+++ b/fme/downscaling/noise.py
@@ -61,6 +61,7 @@ def condition_with_noise_for_training(
     noise_distribution: NoiseDistribution,
     sigma_data: float,
     max_loss_weight: float | None = None,
+    loss_weight_exponent: float = 1.0,
 ) -> ConditionedTarget:
     """
     Condition the targets with noise for training.
@@ -71,14 +72,20 @@ def condition_with_noise_for_training(
         sigma_data: The standard deviation of the data,
             used to determine loss weighting.
         max_loss_weight: Optional upper bound on the loss weight. Low sigma
-            values produce large weights (~1/sigma^2); this clamps the maximum
-            weight to prevent those samples from dominating the loss.
+            values produce large weights; this clamps the maximum weight to
+            prevent those samples from dominating the loss.
+        loss_weight_exponent: Exponent applied to the base EDM loss weight
+            ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``. The default
+            of 1.0 gives the standard EDM weighting (~1/sigma^2 for small
+            sigma). Use 0.5 for ~1/sigma weighting (square root of EDM weight).
 
     Returns:
         The conditioned targets and the loss weighting.
     """
     sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device)
-    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
+    weight = (
+        (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
+    ) ** loss_weight_exponent
     if max_loss_weight is not None:
         weight = torch.clamp(weight, max=max_loss_weight)
     noise = randn_like(targets_norm) * sigma
diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py
index bbd18e6ff..0331c15c5 100755
--- a/fme/downscaling/train.py
+++ b/fme/downscaling/train.py
@@ -216,6 +216,7 @@ def train_one_epoch(self) -> None:
                 self.optimization,
                 loss_weights=self.loss_weight_tensor,
                 max_loss_weight=self.config.max_loss_weight,
+                loss_weight_exponent=self.config.loss_weight_exponent,
             )
             self.ema(self.model.modules)
             with torch.no_grad():
@@ -295,6 +296,7 @@ def valid_one_epoch(self) -> dict[str, float]:
                     self.null_optimization,
                     loss_weights=self.loss_weight_tensor,
                     max_loss_weight=self.config.max_loss_weight,
+                    loss_weight_exponent=self.config.loss_weight_exponent,
                 )
                 validation_aggregator.record_batch(
                     outputs=outputs,
@@ -440,6 +442,7 @@ class TrainerConfig:
     logging: LoggingConfig
     loss_weights: LossWeights | None = None
     max_loss_weight: float | None = None
+    loss_weight_exponent: float = 1.0
     static_inputs: dict[str, str] | None = None
     ema: EMAConfig = dataclasses.field(default_factory=EMAConfig)
     validate_using_ema: bool = False
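The effect of loss_weight_exponent is easy to check numerically. The following standalone sketch (illustrative only, not part of the patch; the sigma values and sigma_data are arbitrary) evaluates the patched weight formula at a few noise levels:

    import torch

    sigma_data = 0.5
    sigma = torch.tensor([0.01, 1.0, 200.0])

    # Base EDM weight, as in condition_with_noise_for_training.
    base = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    for loss_weight_exponent in (1.0, 0.75, 0.5):
        weight = base**loss_weight_exponent
        print(loss_weight_exponent, weight)

    # At sigma = 0.01 the base weight is ~1/sigma^2 = 1e4, while an exponent
    # of 0.5 reduces it to ~1/sigma scaling (~1e2), so low-noise samples
    # dominate the loss far less before max_loss_weight clamping kicks in.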
+CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME @@ -22,7 +23,7 @@ cd $REPO_ROOT # so config path is valid no matter where we are running this scr IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt) -PREVIOUS_RESULTS_DATASET="01KK4SG3X0VBJ7MG7JMT2J7TJX" +PREVIOUS_RESULTS_DATASET="01KKAGKBKK3CPX9TACNA9WA9JM" gantry run \ diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml index 4d13a41a0..84575d74e 100644 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml @@ -109,14 +109,14 @@ validation_data: subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 29 + step: 9 - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr engine: zarr subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 29 + step: 9 coarse: - merge: - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling @@ -125,14 +125,14 @@ validation_data: subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 29 + step: 9 - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data file_pattern: pressfc_renamed_to_prmsl_100km.zarr engine: zarr subset: start_time: '2023-01-01T00:00:00' stop_time: '2024-01-01T00:00:00' - step: 29 + step: 9 # lat_extent: # start: -20.0 # stop: 20.0 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh index e5143922e..cb6101005 100755 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh @@ -6,7 +6,7 @@ set -e # recommended but not required to change this -JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only" +JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.5sigmaexp-tropics" CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') From e1a63dccd71e66eb48e69a7ddea0645c57a0a41b Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Tue, 10 Mar 2026 08:30:10 -0700 Subject: [PATCH 29/41] update scripts --- configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh | 4 ++-- configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh index 0d629d31b..923515e67 100755 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh @@ -2,7 +2,7 @@ set -e -JOB_NAME="eval-xshield-amip-100km-to-3km-0weight-prate-events" +JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events" CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') @@ -19,7 +19,7 @@ NGPU=2 IMAGE="$(cat latest_deps_only_image.txt)" -EXISTING_RESULTS_DATASET=01KK4SG3X0VBJ7MG7JMT2J7TJX +EXISTING_RESULTS_DATASET=01KKAPG0J6DMHN471DH96FAY2V wandb_group="" #--not-preemptible \ diff 
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
index 2f705c153..8dd8029bb 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
@@ -6,8 +6,7 @@ set -e
 
 # recommended but not required to change this
-JOB_NAME="
-xshield-downscaling-100km-to-3km-winds-prmsl-only-0.5sigmaexp-tropics-resume"
+JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.5sigmaexp-tropics-resume"
 CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')

From 0eb82ae94e1d81eec22b69c8236122cb0e3a06ca Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 08:46:00 -0700
Subject: [PATCH 30/41] winds-only resume 0.75exp

---
 .../experiments/2026-02-10-downsc-add-pressfc/resume.sh | 4 ++--
 .../train-100-to-3km-prmsl-winds-only.yaml              | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
index 8dd8029bb..76672008e 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
@@ -6,7 +6,7 @@ set -e
 
 # recommended but not required to change this
-JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.5sigmaexp-tropics-resume"
+JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.75sigmaexp-tropics-resume"
 CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
@@ -22,7 +22,7 @@ cd $REPO_ROOT # so config path is valid no matter where we are running this scr
 
 IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt)
 
-PREVIOUS_RESULTS_DATASET="01KKAGKBKK3CPX9TACNA9WA9JM"
+PREVIOUS_RESULTS_DATASET="01KKAPG0J6DMHN471DH96FAY2V"
 
 
 gantry run \
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
index 84575d74e..d88bc208c 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
@@ -1,4 +1,4 @@
-#resume_results_dir: /previous_results
+resume_results_dir: /previous_results
 static_inputs:
   HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
   land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr
@@ -6,7 +6,7 @@ static_inputs:
 # weights:
 #  - PRATEsfc: 0.0
 #max_loss_weight: 10.0
-loss_weight_exponent: 0.5
+loss_weight_exponent: 0.75
 model:
   use_amp_bf16: true
   out_names:
@@ -146,7 +146,7 @@ validation_data:
 drop_last: true
 coarse_patch_extent_lat: 16
 coarse_patch_extent_lon: 16
-max_epochs: 150
+max_epochs: 900
 validate_interval: 15
 experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16
 save_checkpoints: false

From f403e07032c0400851caa3dff548634eedc67bf5 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 09:18:22 -0700
Subject: [PATCH 31/41] winds only control

---
 .../train-100-to-3km-prmsl-winds-only.yaml             | 6 +++---
 configs/experiments/2026-02-10-downsc-add-pressfc/train.sh | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
index d88bc208c..f22e6f758 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml
@@ -1,4 +1,4 @@
-resume_results_dir: /previous_results
+#resume_results_dir: /previous_results
 static_inputs:
   HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
   land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr
@@ -6,7 +6,7 @@ static_inputs:
 # weights:
 #  - PRATEsfc: 0.0
 #max_loss_weight: 10.0
-loss_weight_exponent: 0.75
+#loss_weight_exponent: 0.75
 model:
   use_amp_bf16: true
   out_names:
@@ -146,7 +146,7 @@ validation_data:
 drop_last: true
 coarse_patch_extent_lat: 16
 coarse_patch_extent_lon: 16
-max_epochs: 900
+max_epochs: 500
 validate_interval: 15
 experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16
 save_checkpoints: false
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
index cb6101005..f2a4cfe16 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
@@ -6,7 +6,7 @@ set -e
 
 # recommended but not required to change this
-JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.5sigmaexp-tropics"
+JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-tropics"
 CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
@@ -26,9 +26,10 @@ gantry run \
     --name $JOB_NAME \
     --description 'Run downscaling 100km to 3km multivar training' \
     --workspace ai2/climate-titan \
-    --priority urgent \
+    --priority low \
     --preemptible \
     --cluster ai2/titan \
+    --cluster ai2/jupiter \
     --beaker-image $IMAGE \
     --env WANDB_USERNAME=$BEAKER_USERNAME \
     --env WANDB_NAME=$JOB_NAME \

From 62ef51296ab7d308c29a4f8200a0a31e766d848d Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 10:46:18 -0700
Subject: [PATCH 32/41] fix docstring

---
 scripts/downscaling/plot_events.py | 375 +++++++++++++++++++++++++++++
 1 file changed, 375 insertions(+)
 create mode 100644 scripts/downscaling/plot_events.py

diff --git a/scripts/downscaling/plot_events.py b/scripts/downscaling/plot_events.py
new file mode 100644
index 000000000..a1ca6711d
--- /dev/null
+++ b/scripts/downscaling/plot_events.py
@@ -0,0 +1,375 @@
+#!/usr/bin/env python
+"""
+Fetch netCDF event files from a beaker dataset and generate map and histogram
+plots for each variable (coarse, target, and predicted ensemble samples).
+
+This works with saved event outputs from `fme.downscaling.evaluator` runs in
+a beaker experiment. It downloads the experiment files to a temporary
+directory, parses filenames for *YYYYMMDD*.nc event outputs, and merges in
+coarse data for map comparison. If no local directory is provided for the
+coarse data, it is read from a hard-coded GCS path.
+
+Usage:
+    python plot_events.py <beaker_dataset_id> [--output-dir <dir>]
+        [--coarse-data <path>] [--variables VAR1 VAR2 ...]
+
+Requires:
+    beaker CLI to be installed and authenticated (https://github.com/allenai/beaker).
+""" + +import argparse +import math +import re +import subprocess +import tempfile +import warnings +from pathlib import Path + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import cftime +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from cartopy.feature import ShapelyFeature +from cartopy.io import shapereader +from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER + +warnings.filterwarnings("ignore") + +from plot_beaker_histograms import plot_histogram_lines + +TIME_SEL = slice(cftime.DatetimeJulian(2023, 1, 1), None) +# Matching for *YYYYMMDD*.nc (date can appear anywhere in the filename) +_EVENT_FILE_RE = re.compile(r"(.+?)[\._-]?(\d{8})[\._-]?(.*)\.nc$") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Generate map plots from beaker dataset event files" + ) + parser.add_argument( + "beaker_dataset_id", + help="The beaker dataset ID to fetch", + ) + parser.add_argument( + "--output-dir", + default="./event_maps", + help="Output directory for figures (default: ./event_maps)", + ) + parser.add_argument( + "--coarse-data", + default=None, + help="Path to coarse data (default: None)", + ) + parser.add_argument( + "--variables", + nargs="*", + default=None, + help="Filter to only these variables (default: all eligible variables)", + ) + return parser.parse_args() + + +# Create a STATES feature with no fill +# Read state geometries from shapefile +shpfilename = shapereader.natural_earth( + resolution="50m", category="cultural", name="admin_1_states_provinces" +) +reader = shapereader.Reader(shpfilename) +states_feature = ShapelyFeature( + reader.geometries(), + ccrs.PlateCarree(), + facecolor="none", + edgecolor="lightgrey", + linewidth=0.35, + linestyle="--", +) + + +def add_outer_latlon_grid(ax, *, show_left, show_bottom): + gl = ax.gridlines( + draw_labels=True, + linewidth=0.5, + color="gray", + alpha=0.5, + linestyle="--", + ) + + # Explicitly disable all labels first + gl.top_labels = False + gl.right_labels = False + gl.left_labels = False + gl.bottom_labels = False + + # Enable only outer ones + if show_left: + gl.left_labels = True + if show_bottom: + gl.bottom_labels = True + + gl.xlabel_style = {"size": 8} + gl.ylabel_style = {"size": 8} + gl.xformatter = LONGITUDE_FORMATTER + gl.yformatter = LATITUDE_FORMATTER + + # square grid cells + ax.set_aspect("equal", adjustable="box") + + +def get_coarse_data(path: str | None, time_sel: slice | None = TIME_SEL) -> xr.Dataset: + if path is not None: + return xr.open_zarr(path) + else: + winds = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/instantaneous_physics_fields.zarr" + ).sel(time=time_sel)[ + ["eastward_wind_at_ten_meters", "northward_wind_at_ten_meters"] + ] + prate = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/fluxes_2d.zarr" + ).sel(time=time_sel)["PRATEsfc"] + pres = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/column_integrated_dynamical_fields.zarr" + ).sel(time=time_sel)["PRESsfc"] + # in training, PRESsfc is used as input for outputting PRMSL + prmsl = pres.rename("PRMSL") + return xr.merge([winds, prate, pres, prmsl]) + + +def bbox(lat, lon, width=2.0): + return { + "lat": slice(lat - width / 2.0, lat + width / 2.0), + "lon": slice(lon - width / 2.0, lon + width / 2.0), + } + + +def 
+def upsample_array(x: np.ndarray, upsample_factor: int = 32) -> np.ndarray:
+    # upsample coarse data for plotting with fine res
+    x = np.repeat(x, upsample_factor, axis=0)  # repeat rows
+    x = np.repeat(x, upsample_factor, axis=1)  # repeat columns
+    return x
+
+
+def plot_event(ds, var_name, samples=None, sel=None, n_cols=5, **plot_kwargs):
+    if samples is None:
+        samples = list(range(ds.sample.size))
+    N = 2 + len(samples)
+    n_rows = math.ceil(N / n_cols)
+
+    fig, axes = plt.subplots(
+        n_rows,
+        n_cols,
+        figsize=(2 * n_cols, 2 * n_rows),
+        subplot_kw={"projection": ccrs.PlateCarree()},
+    )
+
+    axes = axes.ravel()  # 1D array, easy to index
+    # Use only the first N axes
+    for ax in axes[N:]:
+        ax.set_visible(False)
+
+    suffixes = ["coarse", "target", "predicted"]
+    vars = [f"{var_name}_{suffix}" for suffix in suffixes]
+    ds_ = ds[vars]
+
+    if sel:
+        ds_ = ds_.sel(sel)
+
+    if len(samples) == 0:
+        samples = [0]
+
+    if var_name == "PRMSL":
+        # fill PRMSL_coarse with nans
+        ds_["PRMSL_coarse"].values[:] = np.nan
+
+    vmax = ds_.to_array().max()
+    if ds_.to_array().min() < -0.2:
+        plot_kwargs["cmap"] = "RdBu_r"
+    else:
+        plot_kwargs["cmap"] = "turbo"
+        plot_kwargs["vmin"] = min(0, ds_.to_array().min())
+    if "vmax" not in plot_kwargs:
+        plot_kwargs["vmax"] = vmax
+    # coarse and target
+    for i, var in enumerate(vars[:2]):
+        ax = axes[i]
+
+        da = ds_[var]
+        img = da.plot(ax=ax, add_colorbar=False, **plot_kwargs)
+
+        ax.set_title(suffixes[i], fontsize=10)
+        ax.add_feature(states_feature)
+        ax.add_feature(cfeature.BORDERS, color="lightgrey")
+        ax.coastlines(color="lightgrey")
+
+        row = i // n_cols
+        col = i % n_cols
+
+        add_outer_latlon_grid(
+            ax,
+            show_left=(col == 0),
+            show_bottom=(row == n_rows - 1),
+        )
+
+    for i, s in enumerate(samples):
+        ax = axes[2 + i]
+        da = ds_[vars[-1]].isel(sample=s)
+
+        img = da.plot(ax=ax, add_colorbar=False, **plot_kwargs)
+        ax.set_title(f"predicted {s}", fontsize=10)
+        ax.add_feature(states_feature)
+        ax.coastlines(color="lightgrey")
+        ax.add_feature(cfeature.BORDERS, linestyle="-", color="lightgrey")
+        row = (i + 2) // n_cols
+        col = (i + 2) % n_cols
+        add_outer_latlon_grid(
+            ax,
+            show_left=(col == 0),
+            show_bottom=(row == n_rows - 1),
+        )
+    cbar_ax = fig.add_axes([0.99, 0.25, 0.01, 0.5])  # [left, bottom, width, height]
+    cbar = fig.colorbar(img, cax=cbar_ax)
+    cbar.set_label(f"{var_name} [m/s]")
+    # plt.tight_layout()
+
+    return fig, axes
+
+
+def fetch_beaker_dataset(dataset_id: str, target_dir: str) -> None:
+    """Fetch a beaker dataset to the specified directory."""
+    subprocess.run(
+        ["beaker", "dataset", "fetch", dataset_id, "--output", target_dir],
+        check=True,
+    )
+
+
+def find_event_files(directory: str) -> dict[str, Path]:
+    """Find netCDF files matching the event naming pattern, keyed by event name."""
+    event_files = {}
+    for p in sorted(Path(directory).glob("*.nc")):
+        # extract event name
+        matched = _EVENT_FILE_RE.match(p.name)
+        if matched:
+            prefix, date, suffix = matched.group(1), matched.group(2), matched.group(3)
+            parts = [s for s in (prefix, suffix) if s]
+            event_name = f"{'_'.join(parts)}_{date}"
+            event_files[event_name] = p
+    return event_files
+
+
+def detect_variable_pairs(ds: xr.Dataset) -> list[str]:
+    """Detect variables that have both _predicted and _target versions."""
+    predicted = {
+        v[: -len("_predicted")] for v in ds.data_vars if v.endswith("_predicted")
+    }
+    target = {v[: -len("_target")] for v in ds.data_vars if v.endswith("_target")}
+    return sorted(predicted & target)
+
+
+def filename_to_datetime(filename: str) -> cftime.DatetimeJulian:
+    match = re.search(r"(\d{4})(\d{2})(\d{2})(?:T(\d{2}))?", filename)
+    if match is None:
+        raise ValueError(f"Could not parse date from filename: {filename}")
+    return cftime.DatetimeJulian(
+        int(match.group(1)),
+        int(match.group(2)),
+        int(match.group(3)),
+        int(match.group(4) or 12),
+    )
+
+
+def add_wind_speed(ds: xr.Dataset) -> xr.Dataset:
+    variables = detect_variable_pairs(ds)
+    if (
+        "eastward_wind_at_ten_meters" in variables
+        and "northward_wind_at_ten_meters" in variables
+    ):
+        ds["wind_speed_target"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_target**2
+            + ds.northward_wind_at_ten_meters_target**2
+        )
+        ds["wind_speed_predicted"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_predicted**2
+            + ds.northward_wind_at_ten_meters_predicted**2
+        )
+        ds["wind_speed_coarse"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_coarse**2
+            + ds.northward_wind_at_ten_meters_coarse**2
+        )
+    return ds
+
+
+def merge_coarse(
+    event: xr.Dataset, coarse: xr.Dataset, datetime: cftime.DatetimeJulian
+) -> xr.Dataset:
+    _coarse = coarse.sel(
+        time=datetime,
+        grid_yt=slice(event.lat.min(), event.lat.max()),
+        grid_xt=slice(event.lon.min(), event.lon.max()),
+    )
+    for var in detect_variable_pairs(event):
+        event[f"{var}_coarse"] = xr.DataArray(
+            upsample_array(_coarse[var].values, 32), dims=["lat", "lon"]
+        )
+    return event
+
+
+def main():
+    args = parse_args()
+    beaker_id = args.beaker_dataset_id
+    output_dir = Path(args.output_dir)
+    coarse = get_coarse_data(args.coarse_data, time_sel=TIME_SEL)
+
+    print(f"Fetching beaker dataset: {beaker_id}")
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fetch_beaker_dataset(beaker_id, temp_dir)
+
+        event_files = find_event_files(temp_dir)
+        if not event_files:
+            print(f"No event files found in dataset {beaker_id}")
+            return
+
+        print(f"Found {len(event_files)} event file(s)")
+
+        for event_name, nc_file in event_files.items():
+            output_event_dir = output_dir / beaker_id / event_name
+            output_event_dir.mkdir(parents=True, exist_ok=True)
+
+            print(f"Processing: {nc_file.name} -> {output_event_dir}")
+
+            event = xr.open_dataset(nc_file)
+            event = merge_coarse(
+                event, coarse, datetime=filename_to_datetime(nc_file.name)
+            )
+            event = add_wind_speed(event)
+            variables = detect_variable_pairs(event)
+            if args.variables is not None:
+                variables = [v for v in variables if v in args.variables]
+
+            if not variables:
+                print(f"  No variable pairs found in {nc_file.name}")
+                continue
+            for var in variables:
+                fig, axes = plot_event(event, var)
+                fig.savefig(
+                    output_event_dir / f"{var}_generated_maps.png",
+                    transparent=True,
+                    dpi=300,
+                    bbox_inches="tight",
+                )
+                plt.close(fig)
+                plot_histogram_lines(
+                    event,
+                    var,
+                    event_name,
+                    save_path=output_event_dir / f"{var}_histogram.png",
+                )
+            event.close()
+
+    print("Done!")
+
+
+if __name__ == "__main__":
+    main()
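The filename conventions the script assumes are easiest to see with a concrete example. This standalone check (not part of the patch; the filename is hypothetical) exercises the same two regexes the script uses:

    import re

    import cftime

    _EVENT_FILE_RE = re.compile(r"(.+?)[\._-]?(\d{8})[\._-]?(.*)\.nc$")

    name = "hurricane_20230828T06_sample.nc"  # hypothetical event file name
    print(_EVENT_FILE_RE.match(name).groups())
    # ('hurricane', '20230828', 'T06_sample'), keyed as
    # 'hurricane_T06_sample_20230828' by find_event_files

    m = re.search(r"(\d{4})(\d{2})(\d{2})(?:T(\d{2}))?", name)
    dt = cftime.DatetimeJulian(
        int(m.group(1)), int(m.group(2)), int(m.group(3)), int(m.group(4) or 12)
    )
    print(dt)  # 2023-08-28 06:00:00; a date-only filename falls back to 12Z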
From e654a3f0e3dbb9fc77455965eea311df93f70158 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 10:48:31 -0700
Subject: [PATCH 33/41] resume on low priority 0weight prate

---
 .../2026-02-10-downsc-add-pressfc/resume.sh   |  9 +++---
 ...in-100-to-3km-prmsl-clamp-loss-weight.yaml | 30 +++++++++++--------
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
index 76672008e..7aa647448 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
@@ -6,8 +6,8 @@ set -e
 
 # recommended but not required to change this
-JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-0.75sigmaexp-tropics-resume"
-CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
+JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate-tropics-resume"
+CONFIG_FILENAME="train-100-to-3km-prmsl-clamp-loss-weight.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
 CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME
@@ -22,16 +22,17 @@ cd $REPO_ROOT # so config path is valid no matter where we are running this scr
 
 IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt)
 
-PREVIOUS_RESULTS_DATASET="01KKAPG0J6DMHN471DH96FAY2V"
+PREVIOUS_RESULTS_DATASET="01KK9ZRC9M9MF7T45T4XAP7GF0"
 
 
 gantry run \
     --name $JOB_NAME \
     --description 'Run downscaling 100km to 3km multivar training' \
     --workspace ai2/climate-titan \
-    --priority urgent \
+    --priority low \
     --preemptible \
     --cluster ai2/titan \
+    --cluster ai2/jupiter \
     --beaker-image $IMAGE \
     --env WANDB_USERNAME=$BEAKER_USERNAME \
    --env WANDB_NAME=$JOB_NAME \
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml
index eaafd7d09..81f5cea86 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml
@@ -1,4 +1,4 @@
-#resume_results_dir: /previous_results
+resume_results_dir: /previous_results
 static_inputs:
   HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
   land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr
@@ -90,12 +90,15 @@ train_data:
       subset:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
+  #lat_extent:
+  #  start: -66.0
+  #  stop: 70.0
   lat_extent:
-    start: -66.0
-    stop: 70.0
-  # lon_extent:
-  #   start: 0 #230.0
-  #   stop: 16 #246.0
+    start: 0 #-66.0
+    stop: 35 #70.0
+  lon_extent:
+    start: 75 #230.0
+    stop: 195 #246.0
   strict_ensemble: false
 validation_data:
   batch_size: 48
@@ -108,14 +111,14 @@ validation_data:
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr
       engine: zarr
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
   coarse:
   - merge:
     - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
@@ -124,17 +127,20 @@ validation_data:
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: pressfc_renamed_to_prmsl_100km.zarr
       engine: zarr
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
   lat_extent:
-    start: -20.0
-    stop: 20.0
+    start: 0 #-66.0
+    stop: 35 #70.0
+  lon_extent:
+    start: 75 #230.0
+    stop: 195 #246.0
   strict_ensemble: false
 drop_last: true
 coarse_patch_extent_lat: 16
From 4aedd3c4677b94de3c6f2ef5a36aeaeb8ddb4fb4 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 11:42:13 -0700
Subject: [PATCH 34/41] update to tropics region

---
 .../train-100-to-3km-prmsl-output-loguni.yaml | 38 +++++++++++--------
 1 file changed, 22 insertions(+), 16 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml
index 6b5228363..f06db261e 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml
@@ -2,9 +2,9 @@ resume_results_dir: /previous_results
 static_inputs:
   HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
   land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr
-loss_weights:
-  weights:
-  - PRATEsfc: 0.0
+#loss_weights:
+#  weights:
+#  - PRATEsfc: 0.0
 model:
   use_amp_bf16: true
   out_names:
@@ -45,11 +45,11 @@ model:
   num_diffusion_generation_steps: 18
   churn: 0.0
   training_noise_distribution:
-    p_min: 0.01
-    p_max: 200.0
+    p_min: 0.005
+    p_max: 2000.0
   predict_residual: true
-  sigma_max: 200.0
-  sigma_min: 0.01
+  sigma_max: 2000.0
+  sigma_min: 0.005
   use_fine_topography: true
 optimization:
   lr: 0.0001
@@ -90,8 +90,11 @@ train_data:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
   lat_extent:
-    start: -30 #-66.0
-    stop: 30 #70.0
+    start: 0 #-66.0
+    stop: 35 #70.0
+  lon_extent:
+    start: 75 #230.0
+    stop: 195 #246.0
   # lon_extent:
   #   start: 0 #230.0
   #   stop: 16 #246.0
   strict_ensemble: false
 validation_data:
   batch_size: 48
@@ -107,14 +110,14 @@ validation_data:
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr
       engine: zarr
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
   coarse:
   - merge:
     - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
@@ -123,22 +126,25 @@ validation_data:
      subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: pressfc_renamed_to_prmsl_100km.zarr
       engine: zarr
       subset:
         start_time: '2023-01-01T00:00:00'
         stop_time: '2024-01-01T00:00:00'
-        step: 29
+        step: 9
   lat_extent:
-    start: -20.0
-    stop: 20.0
+    start: 0 #-20
+    stop: 35 #20.0
+  lon_extent:
+    start: 75 #230.0
+    stop: 195 #246.0
   strict_ensemble: false
 drop_last: true
 coarse_patch_extent_lat: 16
 coarse_patch_extent_lon: 16
-max_epochs: 300
+max_epochs: 600
 validate_interval: 15
 experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16
 save_checkpoints: false
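This patch widens the log-uniform training noise range by more than an order of magnitude at each end. Since a log-uniform sampler draws sigma uniformly in log space, every equal log-width slice of the range is hit equally often; a standalone sketch (illustrative only, mirroring the p_min/p_max semantics of the log-uniform distribution) of what the widened bounds do:

    import numpy as np

    p_min, p_max = 0.005, 2000.0  # the widened bounds from this patch
    sigma = np.exp(np.random.uniform(np.log(p_min), np.log(p_max), 100_000))

    # log10(p_max / p_min) ~= 5.6 decades of noise levels are now covered;
    # draws are uniform in log(sigma), so equal log-width bins get equal counts.
    hist, _ = np.histogram(np.log10(sigma), bins=6)
    print(hist / hist.sum())  # ~[0.167, 0.167, ...]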
From d29e7d19266836b98fec846ec48dd9d70b0e84b0 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 16:53:07 -0700
Subject: [PATCH 35/41] global eval

---
 .../eval-global.yaml                      | 36 +++++++++++++++++++
 .../2026-02-10-downsc-add-pressfc/eval.sh |  9 +++--
 2 files changed, 42 insertions(+), 3 deletions(-)
 create mode 100644 configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
new file mode 100644
index 000000000..8922eddd1
--- /dev/null
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
@@ -0,0 +1,36 @@
+experiment_dir: /results
+n_samples: 2
+patch:
+  divide_generation: true
+  composite_prediction: true
+  coarse_horizontal_overlap: 1
+model:
+  checkpoint_path: /checkpoints/best.ckpt
+data:
+  #topography: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
+  coarse:
+  - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
+    engine: zarr
+    file_pattern: 100km.zarr
+    subset:
+      start_time: 2023-01-01T00:00
+      step: 29
+  fine:
+  - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
+    engine: zarr
+    file_pattern: 3km.zarr
+    subset:
+      start_time: 2023-01-01T00:00
+      step: 29
+  lat_extent:
+    start: -66
+    stop: 70.0
+  batch_size: 6
+  num_data_workers: 2
+  strict_ensemble: false
+logging:
+  log_to_screen: true
+  log_to_wandb: true
+  log_to_file: true
+  project: andrep-downscaling
+  entity: ai2cm
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
index 923515e67..e5de793f7 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
@@ -2,8 +2,10 @@
 
 set -e
 
-JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events"
-CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml"
+#JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events"
+JOB_NAME="eval-xshield-amip-100km-to-3km-loguni-multivariate-global"
+
+CONFIG_FILENAME="eval-global.yaml"
 
 SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
 CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME
@@ -19,7 +21,7 @@ NGPU=2
 
 IMAGE="$(cat latest_deps_only_image.txt)"
 
-EXISTING_RESULTS_DATASET=01KKAPG0J6DMHN471DH96FAY2V
+EXISTING_RESULTS_DATASET=01KKCHH6F91QR2XHCPEHN0B5VR
 
 wandb_group=""
 #--not-preemptible \
@@ -29,6 +31,7 @@ gantry run \
     --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \
     --workspace ai2/climate-titan \
     --priority urgent \
+    --not-preemptible \
     --cluster ai2/jupiter \
     --cluster ai2/titan \
     --beaker-image $IMAGE \
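The eval-global config turns on patchwise generation (divide_generation) with compositing of the per-patch predictions (composite_prediction) and one coarse cell of horizontal overlap. The repo's compositing code is not shown in this series; as a generic, toy illustration of the technique that kind of config describes (tile layout, function names, and count-weighted averaging in the overlap are all assumptions, not the repo's implementation):

    import numpy as np

    def composite_tiles(tiles, origins, full_shape):
        # Average overlapping tile predictions into one full-domain field.
        total = np.zeros(full_shape)
        count = np.zeros(full_shape)
        for tile, (i, j) in zip(tiles, origins):
            h, w = tile.shape
            total[i : i + h, j : j + w] += tile
            count[i : i + h, j : j + w] += 1
        return total / np.maximum(count, 1)

    # Two 4x4 tiles overlapping by one column on a 4x7 domain:
    a, b = np.ones((4, 4)), 3 * np.ones((4, 4))
    out = composite_tiles([a, b], [(0, 0), (0, 3)], (4, 7))
    print(out[0])  # [1. 1. 1. 2. 3. 3. 3.] -- the overlap column averages to 2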
From 3df1c4019fc71d381729ce353cc991f67057b34d Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Tue, 10 Mar 2026 17:59:20 -0700
Subject: [PATCH 36/41] use datasets with prmsl

---
 .../eval-global.yaml | 40 +++++++++++++------
 1 file changed, 27 insertions(+), 13 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
index 8922eddd1..7d914ce78 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
@@ -8,20 +8,34 @@ model:
   checkpoint_path: /checkpoints/best.ckpt
 data:
   #topography: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr
-  coarse:
-  - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
-    engine: zarr
-    file_pattern: 100km.zarr
-    subset:
-      start_time: 2023-01-01T00:00
-      step: 29
   fine:
-  - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
-    engine: zarr
-    file_pattern: 3km.zarr
-    subset:
-      start_time: 2023-01-01T00:00
-      step: 29
+  - merge:
+    - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
+      file_pattern: 3km.zarr
+      engine: zarr
+      subset:
+        start_time: '2014-01-01T00:00:00'
+        stop_time: '2022-12-31T23:59:00'
+    - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
+      file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr
+      engine: zarr
+      subset:
+        start_time: '2014-01-01T00:00:00'
+        stop_time: '2022-12-31T23:59:00'
+  coarse:
+  - merge:
+    - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
+      file_pattern: 100km.zarr
+      engine: zarr
+      subset:
+        start_time: '2014-01-01T00:00:00'
+        stop_time: '2022-12-31T23:59:00'
+    - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
+      file_pattern: pressfc_renamed_to_prmsl_100km.zarr
+      engine: zarr
+      subset:
+        start_time: '2014-01-01T00:00:00'
+        stop_time: '2022-12-31T23:59:00'
   lat_extent:
     start: -66
     stop: 70.0

From 9c863a67b2c8d68ae9881a88e8de5fc577bc571 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Wed, 11 Mar 2026 08:21:01 -0700
Subject: [PATCH 37/41] step 29

---
 .../2026-02-10-downsc-add-pressfc/eval-global.yaml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
index 7d914ce78..8742dc528 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
@@ -16,12 +16,14 @@ data:
       subset:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
+        step: 29
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr
       engine: zarr
       subset:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
+        step: 29
   coarse:
   - merge:
     - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling
@@ -30,12 +32,14 @@ data:
       subset:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
+        step: 29
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: pressfc_renamed_to_prmsl_100km.zarr
       engine: zarr
       subset:
         start_time: '2014-01-01T00:00:00'
         stop_time: '2022-12-31T23:59:00'
+        step: 29
   lat_extent:
     start: -66
     stop: 70.0

From 6f37b3179bd8c531ec53670ac41d50551c99b089 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Wed, 11 Mar 2026 08:23:23 -0700
Subject: [PATCH 38/41] eval with manuscript prate ckpt

---
 .../eval-100-to-3km-prmsl-output.yaml     | 2 +-
 .../2026-02-10-downsc-add-pressfc/eval.sh | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml
index 7aaed272d..c5fd76ac7 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml
@@ -5,7 +5,7 @@ patch:
   composite_prediction: true
   coarse_horizontal_overlap: 1
 model:
-  checkpoint_path: /checkpoints/ema_ckpt.tar #best.ckpt
+  checkpoint_path: /checkpoints/best_histogram_tail.ckpt #ema_ckpt.tar #best.ckpt
   model_updates:
     num_diffusion_generation_steps: 18
     #sigma_max: 2000.0
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
index e5de793f7..c4cbcaff9 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
@@ -3,9 +3,9 @@ set -e
 
 #JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events"
-JOB_NAME="eval-xshield-amip-100km-to-3km-loguni-multivariate-global"
+JOB_NAME="eval-xshield-amip-100km-to-3km-prate-only-events"
 
-CONFIG_FILENAME="eval-global.yaml"
+CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml"
 
"$(git rev-parse --show-prefix)" | sed 's:/*$::') CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME @@ -21,7 +21,7 @@ NGPU=2 IMAGE="$(cat latest_deps_only_image.txt)" -EXISTING_RESULTS_DATASET=01KKCHH6F91QR2XHCPEHN0B5VR +EXISTING_RESULTS_DATASET=01K8XGEEAJRHN8JRZE4WXQRDVD wandb_group="" #--not-preemptible \ From 5e5c874db0415c61ed82d6cebbbed5f88a906097 Mon Sep 17 00:00:00 2001 From: Anna Kwa Date: Wed, 11 Mar 2026 08:41:07 -0700 Subject: [PATCH 39/41] global eval --- .../experiments/2026-02-10-downsc-add-pressfc/eval.sh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh index c4cbcaff9..baaaac924 100755 --- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh @@ -3,9 +3,9 @@ set -e #JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events" -JOB_NAME="eval-xshield-amip-100km-to-3km-prate-only-events" +JOB_NAME="eval-xshield-amip-100km-to-3km-loguni-multivariate-global" -CONFIG_FILENAME="eval-100-to-3km-prmsl-output.yaml" +CONFIG_FILENAME="eval-global.yaml" SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME @@ -17,11 +17,11 @@ REPO_ROOT=$(git rev-parse --show-toplevel) cd $REPO_ROOT # so config path is valid no matter where we are running this script N_NODES=1 -NGPU=2 +NGPU=4 IMAGE="$(cat latest_deps_only_image.txt)" -EXISTING_RESULTS_DATASET=01K8XGEEAJRHN8JRZE4WXQRDVD +EXISTING_RESULTS_DATASET=01KK1ZH1KFJSHGR24W531AGKVE wandb_group="" #--not-preemptible \ @@ -31,7 +31,6 @@ gantry run \ --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \ --workspace ai2/climate-titan \ --priority urgent \ - --not-preemptible \ --cluster ai2/jupiter \ --cluster ai2/titan \ --beaker-image $IMAGE \ @@ -50,4 +49,4 @@ gantry run \ --no-conda \ --install "pip install --no-deps ." 
     --install "pip install --no-deps ." \
     --allow-dirty \
-    -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH
\ No newline at end of file
+    -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH

From a9d1cfbd8ee087317f45eec6f0915ede4f1d95fd Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Wed, 11 Mar 2026 14:32:13 -0700
Subject: [PATCH 40/41] fix time range

---
 .../2026-02-10-downsc-add-pressfc/eval-global.yaml | 12 ++++--------
 .../2026-02-10-downsc-add-pressfc/eval.sh          |  2 +-
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
index 8742dc528..e9a38af11 100644
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml
@@ -14,15 +14,13 @@ data:
       file_pattern: 3km.zarr
       engine: zarr
       subset:
-        start_time: '2014-01-01T00:00:00'
-        stop_time: '2022-12-31T23:59:00'
+        start_time: '2023-01-01T00:00:00'
         step: 29
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr
       engine: zarr
       subset:
-        start_time: '2014-01-01T00:00:00'
-        stop_time: '2022-12-31T23:59:00'
+        start_time: '2023-01-01T00:00:00'
         step: 29
   coarse:
   - merge:
@@ -30,15 +28,13 @@ data:
       file_pattern: 100km.zarr
       engine: zarr
       subset:
-        start_time: '2014-01-01T00:00:00'
-        stop_time: '2022-12-31T23:59:00'
+        start_time: '2023-01-01T00:00:00'
         step: 29
     - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data
       file_pattern: pressfc_renamed_to_prmsl_100km.zarr
       engine: zarr
       subset:
-        start_time: '2014-01-01T00:00:00'
-        stop_time: '2022-12-31T23:59:00'
+        start_time: '2023-01-01T00:00:00'
         step: 29
   lat_extent:
     start: -66
     stop: 70.0
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
index baaaac924..71afb3b46 100755
--- a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh
@@ -21,7 +21,7 @@ NGPU=4
 
 IMAGE="$(cat latest_deps_only_image.txt)"
 
-EXISTING_RESULTS_DATASET=01KK1ZH1KFJSHGR24W531AGKVE
+EXISTING_RESULTS_DATASET=01KK07E72Z2H4CSP1EWXCT25FB
 
 wandb_group=""
 #--not-preemptible \

From b5d127b54925c47ef148cc70308c7f9dec1250d8 Mon Sep 17 00:00:00 2001
From: Anna Kwa
Date: Wed, 11 Mar 2026 14:46:44 -0700
Subject: [PATCH 41/41] create BatchData wrapper in CascadePredictor._generate,
 make _generate methods private

---
 fme/downscaling/models.py                  |  6 ++--
 fme/downscaling/predictors/cascade.py      | 40 ++++++++++++++++++----
 fme/downscaling/predictors/test_cascade.py |  2 +-
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py
index bb5f48b06..03fa9fa8d 100644
--- a/fme/downscaling/models.py
+++ b/fme/downscaling/models.py
@@ -425,7 +425,7 @@ def train_on_batch(
         )
 
     @torch.no_grad()
-    def generate(
+    def _generate(
         self,
         coarse_data: TensorMapping,
         static_inputs: StaticInputs | None,
@@ -483,7 +483,7 @@ def generate_on_batch_no_target(
         static_inputs: StaticInputs | None,
         n_samples: int = 1,
     ) -> TensorDict:
-        generated, _, _ = self.generate(batch.data, static_inputs, n_samples)
+        generated, _, _ = self._generate(batch.data, static_inputs, n_samples)
         return generated
 
     @torch.no_grad()
@@ -494,7 +494,7 @@ def generate_on_batch(
         self,
         batch: PairedBatchData,
         n_samples: int = 1,
     ) -> ModelOutputs:
         coarse, fine = batch.coarse.data, batch.fine.data
-        generated, generated_norm, latent_steps = self.generate(
+        generated, generated_norm, latent_steps = self._generate(
             coarse, static_inputs, n_samples
         )
diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py
index 0f64c73b5..3399c3db6 100644
--- a/fme/downscaling/predictors/cascade.py
+++ b/fme/downscaling/predictors/cascade.py
@@ -2,6 +2,7 @@
 import math
 
 import torch
+import xarray as xr
 
 from fme.core.coordinates import LatLonCoordinates
 from fme.core.device import get_device
@@ -15,6 +16,7 @@
     adjust_fine_coord_range,
     scale_tuple,
 )
+from fme.downscaling.data.utils import BatchedLatLonCoordinates
 from fme.downscaling.metrics_and_maths import filter_tensor_mapping
 from fme.downscaling.models import CheckpointModelConfig, DiffusionModel, ModelOutputs
 from fme.downscaling.requirements import DataRequirements
@@ -86,6 +88,26 @@ def _restore_batch_and_sample_dims(data: TensorMapping, n_samples: int):
     return unfold_ensemble_dim(squeezed, n_samples)
 
 
+def _batch_data_with_unused_coords(data: TensorMapping) -> BatchData:
+    # Wrapper so that we can call each level's public
+    # generate_on_batch_no_target function using the TensorMapping
+    # from the previous step.
+    data_shape = next(iter(data.values())).shape
+    time = xr.DataArray(
+        [0 for _ in range(data_shape[0])],
+        dims=["time"],
+    )
+    latlon_coordinates = BatchedLatLonCoordinates(
+        lat=torch.zeros((data_shape[0], data_shape[1]), device=get_device()),
+        lon=torch.zeros((data_shape[0], data_shape[2]), device=get_device()),
+    )
+    return BatchData(
+        data=data,
+        time=time,
+        latlon_coordinates=latlon_coordinates,
+    )
+
+
 class CascadePredictor:
     def __init__(
         self, models: list[DiffusionModel], static_inputs: list[StaticInputs | None]
@@ -116,22 +138,26 @@ def modules(self) -> torch.nn.ModuleList:
         return torch.nn.ModuleList([model.modules for model in self.models])
 
     @torch.no_grad()
-    def generate(
+    def _generate(
         self,
         coarse: TensorMapping,
         n_samples: int,
         static_inputs: list[StaticInputs | None],
     ):
         current_coarse = coarse
-        for i, (model, fine_topography) in enumerate(zip(self.models, static_inputs)):
+        for i, (model, step_static_inputs) in enumerate(
+            zip(self.models, static_inputs)
+        ):
             sample_data = next(iter(current_coarse.values()))
             batch_size = sample_data.shape[0]
             # n_samples are generated for the first step, and subsequent models
             # generate 1 sample
             n_samples_cascade_step = n_samples if i == 0 else 1
-            generated, generated_norm, latent_steps = model.generate(
-                current_coarse, fine_topography, n_samples_cascade_step
+            generated = model.generate_on_batch_no_target(
+                _batch_data_with_unused_coords(current_coarse),
+                step_static_inputs,
+                n_samples_cascade_step,
             )
             generated = {
                 k: v.reshape(batch_size * n_samples_cascade_step, *v.shape[-2:])
@@ -139,7 +165,7 @@
             }
             current_coarse = generated
         generated = _restore_batch_and_sample_dims(generated, n_samples)
-        return generated, generated_norm, latent_steps
+        return generated
 
     @torch.no_grad()
     def generate_on_batch_no_target(
@@ -151,7 +177,7 @@ def generate_on_batch_no_target(
         subset_static_inputs = self._get_subset_static_inputs(
             coarse_coords=batch.latlon_coordinates[0]
         )
-        generated, _, _ = self.generate(batch.data, n_samples, subset_static_inputs)
+        generated = self._generate(batch.data, n_samples, subset_static_inputs)
         return generated
 
     @torch.no_grad()
     def generate_on_batch(
         self,
         batch: PairedBatchData,
         n_samples: int = 1,
     ) -> ModelOutputs:
         static_inputs = self._get_subset_static_inputs(
             coarse_coords=batch.coarse.latlon_coordinates[0]
         )
-        generated, _, latent_steps = self.generate(
+        generated, _, latent_steps = self._generate(
             batch.coarse.data, n_samples, static_inputs
         )
         targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names))
diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py
index 99511d576..a2f3f5cdb 100644
--- a/fme/downscaling/predictors/test_cascade.py
+++ b/fme/downscaling/predictors/test_cascade.py
@@ -98,7 +98,7 @@ def test_CascadePredictor_generate(downscale_factors):
             dtype=torch.float32,
         )
     }
-    generated, _, _ = cascade_predictor.generate(
+    generated = cascade_predictor._generate(
         coarse=coarse_input,
         n_samples=n_samples_generate,
         static_inputs=static_inputs_list,