diff --git a/config/diff_config_forecast.yml b/config/diff_config_forecast.yml new file mode 100644 index 000000000..ebd8bebf1 --- /dev/null +++ b/config/diff_config_forecast.yml @@ -0,0 +1,159 @@ +streams_directory: "./config/streams/era5_1deg/" + +embed_orientation: "channels" +embed_local_coords: True +embed_centroids_local_coords: False +embed_size_centroids: 0 +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +target_cell_local_prediction: True + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 256 +ae_global_num_blocks: 8 +ae_global_num_heads: 16 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +ae_global_att_dense_rate: 0.2 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +forecast_offset : 0 +forecast_delta_hrs: 0 +forecast_steps: 1 +forecast_policy: "diffusion" +forecast_freeze_model: False +forecast_att_dense_rate: 0.2 +fe_global_block_factor: 32 +fe_local_num_queries: 1 +fe_num_blocks: 8 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diff_sigma_min: 0.02 +fe_diff_sigma_max: 88 +fe_diff_sigma_data: 1 #I think in gencast this is hard coded to 1... 
+ +healpix_level: 5 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +loss_fcts_lat: + - + - "mse" + - 1.0 +loss_fcts_val: + - + - "mse" + - 1.0 + +batch_size_per_gpu: 1 +batch_size_validation_per_gpu: 1 + +# a regex that needs to fully match the name of the modules you want to freeze +# e.g. ".*ERA5" will match any module whose name ends in ERA5\ +# encoders and decoders that exist per stream have the stream name attached at the end +freeze_modules: "" + +# training mode: "forecast" or "masking" (masked token modeling) +# for "masking" to train with auto-encoder mode, forecast_offset should be 0 +training_mode: "forecast" +# masking rate when training mode is "masking"; ignored in forecast mode +masking_rate: 0.6 +# sample the masking rate (with normal distribution centered at masking_rate) +# note that a sampled masking rate leads to varying requirements +masking_rate_sampling: True +# sample a subset of all target points, useful e.g. 
to reduce memory requirements (also can specify per-stream) +sampling_rate_target: 1.0 +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +masking_strategy: "random" +# masking_strategy_config is a dictionary of additional parameters for the masking strategy +# required for "healpix" and "channel" masking strategies +# "healpix": requires healpix mask level to be specified with `hl_mask` +# "channel": requires "mode" to be specified, "per_cell" or "global", +masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + "probabilities": [0.34, 0.33, 0.33], + "hl_mask": 3, "mode": "per_cell", + "same_strategy_per_batch": false + } + +num_epochs: 32 +samples_per_epoch: 4096 +samples_per_validation: 512 +shuffle: True + +lr_scaling_policy: "sqrt" +lr_start: 1e-6 +lr_max: 5e-5 +lr_final_decay: 1e-6 +lr_final: 0.0 +lr_steps_warmup: 512 +lr_steps_cooldown: 512 +lr_policy_warmup: "cosine" +lr_policy_decay: "linear" +lr_policy_cooldown: "linear" + +grad_clip: 1.0 +weight_decay: 0.1 +norm_type: "LayerNorm" +nn_module: "te" + +start_date: 197901010000 +end_date: 202012310000 +start_date_val: 202101010000 +end_date_val: 202201010000 +len_hrs: 6 +step_hrs: 6 +input_window_steps: 1 + +val_initial: False + +loader_num_workers: 8 +log_validation: 0 +analysis_streams_output: ["ERA5"] + +istep: 0 +run_history: [] + +desc: "" +data_loader_rng_seed: ??? +run_id: ??? 
+ +# Parameters for logging/printing in the training loop +train_log: + # The period to log metrics (in number of batch steps) + log_interval: 20 diff --git a/config/diff_config_old.yml b/config/diff_config_old.yml new file mode 100644 index 000000000..cd6eb79da --- /dev/null +++ b/config/diff_config_old.yml @@ -0,0 +1,152 @@ +streams_directory: "./config/streams/era5_1deg/" + +embed_orientation: "channels" +embed_local_coords: True +embed_centroids_local_coords: False +embed_size_centroids: 0 +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +target_cell_local_prediction: True + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 256 +ae_global_num_blocks: 8 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +ae_global_att_dense_rate: 0.2 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +forecast_offset : 0 +forecast_delta_hrs: 0 +forecast_steps: 1 +forecast_policy: "diffusion" +forecast_freeze_model: False +forecast_att_dense_rate: 1.0 +fe_num_blocks: 2 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diff_sigma_min: 0.02 +fe_diff_sigma_max: 88 +fe_diff_sigma_data: 1 #I think in gencast this is hard coded to 1... 
+ +healpix_level: 5 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +loss_fcts_lat: + - + - "mse" + - 1.0 + +loss_fcts: + - + - "mse" + - 1.0 +loss_fcts_val: + - + - "mse" + - 1.0 + +batch_size_per_gpu: 1 +batch_size_validation_per_gpu: 1 + +# a regex that needs to fully match the name of the modules you want to freeze +# e.g. ".*ERA5" will match any module whose name ends in ERA5\ +# encoders and decoders that exist per stream have the stream name attached at the end +freeze_modules: "" + +# training mode: "forecast" or "masking" (masked token modeling) +# for "masking" to train with auto-encoder mode, forecast_offset should be 0 +training_mode: "forecast" +# masking rate when training mode is "masking"; ignored in forecast mode +masking_rate: 0.6 +# sample the masking rate (with normal distribution centered at masking_rate) +# note that a sampled masking rate leads to varying requirements +masking_rate_sampling: True +# sample a subset of all target points, useful e.g. 
to reduce memory requirements (also can specify per-stream) +sampling_rate_target: 1.0 +# include a masking strategy here, currently only supporting "random", "block", "healpix" and "channel" +masking_strategy: "random" +# masking_strategy_config is a dictionary of additional parameters for the masking strategy +# required for "healpix" and "channel" masking strategies +# "healpix": requires healpix mask level to be specified with `hl_mask` +# "channel": requires "mode" to be specified, "per_cell" or "global", +masking_strategy_config: {"hl_mask": 3} + +num_epochs: 32 +samples_per_epoch: 4096 +samples_per_validation: 512 +shuffle: True + +lr_scaling_policy: "sqrt" +lr_start: 1e-6 +lr_max: 5e-5 +lr_final_decay: 1e-6 +lr_final: 0.0 +lr_steps_warmup: 512 +lr_steps_cooldown: 512 +lr_policy_warmup: "cosine" +lr_policy_decay: "linear" +lr_policy_cooldown: "linear" + +grad_clip: 1.0 +weight_decay: 0.1 +norm_type: "LayerNorm" +nn_module: "te" + +start_date: 197901010000 +end_date: 202012310000 +start_date_val: 202101010000 +end_date_val: 202201010000 +len_hrs: 6 +step_hrs: 6 +input_window_steps: 1 + +val_initial: False + +loader_num_workers: 8 +log_validation: 0 +analysis_streams_output: ["ERA5"] + +istep: 0 +run_history: [] + +desc: "" +data_loader_rng_seed: ??? +run_id: ??? 
+ +# Parameters for logging/printing in the training loop +train_log: + # The period to log metrics (in number of batch steps) + log_interval: 20 diff --git a/config/diff_config_pretrain.yml b/config/diff_config_pretrain.yml new file mode 100644 index 000000000..ea2f8c1ef --- /dev/null +++ b/config/diff_config_pretrain.yml @@ -0,0 +1,159 @@ +streams_directory: "./config/streams/era5_1deg/" + +embed_orientation: "channels" +embed_local_coords: True +embed_centroids_local_coords: False +embed_size_centroids: 0 +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +target_cell_local_prediction: True + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 2 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 256 +ae_global_num_blocks: 8 +ae_global_num_heads: 16 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +ae_global_att_dense_rate: 0.2 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +forecast_offset : 0 +forecast_delta_hrs: 0 +forecast_steps: 0 +forecast_policy: null +forecast_freeze_model: False +forecast_att_dense_rate: 1.0 +fe_num_blocks: 0 +fe_num_heads: 16 +fe_global_block_factor: 32 +fe_local_num_queries: 1 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_diff_sigma_min: 0.02 +fe_diff_sigma_max: 88 +fe_diff_sigma_data: 1 #I think in gencast this is hard coded to 1... 
+ +healpix_level: 5 + +with_mixed_precision: True +with_flash_attention: True +compile_model: False +with_fsdp: True +attention_dtype: bf16 +mlp_norm_eps: 1e-5 +norm_eps: 1e-4 + +latent_noise_kl_weight: 0.0 # 1e-5 +latent_noise_gamma: 2.0 +latent_noise_saturate_encodings: 5 +latent_noise_use_additive_noise: False +latent_noise_deterministic_latents: True + +loss_fcts: + - + - "mse" + - 1.0 +loss_fcts_val: + - + - "mse" + - 1.0 + +batch_size_per_gpu: 1 +batch_size_validation_per_gpu: 1 + +# a regex that needs to fully match the name of the modules you want to freeze +# e.g. ".*ERA5" will match any module whose name ends in ERA5\ +# encoders and decoders that exist per stream have the stream name attached at the end +freeze_modules: "" + +# training mode: "forecast" or "masking" (masked token modeling) +# for "masking" to train with auto-encoder mode, forecast_offset should be 0 +training_mode: "masking" +# masking rate when training mode is "masking"; ignored in forecast mode +masking_rate: 0.6 +# sample the masking rate (with normal distribution centered at masking_rate) +# note that a sampled masking rate leads to varying requirements +masking_rate_sampling: True +# sample a subset of all target points, useful e.g. 
to reduce memory requirements (also can specify per-stream) +sampling_rate_target: 1.0 +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +masking_strategy: "random" +# masking_strategy_config is a dictionary of additional parameters for the masking strategy +# required for "healpix" and "channel" masking strategies +# "healpix": requires healpix mask level to be specified with `hl_mask` +# "channel": requires "mode" to be specified, "per_cell" or "global", +masking_strategy_config: {"strategies": ["random", "healpix", "channel"], + "probabilities": [0.34, 0.33, 0.33], + "hl_mask": 3, "mode": "per_cell", + "same_strategy_per_batch": false + } + +num_epochs: 32 +samples_per_epoch: 4096 +samples_per_validation: 512 +shuffle: True + +lr_scaling_policy: "sqrt" +lr_start: 1e-6 +lr_max: 5e-5 +lr_final_decay: 1e-6 +lr_final: 0.0 +lr_steps_warmup: 512 +lr_steps_cooldown: 512 +lr_policy_warmup: "cosine" +lr_policy_decay: "linear" +lr_policy_cooldown: "linear" + +grad_clip: 1.0 +weight_decay: 0.1 +norm_type: "LayerNorm" +nn_module: "te" + +start_date: 197901010000 +end_date: 202012310000 +start_date_val: 202101010000 +end_date_val: 202201010000 +len_hrs: 6 +step_hrs: 6 +input_window_steps: 1 + +val_initial: False + +loader_num_workers: 8 +log_validation: 0 +analysis_streams_output: ["ERA5"] + +istep: 0 +run_history: [] + +desc: "" +data_loader_rng_seed: ??? +run_id: ??? 
+ +# Parameters for logging/printing in the training loop +train_log: + # The period to log metrics (in number of batch steps) + log_interval: 20 diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 5561ef0c6..5c91448d1 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -34,4 +34,4 @@ ERA5 : # sampling_rate : 0.2 pred_head : ens_size : 1 - num_layers : 1 \ No newline at end of file + num_layers : 1 diff --git a/config/streams/streams_anemoi/era5.yml b/config/streams/streams_anemoi/era5.yml new file mode 100644 index 000000000..382783a72 --- /dev/null +++ b/config/streams/streams_anemoi/era5.yml @@ -0,0 +1,39 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # source : ['u_', 'v_', '10u', '10v'] + # target : ['10u', '10v'] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + loss_weight : 1. 
+ masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 256 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 256 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 \ No newline at end of file diff --git a/src/weathergen/datasets/data_reader_base.py b/src/weathergen/datasets/data_reader_base.py index 8e01b189f..9e843a104 100644 --- a/src/weathergen/datasets/data_reader_base.py +++ b/src/weathergen/datasets/data_reader_base.py @@ -492,11 +492,18 @@ def normalize_source_channels(self, source: NDArray[DType]) -> NDArray[DType]: ------- Normalized data """ - assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels" - for i, ch in enumerate(self.source_idx): - source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch] - - return source + #TODO: KCT, do this properly, otherwise very dangerous + # assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels" + if not (source.shape[-1] == len(self.source_idx)) and (source.shape[-1] == len(self.target_idx)): + # if the source is actually being called as a target (check by looking at the no of channels) + for i, ch in enumerate(self.target_idx): + source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch] + return source + else: + for i, ch in enumerate(self.source_idx): + source[..., i] = (source[..., i] - self.mean[ch]) / self.stdev[ch] + + return source def normalize_target_channels(self, target: NDArray[DType]) -> NDArray[DType]: """ diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c41916f9d..d527fda10 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ 
b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,6 +31,7 @@ from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, + compute_offsets_scatter_embed_target_srclk, compute_source_cell_lens, ) from weathergen.utils.logger import logger @@ -256,7 +257,7 @@ def reset(self): len_dt_samples = len(self) // self.batch_size if self.forecast_policy is None: self.perms_forecast_dt = np.zeros(len_dt_samples, dtype=np.int64) - elif self.forecast_policy == "fixed" or self.forecast_policy == "sequential": + elif self.forecast_policy == "fixed" or self.forecast_policy == "sequential" or self.forecast_policy == "diffusion": self.perms_forecast_dt = fsm * np.ones(len_dt_samples, dtype=np.int64) elif self.forecast_policy == "random" or self.forecast_policy == "sequential_random": # randint high=one-past @@ -367,6 +368,7 @@ def __iter__(self): if rdata.is_empty(): stream_data.add_empty_target(fstep) + stream_data.add_empty_target_srclk(fstep) else: (tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target( stream_info, @@ -379,7 +381,29 @@ def __iter__(self): ds, ) + target_raw_srclk = torch.from_numpy( + np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) + ) + (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( + self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function + stream_info, + torch.from_numpy(rdata.coords), + torch.from_numpy(rdata.geoinfos), + torch.from_numpy(rdata.data), + rdata.datetimes, + (time_win2.start, time_win2.end), + ds, + ) + ) + stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) + stream_data.add_target_srclk( + fstep, + target_raw_srclk, + tt_lens_srclk, + tt_cells_srclk, + tt_centroids_srclk, + ) # merge inputs for sources and targets for current stream stream_data.merge_inputs() @@ -398,6 +422,7 @@ def __iter__(self): # compute offsets for scatter computation after embedding batch = compute_offsets_scatter_embed(batch) + batch = 
compute_offsets_scatter_embed_target_srclk(batch) # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index a5f12327e..a074e7eb5 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -66,6 +66,16 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.source_idxs_embed = torch.tensor([]) self.source_idxs_embed_pe = torch.tensor([]) + # below are for targets which are tokenized like sources + self.target_srclk_raw = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_tokens_lens = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_tokens_cells = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_centroids = [[] for _ in range(forecast_steps + 1)] + + self.target_srclk_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_srclk_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] + + def to_device(self, device="cuda") -> None: """ Move data to GPU @@ -91,6 +101,24 @@ def to_device(self, device="cuda") -> None: self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) + self.target_srclk_raw = [t.to(device, non_blocking=True) for t in self.target_srclk_raw] + self.target_srclk_tokens_lens = [ + t.to(device, non_blocking=True) for t in self.target_srclk_tokens_lens + ] + self.target_srclk_tokens_cells = [ + t.to(device, non_blocking=True) for t in self.target_srclk_tokens_cells + ] + self.target_srclk_centroids = [ + t.to(device, non_blocking=True) for t in self.target_srclk_centroids + ] + + self.target_srclk_idxs_embed = [ + t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed + ] + self.target_srclk_idxs_embed_pe = [ + t.to(device, 
non_blocking=True) for t in self.target_srclk_idxs_embed_pe + ] + return self def add_empty_source(self, source: IOReaderData) -> None: @@ -111,6 +139,22 @@ def add_empty_source(self, source: IOReaderData) -> None: self.source_tokens_cells += [torch.tensor([])] self.source_centroids += [torch.tensor([])] + def add_empty_target_srclk(self, fstep: int) -> None: + """ + Add an empty target for an input encoded like source. + Parameters + ---------- + None + Returns + ------- + None + """ + + self.target_srclk_raw[fstep] += [torch.tensor([])] + self.target_srclk_tokens_lens[fstep] += [torch.zeros([self.nhc_source], dtype=torch.int32)] + self.target_srclk_tokens_cells[fstep] += [torch.tensor([])] + self.target_srclk_centroids[fstep] += [torch.tensor([])] + def add_empty_target(self, fstep: int) -> None: """ Add an empty target for an input. @@ -159,6 +203,34 @@ def add_source( self.source_tokens_cells += [ss_cells] self.source_centroids += [ss_centroids] + def add_target_srclk( + self, + fstep: int, + tt_raw: torch.tensor, + tt_lens: torch.tensor, + tt_cells: list, + tt_centroids: list, + ) -> None: + """ + Add data for source for one input. 
+ Parameters + ---------- + ss_raw : torch.tensor( number of data points in time window , number of channels ) + ss_lens : torch.tensor( number of healpix cells ) + ss_cells : list( number of healpix cells ) + [ torch.tensor( tokens per cell, token size, number of channels) ] + ss_centroids : list(number of healpix cells ) + [ torch.tensor( for source , 5) ] + Returns + ------- + None + """ + + self.target_srclk_raw[fstep] += [tt_raw] + self.target_srclk_tokens_lens[fstep] += [tt_lens] + self.target_srclk_tokens_cells[fstep] += [tt_cells] + self.target_srclk_centroids[fstep] += [tt_centroids] + def add_target( self, fstep: int, @@ -318,6 +390,37 @@ def merge_inputs(self) -> None: self.source_tokens_cells = torch.tensor([]) self.source_centroids = torch.tensor([]) + # >>>>>>> + # collect all source like tokens in current stream and add to batch sample list when non-empty + for fstep in range(len(self.target_srclk_tokens_cells)): + if torch.tensor([len(s) for s in self.target_srclk_tokens_cells[fstep]]).sum() > 0: + self.target_srclk_raw[fstep] = torch.cat(self.target_srclk_raw[fstep]) + + # collect by merging entries per cells, preserving cell structure + self.target_srclk_tokens_cells[fstep] = self._merge_cells( + self.target_srclk_tokens_cells[fstep], self.nhc_source + ) + self.target_srclk_centroids[fstep] = self._merge_cells( + self.target_srclk_centroids[fstep], self.nhc_source + ) + # lens can be stacked and summed + self.target_srclk_tokens_lens[fstep] = torch.stack( + self.target_srclk_tokens_lens[fstep] + ).sum(0) + + # remove NaNs + idx = torch.isnan(self.target_srclk_tokens_cells[fstep]) + self.target_srclk_tokens_cells[fstep][idx] = self.mask_value + idx = torch.isnan(self.target_srclk_centroids[fstep]) + self.target_srclk_centroids[fstep][idx] = self.mask_value + + else: + self.target_srclk_raw[fstep] = torch.tensor([]) + self.target_srclk_tokens_lens[fstep] = torch.zeros([self.nhc_source]) + self.target_srclk_tokens_cells[fstep] = torch.tensor([]) + 
self.target_srclk_centroids[fstep] = torch.tensor([]) + # <<<<< + # targets for fstep in range(len(self.target_coords)): # collect all targets in current stream and add to batch sample list when non-empty @@ -358,4 +461,4 @@ def merge_inputs(self) -> None: self.target_times_raw[fstep] = np.array([], dtype="datetime64[ns]") self.target_tokens[fstep] = torch.tensor([]) self.target_tokens_lens[fstep] = torch.tensor([0]) - self.target_coords_lens[fstep] = torch.tensor([]) + self.target_coords_lens[fstep] = torch.tensor([]) \ No newline at end of file diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index 3b17fddb2..a02d52ef1 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -141,6 +141,7 @@ def batchify_target( target_times = np.split(times_reordered_enc, ll) target_tokens_lens = torch.tensor([len(s) for s in target_tokens], dtype=torch.int32) + target_centroids = torch.tensor([]) # compute encoding of target coordinates used in prediction network if target_tokens_lens.sum() > 0: diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index b5d2279b8..1afb21bab 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -616,6 +616,7 @@ def get_target_coords_local_ffast( return a + def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to @@ -676,6 +677,76 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: return batch +def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: + """ + Compute auxiliary information for scatter operation that changes from stream-centric to + cell-centric computations + + Parameters + ---------- + batch : str + batch of stream data information for which offsets have to be computed + + Returns + ------- + StreamData + stream 
data with offsets added as members + """ + + # collect source_tokens_lens for all stream datas + target_srclk_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_srclk_tokens_lens[fstep] + if len(s.target_srclk_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_srclk_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in batch + ] + ) + + # precompute index sets for scatter operation after embed + offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) + #take offset_base up to last col and append a 0 in the beginning per fstep + zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device) + offsets = torch.cat([zeros_col, offsets_base[:,:-1]], dim=1) + offsets_pe = torch.zeros_like(offsets) + + for ib, sb in enumerate(batch): + for itype, s in enumerate(sb): + for fstep in range(offsets.shape[0]): + if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + s.target_srclk_idxs_embed[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int64) + for offset, token_len in zip( + offsets[fstep], target_srclk_tokens_lens[ib, itype, fstep], strict=False + ) + ] + ) + s.target_srclk_idxs_embed_pe[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int32) + for offset, token_len in zip( + offsets_pe[fstep], target_srclk_tokens_lens[ib][itype][fstep], strict=False + ) + ] + ) + + # advance offsets + offsets[fstep] += target_srclk_tokens_lens[ib][itype][fstep] + offsets_pe[fstep] += target_srclk_tokens_lens[ib][itype][fstep] + + return batch + def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: """ diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index 18cd1b763..a7f8eb275 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -15,7 +15,6 @@ from weathergen.model.norms import AdaLayerNorm, RMSNorm 
- class MultiSelfAttentionHeadVarlen(torch.nn.Module): def __init__( self, @@ -194,6 +193,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_noise_conditioning=False ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -239,11 +239,21 @@ def mask_block_local(batch, head, idx_q, idx_kv): # compile for efficiency self.flex_attention = torch.compile(flex_attention, dynamic=False) - def forward(self, x, ada_ln_aux=None): + self.noise_conditioning = None + if with_noise_conditioning: + self.noise_conditioning = LinearNormConditioning(dim_embed, dtype=self.dtype) + + + def forward(self, x, noise_embedding=None, ada_ln_aux=None): + if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + if self.noise_conditioning: + assert noise_embedding is not None, "Need noise embedding if using noise conditioning" + x = self.noise_conditioning(x, noise_embedding) + # project onto heads s = [x.shape[0], x.shape[1], self.num_heads, -1] qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) @@ -462,6 +472,36 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None): return outs +class LinearNormConditioning(torch.nn.Module): + """Module for norm conditioning. + + Conditions the normalization of `inputs` by applying a linear layer to the + `norm_conditioning` which produces the scale and offset for each channel. 
+ """ + + def __init__(self, feature_size, dtype=torch.bfloat16): + super().__init__() + self.dtype = dtype + + self.conditional_linear_layer = torch.nn.Linear( + in_features=feature_size, + out_features=2 * feature_size, + ) + # Optional: initialize weights similar to TruncatedNormal(stddev=1e-8) + torch.nn.init.normal_(self.conditional_linear_layer.weight, std=1e-8) + torch.nn.init.zeros_(self.conditional_linear_layer.bias) + + def forward(self, inputs, norm_conditioning, dtype = None): + # norm_conditioning: [batch, feature_size] + # inputs: [batch, ..., feature_size] + conditional_scale_offset = self.conditional_linear_layer(norm_conditioning.to(self.dtype)) + scale_minus_one, offset = torch.chunk(conditional_scale_offset, 2, dim=-1) + scale = scale_minus_one + 1.0 + # Reshape scale and offset for broadcasting if needed + while scale.dim() < inputs.dim(): + scale = scale.unsqueeze(1) + offset = offset.unsqueeze(1) + return (inputs * scale + offset).to(self.dtype) #TODO: check if to(self.dtype) needed here class MultiSelfAttentionHead(torch.nn.Module): def __init__( @@ -478,6 +518,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, + with_noise_conditioning=False ): super(MultiSelfAttentionHead, self).__init__() @@ -517,12 +558,20 @@ def __init__( else: self.att = self.attention self.softmax = torch.nn.Softmax(dim=-1) + + self.noise_conditioning = None + if with_noise_conditioning: + self.noise_conditioning = LinearNormConditioning(dim_embed, dtype=self.dtype) - def forward(self, x, ada_ln_aux=None): + def forward(self, x, noise_embedding=None, ada_ln_aux=None): if self.with_residual: x_in = x x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux) + if self.noise_conditioning: + assert noise_embedding is not None, "Need noise embedding if using noise conditioning" + x = self.noise_conditioning(x, noise_embedding, dtype=self.dtype) + # project onto heads and q,k,v and # ensure these are 4D tensors as required for flash 
attention s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1] diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 29bae5806..e208ba6b6 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -10,6 +10,8 @@ import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint +import numpy as np + from weathergen.common.config import Config from weathergen.model.attention import ( @@ -24,7 +26,7 @@ StreamEmbedLinear, StreamEmbedTransformer, ) -from weathergen.model.layers import MLP +from weathergen.model.layers import MLP, PositionalEmbedding, Linear from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -276,6 +278,17 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.cf = cf self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + if self.cf.forecast_policy == "diffusion": + assert hasattr(self.cf, 'fe_diff_sigma_min') and self.cf.fe_diff_sigma_min is not None, "fe_diff_sigma_min must be set if forecast_policy is diffusion" + assert hasattr(self.cf, 'fe_diff_sigma_max') and self.cf.fe_diff_sigma_max is not None, "fe_diff_sigma_max must be set if forecast_policy is diffusion" + assert hasattr(self.cf, 'fe_diff_sigma_data') and self.cf.fe_diff_sigma_data is not None, "fe_diff_sigma_data must be set if forecast_policy is diffusion" + self.sigma_min = self.cf.fe_diff_sigma_min + self.sigma_max = self.cf.fe_diff_sigma_max if self.cf.fe_diff_sigma_max is not None else float('inf') + self.sigma_data = self.cf.fe_diff_sigma_data + self.map_noise = PositionalEmbedding(self.cf.ae_global_dim_embed, self.cf.fe_diff_sigma_data) + init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3)) + self.map_layer0 = Linear(in_features=self.cf.ae_global_dim_embed, out_features=self.cf.ae_global_dim_embed, **init) + self.map_layer1 = 
Linear(in_features=self.cf.ae_global_dim_embed, out_features=self.cf.ae_global_dim_embed, **init) def create(self) -> torch.nn.ModuleList: """ @@ -284,6 +297,7 @@ def create(self) -> torch.nn.ModuleList: :return: torch.nn.ModuleList containing the forecasting blocks. """ global_rate = int(1 / self.cf.forecast_att_dense_rate) + diff_factor = 2 if self.cf.forecast_policy == "diffusion" else 1 #NOTE: this is a hot fix to handle conditioning on previous state via concatenation... if self.cf.forecast_policy is not None: for i in range(self.cf.fe_num_blocks): # Alternate between global and local attention @@ -299,15 +313,17 @@ def create(self) -> torch.nn.ModuleList: dim_aux=1, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=True if self.cf.forecast_policy == "diffusion" else False, ) ) else: + print(f'diff factor is {diff_factor}') self.fe_blocks.append( MultiSelfAttentionHeadLocal( self.cf.ae_global_dim_embed, num_heads=self.cf.fe_num_heads, - qkv_len=self.num_healpix_cells * self.cf.ae_local_num_queries, - block_factor=self.cf.ae_global_block_factor, + qkv_len=self.num_healpix_cells * self.cf.fe_local_num_queries * diff_factor, + block_factor=self.cf.fe_global_block_factor, dropout_rate=self.cf.fe_dropout_rate, with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, @@ -315,6 +331,7 @@ def create(self) -> torch.nn.ModuleList: dim_aux=1, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), + with_noise_conditioning=True if self.cf.forecast_policy == "diffusion" else False, ) ) # Add MLP block @@ -341,6 +358,44 @@ def init_weights_final(m): return self.fe_blocks +def edm_sampler( + net, latents, class_labels=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, +): + # Adjust noise levels based on what's supported by the network. 
+ sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float32, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float32) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat, t_hat, class_labels).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
+ if i < num_steps - 1: + denoised = net(x_next, t_next, class_labels).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + + class EnsPredictionHead(torch.nn.Module): def __init__( diff --git a/src/weathergen/model/layers.py b/src/weathergen/model/layers.py index 1f7b8df5d..bc3de80c9 100644 --- a/src/weathergen/model/layers.py +++ b/src/weathergen/model/layers.py @@ -10,9 +10,10 @@ import torch import torch.nn as nn +import numpy as np from weathergen.model.norms import AdaLayerNorm, RMSNorm - +# from weathergen.model.attention import LinearNormConditioning class NamedLinear(torch.nn.Module): def __init__(self, name: str | None = None, **kwargs): @@ -78,7 +79,6 @@ def __init__( self.layers.append(torch.nn.Dropout(p=dropout_rate)) self.layers.append(torch.nn.Linear(dim_hidden, dim_out)) - def forward(self, *args): x, x_in, aux = args[0], args[0], args[-1] @@ -93,3 +93,105 @@ def forward(self, *args): x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) return x + + +# class FFW(MLP): +# def __init__( +# self, +# dim, +# dim_out=None, +# hidden_factor=2, +# pre_layer_norm=True, +# dropout_rate=0.0, +# nonlin=torch.nn.GELU, +# with_residual=True, +# norm_type="LayerNorm", +# dim_aux=None, +# norm_eps=1e-5, +# name: str | None = None, +# with_noise_conditioning=False + +# ): +# """Constructor""" + +# super(FFW, self).__init__( +# dim_in=dim, +# dim_out=dim_out, +# num_layers=2, +# hidden_factor=hidden_factor, +# pre_layer_norm=pre_layer_norm, +# dropout_rate=dropout_rate, +# nonlin=nonlin, +# with_residual=with_residual, +# norm_type=norm_type, +# dim_aux=dim_aux, +# norm_eps=norm_eps, +# name=name, +# ) + +# if with_noise_conditioning: +# self.noise_conditioning = LinearNormConditioning(dim_in) + +# def forward(self, *args): +# x, x_in, noise_embedding, aux = args[0], args[0], args[-1], args[-2] + +# if self.noise_conditioning: +# assert 
noise_embedding is not None, "Need noise embedding if using noise conditioning" +# x = self.noise_conditioning(x, noise_embedding) + +# for i, layer in enumerate(self.layers): +# x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x) + +# if self.with_residual: +# if x.shape[-1] == x_in.shape[-1]: +# x = x_in + x +# else: +# assert x.shape[-1] % x_in.shape[-1] == 0 +# x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]]) + +# return x + + +#from EDM +class PositionalEmbedding(torch.nn.Module): + def __init__(self, num_channels, max_positions=25000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + # TODO: should I set dtype here like in the attention blocks? + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + +#from EDM +class Linear(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) + self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t().to(x.device) + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype).to(x.device)) + return x + +#from EDM +def weight_init(shape, mode, fan_in, fan_out): + if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == 
'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') + + +#TODO: try gencast positional embeddings for noise encoding... \ No newline at end of file diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 136d96149..54d72752b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -12,6 +12,7 @@ import math import warnings from pathlib import Path +import logging import astropy_healpix as hp import astropy_healpix.healpy @@ -19,6 +20,7 @@ import torch from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint +from torch.nn.functional import silu from weathergen.common.config import Config from weathergen.model.engines import ( @@ -36,6 +38,7 @@ from weathergen.model.utils import get_num_parameters from weathergen.utils.logger import logger from weathergen.utils.utils import get_dtype +import sys class ModelParams(torch.nn.Module): @@ -192,6 +195,10 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size + self.P_mean = cf.P_mean if "P_mean" in cf else -1.2 + self.P_std = cf.P_std if "P_std" in cf else 1.2 + self.sigma_data = cf.sigma_data if "sigma_data" in cf else 1.0 + ######################################### def create(self) -> "Model": """Create each individual module of the model""" @@ -259,7 +266,8 @@ def create(self) -> "Model": "Empty forecast engine (fe_num_blocks = 0), but forecast_steps[i] > 0 for some i" ) - self.fe_blocks = ForecastingEngine(cf, self.num_healpix_cells).create() + self.fe = ForecastingEngine(cf, self.num_healpix_cells) + self.fe_blocks = self.fe.create() ############### # embed coordinates yielding 
one query token for each target token @@ -400,6 +408,14 @@ def freeze_weights_forecast(self) -> "Model": # unfreeze forecast part for p in self.fe_blocks.parameters(): p.requires_grad = True + + if self.cf.forecast_policy == "diffusion": + for p in self.fe.map_layer0.parameters(): + p.requires_grad = True + for p in self.fe.map_layer1.parameters(): + p.requires_grad = True + for p in self.fe.map_noise.parameters(): + p.requires_grad = True return self @@ -422,6 +438,7 @@ def print_num_parameters(self) -> None: num_params_embed_tcs = [get_num_parameters(etc) for etc in self.embed_target_coords] num_params_tte = [get_num_parameters(tte) for tte in self.target_token_engines] num_params_preds = [get_num_parameters(head) for head in self.pred_heads] + print("-----------------") print(f"Total number of trainable parameters: {num_params_total:,}") @@ -518,8 +535,30 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.assimilate_global(model_params, tokens) + # now encode the targets into the latent space for all fsteps + if self.cf.get("encode_targets_latent", False) or (self.cf.forecast_policy == "diffusion"): + with torch.no_grad(): + tokens_targets = [] + tokens_targets_srclk = self.embed_cells_targets_srclk(model_params, streams_data) + for fstep in range(len(tokens_targets_srclk)): + tokens_target, _ = self.assimilate_local( + model_params, tokens_targets_srclk[fstep], source_cell_lens + ) + tokens_target = self.assimilate_global(model_params, tokens_target) + tokens_target_det = tokens_target.detach() # explicitly detach as well + tokens_targets.append(tokens_target_det) + # roll-out in latent space preds_all = [] + + #weights default to 1s + weights = torch.ones(tokens.shape[0], dtype=self.dtype, device=tokens.device) + + + if self.cf.forecast_policy == "diffusion": + assert forecast_steps == 1, "diffusion forecast only implemented for single step" + + tokens_all = [tokens] for fstep in range(forecast_offset, 
forecast_offset + forecast_steps): # prediction preds_all += [ @@ -532,9 +571,25 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - tokens = self.forecast(model_params, tokens) + if not self.cf.forecast_policy == "diffusion": + tokens = self.forecast(model_params, tokens) + tokens_all += [tokens] + if self.cf.forecast_policy == "diffusion": + #change the latent targets to residuals + tokens_targets[fstep - forecast_offset] = tokens_targets[fstep - forecast_offset] - tokens + if self.training: + logger.info('denoising step') + #TODO should never denoise multipel delta_t... + #in this case, we predict residual, and add it to previous state before decoding + res_tokens, weights = self.edm_denoise(model_params, tokens, tokens_targets[fstep - forecast_offset]) + else: + logger.info('sampling step') + res_tokens = self.edm_sample(model_params, tokens) #NOTE: weights are set to default 1s during sampling... + tokens = tokens + res_tokens + tokens_all += [res_tokens] # prediction for final step + #TODO: May exclude this step during forecast training when only latent loss is used (when working with diffusion) preds_all += [ self.predict( model_params, @@ -544,8 +599,11 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca target_coords_idxs, ) ] - - return preds_all, posteriors + + if self.cf.get("encode_targets_latent", False) or (self.cf.forecast_policy == "diffusion"): + return preds_all, posteriors, weights, tokens_all, tokens_targets + else: + return preds_all, posteriors, weights, None, None ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: @@ -600,6 +658,70 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all + def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> torch.Tensor: + """Embeds target data similar to source tokens for each fstep and stream 
separately and rearranges it to cell-wise order + Args: + model_params : Query and embedding parameters + streams_data : Used to initialize first tokens for pre-processing + Returns: + Tokens for local assimilation + """ + with torch.no_grad(): + target_srclk_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_srclk_tokens_lens[fstep] + if len(s.target_srclk_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_srclk_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in streams_data + ] + ) + offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) + num_fsteps = target_srclk_tokens_lens.shape[2] # TODO: KCT, if there are diff no of tokens per fstep, this may fail + tokens_all = [] + for fstep in range(num_fsteps): + tokens_all.append( + torch.empty( + (int(offsets_base[fstep][-1]), self.cf.ae_local_dim_embed), + dtype=self.dtype, + device="cuda", + ) + ) + + tokens_all_scattered = [] + for _, sb in enumerate(streams_data): + for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): + for fstep in range(num_fsteps): + if not (s.target_srclk_tokens_lens[fstep].sum() == 0): + idxs = s.target_srclk_idxs_embed[fstep] + idxs_pe = s.target_srclk_idxs_embed_pe[fstep] + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + + x_embed = embed( + s.target_srclk_tokens_cells[fstep], s.source_centroids # TODO: KCT, get this from the srclk targets + ).flatten(0, 1) + + # scatter write to reorder from per stream to per cell ordering + tokens_all_fstep = tokens_all[fstep] + tokens_all_fstep.scatter_( + 0, idxs, x_embed + model_params.pe_embed[idxs_pe] + ) + tokens_all_scattered.append(tokens_all_fstep) + + return tokens_all_scattered + ######################################### def assimilate_local( self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor @@ -728,7 +850,7 
@@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> return tokens ######################################### - def forecast(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Tensor: + def forecast(self, model_params: ModelParams, tokens: torch.Tensor, noise_conditioning: None) -> torch.Tensor: """Advances latent space representation in time Args: @@ -742,10 +864,79 @@ def forecast(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Ten for it, block in enumerate(self.fe_blocks): aux_info = torch.tensor([it], dtype=torch.float32, device="cuda") - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + tokens = checkpoint(block, tokens, noise_conditioning, aux_info, use_reentrant=False) return tokens + def edm_preconditioning(self, model_params: ModelParams, condition_tokens: torch.Tensor, noised_target: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: + + c_skip = self.sigma_data ** 2 / (noise ** 2 + self.sigma_data ** 2) + c_out = noise * self.sigma_data / (noise ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + noise ** 2).sqrt() + c_noise = noise.log() / 4 + + target_tokens = c_in * noised_target + + concat_x = torch.concat([condition_tokens, target_tokens], dim=1) + + #embed the noise + emb = self.fe.map_noise(c_noise.flatten()) + emb = silu(self.fe.map_layer0(emb)) + emb = silu(self.fe.map_layer1(emb)).to(concat_x.device) + + F_x = self.forecast(model_params, concat_x, emb) + D_x = c_skip * noised_target + c_out * F_x[:, -target_tokens.shape[1]:, :] + return D_x + + def edm_denoise(self, model_params: ModelParams, condition_tokens: torch.Tensor, target_tokens: torch.Tensor): + rnd_normal = torch.randn([target_tokens.shape[0], 1, 1], dtype=torch.float32, device=target_tokens.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + n = torch.randn_like(target_tokens) * sigma + noised_target = target_tokens + n + weight = (sigma ** 2 + self.sigma_data ** 2) / 
(sigma * self.sigma_data) ** 2 + + return self.edm_preconditioning(model_params, condition_tokens, noised_target, noise=sigma), weight + + def edm_sample( + #default parameters taken form gencast supplementary, sampler architecture adapted form edm + self, model_params: ModelParams, condition_tokens: torch.Tensor, class_labels=None, randn_like=torch.randn_like, + num_steps=20, sigma_min=0.02, sigma_max=88, rho=7, + S_churn=2.5, S_min=0, S_max=80, S_noise=1.05, + ): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, self.fe.sigma_min) + sigma_max = min(sigma_max, self.fe.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float32) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + n = torch.randn(condition_tokens.shape, device=condition_tokens.device) #shape should be like target_tokens (may need adaptation if conditioning on multiple lags) + x_next = (n * t_steps[0]).to(condition_tokens.device) #.to(torch.float64) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = torch.as_tensor(t_cur + gamma * t_cur).to(condition_tokens.device) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = self.edm_preconditioning(model_params, condition_tokens, x_hat, noise=t_hat) #.to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
+ if i < num_steps - 1: + denoised = self.edm_preconditioning(model_params, condition_tokens, x_hat, noise=t_next) #.to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + + ######################################### def predict( self, diff --git a/src/weathergen/run_evaluate.py b/src/weathergen/run_evaluate.py new file mode 100644 index 000000000..a0fb28419 --- /dev/null +++ b/src/weathergen/run_evaluate.py @@ -0,0 +1,180 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +The entry point for training and inference weathergen-atmo +""" + +import logging +import pdb +import sys +import time +import traceback +from pathlib import Path + +import weathergen.utils.cli as cli +import weathergen.utils.config as config +from weathergen.train.trainer import Trainer +from weathergen.utils.logger import init_loggers + + +def inference(): + # By default, arguments from the command line are read. + inference_from_args(sys.argv[1:]) + + +def inference_from_args(argl: list[str]): + """ + Inference function for WeatherGenerator model. + Entry point for calling the inference code from the command line. + + When running integration tests, the arguments are directly provided. 
+ """ + parser = cli.get_inference_parser() + args = parser.parse_args(argl) + + inference_overwrite = dict( + shuffle=False, + start_date_val=args.start_date, + end_date_val=args.end_date, + samples_per_validation=args.samples, + log_validation=args.samples if args.save_samples else 0, + analysis_streams_output=args.analysis_streams_output, + ) + + cli_overwrite = config.from_cli_arglist(args.options) + cf = config.load_config( + args.private_config, + args.from_run_id, + args.epoch, + *args.config, + inference_overwrite, + cli_overwrite, + ) + cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + cf.run_history += [(args.from_run_id, cf.istep)] + + trainer = Trainer() + trainer.inference(cf, args.from_run_id, args.epoch) + + +#################################################################################################### +def train_continue() -> None: + parser = cli.get_continue_parser() + args = parser.parse_args() + + init_loggers() + + if args.finetune_forecast: + finetune_overwrite = dict( + training_mode="forecast", + forecast_delta_hrs=0, # 12 + forecast_steps=1, # [j for j in range(1,9) for i in range(4)] + forecast_policy="fixed", # 'sequential_random' # 'fixed' #'sequential' #_random' + forecast_freeze_model=True, + forecast_att_dense_rate=1.0, # 0.25 + fe_num_blocks=8, + fe_num_heads=16, + fe_dropout_rate=0.1, + fe_with_qk_lnorm=True, + lr_start=0.000001, + lr_max=0.00003, + lr_final_decay=0.00003, + lr_final=0.0, + lr_steps_warmup=1024, + lr_steps_cooldown=4096, + lr_policy_warmup="cosine", + lr_policy_decay="linear", + lr_policy_cooldown="linear", + num_epochs=12, # len(cf.forecast_steps) + 4 + istep=0, + ) + else: + finetune_overwrite = dict() + + cli_overwrite = config.from_cli_arglist(args.options) + cf = config.load_config( + args.private_config, + args.from_run_id, + args.epoch, + 
finetune_overwrite, + *args.config, + cli_overwrite, + ) + cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + # track history of run to ensure traceability of results + cf.run_history += [(args.from_run_id, cf.istep)] + + if args.finetune_forecast: + if cf.forecast_freeze_model: + cf.with_fsdp = False + import torch + + torch._dynamo.config.optimize_ddp = False + trainer = Trainer() + trainer.run(cf, args.from_run_id, args.epoch) + + +#################################################################################################### +def train() -> None: + """ + Training function for WeatherGenerator model. + Entry point for calling the training code from the command line. + Configurations are set in the function body. + + Args: + run_id (str, optional): Run/model id of pretrained WeatherGenerator model to + continue training. Defaults to None. + Note: All model configurations are set in the function body. 
+ """ + train_with_args(sys.argv[1:], None) + + +def train_with_args(argl: list[str], stream_dir: str | None): + """ + Training function for WeatherGenerator model.""" + parser = cli.get_train_parser() + args = parser.parse_args(argl) + + cli_overwrite = config.from_cli_arglist(args.options) + + cf = config.load_config(args.private_config, None, None, *args.config, cli_overwrite) + cf = config.set_run_id(cf, args.run_id, False) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + cf.streams = config.load_streams(Path(cf.streams_directory)) + + if cf.with_flash_attention: + assert cf.with_mixed_precision + cf.data_loader_rng_seed = int(time.time()) + + trainer = Trainer(checkpoint_freq=250, print_freq=10) + + try: + trainer.run(cf) + except Exception: + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) + + +if __name__ == "__main__": + # Entry point for slurm script. + # Check whether --from_run_id passed as argument. 
+ inference() diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss.py index 6f5cf5819..47791987d 100644 --- a/src/weathergen/train/loss.py +++ b/src/weathergen/train/loss.py @@ -64,6 +64,9 @@ def stats_normalized_erf(target, ens, mu, stddev): def mse(target, ens, mu, *kwargs): return torch.nn.functional.mse_loss(target, mu) +def diff_mse(target, ens, mu, *kwargs): + return torch.nn.functional.mse_loss(target, mu, reduction="none") + def mse_ens(target, ens, mu, stddev): mse_loss = torch.nn.functional.mse_loss @@ -89,6 +92,7 @@ def mse_channel_location_weighted( pred: torch.Tensor, weights_channels: torch.Tensor | None, weights_points: torch.Tensor | None, + weights_samples: torch.Tensor | None = None, ): """ Compute weighted MSE loss for one window or step @@ -138,12 +142,14 @@ def mse_channel_location_weighted( diff2 = torch.square(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)) if weights_points is not None: diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0) - loss_chs = diff2.mean(0) - loss = torch.mean(loss_chs * weights_channels if weights_channels else loss_chs) + diff2_weighted = diff2 * weights_samples.squeeze(0) if weights_samples is not None else diff2 + loss_chs = diff2_weighted.mean(0) + loss_per_sample = loss_chs * weights_channels if weights_channels else loss_chs + loss_per_sample_weighted = loss_per_sample * weights_samples if weights_samples is not None else loss_per_sample + loss = loss_per_sample_weighted.mean() return loss, loss_chs - def cosine_latitude(stream_data, forecast_offset, fstep, min_value=1e-3, max_value=1.0): latitudes_radian = stream_data.target_coords_raw[forecast_offset + fstep][:, 0] * np.pi / 180 return (max_value - min_value) * np.cos(latitudes_radian) + min_value diff --git a/src/weathergen/train/loss_calculator.py b/src/weathergen/train/loss_calculator.py index 4ad91fe5a..0a876274c 100644 --- a/src/weathergen/train/loss_calculator.py +++ 
b/src/weathergen/train/loss_calculator.py @@ -38,6 +38,8 @@ class LossValues: # well as standard deviations when operating with ensembles (e.g., when training with CRPS). losses_all: dict[str, Tensor] stddev_all: dict[str, Tensor] + + losses_all_lat: Tensor class LossCalculator: @@ -81,6 +83,19 @@ def __init__( [getattr(losses, name if name != "mse" else "mse_channel_location_weighted"), w] for name, w in loss_fcts ] + + loss_fcts_lat = cf.get("latent_loss_fcts") if stage == TRAIN else cf.get("latent_loss_fcts_val") + if loss_fcts_lat: + self.loss_fcts_lat = [ + [getattr(losses, name), w] + for name, w in loss_fcts_lat + ] + else: + self.loss_fcts_lat = [] + + if self.cf.forecast_policy == "diffusion" and self.cf.training_mode == "forecast" and self.loss_fcts: + _logger.warning("Loss functions in physical space are specified despite training diffusion-based forecast engine – Carefully check if this is the desired specification.") + def _get_weights(self, stream_info): """ @@ -140,6 +155,7 @@ def _loss_per_loss_function( substep_masks: list[torch.Tensor], weights_channels: torch.Tensor, weights_locations: torch.Tensor, + weights_samples: torch.Tensor = None, ): """ Compute loss for given loss function @@ -153,7 +169,7 @@ def _loss_per_loss_function( assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True loss, loss_chs = loss_fct( - target[mask_t], pred[:, mask_t], weights_channels, weights_locations + target[mask_t], pred[:, mask_t], weights_channels, weights_locations, weights_samples ) # accumulate loss @@ -168,10 +184,24 @@ def _loss_per_loss_function( loss_lfct = loss_lfct / (ctr_substeps if ctr_substeps > 0 else 1.0) return loss_lfct, losses_chs + + def _loss_per_loss_function_lat( + loss_fct, + stream_info, + target: torch.Tensor, + pred: torch.Tensor, + ): + """ + Compute loss for given loss function + """ + + loss_val = loss_fct(target=target, ens=None, mu=pred) + + return loss_val def compute_loss( self, - preds: 
list[list[Tensor]], + out: list[list[Tensor]], streams_data: list[list[any]], ) -> LossValues: """ @@ -221,10 +251,21 @@ def compute_loss( stddev_all: dict[str, Tensor] = { st.name: torch.zeros(len(stat_loss_fcts), device=self.device) for st in self.cf.streams } + + losses_all_lat: Tensor= torch.zeros( + len(self.loss_fcts_lat), + device=self.device, + ) + + preds, posteriors, weights, tokens_all, tokens_targets = out # TODO: iterate over batch dimension i_batch = 0 + for i_stream_info, stream_info in enumerate(self.cf.streams): + + # 1. First go through the losses in the physical space + # extract target tokens for current stream from the specified forecast offset onwards targets = streams_data[i_batch][i_stream_info].target_tokens[self.cf.forecast_offset :] @@ -234,6 +275,7 @@ def compute_loss( ctr_fsteps = 0 for fstep, target in enumerate(targets): # skip if either target or prediction has no data points + preds = out[0] pred = preds[fstep][i_stream_info] if not (target.shape[0] > 0 and pred.shape[0] > 0): continue @@ -268,6 +310,7 @@ def compute_loss( substep_masks, weights_channels, weights_locations, + weights ) losses_all[stream_info.name][:, i_lfct] += loss_lfct_chs @@ -278,9 +321,12 @@ def compute_loss( loss_fsteps = loss_fsteps + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) ctr_fsteps += 1 if ctr_loss_fcts > 0 else 0 - + + + loss = loss + (loss_fsteps / (ctr_fsteps if ctr_fsteps > 0 else 1.0)) ctr_streams += 1 if ctr_fsteps > 0 else 0 + # normalize by forecast step losses_all[stream_info.name] /= ctr_fsteps if ctr_fsteps > 0 else 1.0 @@ -289,6 +335,41 @@ def compute_loss( # replace channels without information by nan to exclude from further computations losses_all[stream_info.name][losses_all[stream_info.name] == 0.0] = torch.nan stddev_all[stream_info.name][stddev_all[stream_info.name] == 0.0] = torch.nan + + + # 2. 
Now go through the losses in the latent space, if applicable + if self.loss_fcts_lat: + loss_fsteps_lat = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_fsteps_lat = 0 + # TODO: KCT, do we need the below per fstep? + for fstep in range(1, len(tokens_all)): # the first entry in tokens_all is the source itself, so skip it + loss_fstep = torch.tensor(0.0, device=self.device, requires_grad=True) + ctr_loss_fcts = 0 + # if forecast_offset==0, then the timepoints correspond. Otherwise targets don't encode the source timestep, so we don't need to skip + fstep_targs = fstep if self.cf.forecast_offset == 0 else fstep -1 + for i_lfct, (loss_fct, loss_fct_weight) in enumerate(self.loss_fcts_lat): + loss_lfct = LossCalculator._loss_per_loss_function_lat( + loss_fct, + stream_info=None, + target=tokens_targets[fstep_targs], + pred=tokens_all[fstep] + ) + + losses_all_lat[i_lfct] += loss_lfct # TODO: break into fsteps + + # Add the weighted and normalized loss from this loss function to the total + # batch loss + loss_fstep = loss_fstep + (loss_fct_weight * loss_lfct) + ctr_loss_fcts += 1 if loss_lfct > 0.0 else 0 + + loss_fsteps_lat = loss_fsteps_lat + (loss_fstep / ctr_loss_fcts if ctr_loss_fcts > 0 else 0) + ctr_fsteps_lat += 1 if ctr_loss_fcts > 0 else 0 + + loss = loss + (loss_fsteps_lat / (ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0)) + ctr_streams = ctr_streams if ctr_streams > 0 else 1 # TODO: KCT, check the logic here + + losses_all_lat /= ctr_fsteps_lat if ctr_fsteps_lat > 0 else 1.0 + losses_all_lat[losses_all_lat == 0.0] = torch.nan if loss == 0.0: # streams_data[i] are samples in batch @@ -303,4 +384,4 @@ def compute_loss( loss = loss / ctr_streams # Return all computed loss components encapsulated in a ModelLoss dataclass - return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all) + return LossValues(loss=loss, losses_all=losses_all, stddev_all=stddev_all, losses_all_lat=losses_all_lat) diff --git 
a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f3eb8850e..029e7ca30 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -416,6 +416,7 @@ def _prepare_logging( # removing the reshaping, make sure to index the tensors starting at forecast_offset, e.g., # target_times_raw = streams_data[i_batch][i_strm].target_times_raw[forecast_offset+fstep], # when iterating over batch, stream, and fsteps. + targets_rt = [ [ torch.cat([t[i].target_tokens[fstep] for t in streams_data]) @@ -487,7 +488,7 @@ def train(self, epoch): self.optimizer.zero_grad() # Unweighted loss, real weighted loss, std for losses that need it - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist, self.loss_unweighted_lat_hist = [], [], [], [] # training loop self.t_start = time.time() @@ -501,16 +502,28 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, posteriors = self.ddp_model( + out = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) loss_values = self.loss_calculator.compute_loss( - preds=preds, + out=out, streams_data=batch[0], ) if cf.latent_noise_kl_weight > 0.0: kl = torch.cat([posterior.kl() for posterior in posteriors]) loss_values.loss += cf.latent_noise_kl_weight * kl.mean() + + # if bidx == 0 and is_root(): + # if self.cf.get("encode_targets_latent", False): + # unpack the predictions/tokens from the latent space if the latent space tokens are encoded + # preds, posteriors, weights, tokens_all, tokens_targets = out + # save_dir = "/iopsstor/scratch/cscs/ktezcan/weathergen/tokens/" + # np.save(save_dir + self.cf.run_id + "_tokens_all_epoch" + str(epoch) + ".npy", [t.detach().cpu().numpy() for t in tokens_all]) + # np.save(save_dir + self.cf.run_id + "_tokens_targets_epoch" + str(epoch) + ".npy", [t.detach().cpu().numpy() for t in 
tokens_targets]) + preds, posteriors, weights, tokens_all, tokens_targets = out + if cf.latent_noise_kl_weight > 0.0: + kl = torch.cat([posterior.kl() for posterior in posteriors]) + loss_values.loss += cf.latent_noise_kl_weight * kl.mean() # backward pass self.grad_scaler.scale(loss_values.loss).backward() @@ -530,6 +543,9 @@ def train(self, epoch): self.loss_unweighted_hist += [loss_values.losses_all] self.loss_model_hist += [loss_values.loss.item()] self.stdev_unweighted_hist += [loss_values.stddev_all] + + if loss_values.losses_all_lat.numel() > 0: + self.loss_unweighted_lat_hist += [loss_values.losses_all_lat.item()] perf_gpu, perf_mem = self.get_perf() self.perf_gpu = ddp_average(torch.tensor([perf_gpu])).item() @@ -552,7 +568,7 @@ def validate(self, epoch): self.ddp_model.eval() dataset_val_iter = iter(self.data_loader_validation) - self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist, self.loss_unweighted_lat_hist = [], [], [], [] with torch.no_grad(): # print progress bar but only in interactive mode, i.e. 
when without ddp @@ -569,14 +585,14 @@ def validate(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, _ = self.ddp_model( + out = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) # compute loss and log output if bidx < cf.log_validation: loss_values = self.loss_calculator_val.compute_loss( - preds=preds, + out=out, streams_data=batch[0], ) @@ -588,7 +604,7 @@ def validate(self, epoch): targets_times_all, targets_lens, ) = self._prepare_logging( - preds=preds, + preds=out[0], forecast_offset=cf.forecast_offset, forecast_steps=cf.forecast_steps, streams_data=batch[0], @@ -608,13 +624,16 @@ def validate(self, epoch): else: loss_values = self.loss_calculator_val.compute_loss( - preds=preds, + out=out, streams_data=batch[0], ) self.loss_unweighted_hist += [loss_values.losses_all] self.loss_model_hist += [loss_values.loss.item()] self.stdev_unweighted_hist += [loss_values.stddev_all] + + if loss_values.losses_all_lat.numel() > 0: + self.loss_unweighted_lat_hist += [loss_values.losses_all_lat.item()] pbar.update(self.cf.batch_size_validation_per_gpu) @@ -691,21 +710,27 @@ def _prepare_losses_for_logging( """ losses_all: dict[str, Tensor] = {} stddev_all: dict[str, Tensor] = {} + losses_all_lat: Tensor # Make list of losses into a tensor. 
This is individual tensor per rank real_loss = torch.tensor(self.loss_model_hist, device=self.devices[0]) # Gather all tensors from all ranks into a list and stack them into one tensor again real_loss = torch.cat(all_gather_vlen(real_loss)) - for stream in self.cf.streams: # Loop over all steams + for stream in self.cf.streams: # Loop over all streams stream_hist = [losses_all[stream.name] for losses_all in self.loss_unweighted_hist] stream_all = torch.stack(stream_hist).to(torch.float64) losses_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) + stream_hist = [stddev_all[stream.name] for stddev_all in self.stdev_unweighted_hist] stream_all = torch.stack(stream_hist).to(torch.float64) stddev_all[stream.name] = torch.cat(all_gather_vlen(stream_all)) + + lat_hist = [losses_all_lat for losses_all_lat in self.loss_unweighted_lat_hist] + lat_all = torch.tensor(lat_hist, device=self.devices[0]).to(torch.float64) + losses_all_lat = torch.cat(all_gather_vlen(lat_all)) - return real_loss, losses_all, stddev_all + return real_loss, losses_all, stddev_all, losses_all_lat def _log(self, stage: Stage): """ @@ -719,7 +744,7 @@ def _log(self, stage: Stage): - This method only executes logging on the main process (rank 0). - After logging, historical loss and standard deviation records are cleared. 
""" - avg_loss, losses_all, stddev_all = self._prepare_losses_for_logging() + avg_loss, losses_all, stddev_all, losses_all_lat = self._prepare_losses_for_logging() samples = self.cf.istep * self.cf.batch_size_per_gpu * self.cf.num_ranks if is_root(): @@ -743,7 +768,7 @@ def _log(self, stage: Stage): def _log_terminal(self, bidx: int, epoch: int, stage: Stage): if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL: # compute from last iteration - avg_loss, losses_all, _ = self._prepare_losses_for_logging() + avg_loss, losses_all, _, losses_all_lat = self._prepare_losses_for_logging() if is_root(): if stage == VAL: @@ -783,6 +808,13 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): + f" : {losses_all[st['name']].nanmean():0.4E} \t", end="", ) + # if the latent loss tensor is not empty: + if losses_all_lat.numel(): + # print latent losses if available + print( + f"latent loss : {losses_all_lat.nanmean():0.4E} \t", + end="", + ) print("\n", flush=True) self.t_start = time.time() diff --git a/src/weathergen/train/trainer_base.py b/src/weathergen/train/trainer_base.py index ad7fad8bc..e8ba7f4c1 100644 --- a/src/weathergen/train/trainer_base.py +++ b/src/weathergen/train/trainer_base.py @@ -22,6 +22,9 @@ from weathergen.train.utils import str_to_tensor, tensor_to_str from weathergen.utils.distributed import is_root +import socket +import time + _logger = logging.getLogger(__name__) @@ -82,9 +85,11 @@ def init_ddp(cf): _logger.info(f"rank: {rank} has run_id: {cf.run_id}") return + master_port = os.environ.get("MASTER_PORT", "29514") + local_rank = int(os.environ.get("SLURM_LOCALID")) ranks_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE", "1")[0]) - rank = int(os.environ.get("SLURM_NODEID")) * ranks_per_node + local_rank + rank = int(os.environ.get("SLURM_PROCID")) num_ranks = int(os.environ.get("SLURM_NTASKS")) _logger.info( f"DDP initialization: local_rank={local_rank}, ranks_per_node={ranks_per_node}, " @@ -96,6 +101,7 @@ def init_ddp(cf): with 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind((master_node, 1345)) + print("Port 1345 is available for DDP initialization.") except OSError as e: if e.errno == errno.EADDRINUSE: _logger.error( @@ -112,11 +118,30 @@ def init_ddp(cf): _logger.info( f"Initializing DDP with rank {rank} out of {num_ranks} on master_node:{master_node}." ) + + # def check_port_open(host, port, timeout=5): + # s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # s.settimeout(timeout) + # try: + # s.connect((host, port)) + # s.close() + # return True + # except Exception: + # return False + + # if rank != 0: + # # Wait for master to bind the port + # time.sleep(2) + # port_open = check_port_open(master_node, 1345) + # if not port_open: + # raise RuntimeError(f"Rank {rank} cannot connect to {master_node}:1345") + # else: + # _logger.info(f"Rank {rank} port open") dist.init_process_group( backend="nccl", - init_method="tcp://" + master_node + ":1345", - timeout=datetime.timedelta(seconds=240), + init_method=f"tcp://{master_node}:{master_port}", + timeout=datetime.timedelta(seconds=60), world_size=num_ranks, rank=rank, device_id=torch.device("cuda", local_rank),