diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml new file mode 100644 index 000000000..c5fd76ac7 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-100-to-3km-prmsl-output.yaml @@ -0,0 +1,207 @@ +experiment_dir: /results +n_samples: 2 +patch: + divide_generation: true + composite_prediction: true + coarse_horizontal_overlap: 1 +model: + checkpoint_path: /checkpoints/best_histogram_tail.ckpt #ema_ckpt.tar #best.ckpt + model_updates: + num_diffusion_generation_steps: 18 + #sigma_max: 2000.0 + #sigma_min: 0.002 + #churn: 2.0 +data: + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start: 0 + stop: 2 + #start_time: '2023-01-01T00:00:00' + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + #start_time: '2023-01-01T00:00:00' + start: 0 + stop: 2 + # CONUS + lat_extent: + start: 0 + stop: 16 + # start: 22.0 + # stop: 50.0 + lon_extent: + #start: 230.0 + #stop: 295.0 + start: 0 + stop: 16 + batch_size: 2 + num_data_workers: 2 + strict_ensemble: false +logging: + log_to_screen: true + log_to_wandb: true + log_to_file: true + project: multivariate-downscaling + entity: ai2cm +events: +- name: NE_US_Quebec_20230206 + date: 2023-02-06T00:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 283 + stop: 299 + save_generated_samples: true + n_samples: 16 +- name: WA_AR_20230101 + date: 2023-01-01T06:00 + lat_extent: + start: 36 + stop: 52 + lon_extent: + start: 228 + stop: 244 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_20230425 + date: 2023-04-25T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 130.0 + stop: 146.0 + save_generated_samples: true + n_samples: 16 +- name: WPacific_hurricane_landfall_china_20230510 + date: 2023-05-10T12:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 104 + stop: 120 + save_generated_samples: true + n_samples: 16 +- name: extratropical_cyclone_US_20230403 + date: 2023-04-03T12:00 + lat_extent: + start: 34 + stop: 50 + lon_extent: + start: 254 + stop: 270 + save_generated_samples: true + n_samples: 16 +- name: santa_ana_winds_20231221 + date: 2023-12-21T06:00 + lat_extent: + start: 26 + stop: 42 + lon_extent: + start: 234 + stop: 250 + save_generated_samples: true + n_samples: 16 +- name: alpine_foehn_20230330 + date: 2023-03-30T18:00 + lat_extent: + start: 37 + stop: 53 + lon_extent: + start: 2 + stop: 18 + save_generated_samples: true + n_samples: 16 +- name: hindu_kush_20230122 + date: 2023-01-22T06:00 + lat_extent: + start: 28 + stop: 44 + lon_extent: + start: 60 + stop: 76 + save_generated_samples: true + n_samples: 16 +- name: WPac_tc_20230426T06 + date: 2023-04-26T06:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 130 + stop: 146 + save_generated_samples: true + 
n_samples: 16 +- name: Phl_tc_landfall_20230514T06 + date: 2023-05-14T06:00 + lat_extent: + start: 4 + stop: 20 + lon_extent: + start: 117 + stop: 133 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20230517T18 + date: 2023-05-17T18:00 + lat_extent: + start: 7 + stop: 23 + lon_extent: + start: 133 + stop: 149 + save_generated_samples: true + n_samples: 16 +- name: Taiwan_tc_landfall_20230707T18 + date: 2023-07-07T18:00 + lat_extent: + start: 14 + stop: 30 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 +- name: Japan_tc_landfall_20230919T18 + date: 2023-09-19T18:00 + lat_extent: + start: 22 + stop: 38 + lon_extent: + start: 123 + stop: 139 + save_generated_samples: true + n_samples: 16 +- name: Phl_tc_landfall_20231027T00 + date: 2023-10-27T00:00 + lat_extent: + start: 8 + stop: 24 + lon_extent: + start: 115 + stop: 131 + save_generated_samples: true + n_samples: 16 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml new file mode 100644 index 000000000..e9a38af11 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval-global.yaml @@ -0,0 +1,50 @@ +experiment_dir: /results +n_samples: 2 +patch: + divide_generation: true + composite_prediction: true + coarse_horizontal_overlap: 1 +model: + checkpoint_path: /checkpoints/best.ckpt +data: + #topography: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + step: 29 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + step: 29 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + step: 29 + lat_extent: + start: -66 + stop: 70.0 + batch_size: 6 + num_data_workers: 2 + strict_ensemble: false +logging: + log_to_screen: true + log_to_wandb: true + log_to_file: true + project: andrep-downscaling + entity: ai2cm diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh new file mode 100755 index 000000000..71afb3b46 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/eval.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +set -e + +#JOB_NAME="eval-xshield-amip-100km-to-3km-0.5sigmaexp-tropics-events" +JOB_NAME="eval-xshield-amip-100km-to-3km-loguni-multivariate-global" + +CONFIG_FILENAME="eval-global.yaml" + +SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::') +CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME + + # since we use a service account API key for wandb, we use the beaker username to set the wandb username +BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name') +REPO_ROOT=$(git rev-parse --show-toplevel) + +cd $REPO_ROOT # so config path is valid no matter where we are running this script + +N_NODES=1 +NGPU=4 + +IMAGE="$(cat 
latest_deps_only_image.txt)"
+
+EXISTING_RESULTS_DATASET=01KK07E72Z2H4CSP1EWXCT25FB
+wandb_group=""
+
+#--not-preemptible \
+
+gantry run \
+  --name $JOB_NAME \
+  --description 'Run 100km to 3km evaluation on coarsened X-SHiELD' \
+  --workspace ai2/climate-titan \
+  --priority urgent \
+  --cluster ai2/jupiter \
+  --cluster ai2/titan \
+  --beaker-image $IMAGE \
+  --env WANDB_USERNAME=$BEAKER_USERNAME \
+  --env WANDB_NAME=$JOB_NAME \
+  --env WANDB_JOB_TYPE=inference \
+  --env WANDB_RUN_GROUP=$wandb_group \
+  --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \
+  --env-secret WANDB_API_KEY=wandb-api-key-annak \
+  --dataset-secret google-credentials:/tmp/google_application_credentials.json \
+  --dataset $EXISTING_RESULTS_DATASET:checkpoints:/checkpoints \
+  --weka climate-default:/climate-default \
+  --gpus $NGPU \
+  --shared-memory 400GiB \
+  --budget ai2/climate \
+  --no-conda \
+  --install "pip install --no-deps ." \
+  --allow-dirty \
+  -- torchrun --nproc_per_node $NGPU -m fme.downscaling.evaluator $CONFIG_PATH
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
new file mode 100755
index 000000000..7aa647448
--- /dev/null
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/resume.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+# resumes training on the ai2/titan and ai2/jupiter clusters, which have weka
+# access for the training data
+
+set -e
+
+# recommended (but not required) to change JOB_NAME for each run
+JOB_NAME="xshield-downscaling-100km-to-3km-0weight-prate-tropics-resume"
+CONFIG_FILENAME="train-100-to-3km-prmsl-clamp-loss-weight.yaml"
+
+SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
+CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME
+wandb_group=""
+
+# since we use a service account API key for wandb, we use the beaker username
+# to set the wandb username
+BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name')
+REPO_ROOT=$(git rev-parse --show-toplevel)
+N_GPUS=8
+
+cd $REPO_ROOT  # so config path is valid no matter where we are running this script
+
+IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt)
+
+PREVIOUS_RESULTS_DATASET="01KK9ZRC9M9MF7T45T4XAP7GF0"
+
+gantry run \
+  --name $JOB_NAME \
+  --description 'Run downscaling 100km to 3km multivar training' \
+  --workspace ai2/climate-titan \
+  --priority low \
+  --preemptible \
+  --cluster ai2/titan \
+  --cluster ai2/jupiter \
+  --beaker-image $IMAGE \
+  --env WANDB_USERNAME=$BEAKER_USERNAME \
+  --env WANDB_NAME=$JOB_NAME \
+  --env WANDB_JOB_TYPE=training \
+  --env WANDB_RUN_GROUP=$wandb_group \
+  --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \
+  --env-secret WANDB_API_KEY=wandb-api-key-annak \
+  --dataset $PREVIOUS_RESULTS_DATASET:/previous_results \
+  --dataset-secret google-credentials:/tmp/google_application_credentials.json \
+  --weka climate-default:/climate-default \
+  --gpus $N_GPUS \
+  --shared-memory 400GiB \
+  --budget ai2/climate \
+  --no-conda \
+  --install "pip install --no-deps ."
\ + --allow-dirty \ + -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH \ No newline at end of file diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml new file mode 100644 index 000000000..81f5cea86 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-clamp-loss-weight.yaml @@ -0,0 +1,156 @@ +resume_results_dir: /previous_results +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +loss_weights: + weights: + - PRATEsfc: 0.0 +max_loss_weight: 10.0 +model: + use_amp_bf16: true + out_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.01 + p_max: 200.0 + predict_residual: true + sigma_max: 200.0 + sigma_min: 0.01 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + #lat_extent: + # start: -66.0 + # stop: 70.0 + lat_extent: + start: 0 #-66.0 + stop: 35 #70.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + 
file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + lat_extent: + start: 0 #-66.0 + stop: 35 #70.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + strict_ensemble: false + drop_last: true +coarse_patch_extent_lat: 16 +coarse_patch_extent_lon: 16 +max_epochs: 150 +validate_interval: 15 +experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 +save_checkpoints: false +logging: + project: multivariate-downscaling + entity: ai2cm + log_to_wandb: true +generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml new file mode 100644 index 000000000..f06db261e --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-output-loguni.yaml @@ -0,0 +1,155 @@ +resume_results_dir: /previous_results +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +#loss_weights: +# weights: +# - PRATEsfc: 0.0 +model: + use_amp_bf16: true + out_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.005 + p_max: 2000.0 + predict_residual: true + sigma_max: 2000.0 + sigma_min: 0.005 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: 
/climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: 0 #-66.0 + stop: 35 #70.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + # lon_extent: + # start: 0 #230.0 + # stop: 16 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + lat_extent: + start: 0 #-20 + stop: 35 #20.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + strict_ensemble: false + drop_last: true +coarse_patch_extent_lat: 16 +coarse_patch_extent_lon: 16 +max_epochs: 600 +validate_interval: 15 +experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16 +save_checkpoints: false +logging: + project: multivariate-downscaling + entity: ai2cm + log_to_wandb: true +generate_n_samples: 2 diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml new file mode 100644 index 000000000..f22e6f758 --- /dev/null +++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train-100-to-3km-prmsl-winds-only.yaml @@ -0,0 +1,157 @@ +#resume_results_dir: /previous_results +static_inputs: + HGTsfc: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/3km.zarr + land_fraction: /climate-default/2025-11-19-X-SHiELD-AMIP-downscaling-land-fraction/3km/land_fraction.zarr +#loss_weights: +# weights: +# - PRATEsfc: 0.0 +#max_loss_weight: 10.0 +#loss_weight_exponent: 0.75 +model: + use_amp_bf16: true + out_names: + #- PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + - PRMSL + in_names: + - PRATEsfc + - eastward_wind_at_ten_meters + - northward_wind_at_ten_meters + 
- PRMSL + loss: + type: MSE + module: + config: + model_channels: 128 + attn_resolutions: [] + num_blocks: 1 + channel_mult_emb: 6 + channel_mult: + - 1 + - 2 + - 2 + - 2 + - 2 + - 2 + - 2 + use_apex_gn: true + type: unet_diffusion_song_v2 + normalization: + coarse: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/centering-pressfc-cp-to-prmsl.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/100km/scaling-full-field-pressfc-cp-to-prmsl.nc + fine: + global_means_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/centering-20260206.nc + global_stds_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling/stats/3km/scaling-full-field-20260206.nc + num_diffusion_generation_steps: 18 + churn: 0.0 + training_noise_distribution: + p_min: 0.01 + p_max: 200.0 + predict_residual: true + sigma_max: 200.0 + sigma_min: 0.01 + use_fine_topography: true +optimization: + lr: 0.0001 + optimizer_type: Adam +ema: + decay: 0.999 +validate_using_ema: true +train_data: + sample_with_replacement: 640 + batch_size: 80 # 10 per gpu + num_data_workers: 2 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2014-01-01T00:00:00' + stop_time: '2022-12-31T23:59:00' + lat_extent: + start: 0 #-66.0 + stop: 35 #70.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + strict_ensemble: false +validation_data: + batch_size: 48 + num_data_workers: 4 + fine: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: instantaneous_surface_and_sea_level_pressure_3km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + coarse: + - merge: + - data_path: /climate-default/2025-09-25-downscaling-data-X-SHiELD-AMIP-downscaling + file_pattern: 100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + - data_path: /climate-default/2026-02-03-downscaling-data-X-SHiELD-AMIP-pressure-data + file_pattern: pressfc_renamed_to_prmsl_100km.zarr + engine: zarr + subset: + start_time: '2023-01-01T00:00:00' + stop_time: '2024-01-01T00:00:00' + step: 9 + # lat_extent: + # start: -20.0 + # stop: 20.0 + lat_extent: + start: 0 #-66.0 + stop: 35 #70.0 + lon_extent: + start: 75 #230.0 + stop: 195 #246.0 + strict_ensemble: false + drop_last: true +coarse_patch_extent_lat: 16 
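+# together with coarse_patch_extent_lon below, enables patch-based training on
+# 16x16 coarse-grid-cell patches (see Trainer.patch_data in fme/downscaling/train.py)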
+coarse_patch_extent_lon: 16
+max_epochs: 500
+validate_interval: 15
+experiment_dir: /results #/climate-default/home/annak/scratch/2026-02-10-downsc-add-pressfc/3km_bf16
+save_checkpoints: false
+logging:
+  project: multivariate-downscaling
+  entity: ai2cm
+  log_to_wandb: true
+generate_n_samples: 2
diff --git a/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
new file mode 100755
index 000000000..f2a4cfe16
--- /dev/null
+++ b/configs/experiments/2026-02-10-downsc-add-pressfc/train.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+# runs training on the ai2/titan and ai2/jupiter clusters, which have weka
+# access for the training data
+
+set -e
+
+# recommended (but not required) to change JOB_NAME for each run
+JOB_NAME="xshield-downscaling-100km-to-3km-winds-prmsl-only-tropics"
+CONFIG_FILENAME="train-100-to-3km-prmsl-winds-only.yaml"
+
+SCRIPT_PATH=$(echo "$(git rev-parse --show-prefix)" | sed 's:/*$::')
+CONFIG_PATH=$SCRIPT_PATH/$CONFIG_FILENAME
+wandb_group=""
+
+# since we use a service account API key for wandb, we use the beaker username
+# to set the wandb username
+BEAKER_USERNAME=$(beaker account whoami --format=json | jq -r '.[0].name')
+REPO_ROOT=$(git rev-parse --show-toplevel)
+N_GPUS=8
+
+cd $REPO_ROOT  # so config path is valid no matter where we are running this script
+
+IMAGE=$(cat $REPO_ROOT/latest_deps_only_image.txt)
+
+gantry run \
+  --name $JOB_NAME \
+  --description 'Run downscaling 100km to 3km multivar training' \
+  --workspace ai2/climate-titan \
+  --priority low \
+  --preemptible \
+  --cluster ai2/titan \
+  --cluster ai2/jupiter \
+  --beaker-image $IMAGE \
+  --env WANDB_USERNAME=$BEAKER_USERNAME \
+  --env WANDB_NAME=$JOB_NAME \
+  --env WANDB_JOB_TYPE=training \
+  --env WANDB_RUN_GROUP=$wandb_group \
+  --env GOOGLE_APPLICATION_CREDENTIALS=/tmp/google_application_credentials.json \
+  --env-secret WANDB_API_KEY=wandb-api-key-annak \
+  --dataset-secret google-credentials:/tmp/google_application_credentials.json \
+  --weka climate-default:/climate-default \
+  --gpus $N_GPUS \
+  --shared-memory 400GiB \
+  --budget ai2/climate \
+  --no-conda \
+  --install "pip install --no-deps ." \
+  --allow-dirty \
+  -- torchrun --nproc_per_node $N_GPUS -m fme.downscaling.train $CONFIG_PATH
\ No newline at end of file
diff --git a/fme/downscaling/aggregators/main.py b/fme/downscaling/aggregators/main.py
index 23b2a3ce9..9b23a8d6c 100644
--- a/fme/downscaling/aggregators/main.py
+++ b/fme/downscaling/aggregators/main.py
@@ -18,7 +18,7 @@
 from fme.core.device import get_device
 from fme.core.distributed import Distributed
 from fme.core.histogram import ComparedDynamicHistograms
-from fme.core.typing_ import TensorMapping
+from fme.core.typing_ import TensorDict, TensorMapping
 from fme.core.wandb import WandB
 from fme.downscaling.aggregators.adapters import ComparedDynamicHistogramsAdapter
 from fme.downscaling.data import PairedBatchData
@@ -50,6 +50,119 @@ def _tensor_mapping_to_numpy(data: TensorMapping) -> TensorMapping:
     return {k: v.cpu().numpy() for k, v in data.items()}
 
 
+class LossVsNoiseAggregator:
+    """
+    Aggregates binned diffusion losses as a function of sampled noise level.
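+
+    Losses are bucketed by log10(sigma) into fixed-width bins. Only running
+    sums and counts are stored per bin (and per output channel), so memory
+    use is O(n_bins) regardless of how many batches are recorded; per-bin
+    means are computed in get_wandb.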
+ """ + + def __init__( + self, + name: str = "metrics/loss_vs_noise", + n_bins: int = 40, + log10_sigma_min: float = -3.0, + log10_sigma_max: float = 3.0, + ) -> None: + if n_bins < 1: + raise ValueError("n_bins must be >= 1") + if log10_sigma_min >= log10_sigma_max: + raise ValueError("log10_sigma_min must be less than log10_sigma_max") + + self._name = ensure_trailing_slash(name) + self._n_bins = n_bins + edges = torch.linspace(log10_sigma_min, log10_sigma_max, n_bins + 1) + self._inner_edges = edges[1:-1] + self._bin_centers = ((edges[:-1] + edges[1:]) / 2).numpy() + + self._total_sum = torch.zeros(n_bins, dtype=torch.float64) + self._total_count = torch.zeros(n_bins, dtype=torch.int64) + self._channel_sum: dict[str, torch.Tensor] = {} + self._channel_count: dict[str, torch.Tensor] = {} + + def _accumulate( + self, values: torch.Tensor, bin_indices: torch.Tensor, name: str + ) -> None: + if name not in self._channel_sum: + self._channel_sum[name] = torch.zeros(self._n_bins, dtype=torch.float64) + self._channel_count[name] = torch.zeros(self._n_bins, dtype=torch.int64) + self._channel_sum[name].scatter_add_(0, bin_indices, values.to(torch.float64)) + self._channel_count[name].scatter_add_( + 0, bin_indices, torch.ones_like(bin_indices, dtype=torch.int64) + ) + + @torch.no_grad() + def record_batch(self, outputs: ModelOutputs) -> None: + if outputs.sigma is None or not outputs.per_sample_channel_loss: + return + + sigma = outputs.sigma.detach().flatten().cpu() + if torch.any(sigma <= 0): + raise ValueError("Sigma must be strictly positive for log10 binning") + log_sigma = torch.log10(sigma) + # Indices in [0, n_bins-1], with out-of-range values placed in edge bins. + bin_indices = torch.bucketize(log_sigma, self._inner_edges) + + per_channel: TensorDict = {} + for name, loss in outputs.per_sample_channel_loss.items(): + per_channel[name] = loss.detach().flatten().cpu() + if per_channel[name].shape != sigma.shape: + raise ValueError( + "Expected per-sample channel losses and sigma to share batch shape" + ) + + stacked = torch.stack([value for value in per_channel.values()], dim=-1) + total_loss = torch.mean(stacked, dim=-1).to(torch.float64) + self._total_sum.scatter_add_(0, bin_indices, total_loss) + self._total_count.scatter_add_( + 0, bin_indices, torch.ones_like(bin_indices, dtype=torch.int64) + ) + for name, values in per_channel.items(): + self._accumulate(values=values, bin_indices=bin_indices, name=name) + + def _plot_binned(self, y_values: np.ndarray, counts: np.ndarray, title: str) -> Any: + fig, ax = plt.subplots() + mask = counts > 0 + ax.plot(self._bin_centers[mask], y_values[mask], marker="o", linewidth=1.0) + ax.set_xlabel("log10(sigma)") + ax.set_ylabel("mean weighted loss") + ax.set_title(title) + ax.grid(True, alpha=0.3) + plt.close(fig) + return fig + + def get_wandb(self, prefix: str = "") -> Mapping[str, Any]: + prefix = ensure_trailing_slash(prefix) + if torch.sum(self._total_count) == 0: + return {} + + ret: dict[str, Any] = {} + total_count = self._total_count.numpy() + total_mean = np.divide( + self._total_sum.numpy(), + total_count, + out=np.zeros_like(self._total_sum.numpy()), + where=total_count > 0, + ) + ret[f"{prefix}{self._name}total"] = self._plot_binned( + y_values=total_mean, + counts=total_count, + title="Total weighted loss vs noise", + ) + for name in sorted(self._channel_sum): + count = self._channel_count[name].numpy() + mean = np.divide( + self._channel_sum[name].numpy(), + count, + out=np.zeros_like(self._channel_sum[name].numpy()), + where=count 
> 0, + ) + ret[f"{prefix}{self._name}{name}"] = self._plot_binned( + y_values=mean, + counts=count, + title=f"{name} weighted loss vs noise", + ) + return ret + + def _get_spectrum_metrics( gen_spectrum: Mapping[str, np.ndarray], target_spectrum: Mapping[str, np.ndarray], @@ -798,6 +911,7 @@ def __init__( self.loss = Mean(torch.mean) self.channel_loss = Mean(torch.mean) + self.loss_vs_noise = LossVsNoiseAggregator() self._fine_latlon_coordinates: LatLonCoordinates | None = None @torch.no_grad() @@ -841,6 +955,7 @@ def weighted_rmse(truth, pred): self.loss.record_batch({"loss": outputs.loss}) if outputs.channel_losses: self.channel_loss.record_batch(outputs.channel_losses) + self.loss_vs_noise.record_batch(outputs) def get_wandb( self, @@ -855,6 +970,7 @@ def get_wandb( ret.update(self.loss.get_wandb(prefix)) if self.channel_loss._count > 0: ret.update(self.channel_loss.get_wandb(f"{prefix}channel_loss/")) + ret.update(self.loss_vs_noise.get_wandb(prefix)) for comparison in self._comparisons: ret.update(comparison.get_wandb(prefix)) for coarse_comparison in self._coarse_comparisons: diff --git a/fme/downscaling/aggregators/test_aggregators.py b/fme/downscaling/aggregators/test_aggregators.py index f24fdf7ba..e6a301fb8 100644 --- a/fme/downscaling/aggregators/test_aggregators.py +++ b/fme/downscaling/aggregators/test_aggregators.py @@ -13,6 +13,7 @@ from ..models import ModelOutputs from .generation import GenerationAggregator from .main import ( + LossVsNoiseAggregator, Mean, MeanComparison, MeanMapAggregator, @@ -200,6 +201,49 @@ def test_map_aggregator(n_steps: int): aggregator.get_wandb() +@pytest.mark.parametrize("prefix", ["train", "validation"]) +def test_loss_vs_noise_aggregator_get_wandb(prefix: str): + aggregator = LossVsNoiseAggregator(n_bins=8) + outputs_a = ModelOutputs( + prediction={}, + target={}, + latent_steps=[], + loss=torch.tensor(0.0), + sigma=torch.tensor([0.1, 1.0]), + per_sample_channel_loss={ + "x": torch.tensor([1.0, 2.0]), + "y": torch.tensor([2.0, 4.0]), + }, + ) + outputs_b = ModelOutputs( + prediction={}, + target={}, + latent_steps=[], + loss=torch.tensor(0.0), + sigma=torch.tensor([10.0]), + per_sample_channel_loss={ + "x": torch.tensor([3.0]), + "y": torch.tensor([6.0]), + }, + ) + aggregator.record_batch(outputs_a) + aggregator.record_batch(outputs_b) + + # Binning happens in record_batch, not get_wandb. 
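+    # outputs_a contributes two samples (sigma 0.1 and 1.0) and outputs_b one
+    # (sigma 10.0), so the total and per-channel counts should each sum to 3.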
+ assert int(aggregator._total_count.sum().item()) == 3 + assert int(aggregator._channel_count["x"].sum().item()) == 3 + assert int(aggregator._channel_count["y"].sum().item()) == 3 + + logs = aggregator.get_wandb(prefix=prefix) + assert set(logs.keys()) == { + f"{prefix}/metrics/loss_vs_noise/total", + f"{prefix}/metrics/loss_vs_noise/x", + f"{prefix}/metrics/loss_vs_noise/y", + } + for value in logs.values(): + assert hasattr(value, "savefig") + + @pytest.mark.parametrize("n_latent_steps", [0, 2]) def test_aggregator_integration(n_latent_steps, percentiles=[99.999]): downscale_factor = 2 diff --git a/fme/downscaling/models.py b/fme/downscaling/models.py index ca73c792e..03fa9fa8d 100644 --- a/fme/downscaling/models.py +++ b/fme/downscaling/models.py @@ -33,6 +33,8 @@ class ModelOutputs: loss: torch.Tensor latent_steps: list[torch.Tensor] = dataclasses.field(default_factory=list) channel_losses: TensorDict = dataclasses.field(default_factory=dict) + sigma: torch.Tensor | None = None + per_sample_channel_loss: TensorDict = dataclasses.field(default_factory=dict) def _rename_normalizer( @@ -352,6 +354,9 @@ def train_on_batch( batch: PairedBatchData, static_inputs: StaticInputs | None, optimizer: Optimization | NullOptimization, + loss_weights: torch.Tensor, + max_loss_weight: float | None = None, + loss_weight_exponent: float = 1.0, ) -> ModelOutputs: """Performs a denoising training step on a batch of data.""" coarse, fine = batch.coarse.data, batch.fine.data @@ -373,7 +378,11 @@ def train_on_batch( targets_norm = targets_norm - base_prediction conditioned_target = condition_with_noise_for_training( - targets_norm, self.config.noise_distribution, self.sigma_data + targets_norm, + self.config.noise_distribution, + self.sigma_data, + max_loss_weight=max_loss_weight, + loss_weight_exponent=loss_weight_exponent, ) denoised_norm = self.module( @@ -382,6 +391,7 @@ def train_on_batch( weighted_loss = conditioned_target.weight * self.loss( denoised_norm, targets_norm ) + weighted_loss = weighted_loss * loss_weights loss = torch.mean(weighted_loss) optimizer.accumulate_loss(loss) optimizer.step_weights() @@ -391,6 +401,11 @@ def train_on_batch( name: torch.mean(weighted_loss[:, i, :, :]) for i, name in enumerate(self.out_packer.names) } + per_sample_channel_loss = { + name: torch.mean(weighted_loss[:, i, :, :], dim=(-2, -1)).detach() + for i, name in enumerate(self.out_packer.names) + } + sigma = conditioned_target.sigma[:, 0, 0, 0].detach() if self.config.predict_residual: denoised_norm = denoised_norm + base_prediction @@ -404,11 +419,13 @@ def train_on_batch( target=target, loss=loss, channel_losses=channel_losses, + sigma=sigma, + per_sample_channel_loss=per_sample_channel_loss, latent_steps=[], ) @torch.no_grad() - def generate( + def _generate( self, coarse_data: TensorMapping, static_inputs: StaticInputs | None, @@ -466,7 +483,7 @@ def generate_on_batch_no_target( static_inputs: StaticInputs | None, n_samples: int = 1, ) -> TensorDict: - generated, _, _ = self.generate(batch.data, static_inputs, n_samples) + generated, _, _ = self._generate(batch.data, static_inputs, n_samples) return generated @torch.no_grad() @@ -477,7 +494,7 @@ def generate_on_batch( n_samples: int = 1, ) -> ModelOutputs: coarse, fine = batch.coarse.data, batch.fine.data - generated, generated_norm, latent_steps = self.generate( + generated, generated_norm, latent_steps = self._generate( coarse, static_inputs, n_samples ) diff --git a/fme/downscaling/noise.py b/fme/downscaling/noise.py index 48f5b5157..1355bee9b 100644 
--- a/fme/downscaling/noise.py +++ b/fme/downscaling/noise.py @@ -60,6 +60,8 @@ def condition_with_noise_for_training( targets_norm: torch.Tensor, noise_distribution: NoiseDistribution, sigma_data: float, + max_loss_weight: float | None = None, + loss_weight_exponent: float = 1.0, ) -> ConditionedTarget: """ Condition the targets with noise for training. @@ -69,12 +71,23 @@ def condition_with_noise_for_training( noise_distribution: The noise distribution to use for conditioning. sigma_data: The standard deviation of the data, used to determine loss weighting. + max_loss_weight: Optional upper bound on the loss weight. Low sigma + values produce large weights; this clamps the maximum weight to + prevent those samples from dominating the loss. + loss_weight_exponent: Exponent applied to the base EDM loss weight + ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``. The default + of 1.0 gives the standard EDM weighting (~1/sigma^2 for small + sigma). Use 0.5 for ~1/sigma weighting (square root of EDM weight). Returns: The conditioned targets and the loss weighting. """ sigma = noise_distribution.sample(targets_norm.shape[0], targets_norm.device) - weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + weight = ( + (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2 + ) ** loss_weight_exponent + if max_loss_weight is not None: + weight = torch.clamp(weight, max=max_loss_weight) noise = randn_like(targets_norm) * sigma latents = targets_norm + noise return ConditionedTarget(latents=latents, sigma=sigma, weight=weight) diff --git a/fme/downscaling/predictors/cascade.py b/fme/downscaling/predictors/cascade.py index 0f64c73b5..3399c3db6 100644 --- a/fme/downscaling/predictors/cascade.py +++ b/fme/downscaling/predictors/cascade.py @@ -2,6 +2,7 @@ import math import torch +import xarray as xr from fme.core.coordinates import LatLonCoordinates from fme.core.device import get_device @@ -15,6 +16,7 @@ adjust_fine_coord_range, scale_tuple, ) +from fme.downscaling.data.utils import BatchedLatLonCoordinates from fme.downscaling.metrics_and_maths import filter_tensor_mapping from fme.downscaling.models import CheckpointModelConfig, DiffusionModel, ModelOutputs from fme.downscaling.requirements import DataRequirements @@ -86,6 +88,26 @@ def _restore_batch_and_sample_dims(data: TensorMapping, n_samples: int): return unfold_ensemble_dim(squeezed, n_samples) +def _batch_data_with_unused_coords(data: TensorMapping) -> BatchData: + # wrapper function so that we can call each level's + # public generate_on_batch_no_target function using tensormapping + # from the previous step. 
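+    # The time and lat/lon coordinate values are zero-filled placeholders:
+    # BatchData requires coordinates to be provided, but the model-level
+    # generate_on_batch_no_target only reads batch.data.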
+ data_shape = next(iter(data.values())).shape + time = xr.DataArray( + [0 for _ in range(data_shape[0])], + dims=["time"], + ) + latlon_coordinates = BatchedLatLonCoordinates( + lat=torch.zeros((data_shape[0], data_shape[1]), device=get_device()), + lon=torch.zeros((data_shape[0], data_shape[2]), device=get_device()), + ) + return BatchData( + data=data, + time=time, + latlon_coordinates=latlon_coordinates, + ) + + class CascadePredictor: def __init__( self, models: list[DiffusionModel], static_inputs: list[StaticInputs | None] @@ -116,22 +138,26 @@ def modules(self) -> torch.nn.ModuleList: return torch.nn.ModuleList([model.modules for model in self.models]) @torch.no_grad() - def generate( + def _generate( self, coarse: TensorMapping, n_samples: int, static_inputs: list[StaticInputs | None], ): current_coarse = coarse - for i, (model, fine_topography) in enumerate(zip(self.models, static_inputs)): + for i, (model, step_static_inputs) in enumerate( + zip(self.models, static_inputs) + ): sample_data = next(iter(current_coarse.values())) batch_size = sample_data.shape[0] # n_samples are generated for the first step, and subsequent models # generate 1 sample n_samples_cascade_step = n_samples if i == 0 else 1 - generated, generated_norm, latent_steps = model.generate( - current_coarse, fine_topography, n_samples_cascade_step + generated = model.generate_on_batch_no_target( + _batch_data_with_unused_coords(current_coarse), + step_static_inputs, + n_samples_cascade_step, ) generated = { k: v.reshape(batch_size * n_samples_cascade_step, *v.shape[-2:]) @@ -139,7 +165,7 @@ def generate( } current_coarse = generated generated = _restore_batch_and_sample_dims(generated, n_samples) - return generated, generated_norm, latent_steps + return generated @torch.no_grad() def generate_on_batch_no_target( @@ -151,7 +177,7 @@ def generate_on_batch_no_target( subset_static_inputs = self._get_subset_static_inputs( coarse_coords=batch.latlon_coordinates[0] ) - generated, _, _ = self.generate(batch.data, n_samples, subset_static_inputs) + generated = self._generate(batch.data, n_samples, subset_static_inputs) return generated @torch.no_grad() @@ -164,7 +190,7 @@ def generate_on_batch( static_inputs = self._get_subset_static_inputs( coarse_coords=batch.coarse.latlon_coordinates[0] ) - generated, _, latent_steps = self.generate( + generated, _, latent_steps = self._generate( batch.coarse.data, n_samples, static_inputs ) targets = filter_tensor_mapping(batch.fine.data, set(self.out_packer.names)) diff --git a/fme/downscaling/predictors/test_cascade.py b/fme/downscaling/predictors/test_cascade.py index 99511d576..a2f3f5cdb 100644 --- a/fme/downscaling/predictors/test_cascade.py +++ b/fme/downscaling/predictors/test_cascade.py @@ -98,7 +98,7 @@ def test_CascadePredictor_generate(downscale_factors): dtype=torch.float32, ) } - generated, _, _ = cascade_predictor.generate( + generated = cascade_predictor._generate( coarse=coarse_input, n_samples=n_samples_generate, static_inputs=static_inputs_list, diff --git a/fme/downscaling/test_models.py b/fme/downscaling/test_models.py index adf015e10..4efcbd38a 100644 --- a/fme/downscaling/test_models.py +++ b/fme/downscaling/test_models.py @@ -224,7 +224,10 @@ def test_diffusion_model_train_and_generate(predict_residual, use_fine_topograph [batch_size, *coarse_shape], [batch_size, *fine_shape] ) optimization = OptimizationConfig().build(modules=[model.module], max_epochs=2) - train_outputs = model.train_on_batch(batch, static_inputs, optimization) + loss_weights = torch.ones(1, 
len(model.out_packer.names), 1, 1, device=get_device()) + train_outputs = model.train_on_batch( + batch, static_inputs, optimization, loss_weights=loss_weights + ) assert torch.allclose(train_outputs.target["x"], batch.fine.data["x"]) n_generated_samples = 2 diff --git a/fme/downscaling/test_train.py b/fme/downscaling/test_train.py index 5ed89d36d..ac31491e1 100644 --- a/fme/downscaling/test_train.py +++ b/fme/downscaling/test_train.py @@ -181,6 +181,26 @@ def test_train_main_only( main(config_path=config_path) +def test_train_main_with_loss_weights( + default_trainer_config, tmp_path, very_fast_only: bool +): + """Check that training loop runs with per-variable loss weighting.""" + if very_fast_only: + pytest.skip("Skipping non-fast tests") + + config = _update_in_out_names( + default_trainer_config, ["var0", "var1"], ["var0", "var1"] + ) + config["max_epochs"] = 1 + config["loss_weights"] = {"weights": [{"var0": 2.0}, {"var1": 0.5}]} + config_path = _store_config( + tmp_path, config, filename="train-config-loss-weights.yaml" + ) + + with mock_wandb(): + main(config_path=config_path) + + def test_train_main_logs(default_trainer_config, tmp_path, very_fast_only: bool): """Check that training loop records the appropriate logs.""" if very_fast_only: diff --git a/fme/downscaling/train.py b/fme/downscaling/train.py index 6c5a23881..0331c15c5 100755 --- a/fme/downscaling/train.py +++ b/fme/downscaling/train.py @@ -87,6 +87,20 @@ def restore_checkpoint(trainer: "Trainer") -> None: trainer.ema = EMATracker.from_state(ema_checkpoint["ema"], ema_model.modules) +@dataclasses.dataclass +class LossWeights: + weights: list[dict[str, float]] + + def get_weight_tensor( + self, variable_names: list[str], device: torch.device + ) -> torch.Tensor: + weight_map = {} + for mapping in self.weights: + weight_map.update(mapping) + weights = [weight_map.get(name, 1.0) for name in variable_names] + return torch.tensor(weights, device=device).reshape(1, -1, 1, 1) + + class Trainer: def __init__( self, @@ -108,6 +122,15 @@ def __init__( wandb.watch(self.model.modules) self.num_batches_seen = 0 self.config = config + if config.loss_weights is None: + self.loss_weight_tensor = torch.ones( + 1, len(self.model.out_packer.names), 1, 1, device=get_device() + ) + else: + self.loss_weight_tensor = config.loss_weights.get_weight_tensor( + variable_names=self.model.out_packer.names, + device=get_device(), + ) self.patch_data = ( True if (config.coarse_patch_extent_lat and config.coarse_patch_extent_lon) @@ -187,7 +210,14 @@ def train_one_epoch(self) -> None: self.num_batches_seen += 1 if i % 10 == 0: logging.info(f"Training on batch {i+1}") - outputs = self.model.train_on_batch(batch, static_inputs, self.optimization) + outputs = self.model.train_on_batch( + batch, + static_inputs, + self.optimization, + loss_weights=self.loss_weight_tensor, + max_loss_weight=self.config.max_loss_weight, + loss_weight_exponent=self.config.loss_weight_exponent, + ) self.ema(self.model.modules) with torch.no_grad(): train_aggregator.record_batch( @@ -261,7 +291,12 @@ def valid_one_epoch(self) -> dict[str, float]: ) for batch, static_inputs in validation_batch_generator: outputs = self.model.train_on_batch( - batch, static_inputs, self.null_optimization + batch, + static_inputs, + self.null_optimization, + loss_weights=self.loss_weight_tensor, + max_loss_weight=self.config.max_loss_weight, + loss_weight_exponent=self.config.loss_weight_exponent, ) validation_aggregator.record_batch( outputs=outputs, @@ -405,6 +440,9 @@ class TrainerConfig: 
     experiment_dir: str
     save_checkpoints: bool
     logging: LoggingConfig
+    loss_weights: LossWeights | None = None
+    max_loss_weight: float | None = None
+    loss_weight_exponent: float = 1.0
     static_inputs: dict[str, str] | None = None
     ema: EMAConfig = dataclasses.field(default_factory=EMAConfig)
     validate_using_ema: bool = False
diff --git a/scripts/downscaling/plot_events.py b/scripts/downscaling/plot_events.py
new file mode 100644
index 000000000..a1ca6711d
--- /dev/null
+++ b/scripts/downscaling/plot_events.py
@@ -0,0 +1,375 @@
+#!/usr/bin/env python
+"""
+Fetch netCDF event files from a beaker dataset and generate map and histogram
+plots for each variable (coarse, target, and predicted ensemble samples).
+
+This works with saved event outputs written by `fme.downscaling.evaluator` in a
+beaker experiment. It downloads the experiment files to a temporary directory,
+parses filenames for *YYYYMMDD*.nc event outputs, and merges in coarse data for
+map comparison. If no local path is provided for the coarse data, it is read
+from a hard-coded GCS path.
+
+Usage:
+    python plot_events.py <beaker_dataset_id> [--output-dir <dir>]
+        [--coarse-data <path>] [--variables VAR1 VAR2 ...]
+
+Requires:
+    beaker CLI to be installed and authenticated (https://github.com/allenai/beaker).
+"""
+
+import argparse
+import math
+import re
+import subprocess
+import tempfile
+import warnings
+from pathlib import Path
+
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import cftime
+import matplotlib.pyplot as plt
+import numpy as np
+import xarray as xr
+from cartopy.feature import ShapelyFeature
+from cartopy.io import shapereader
+from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
+
+warnings.filterwarnings("ignore")
+
+from plot_beaker_histograms import plot_histogram_lines
+
+TIME_SEL = slice(cftime.DatetimeJulian(2023, 1, 1), None)
+# Matches *YYYYMMDD*.nc (the date can appear anywhere in the filename)
+_EVENT_FILE_RE = re.compile(r"(.+?)[\._-]?(\d{8})[\._-]?(.*)\.nc$")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Generate map plots from beaker dataset event files"
+    )
+    parser.add_argument(
+        "beaker_dataset_id",
+        help="The beaker dataset ID to fetch",
+    )
+    parser.add_argument(
+        "--output-dir",
+        default="./event_maps",
+        help="Output directory for figures (default: ./event_maps)",
+    )
+    parser.add_argument(
+        "--coarse-data",
+        default=None,
+        help="Path to coarse data (default: read from a hard-coded GCS path)",
+    )
+    parser.add_argument(
+        "--variables",
+        nargs="*",
+        default=None,
+        help="Filter to only these variables (default: all eligible variables)",
+    )
+    return parser.parse_args()
+
+
+# Create a STATES feature with no fill
+# Read state geometries from shapefile
+shpfilename = shapereader.natural_earth(
+    resolution="50m", category="cultural", name="admin_1_states_provinces"
+)
+reader = shapereader.Reader(shpfilename)
+states_feature = ShapelyFeature(
+    reader.geometries(),
+    ccrs.PlateCarree(),
+    facecolor="none",
+    edgecolor="lightgrey",
+    linewidth=0.35,
+    linestyle="--",
+)
+
+
+def add_outer_latlon_grid(ax, *, show_left, show_bottom):
+    gl = ax.gridlines(
+        draw_labels=True,
+        linewidth=0.5,
+        color="gray",
+        alpha=0.5,
+        linestyle="--",
+    )
+
+    # Explicitly disable all labels first
+    gl.top_labels = False
+    gl.right_labels = False
+    gl.left_labels = False
+    gl.bottom_labels = False
+
+    # Enable only outer ones
+    if show_left:
+        gl.left_labels = True
+    if show_bottom:
+        gl.bottom_labels = True
+
+    gl.xlabel_style = {"size": 8}
+    gl.ylabel_style = {"size": 8}
+    gl.xformatter
= LONGITUDE_FORMATTER + gl.yformatter = LATITUDE_FORMATTER + + # square grid cells + ax.set_aspect("equal", adjustable="box") + + +def get_coarse_data(path: str | None, time_sel: slice | None = TIME_SEL) -> xr.Dataset: + if path is not None: + return xr.open_zarr(path) + else: + winds = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/instantaneous_physics_fields.zarr" + ).sel(time=time_sel)[ + ["eastward_wind_at_ten_meters", "northward_wind_at_ten_meters"] + ] + prate = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/fluxes_2d.zarr" + ).sel(time=time_sel)["PRATEsfc"] + pres = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2025-07-25-X-SHiELD-AMIP-FME/regridded-zarrs/gaussian_grid_180_by_360/control/column_integrated_dynamical_fields.zarr" + ).sel(time=time_sel)["PRESsfc"] + # in training, PRESsfc is used as input for outputting PRMSL + prmsl = pres.rename("PRMSL") + return xr.merge([winds, prate, pres, prmsl]) + + +def bbox(lat, lon, width=2.0): + return { + "lat": slice(lat - width / 2.0, lat + width / 2.0), + "lon": slice(lon - width / 2.0, lon + width / 2.0), + } + + +def upsample_array(x: np.ndarray, upsample_factor: int = 32) -> np.ndarray: + # upsample coarse data for plotting with fine res + x = np.repeat(x, upsample_factor, axis=0) # repeat rows + x = np.repeat(x, upsample_factor, axis=1) # repeat columns + return x + + +def plot_event(ds, var_name, samples=None, sel=None, n_cols=5, **plot_kwargs): + if samples is None: + samples = list(range(ds.sample.size)) + N = 2 + len(samples) + n_rows = math.ceil(N / n_cols) + + fig, axes = plt.subplots( + n_rows, + n_cols, + figsize=(2 * n_cols, 2 * n_rows), + subplot_kw={"projection": ccrs.PlateCarree()}, + ) + + axes = axes.ravel() # 1D array, easy to index + # Use only the first N axes + for ax in axes[N:]: + ax.set_visible(False) + + suffixes = ["coarse", "target", "predicted"] + vars = [f"{var_name}_{suffix}" for suffix in suffixes] + ds_ = ds[vars] + + if sel: + ds_ = ds_.sel(sel) + + if len(samples) == 0: + samples = [0] + + if var_name == "PRMSL": + # fill PRMSL_coarse with nans + ds_["PRMSL_coarse"].values[:] = np.nan + + vmax = ds_.to_array().max() + if ds_.to_array().min() < -0.2: + plot_kwargs["cmap"] = "RdBu_r" + else: + plot_kwargs["cmap"] = "turbo" + plot_kwargs["vmin"] = min(0, ds_.to_array().min()) + if "vmax" not in plot_kwargs: + plot_kwargs["vmax"] = vmax + # coarse and target + for i, var in enumerate(vars[:2]): + ax = axes[i] + + da = ds_[var] + img = da.plot(ax=ax, add_colorbar=False, **plot_kwargs) + + ax.set_title(suffixes[i], fontsize=10) + ax.add_feature(states_feature) + ax.add_feature(cfeature.BORDERS, color="lightgrey") + ax.coastlines(color="lightgrey") + + row = i // n_cols + col = i % n_cols + + add_outer_latlon_grid( + ax, + show_left=(col == 0), + show_bottom=(row == n_rows - 1), + ) + + for i, s in enumerate(samples): + ax = axes[2 + i] + da = ds_[vars[-1]].isel(sample=s) + + img = da.plot(ax=ax, add_colorbar=False, **plot_kwargs) + ax.set_title(f"predicted {s}", fontsize=10) + ax.add_feature(states_feature) + ax.coastlines(color="lightgrey") + ax.add_feature(cfeature.BORDERS, linestyle="-", color="lightgrey") + row = (i + 2) // n_cols + col = (i + 2) % n_cols + add_outer_latlon_grid( + ax, + show_left=(col == 0), + show_bottom=(row == n_rows - 1), + ) + cbar_ax = fig.add_axes([0.99, 0.25, 0.01, 0.5]) # [left, bottom, width, height] + cbar 
= fig.colorbar(img, cax=cbar_ax)
+    # Label with the variable name only; units differ across variables
+    # (e.g. m/s for winds, Pa for PRMSL, kg/m^2/s for PRATEsfc).
+    cbar.set_label(var_name)
+    # plt.tight_layout()
+
+    return fig, axes
+
+
+def fetch_beaker_dataset(dataset_id: str, target_dir: str) -> None:
+    """Fetch a beaker dataset to the specified directory."""
+    subprocess.run(
+        ["beaker", "dataset", "fetch", dataset_id, "--output", target_dir],
+        check=True,
+    )
+
+
+def find_event_files(directory: str) -> dict[str, Path]:
+    """Find netCDF files matching the event naming pattern, keyed by event name."""
+    event_files = {}
+    for p in sorted(Path(directory).glob("*.nc")):
+        # extract event name
+        matched = _EVENT_FILE_RE.match(p.name)
+        if matched:
+            prefix, date, suffix = matched.group(1), matched.group(2), matched.group(3)
+            parts = [s for s in (prefix, suffix) if s]
+            event_name = f"{'_'.join(parts)}_{date}"
+            event_files[event_name] = p
+    return event_files
+
+
+def detect_variable_pairs(ds: xr.Dataset) -> list[str]:
+    """Detect variables that have both _predicted and _target versions."""
+    predicted = {
+        v[: -len("_predicted")] for v in ds.data_vars if v.endswith("_predicted")
+    }
+    target = {v[: -len("_target")] for v in ds.data_vars if v.endswith("_target")}
+    return sorted(predicted & target)
+
+
+def filename_to_datetime(filename: str) -> cftime.DatetimeJulian:
+    match = re.search(r"(\d{4})(\d{2})(\d{2})(?:T(\d{2}))?", filename)
+    if match is None:
+        raise ValueError(f"Could not parse date from filename: {filename}")
+    return cftime.DatetimeJulian(
+        int(match.group(1)),
+        int(match.group(2)),
+        int(match.group(3)),
+        int(match.group(4) or 12),
+    )
+
+
+def add_wind_speed(ds: xr.Dataset) -> xr.Dataset:
+    variables = detect_variable_pairs(ds)
+    if (
+        "eastward_wind_at_ten_meters" in variables
+        and "northward_wind_at_ten_meters" in variables
+    ):
+        ds["wind_speed_target"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_target**2
+            + ds.northward_wind_at_ten_meters_target**2
+        )
+        ds["wind_speed_predicted"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_predicted**2
+            + ds.northward_wind_at_ten_meters_predicted**2
+        )
+        ds["wind_speed_coarse"] = np.sqrt(
+            ds.eastward_wind_at_ten_meters_coarse**2
+            + ds.northward_wind_at_ten_meters_coarse**2
+        )
+    return ds
+
+
+def merge_coarse(
+    event: xr.Dataset, coarse: xr.Dataset, datetime: cftime.DatetimeJulian
+) -> xr.Dataset:
+    _coarse = coarse.sel(
+        time=datetime,
+        grid_yt=slice(event.lat.min(), event.lat.max()),
+        grid_xt=slice(event.lon.min(), event.lon.max()),
+    )
+    for var in detect_variable_pairs(event):
+        event[f"{var}_coarse"] = xr.DataArray(
+            upsample_array(_coarse[var].values, 32), dims=["lat", "lon"]
+        )
+    return event
+
+
+def main():
+    args = parse_args()
+    beaker_id = args.beaker_dataset_id
+    output_dir = Path(args.output_dir)
+    coarse = get_coarse_data(args.coarse_data, time_sel=TIME_SEL)
+
+    print(f"Fetching beaker dataset: {beaker_id}")
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fetch_beaker_dataset(beaker_id, temp_dir)
+
+        event_files = find_event_files(temp_dir)
+        if not event_files:
+            print(f"No event files found in dataset {beaker_id}")
+            return
+
+        print(f"Found {len(event_files)} event file(s)")
+
+        for event_name, nc_file in event_files.items():
+            output_event_dir = output_dir / beaker_id / event_name
+            output_event_dir.mkdir(parents=True, exist_ok=True)
+
+            print(f"Processing: {nc_file.name} -> {output_event_dir}")
+
+            event = xr.open_dataset(nc_file)
+            event = merge_coarse(
+                event, coarse, datetime=filename_to_datetime(nc_file.name)
+            )
+            event = add_wind_speed(event)
+            variables = detect_variable_pairs(event)
+            if
args.variables is not None: + variables = [v for v in variables if v in args.variables] + + if not variables: + print(f" No variable pairs found in {nc_file.name}") + continue + for var in variables: + fig, axes = plot_event(event, var) + fig.savefig( + output_event_dir / f"{var}_generated_maps.png", + transparent=True, + dpi=300, + bbox_inches="tight", + ) + plt.close(fig) + plot_histogram_lines( + event, + var, + event_name, + save_path=output_event_dir / f"{var}_histogram.png", + ) + event.close() + + print("Done!") + + +if __name__ == "__main__": + main()
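
The standalone sketch below (not part of the diff) illustrates the two loss-weighting mechanisms this change set adds: the sigma-dependent EDM weight with the new loss_weight_exponent and max_loss_weight knobs from fme/downscaling/noise.py, and the per-variable weight tensor built by LossWeights.get_weight_tensor in fme/downscaling/train.py. Function and variable names in the sketch are illustrative only.

import torch

def edm_loss_weight(
    sigma: torch.Tensor,
    sigma_data: float,
    loss_weight_exponent: float = 1.0,
    max_loss_weight: float | None = None,
) -> torch.Tensor:
    # Base EDM weight; exponent 1.0 is the standard scheme (~1/sigma^2 for
    # small sigma), while 0.5 softens it to ~1/sigma.
    weight = (
        (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    ) ** loss_weight_exponent
    if max_loss_weight is not None:
        # Clamp so very-low-sigma samples cannot dominate the mean loss.
        weight = torch.clamp(weight, max=max_loss_weight)
    return weight

sigma = torch.logspace(-2, 2, steps=5)  # 0.01, 0.1, 1, 10, 100
print(edm_loss_weight(sigma, sigma_data=1.0))                            # standard EDM
print(edm_loss_weight(sigma, sigma_data=1.0, loss_weight_exponent=0.5))  # ~1/sigma
print(edm_loss_weight(sigma, sigma_data=1.0, max_loss_weight=10.0))      # clamped

# Per-variable weights broadcast over (batch, channel, lat, lon); e.g. the
# clamp-loss-weight config zeroes out PRATEsfc:
names = [
    "PRATEsfc",
    "eastward_wind_at_ten_meters",
    "northward_wind_at_ten_meters",
    "PRMSL",
]
weight_map = {"PRATEsfc": 0.0}
per_variable = torch.tensor(
    [weight_map.get(n, 1.0) for n in names]
).reshape(1, -1, 1, 1)
# train_on_batch multiplies the per-pixel loss by both factors before taking
# the mean, so a variable with weight 0.0 contributes nothing to the gradient.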