From 1ab20ceff8b96c64c9d034c724ba7c4e1c8a8e82 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 25 Sep 2025 16:09:10 +0200 Subject: [PATCH 01/17] initial changes --- .../datasets/multi_stream_data_sampler.py | 27 +++++ src/weathergen/datasets/stream_data.py | 102 +++++++++++++++++ src/weathergen/datasets/tokenizer_forecast.py | 7 +- src/weathergen/datasets/tokenizer_masking.py | 5 +- src/weathergen/datasets/utils.py | 70 ++++++++++++ src/weathergen/model/model.py | 108 +++++++++++++++++- 6 files changed, 314 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index c41916f9d..9e2d13e6f 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,6 +31,7 @@ from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, + compute_offsets_scatter_embed_target_srclk, compute_source_cell_lens, ) from weathergen.utils.logger import logger @@ -348,6 +349,7 @@ def __iter__(self): rdata.datetimes, (time_win1.start, time_win1.end), ds, + "source_normalizer" ) stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids) @@ -367,6 +369,7 @@ def __iter__(self): if rdata.is_empty(): stream_data.add_empty_target(fstep) + stream_data.add_empty_target_srclk(fstep) else: (tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target( stream_info, @@ -379,7 +382,30 @@ def __iter__(self): ds, ) + target_raw_srclk = torch.from_numpy( + np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) + ) + (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( + self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function + stream_info, + torch.from_numpy(rdata.coords), + torch.from_numpy(rdata.geoinfos), + torch.from_numpy(rdata.data), + rdata.datetimes, + (time_win2.start, time_win2.end), + ds, + "target_normalizer" + ) + ) + stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) + stream_data.add_target_srclk( + fstep, + target_raw_srclk, + tt_lens_srclk, + tt_cells_srclk, + tt_centroids_srclk, + ) # merge inputs for sources and targets for current stream stream_data.merge_inputs() @@ -398,6 +424,7 @@ def __iter__(self): # compute offsets for scatter computation after embedding batch = compute_offsets_scatter_embed(batch) + batch = compute_offsets_scatter_embed_target_srclk(batch) # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index a5f12327e..4dc7bfd5d 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -66,6 +66,15 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.source_idxs_embed = torch.tensor([]) self.source_idxs_embed_pe = torch.tensor([]) + # below are for targets which are tokenized like sources + self.target_srclk_raw = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_tokens_lens = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_tokens_cells = [[] for _ in range(forecast_steps + 1)] + self.target_srclk_centroids = [[] for _ in range(forecast_steps + 1)] + + self.target_srclk_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_srclk_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] + def to_device(self, device="cuda") -> 
None: """ Move data to GPU @@ -91,6 +100,24 @@ def to_device(self, device="cuda") -> None: self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) + self.target_srclk_raw = [t.to(device, non_blocking=True) for t in self.target_srclk_raw] + self.target_srclk_tokens_lens = [ + t.to(device, non_blocking=True) for t in self.target_srclk_tokens_lens + ] + self.target_srclk_tokens_cells = [ + t.to(device, non_blocking=True) for t in self.target_srclk_tokens_cells + ] + self.target_srclk_centroids = [ + t.to(device, non_blocking=True) for t in self.target_srclk_centroids + ] + + self.target_srclk_idxs_embed = [ + t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed + ] + self.target_srclk_idxs_embed_pe = [ + t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed_pe + ] + return self def add_empty_source(self, source: IOReaderData) -> None: @@ -111,6 +138,22 @@ def add_empty_source(self, source: IOReaderData) -> None: self.source_tokens_cells += [torch.tensor([])] self.source_centroids += [torch.tensor([])] + def add_empty_target_srclk(self, fstep: int) -> None: + """ + Add an empty target for an input encoded like source. + Parameters + ---------- + None + Returns + ------- + None + """ + + self.target_srclk_raw[fstep] += [torch.tensor([])] + self.target_srclk_tokens_lens[fstep] += [torch.zeros([self.nhc_source], dtype=torch.int32)] + self.target_srclk_tokens_cells[fstep] += [torch.tensor([])] + self.target_srclk_centroids[fstep] += [torch.tensor([])] + def add_empty_target(self, fstep: int) -> None: """ Add an empty target for an input. @@ -159,6 +202,34 @@ def add_source( self.source_tokens_cells += [ss_cells] self.source_centroids += [ss_centroids] + def add_target_srclk( + self, + fstep: int, + tt_raw: torch.tensor, + tt_lens: torch.tensor, + tt_cells: list, + tt_centroids: list, + ) -> None: + """ + Add data for source for one input. 
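For orientation: the add_target_srclk / add_empty_target_srclk helpers above store targets that were run through the same tokenization path as sources; aside from the target time window, the main difference from a real source is which channel normalizer batchify_source applies, selected by the new use_normalizer string. Below is a minimal sketch of that selection, outside the diff, assuming only a simplified normalizer object with the two normalize_*_channels methods referenced in this patch; it mirrors the ternary used in tokenizer_forecast.py and tokenizer_masking.py.

# illustrative sketch, not part of the patch
def select_channel_normalizer(normalizer, use_normalizer: str):
    # "source_normalizer" keeps the original source behaviour;
    # "target_normalizer" normalizes target data that is tokenized like a source
    if use_normalizer == "source_normalizer":
        return normalizer.normalize_source_channels
    if use_normalizer == "target_normalizer":
        return normalizer.normalize_target_channels
    raise ValueError(f"unknown use_normalizer: {use_normalizer}")
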
+ Parameters + ---------- + ss_raw : torch.tensor( number of data points in time window , number of channels ) + ss_lens : torch.tensor( number of healpix cells ) + ss_cells : list( number of healpix cells ) + [ torch.tensor( tokens per cell, token size, number of channels) ] + ss_centroids : list(number of healpix cells ) + [ torch.tensor( for source , 5) ] + Returns + ------- + None + """ + + self.target_srclk_raw[fstep] += [tt_raw] + self.target_srclk_tokens_lens[fstep] += [tt_lens] + self.target_srclk_tokens_cells[fstep] += [tt_cells] + self.target_srclk_centroids[fstep] += [tt_centroids] + def add_target( self, fstep: int, @@ -318,6 +389,37 @@ def merge_inputs(self) -> None: self.source_tokens_cells = torch.tensor([]) self.source_centroids = torch.tensor([]) + # >>>>>>> + # collect all source like tokens in current stream and add to batch sample list when non-empty + for fstep in range(len(self.target_srclk_tokens_cells)): + if torch.tensor([len(s) for s in self.target_srclk_tokens_cells[fstep]]).sum() > 0: + self.target_srclk_raw[fstep] = torch.cat(self.target_srclk_raw[fstep]) + + # collect by merging entries per cells, preserving cell structure + self.target_srclk_tokens_cells[fstep] = self._merge_cells( + self.target_srclk_tokens_cells[fstep], self.nhc_source + ) + self.target_srclk_centroids[fstep] = self._merge_cells( + self.target_srclk_centroids[fstep], self.nhc_source + ) + # lens can be stacked and summed + self.target_srclk_tokens_lens[fstep] = torch.stack( + self.target_srclk_tokens_lens[fstep] + ).sum(0) + + # remove NaNs + idx = torch.isnan(self.target_srclk_tokens_cells[fstep]) + self.target_srclk_tokens_cells[fstep][idx] = self.mask_value + idx = torch.isnan(self.target_srclk_centroids[fstep]) + self.target_srclk_centroids[fstep][idx] = self.mask_value + + else: + self.target_srclk_raw[fstep] = torch.tensor([]) + self.target_srclk_tokens_lens[fstep] = torch.zeros([self.nhc_source]) + self.target_srclk_tokens_cells[fstep] = torch.tensor([]) + self.target_srclk_centroids[fstep] = torch.tensor([]) + # <<<<< + # targets for fstep in range(len(self.target_coords)): # collect all targets in current stream and add to batch sample list when non-empty diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index 3b17fddb2..a815f4600 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -41,12 +41,15 @@ def batchify_source( source: np.array, times: np.array, time_win: tuple, - normalizer, # dataset + normalizer, # dataset, + use_normalizer: str, # "source" or "target" ): init_loggers() token_size = stream_info["token_size"] is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) + + channel_normalizer = normalizer.normalize_source_channels if use_normalizer == "source_normalizer" else normalizer.normalize_target_channels tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, @@ -56,7 +59,7 @@ def batchify_source( hpy_verts_rots=self.hpy_verts_rots_source[-1], n_coords=normalizer.normalize_coords, n_geoinfos=normalizer.normalize_geoinfos, - n_data=normalizer.normalize_source_channels, + n_data=channel_normalizer, enc_time=encode_times_source, ) diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 3da8508f6..b2bab5ac8 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ 
b/src/weathergen/datasets/tokenizer_masking.py @@ -48,11 +48,14 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset + use_normalizer: str, # "source" or "target" ): init_loggers() token_size = stream_info["token_size"] is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) + + channel_normalizer = normalizer.normalize_source_channels if use_normalizer == "source_normalizer" else normalizer.normalize_target_channels tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, @@ -62,7 +65,7 @@ def batchify_source( hpy_verts_rots=self.hpy_verts_rots_source[-1], n_coords=normalizer.normalize_coords, n_geoinfos=normalizer.normalize_geoinfos, - n_data=normalizer.normalize_source_channels, + n_data=channel_normalizer, enc_time=encode_times_source, ) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index b5d2279b8..803c7fb2f 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -676,6 +676,76 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: return batch +def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: + """ + Compute auxiliary information for scatter operation that changes from stream-centric to + cell-centric computations + + Parameters + ---------- + batch : str + batch of stream data information for which offsets have to be computed + + Returns + ------- + StreamData + stream data with offsets added as members + """ + + # collect source_tokens_lens for all stream datas + target_srclk_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_srclk_tokens_lens[fstep] + if len(s.target_srclk_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_srclk_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in batch + ] + ) + + # precompute index sets for scatter operation after embed + offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) + #take offset_base up to last col and append a 0 in the beginning per fstep + zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device) + offsets = torch.cat([zeros_col, offsets_base[:,:-1]], dim=1) + offsets_pe = torch.zeros_like(offsets) + + for ib, sb in enumerate(batch): + for itype, s in enumerate(sb): + for fstep in range(offsets.shape[0]): + if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + s.target_srclk_idxs_embed[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int64) + for offset, token_len in zip( + offsets[fstep], target_srclk_tokens_lens[ib, itype, fstep], strict=False + ) + ] + ) + s.target_srclk_idxs_embed_pe[fstep] = torch.cat( + [ + torch.arange(offset, offset + token_len, dtype=torch.int32) + for offset, token_len in zip( + offsets_pe[fstep], target_srclk_tokens_lens[ib][itype][fstep], strict=False + ) + ] + ) + + # advance offsets + offsets[fstep] += target_srclk_tokens_lens[ib][itype][fstep] + offsets_pe[fstep] += target_srclk_tokens_lens[ib][itype][fstep] + + return batch + def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list: """ diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 136d96149..7a14583c7 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -517,9 +517,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, 
foreca tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) tokens = self.assimilate_global(model_params, tokens) - + # roll-out in latent space preds_all = [] + tokens_all = [tokens] for fstep in range(forecast_offset, forecast_offset + forecast_steps): # prediction preds_all += [ @@ -533,6 +534,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ] tokens = self.forecast(model_params, tokens) + tokens_all += [tokens] # prediction for final step preds_all += [ @@ -545,7 +547,45 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca ) ] - return preds_all, posteriors + # now encode the targets into the latent space for all fsteps + if self.cf.get("encode_targets_latent", False): + with torch.no_grad(): + tokens_targets = [] + tokens_targets_srclk = self.embed_cells_targets_srclk(model_params, streams_data) + for fstep in range(len(tokens_targets_srclk)): + tokens_target, _ = self.assimilate_local( + model_params, tokens_targets_srclk[fstep], source_cell_lens + ) + tokens_target = self.assimilate_global(model_params, tokens_target) + tokens_target_det = tokens_target.detach() # explicitly detach as well + tokens_targets.append(tokens_target_det) + + print(torch.linalg.norm(tokens_all[1] - tokens_targets[0])) + + if False: + preds_from_targets = [] + for fstep in range(forecast_offset, forecast_offset + forecast_steps): + # prediction + preds_from_targets += [ + self.predict( + model_params, + fstep, + tokens_targets[fstep], + streams_data, + target_coords_idxs, + ) + ] + save_dir = Path("/users/ktezcan/projects/Meteoswiss/WeatherGenerator/personal/clariden/experiments_latent_loss_rnd2/tokens") + np.save(save_dir / (self.cf.run_id+"_preds_from_targets"), [tt[0][0].cpu().detach().numpy() for tt in preds_from_targets]) + np.save(save_dir / (self.cf.run_id+"_preds_all"), [tt[0][0].cpu().detach().numpy() for tt in preds_all]) + np.save(save_dir / (self.cf.run_id+"_tokens_targets"), [tt.cpu().detach().numpy() for tt in tokens_targets]) + np.save(save_dir / (self.cf.run_id+"_tokens_all"), [tt.cpu().detach().numpy() for tt in tokens_all]) + + + if self.cf.get("encode_targets_latent", False): #TODO: KCT, put a safeguard: if there is a latent loss, encode_targets_latent has to be True + return preds_all, tokens_all, tokens_targets + else: + return preds_all, posteriors ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: @@ -600,6 +640,70 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all + def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> torch.Tensor: + """Embeds target data similar to source tokens for each fstep and stream separately and rearranges it to cell-wise order + Args: + model_params : Query and embedding parameters + streams_data : Used to initialize first tokens for pre-processing + Returns: + Tokens for local assimilation + """ + with torch.no_grad(): + target_srclk_tokens_lens = torch.stack( + [ + torch.stack( + [ + torch.stack( + [ + s.target_srclk_tokens_lens[fstep] + if len(s.target_srclk_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_srclk_tokens_lens)) + ] + ) + for s in stl_b + ] + ) + for stl_b in streams_data + ] + ) + offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) + num_fsteps = target_srclk_tokens_lens.shape[2] # TODO: KCT, if there are diff no of tokens per fstep, this may fail + 
tokens_all = [] + for fstep in range(num_fsteps): + tokens_all.append( + torch.empty( + (int(offsets_base[fstep][-1]), self.cf.ae_local_dim_embed), + dtype=self.dtype, + device="cuda", + ) + ) + + tokens_all_scattered = [] + for _, sb in enumerate(streams_data): + for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): + for fstep in range(num_fsteps): + if not (s.target_srclk_tokens_lens[fstep].sum() == 0): + idxs = s.target_srclk_idxs_embed[fstep] + idxs_pe = s.target_srclk_idxs_embed_pe[fstep] + + # create full scatter index + # (there's no broadcasting which is likely highly inefficient) + idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) + + x_embed = embed( + s.target_srclk_tokens_cells[fstep], s.target_srclk_centroids[fstep] + ).flatten(0, 1) + + # scatter write to reorder from per stream to per cell ordering + tokens_all_fstep = tokens_all[fstep] + tokens_all_fstep.scatter_( + 0, idxs, x_embed + model_params.pe_embed[idxs_pe] + ) + tokens_all_scattered.append(tokens_all_fstep) + + return tokens_all_scattered + ######################################### def assimilate_local( self, model_params: ModelParams, tokens: torch.Tensor, cell_lens: torch.Tensor From 7a74aaafd981b316e8384ebacb985c2da361d65f Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 25 Sep 2025 16:58:48 +0200 Subject: [PATCH 02/17] clean up debug statements --- src/weathergen/model/model.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 7a14583c7..67440346c 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -560,32 +560,13 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens_target_det = tokens_target.detach() # explicitly detach as well tokens_targets.append(tokens_target_det) - print(torch.linalg.norm(tokens_all[1] - tokens_targets[0])) - - if False: - preds_from_targets = [] - for fstep in range(forecast_offset, forecast_offset + forecast_steps): - # prediction - preds_from_targets += [ - self.predict( - model_params, - fstep, - tokens_targets[fstep], - streams_data, - target_coords_idxs, - ) - ] - save_dir = Path("/users/ktezcan/projects/Meteoswiss/WeatherGenerator/personal/clariden/experiments_latent_loss_rnd2/tokens") - np.save(save_dir / (self.cf.run_id+"_preds_from_targets"), [tt[0][0].cpu().detach().numpy() for tt in preds_from_targets]) - np.save(save_dir / (self.cf.run_id+"_preds_all"), [tt[0][0].cpu().detach().numpy() for tt in preds_all]) - np.save(save_dir / (self.cf.run_id+"_tokens_targets"), [tt.cpu().detach().numpy() for tt in tokens_targets]) - np.save(save_dir / (self.cf.run_id+"_tokens_all"), [tt.cpu().detach().numpy() for tt in tokens_all]) - + + return_dict = {"preds_all": preds_all, "posteriors": posteriors} + if self.cf.get("encode_targets_latent", False): + return_dict["tokens_all"] = tokens_all + return_dict["tokens_targets"] = tokens_targets - if self.cf.get("encode_targets_latent", False): #TODO: KCT, put a safeguard: if there is a latent loss, encode_targets_latent has to be True - return preds_all, tokens_all, tokens_targets - else: - return preds_all, posteriors + return return_dict ######################################### def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: From c4129a33f7c265e3d6de3cea82e4e8b163d69449 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 25 Sep 2025 16:59:36 +0200 Subject: [PATCH 03/17] reading model output 
as dictionary --- src/weathergen/train/trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index f3eb8850e..000ecb944 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -501,9 +501,12 @@ def train(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, posteriors = self.ddp_model( + model_output = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + preds = model_output["preds_all"] + posteriors = model_output["posteriors"] + loss_values = self.loss_calculator.compute_loss( preds=preds, streams_data=batch[0], @@ -569,9 +572,10 @@ def validate(self, epoch): dtype=self.mixed_precision_dtype, enabled=cf.with_mixed_precision, ): - preds, _ = self.ddp_model( + model_output = self.ddp_model( self.model_params, batch, cf.forecast_offset, forecast_steps ) + preds = model_output["preds_all"] # compute loss and log output if bidx < cf.log_validation: From 4c2d12f147dc8c045078b48da8971756ade8e664 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 25 Sep 2025 17:12:40 +0200 Subject: [PATCH 04/17] added the config parameter with false as def --- config/default_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/default_config.yml b/config/default_config.yml index 56a7c3e25..c7459c5aa 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -66,6 +66,7 @@ latent_noise_gamma: 2.0 latent_noise_saturate_encodings: 5 latent_noise_use_additive_noise: False latent_noise_deterministic_latents: True +encode_targets_latent: False loss_fcts: - From 029d0ada278747fa66720827d24caf0d71fcf6a4 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Thu, 25 Sep 2025 17:16:54 +0200 Subject: [PATCH 05/17] ruff --- .../datasets/multi_stream_data_sampler.py | 6 +- src/weathergen/datasets/stream_data.py | 4 +- src/weathergen/datasets/tokenizer_forecast.py | 10 +- src/weathergen/datasets/tokenizer_masking.py | 10 +- src/weathergen/datasets/utils.py | 39 ++-- src/weathergen/model/model.py | 15 +- src/weathergen/run_evaluate.py | 192 ++++++++++++++++++ src/weathergen/train/trainer.py | 2 +- 8 files changed, 243 insertions(+), 35 deletions(-) create mode 100644 src/weathergen/run_evaluate.py diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 9e2d13e6f..585ed3872 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -349,7 +349,7 @@ def __iter__(self): rdata.datetimes, (time_win1.start, time_win1.end), ds, - "source_normalizer" + "source_normalizer", ) stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids) @@ -386,7 +386,7 @@ def __iter__(self): np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) ) (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( - self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function + self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function stream_info, torch.from_numpy(rdata.coords), torch.from_numpy(rdata.geoinfos), @@ -394,7 +394,7 @@ def __iter__(self): rdata.datetimes, (time_win2.start, time_win2.end), ds, - "target_normalizer" + "target_normalizer", ) ) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 4dc7bfd5d..d50df3910 100644 --- 
a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -71,7 +71,7 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.target_srclk_tokens_lens = [[] for _ in range(forecast_steps + 1)] self.target_srclk_tokens_cells = [[] for _ in range(forecast_steps + 1)] self.target_srclk_centroids = [[] for _ in range(forecast_steps + 1)] - + self.target_srclk_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] self.target_srclk_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] @@ -110,7 +110,7 @@ def to_device(self, device="cuda") -> None: self.target_srclk_centroids = [ t.to(device, non_blocking=True) for t in self.target_srclk_centroids ] - + self.target_srclk_idxs_embed = [ t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed ] diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index a815f4600..68a74cfe7 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -42,14 +42,18 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset, - use_normalizer: str, # "source" or "target" + use_normalizer: str, # "source" or "target" ): init_loggers() token_size = stream_info["token_size"] is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) - - channel_normalizer = normalizer.normalize_source_channels if use_normalizer == "source_normalizer" else normalizer.normalize_target_channels + + channel_normalizer = ( + normalizer.normalize_source_channels + if use_normalizer == "source_normalizer" + else normalizer.normalize_target_channels + ) tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index b2bab5ac8..57783f2eb 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -48,14 +48,18 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset - use_normalizer: str, # "source" or "target" + use_normalizer: str, # "source" or "target" ): init_loggers() token_size = stream_info["token_size"] is_diagnostic = stream_info.get("diagnostic", False) tokenize_spacetime = stream_info.get("tokenize_spacetime", False) - - channel_normalizer = normalizer.normalize_source_channels if use_normalizer == "source_normalizer" else normalizer.normalize_target_channels + + channel_normalizer = ( + normalizer.normalize_source_channels + if use_normalizer == "source_normalizer" + else normalizer.normalize_target_channels + ) tokenize_window = partial( tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space, diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 803c7fb2f..9f7d46583 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -676,6 +676,7 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: return batch + def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to @@ -694,40 +695,44 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: # collect source_tokens_lens for all stream datas target_srclk_tokens_lens = torch.stack( + [ + 
torch.stack( [ torch.stack( [ - torch.stack( - [ - s.target_srclk_tokens_lens[fstep] - if len(s.target_srclk_tokens_lens[fstep]) > 0 - else torch.tensor([]) - for fstep in range(len(s.target_srclk_tokens_lens)) - ] - ) - for s in stl_b + s.target_srclk_tokens_lens[fstep] + if len(s.target_srclk_tokens_lens[fstep]) > 0 + else torch.tensor([]) + for fstep in range(len(s.target_srclk_tokens_lens)) ] ) - for stl_b in batch + for s in stl_b ] ) + for stl_b in batch + ] + ) # precompute index sets for scatter operation after embed offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) - #take offset_base up to last col and append a 0 in the beginning per fstep - zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device) - offsets = torch.cat([zeros_col, offsets_base[:,:-1]], dim=1) + # take offset_base up to last col and append a 0 in the beginning per fstep + zeros_col = torch.zeros( + (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device + ) + offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1) offsets_pe = torch.zeros_like(offsets) for ib, sb in enumerate(batch): for itype, s in enumerate(sb): for fstep in range(offsets.shape[0]): - if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty s.target_srclk_idxs_embed[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int64) for offset, token_len in zip( - offsets[fstep], target_srclk_tokens_lens[ib, itype, fstep], strict=False + offsets[fstep], + target_srclk_tokens_lens[ib, itype, fstep], + strict=False, ) ] ) @@ -735,7 +740,9 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: [ torch.arange(offset, offset + token_len, dtype=torch.int32) for offset, token_len in zip( - offsets_pe[fstep], target_srclk_tokens_lens[ib][itype][fstep], strict=False + offsets_pe[fstep], + target_srclk_tokens_lens[ib][itype][fstep], + strict=False, ) ] ) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 67440346c..c87ee0b2d 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -517,7 +517,7 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens, posteriors = self.assimilate_local(model_params, tokens, source_cell_lens) tokens = self.assimilate_global(model_params, tokens) - + # roll-out in latent space preds_all = [] tokens_all = [tokens] @@ -557,15 +557,14 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca model_params, tokens_targets_srclk[fstep], source_cell_lens ) tokens_target = self.assimilate_global(model_params, tokens_target) - tokens_target_det = tokens_target.detach() # explicitly detach as well + tokens_target_det = tokens_target.detach() # explicitly detach as well tokens_targets.append(tokens_target_det) - - + return_dict = {"preds_all": preds_all, "posteriors": posteriors} if self.cf.get("encode_targets_latent", False): return_dict["tokens_all"] = tokens_all return_dict["tokens_targets"] = tokens_targets - + return return_dict ######################################### @@ -649,7 +648,9 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> ] ) offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) - num_fsteps = target_srclk_tokens_lens.shape[2] # TODO: KCT, if there are diff no of tokens per fstep, this may fail + num_fsteps = 
target_srclk_tokens_lens.shape[ + 2 + ] # TODO: KCT, if there are diff no of tokens per fstep, this may fail tokens_all = [] for fstep in range(num_fsteps): tokens_all.append( @@ -665,7 +666,7 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): for fstep in range(num_fsteps): if not (s.target_srclk_tokens_lens[fstep].sum() == 0): - idxs = s.target_srclk_idxs_embed[fstep] + idxs = s.target_srclk_idxs_embed[fstep] idxs_pe = s.target_srclk_idxs_embed_pe[fstep] # create full scatter index diff --git a/src/weathergen/run_evaluate.py b/src/weathergen/run_evaluate.py new file mode 100644 index 000000000..9d210f110 --- /dev/null +++ b/src/weathergen/run_evaluate.py @@ -0,0 +1,192 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +The entry point for training and inference weathergen-atmo +""" + +import logging +import pdb +import sys +import time +import traceback +from pathlib import Path + +import weathergen.common.config as config +import weathergen.utils.cli as cli +from weathergen.train.trainer import Trainer +from weathergen.utils.logger import init_loggers + + +def inference(): + # By default, arguments from the command line are read. + inference_from_args(sys.argv[1:]) + + +def inference_from_args(argl: list[str]): + """ + Inference function for WeatherGenerator model. + Entry point for calling the inference code from the command line. + + When running integration tests, the arguments are directly provided. + """ + parser = cli.get_inference_parser() + args = parser.parse_args(argl) + + inference_overwrite = dict( + shuffle=False, + start_date_val=args.start_date, + end_date_val=args.end_date, + samples_per_validation=args.samples, + log_validation=args.samples if args.save_samples else 0, + analysis_streams_output=args.analysis_streams_output, + ) + + cli_overwrite = config.from_cli_arglist(args.options) + cf = config.load_config( + args.private_config, + args.from_run_id, + args.epoch, + *args.config, + inference_overwrite, + cli_overwrite, + ) + cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + cf.run_history += [(args.from_run_id, cf.istep)] + + trainer = Trainer() + trainer.inference(cf, args.from_run_id, args.epoch) + + +#################################################################################################### +def train_continue() -> None: + """ + Function to continue training for WeatherGenerator model. + Entry point for calling train_continue from the command line. + Configurations are set in the function body. + + Args: + from_run_id (str): Run/model id of pretrained WeatherGenerator model to + continue training. Defaults to None. + Note: All model configurations are set in the function body. 
+ """ + train_continue_from_args(sys.argv[1:]) + + +def train_continue_from_args(argl: list[str]): + parser = cli.get_continue_parser() + args = parser.parse_args(argl) + + init_loggers() + + if args.finetune_forecast: + finetune_overwrite = dict( + training_mode="forecast", + forecast_delta_hrs=0, # 12 + forecast_steps=1, # [j for j in range(1,9) for i in range(4)] + forecast_policy="fixed", # 'sequential_random' # 'fixed' #'sequential' #_random' + forecast_freeze_model=True, + forecast_att_dense_rate=1.0, # 0.25 + fe_num_blocks=8, + fe_num_heads=16, + fe_dropout_rate=0.1, + fe_with_qk_lnorm=True, + lr_start=0.000001, + lr_max=0.00003, + lr_final_decay=0.00003, + lr_final=0.0, + lr_steps_warmup=1024, + lr_steps_cooldown=4096, + lr_policy_warmup="cosine", + lr_policy_decay="linear", + lr_policy_cooldown="linear", + num_epochs=12, # len(cf.forecast_steps) + 4 + istep=0, + ) + else: + finetune_overwrite = dict() + + cli_overwrite = config.from_cli_arglist(args.options) + cf = config.load_config( + args.private_config, + args.from_run_id, + args.epoch, + finetune_overwrite, + *args.config, + cli_overwrite, + ) + cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + # track history of run to ensure traceability of results + cf.run_history += [(args.from_run_id, cf.istep)] + + if args.finetune_forecast: + if cf.forecast_freeze_model: + cf.with_fsdp = False + import torch + + torch._dynamo.config.optimize_ddp = False + trainer = Trainer() + trainer.run(cf, args.from_run_id, args.epoch) + + +#################################################################################################### +def train() -> None: + """ + Training function for WeatherGenerator model. + Entry point for calling the training code from the command line. + Configurations are set in the function body. + + Args: + run_id (str, optional): Run/model id of pretrained WeatherGenerator model to + continue training. Defaults to None. + Note: All model configurations are set in the function body. 
+ """ + train_with_args(sys.argv[1:], None) + + +def train_with_args(argl: list[str], stream_dir: str | None): + """ + Training function for WeatherGenerator model.""" + parser = cli.get_train_parser() + args = parser.parse_args(argl) + + cli_overwrite = config.from_cli_arglist(args.options) + + cf = config.load_config(args.private_config, None, None, *args.config, cli_overwrite) + cf = config.set_run_id(cf, args.run_id, False) + + fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" + init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) + + cf.streams = config.load_streams(Path(cf.streams_directory)) + + if cf.with_flash_attention: + assert cf.with_mixed_precision + cf.data_loader_rng_seed = int(time.time()) + + trainer = Trainer(checkpoint_freq=250, print_freq=10) + + try: + trainer.run(cf) + except Exception: + extype, value, tb = sys.exc_info() + traceback.print_exc() + pdb.post_mortem(tb) + + +if __name__ == "__main__": + inference() diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 000ecb944..6047d5162 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -506,7 +506,7 @@ def train(self, epoch): ) preds = model_output["preds_all"] posteriors = model_output["posteriors"] - + loss_values = self.loss_calculator.compute_loss( preds=preds, streams_data=batch[0], From 2f4197e0dfb72dbe3a41244573bd3d6bf8f041b4 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 09:42:20 +0200 Subject: [PATCH 06/17] removed a KCT and corrected some ..._normalizer comments --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/datasets/tokenizer_forecast.py | 2 +- src/weathergen/datasets/tokenizer_masking.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 585ed3872..bb044eee3 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -386,7 +386,7 @@ def __iter__(self): np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) ) (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( - self.tokenizer.batchify_source( # TODO: KCT, check if anything source related is happening in the function + self.tokenizer.batchify_source( stream_info, torch.from_numpy(rdata.coords), torch.from_numpy(rdata.geoinfos), diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index 68a74cfe7..4dd787d76 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -42,7 +42,7 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset, - use_normalizer: str, # "source" or "target" + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 57783f2eb..36f4ad6f5 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -48,7 +48,7 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset - use_normalizer: str, # "source" or "target" + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] From f73adda7e9db032df4dfd787063688442320e475 Mon Sep 17 00:00:00 
2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 11:35:38 +0200 Subject: [PATCH 07/17] some comments and ruff changes --- src/weathergen/datasets/multi_stream_data_sampler.py | 2 +- src/weathergen/datasets/tokenizer_forecast.py | 2 +- src/weathergen/datasets/tokenizer_masking.py | 2 +- src/weathergen/datasets/utils.py | 4 ++-- src/weathergen/model/model.py | 1 + 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index bb044eee3..7f597c8b9 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -386,7 +386,7 @@ def __iter__(self): np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) ) (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( - self.tokenizer.batchify_source( + self.tokenizer.batchify_source( stream_info, torch.from_numpy(rdata.coords), torch.from_numpy(rdata.geoinfos), diff --git a/src/weathergen/datasets/tokenizer_forecast.py b/src/weathergen/datasets/tokenizer_forecast.py index 4dd787d76..be1d4103f 100644 --- a/src/weathergen/datasets/tokenizer_forecast.py +++ b/src/weathergen/datasets/tokenizer_forecast.py @@ -42,7 +42,7 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset, - use_normalizer: str, # "source_normalizer" or "target_normalizer" + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] diff --git a/src/weathergen/datasets/tokenizer_masking.py b/src/weathergen/datasets/tokenizer_masking.py index 36f4ad6f5..cab667d9e 100644 --- a/src/weathergen/datasets/tokenizer_masking.py +++ b/src/weathergen/datasets/tokenizer_masking.py @@ -48,7 +48,7 @@ def batchify_source( times: np.array, time_win: tuple, normalizer, # dataset - use_normalizer: str, # "source_normalizer" or "target_normalizer" + use_normalizer: str, # "source_normalizer" or "target_normalizer" ): init_loggers() token_size = stream_info["token_size"] diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 9f7d46583..2b367e765 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -715,7 +715,7 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: # precompute index sets for scatter operation after embed offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) - # take offset_base up to last col and append a 0 in the beginning per fstep + # shift the offsets for each fstep by one to the right, add a zero to the beginning the first token starts at 0 zeros_col = torch.zeros( (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device ) @@ -725,7 +725,7 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: for ib, sb in enumerate(batch): for itype, s in enumerate(sb): for fstep in range(offsets.shape[0]): - if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty s.target_srclk_idxs_embed[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int64) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index c87ee0b2d..a98f98b72 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -665,6 +665,7 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> for _, sb in 
enumerate(streams_data): for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): for fstep in range(num_fsteps): + # TODO: KCT: should we actually remove the below check and just return an empty tensor? if not (s.target_srclk_tokens_lens[fstep].sum() == 0): idxs = s.target_srclk_idxs_embed[fstep] idxs_pe = s.target_srclk_idxs_embed_pe[fstep] From cd30eb1250a85f8875d1b94b713361deb850f78f Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 11:36:18 +0200 Subject: [PATCH 08/17] ruff --- src/weathergen/datasets/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 2b367e765..6b13c9aad 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -725,7 +725,7 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: for ib, sb in enumerate(batch): for itype, s in enumerate(sb): for fstep in range(offsets.shape[0]): - if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty s.target_srclk_idxs_embed[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int64) From 5f42d43d923de1734162ccbbfe67fede112c27b9 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 11:38:54 +0200 Subject: [PATCH 09/17] removed run_evaluate --- src/weathergen/run_evaluate.py | 192 --------------------------------- 1 file changed, 192 deletions(-) delete mode 100644 src/weathergen/run_evaluate.py diff --git a/src/weathergen/run_evaluate.py b/src/weathergen/run_evaluate.py deleted file mode 100644 index 9d210f110..000000000 --- a/src/weathergen/run_evaluate.py +++ /dev/null @@ -1,192 +0,0 @@ -# (C) Copyright 2025 WeatherGenerator contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -""" -The entry point for training and inference weathergen-atmo -""" - -import logging -import pdb -import sys -import time -import traceback -from pathlib import Path - -import weathergen.common.config as config -import weathergen.utils.cli as cli -from weathergen.train.trainer import Trainer -from weathergen.utils.logger import init_loggers - - -def inference(): - # By default, arguments from the command line are read. - inference_from_args(sys.argv[1:]) - - -def inference_from_args(argl: list[str]): - """ - Inference function for WeatherGenerator model. - Entry point for calling the inference code from the command line. - - When running integration tests, the arguments are directly provided. 
- """ - parser = cli.get_inference_parser() - args = parser.parse_args(argl) - - inference_overwrite = dict( - shuffle=False, - start_date_val=args.start_date, - end_date_val=args.end_date, - samples_per_validation=args.samples, - log_validation=args.samples if args.save_samples else 0, - analysis_streams_output=args.analysis_streams_output, - ) - - cli_overwrite = config.from_cli_arglist(args.options) - cf = config.load_config( - args.private_config, - args.from_run_id, - args.epoch, - *args.config, - inference_overwrite, - cli_overwrite, - ) - cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) - - fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" - init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) - - cf.run_history += [(args.from_run_id, cf.istep)] - - trainer = Trainer() - trainer.inference(cf, args.from_run_id, args.epoch) - - -#################################################################################################### -def train_continue() -> None: - """ - Function to continue training for WeatherGenerator model. - Entry point for calling train_continue from the command line. - Configurations are set in the function body. - - Args: - from_run_id (str): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. - Note: All model configurations are set in the function body. - """ - train_continue_from_args(sys.argv[1:]) - - -def train_continue_from_args(argl: list[str]): - parser = cli.get_continue_parser() - args = parser.parse_args(argl) - - init_loggers() - - if args.finetune_forecast: - finetune_overwrite = dict( - training_mode="forecast", - forecast_delta_hrs=0, # 12 - forecast_steps=1, # [j for j in range(1,9) for i in range(4)] - forecast_policy="fixed", # 'sequential_random' # 'fixed' #'sequential' #_random' - forecast_freeze_model=True, - forecast_att_dense_rate=1.0, # 0.25 - fe_num_blocks=8, - fe_num_heads=16, - fe_dropout_rate=0.1, - fe_with_qk_lnorm=True, - lr_start=0.000001, - lr_max=0.00003, - lr_final_decay=0.00003, - lr_final=0.0, - lr_steps_warmup=1024, - lr_steps_cooldown=4096, - lr_policy_warmup="cosine", - lr_policy_decay="linear", - lr_policy_cooldown="linear", - num_epochs=12, # len(cf.forecast_steps) + 4 - istep=0, - ) - else: - finetune_overwrite = dict() - - cli_overwrite = config.from_cli_arglist(args.options) - cf = config.load_config( - args.private_config, - args.from_run_id, - args.epoch, - finetune_overwrite, - *args.config, - cli_overwrite, - ) - cf = config.set_run_id(cf, args.run_id, args.reuse_run_id) - - fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" - init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) - - # track history of run to ensure traceability of results - cf.run_history += [(args.from_run_id, cf.istep)] - - if args.finetune_forecast: - if cf.forecast_freeze_model: - cf.with_fsdp = False - import torch - - torch._dynamo.config.optimize_ddp = False - trainer = Trainer() - trainer.run(cf, args.from_run_id, args.epoch) - - -#################################################################################################### -def train() -> None: - """ - Training function for WeatherGenerator model. - Entry point for calling the training code from the command line. - Configurations are set in the function body. - - Args: - run_id (str, optional): Run/model id of pretrained WeatherGenerator model to - continue training. Defaults to None. - Note: All model configurations are set in the function body. 
- """ - train_with_args(sys.argv[1:], None) - - -def train_with_args(argl: list[str], stream_dir: str | None): - """ - Training function for WeatherGenerator model.""" - parser = cli.get_train_parser() - args = parser.parse_args(argl) - - cli_overwrite = config.from_cli_arglist(args.options) - - cf = config.load_config(args.private_config, None, None, *args.config, cli_overwrite) - cf = config.set_run_id(cf, args.run_id, False) - - fname_debug_logging = f"./logs/debug_log_{cf.run_id}.txt" - init_loggers(logging_level=logging.DEBUG, debug_output_streams=fname_debug_logging) - - cf.streams = config.load_streams(Path(cf.streams_directory)) - - if cf.with_flash_attention: - assert cf.with_mixed_precision - cf.data_loader_rng_seed = int(time.time()) - - trainer = Trainer(checkpoint_freq=250, print_freq=10) - - try: - trainer.run(cf) - except Exception: - extype, value, tb = sys.exc_info() - traceback.print_exc() - pdb.post_mortem(tb) - - -if __name__ == "__main__": - inference() From 5a38dc70c9a81dfc0dcbe4cce7e1f4f73030b965 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 11:45:18 +0200 Subject: [PATCH 10/17] using != --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a98f98b72..67cd744b0 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -666,7 +666,7 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): for fstep in range(num_fsteps): # TODO: KCT: should we actually remove the below check and just return an empty tensor? - if not (s.target_srclk_tokens_lens[fstep].sum() == 0): + if s.target_srclk_tokens_lens[fstep].sum() != 0: idxs = s.target_srclk_idxs_embed[fstep] idxs_pe = s.target_srclk_idxs_embed_pe[fstep] From 721de0228623a51c5a6f7ac4ad5dd06dc64700bf Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 12:19:04 +0200 Subject: [PATCH 11/17] removed some comments >>>> --- src/weathergen/datasets/stream_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index d50df3910..e71192de8 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -389,7 +389,6 @@ def merge_inputs(self) -> None: self.source_tokens_cells = torch.tensor([]) self.source_centroids = torch.tensor([]) - # >>>>>>> # collect all source like tokens in current stream and add to batch sample list when non-empty for fstep in range(len(self.target_srclk_tokens_cells)): if torch.tensor([len(s) for s in self.target_srclk_tokens_cells[fstep]]).sum() > 0: @@ -418,7 +417,6 @@ def merge_inputs(self) -> None: self.target_srclk_tokens_lens[fstep] = torch.zeros([self.nhc_source]) self.target_srclk_tokens_cells[fstep] = torch.tensor([]) self.target_srclk_centroids[fstep] = torch.tensor([]) - # <<<<< # targets for fstep in range(len(self.target_coords)): From 44f4a11f8d5ba44b0563c25021fab4f04ac97a15 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 13:57:51 +0200 Subject: [PATCH 12/17] renamed everything srclk -> source_like --- .../datasets/multi_stream_data_sampler.py | 20 ++--- src/weathergen/datasets/stream_data.py | 88 +++++++++---------- src/weathergen/datasets/utils.py | 26 +++--- src/weathergen/model/model.py | 28 +++--- 4 files changed, 81 insertions(+), 81 deletions(-) diff --git 
a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 7f597c8b9..1a157705a 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -31,7 +31,7 @@ from weathergen.datasets.utils import ( compute_idxs_predict, compute_offsets_scatter_embed, - compute_offsets_scatter_embed_target_srclk, + compute_offsets_scatter_embed_target_source_like, compute_source_cell_lens, ) from weathergen.utils.logger import logger @@ -369,7 +369,7 @@ def __iter__(self): if rdata.is_empty(): stream_data.add_empty_target(fstep) - stream_data.add_empty_target_srclk(fstep) + stream_data.add_empty_target_source_like(fstep) else: (tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target( stream_info, @@ -382,10 +382,10 @@ def __iter__(self): ds, ) - target_raw_srclk = torch.from_numpy( + target_raw_source_like = torch.from_numpy( np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) ) - (tt_cells_srclk, tt_lens_srclk, tt_centroids_srclk) = ( + (tt_cells_source_like, tt_lens_source_like, tt_centroids_source_like) = ( self.tokenizer.batchify_source( stream_info, torch.from_numpy(rdata.coords), @@ -399,12 +399,12 @@ def __iter__(self): ) stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) - stream_data.add_target_srclk( + stream_data.add_target_source_like( fstep, - target_raw_srclk, - tt_lens_srclk, - tt_cells_srclk, - tt_centroids_srclk, + target_raw_source_like, + tt_lens_source_like, + tt_cells_source_like, + tt_centroids_source_like, ) # merge inputs for sources and targets for current stream @@ -424,7 +424,7 @@ def __iter__(self): # compute offsets for scatter computation after embedding batch = compute_offsets_scatter_embed(batch) - batch = compute_offsets_scatter_embed_target_srclk(batch) + batch = compute_offsets_scatter_embed_target_source_like(batch) # compute offsets and auxiliary data needed for prediction computation # (info is not per stream so separate data structure) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index e71192de8..f1bc687b6 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -67,13 +67,13 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.source_idxs_embed_pe = torch.tensor([]) # below are for targets which are tokenized like sources - self.target_srclk_raw = [[] for _ in range(forecast_steps + 1)] - self.target_srclk_tokens_lens = [[] for _ in range(forecast_steps + 1)] - self.target_srclk_tokens_cells = [[] for _ in range(forecast_steps + 1)] - self.target_srclk_centroids = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_raw = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_tokens_lens = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_tokens_cells = [[] for _ in range(forecast_steps + 1)] + self.target_source_like_centroids = [[] for _ in range(forecast_steps + 1)] - self.target_srclk_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] - self.target_srclk_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_source_like_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_source_like_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] def to_device(self, device="cuda") -> None: """ @@ -100,22 +100,22 @@ def to_device(self, device="cuda") -> None: self.source_idxs_embed = 
self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) - self.target_srclk_raw = [t.to(device, non_blocking=True) for t in self.target_srclk_raw] - self.target_srclk_tokens_lens = [ - t.to(device, non_blocking=True) for t in self.target_srclk_tokens_lens + self.target_source_like_raw = [t.to(device, non_blocking=True) for t in self.target_source_like_raw] + self.target_source_like_tokens_lens = [ + t.to(device, non_blocking=True) for t in self.target_source_like_tokens_lens ] - self.target_srclk_tokens_cells = [ - t.to(device, non_blocking=True) for t in self.target_srclk_tokens_cells + self.target_source_like_tokens_cells = [ + t.to(device, non_blocking=True) for t in self.target_source_like_tokens_cells ] - self.target_srclk_centroids = [ - t.to(device, non_blocking=True) for t in self.target_srclk_centroids + self.target_source_like_centroids = [ + t.to(device, non_blocking=True) for t in self.target_source_like_centroids ] - self.target_srclk_idxs_embed = [ - t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed + self.target_source_like_idxs_embed = [ + t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed ] - self.target_srclk_idxs_embed_pe = [ - t.to(device, non_blocking=True) for t in self.target_srclk_idxs_embed_pe + self.target_source_like_idxs_embed_pe = [ + t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed_pe ] return self @@ -138,7 +138,7 @@ def add_empty_source(self, source: IOReaderData) -> None: self.source_tokens_cells += [torch.tensor([])] self.source_centroids += [torch.tensor([])] - def add_empty_target_srclk(self, fstep: int) -> None: + def add_empty_target_source_like(self, fstep: int) -> None: """ Add an empty target for an input encoded like source. 
Parameters @@ -149,10 +149,10 @@ def add_empty_target_srclk(self, fstep: int) -> None: None """ - self.target_srclk_raw[fstep] += [torch.tensor([])] - self.target_srclk_tokens_lens[fstep] += [torch.zeros([self.nhc_source], dtype=torch.int32)] - self.target_srclk_tokens_cells[fstep] += [torch.tensor([])] - self.target_srclk_centroids[fstep] += [torch.tensor([])] + self.target_source_like_raw[fstep] += [torch.tensor([])] + self.target_source_like_tokens_lens[fstep] += [torch.zeros([self.nhc_source], dtype=torch.int32)] + self.target_source_like_tokens_cells[fstep] += [torch.tensor([])] + self.target_source_like_centroids[fstep] += [torch.tensor([])] def add_empty_target(self, fstep: int) -> None: """ @@ -202,7 +202,7 @@ def add_source( self.source_tokens_cells += [ss_cells] self.source_centroids += [ss_centroids] - def add_target_srclk( + def add_target_source_like( self, fstep: int, tt_raw: torch.tensor, @@ -225,10 +225,10 @@ def add_target_srclk( None """ - self.target_srclk_raw[fstep] += [tt_raw] - self.target_srclk_tokens_lens[fstep] += [tt_lens] - self.target_srclk_tokens_cells[fstep] += [tt_cells] - self.target_srclk_centroids[fstep] += [tt_centroids] + self.target_source_like_raw[fstep] += [tt_raw] + self.target_source_like_tokens_lens[fstep] += [tt_lens] + self.target_source_like_tokens_cells[fstep] += [tt_cells] + self.target_source_like_centroids[fstep] += [tt_centroids] def add_target( self, @@ -390,33 +390,33 @@ def merge_inputs(self) -> None: self.source_centroids = torch.tensor([]) # collect all source like tokens in current stream and add to batch sample list when non-empty - for fstep in range(len(self.target_srclk_tokens_cells)): - if torch.tensor([len(s) for s in self.target_srclk_tokens_cells[fstep]]).sum() > 0: - self.target_srclk_raw[fstep] = torch.cat(self.target_srclk_raw[fstep]) + for fstep in range(len(self.target_source_like_tokens_cells)): + if torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum() > 0: + self.target_source_like_raw[fstep] = torch.cat(self.target_source_like_raw[fstep]) # collect by merging entries per cells, preserving cell structure - self.target_srclk_tokens_cells[fstep] = self._merge_cells( - self.target_srclk_tokens_cells[fstep], self.nhc_source + self.target_source_like_tokens_cells[fstep] = self._merge_cells( + self.target_source_like_tokens_cells[fstep], self.nhc_source ) - self.target_srclk_centroids[fstep] = self._merge_cells( - self.target_srclk_centroids[fstep], self.nhc_source + self.target_source_like_centroids[fstep] = self._merge_cells( + self.target_source_like_centroids[fstep], self.nhc_source ) # lens can be stacked and summed - self.target_srclk_tokens_lens[fstep] = torch.stack( - self.target_srclk_tokens_lens[fstep] + self.target_source_like_tokens_lens[fstep] = torch.stack( + self.target_source_like_tokens_lens[fstep] ).sum(0) # remove NaNs - idx = torch.isnan(self.target_srclk_tokens_cells[fstep]) - self.target_srclk_tokens_cells[fstep][idx] = self.mask_value - idx = torch.isnan(self.target_srclk_centroids[fstep]) - self.target_srclk_centroids[fstep][idx] = self.mask_value + idx = torch.isnan(self.target_source_like_tokens_cells[fstep]) + self.target_source_like_tokens_cells[fstep][idx] = self.mask_value + idx = torch.isnan(self.target_source_like_centroids[fstep]) + self.target_source_like_centroids[fstep][idx] = self.mask_value else: - self.target_srclk_raw[fstep] = torch.tensor([]) - self.target_srclk_tokens_lens[fstep] = torch.zeros([self.nhc_source]) - self.target_srclk_tokens_cells[fstep] = 
torch.tensor([]) - self.target_srclk_centroids[fstep] = torch.tensor([]) + self.target_source_like_raw[fstep] = torch.tensor([]) + self.target_source_like_tokens_lens[fstep] = torch.zeros([self.nhc_source]) + self.target_source_like_tokens_cells[fstep] = torch.tensor([]) + self.target_source_like_centroids[fstep] = torch.tensor([]) # targets for fstep in range(len(self.target_coords)): diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index 6b13c9aad..cf542eb44 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -677,7 +677,7 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData: return batch -def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: +def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> StreamData: """ Compute auxiliary information for scatter operation that changes from stream-centric to cell-centric computations @@ -694,16 +694,16 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: """ # collect source_tokens_lens for all stream datas - target_srclk_tokens_lens = torch.stack( + target_source_like_tokens_lens = torch.stack( [ torch.stack( [ torch.stack( [ - s.target_srclk_tokens_lens[fstep] - if len(s.target_srclk_tokens_lens[fstep]) > 0 + s.target_source_like_tokens_lens[fstep] + if len(s.target_source_like_tokens_lens[fstep]) > 0 else torch.tensor([]) - for fstep in range(len(s.target_srclk_tokens_lens)) + for fstep in range(len(s.target_source_like_tokens_lens)) ] ) for s in stl_b @@ -714,7 +714,7 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: ) # precompute index sets for scatter operation after embed - offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) + offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) # shift the offsets for each fstep by one to the right, add a zero to the beginning the first token starts at 0 zeros_col = torch.zeros( (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device @@ -725,31 +725,31 @@ def compute_offsets_scatter_embed_target_srclk(batch: StreamData) -> StreamData: for ib, sb in enumerate(batch): for itype, s in enumerate(sb): for fstep in range(offsets.shape[0]): - if not (target_srclk_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty - s.target_srclk_idxs_embed[fstep] = torch.cat( + if not (target_source_like_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + s.target_source_like_idxs_embed[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int64) for offset, token_len in zip( offsets[fstep], - target_srclk_tokens_lens[ib, itype, fstep], + target_source_like_tokens_lens[ib, itype, fstep], strict=False, ) ] ) - s.target_srclk_idxs_embed_pe[fstep] = torch.cat( + s.target_source_like_idxs_embed_pe[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int32) for offset, token_len in zip( offsets_pe[fstep], - target_srclk_tokens_lens[ib][itype][fstep], + target_source_like_tokens_lens[ib][itype][fstep], strict=False, ) ] ) # advance offsets - offsets[fstep] += target_srclk_tokens_lens[ib][itype][fstep] - offsets_pe[fstep] += target_srclk_tokens_lens[ib][itype][fstep] + offsets[fstep] += target_source_like_tokens_lens[ib][itype][fstep] + offsets_pe[fstep] += target_source_like_tokens_lens[ib][itype][fstep] return batch diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 67cd744b0..67180ea7d 100644 --- 
a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -551,10 +551,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca if self.cf.get("encode_targets_latent", False): with torch.no_grad(): tokens_targets = [] - tokens_targets_srclk = self.embed_cells_targets_srclk(model_params, streams_data) - for fstep in range(len(tokens_targets_srclk)): + tokens_targets_source_like = self.embed_cells_targets_source_like(model_params, streams_data) + for fstep in range(len(tokens_targets_source_like)): tokens_target, _ = self.assimilate_local( - model_params, tokens_targets_srclk[fstep], source_cell_lens + model_params, tokens_targets_source_like[fstep], source_cell_lens ) tokens_target = self.assimilate_global(model_params, tokens_target) tokens_target_det = tokens_target.detach() # explicitly detach as well @@ -620,7 +620,7 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all - def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> torch.Tensor: + def embed_cells_targets_source_like(self, model_params: ModelParams, streams_data) -> torch.Tensor: """Embeds target data similar to source tokens for each fstep and stream separately and rearranges it to cell-wise order Args: model_params : Query and embedding parameters @@ -629,16 +629,16 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> Tokens for local assimilation """ with torch.no_grad(): - target_srclk_tokens_lens = torch.stack( + target_source_like_tokens_lens = torch.stack( [ torch.stack( [ torch.stack( [ - s.target_srclk_tokens_lens[fstep] - if len(s.target_srclk_tokens_lens[fstep]) > 0 + s.target_source_like_tokens_lens[fstep] + if len(s.target_source_like_tokens_lens[fstep]) > 0 else torch.tensor([]) - for fstep in range(len(s.target_srclk_tokens_lens)) + for fstep in range(len(s.target_source_like_tokens_lens)) ] ) for s in stl_b @@ -647,8 +647,8 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> for stl_b in streams_data ] ) - offsets_base = target_srclk_tokens_lens.sum(1).sum(0).cumsum(1) - num_fsteps = target_srclk_tokens_lens.shape[ + offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) + num_fsteps = target_source_like_tokens_lens.shape[ 2 ] # TODO: KCT, if there are diff no of tokens per fstep, this may fail tokens_all = [] @@ -666,16 +666,16 @@ def embed_cells_targets_srclk(self, model_params: ModelParams, streams_data) -> for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): for fstep in range(num_fsteps): # TODO: KCT: should we actually remove the below check and just return an empty tensor? 
- if s.target_srclk_tokens_lens[fstep].sum() != 0: - idxs = s.target_srclk_idxs_embed[fstep] - idxs_pe = s.target_srclk_idxs_embed_pe[fstep] + if s.target_source_like_tokens_lens[fstep].sum() != 0: + idxs = s.target_source_like_idxs_embed[fstep] + idxs_pe = s.target_source_like_idxs_embed_pe[fstep] # create full scatter index # (there's no broadcasting which is likely highly inefficient) idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) x_embed = embed( - s.target_srclk_tokens_cells[fstep], s.target_srclk_centroids[fstep] + s.target_source_like_tokens_cells[fstep], s.target_source_like_centroids[fstep] ).flatten(0, 1) # scatter write to reorder from per stream to per cell ordering From b16d6ae4a895865fd6c2f95fb5038ba462c55ccc Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 14:04:22 +0200 Subject: [PATCH 13/17] addressed ruff errors --- src/weathergen/datasets/stream_data.py | 3 ++- src/weathergen/datasets/utils.py | 5 +++-- src/weathergen/model/model.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index f1bc687b6..0dc44fb00 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -389,7 +389,8 @@ def merge_inputs(self) -> None: self.source_tokens_cells = torch.tensor([]) self.source_centroids = torch.tensor([]) - # collect all source like tokens in current stream and add to batch sample list when non-empty + # collect all source like tokens in current stream and add to + # batch sample list when non-empty for fstep in range(len(self.target_source_like_tokens_cells)): if torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum() > 0: self.target_source_like_raw[fstep] = torch.cat(self.target_source_like_raw[fstep]) diff --git a/src/weathergen/datasets/utils.py b/src/weathergen/datasets/utils.py index cf542eb44..1f9760c49 100644 --- a/src/weathergen/datasets/utils.py +++ b/src/weathergen/datasets/utils.py @@ -715,7 +715,8 @@ def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> Strea # precompute index sets for scatter operation after embed offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) - # shift the offsets for each fstep by one to the right, add a zero to the beginning the first token starts at 0 + # shift the offsets for each fstep by one to the right, add a zero to the + # beginning as the first token starts at 0 zeros_col = torch.zeros( (offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device ) @@ -725,7 +726,7 @@ def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> Strea for ib, sb in enumerate(batch): for itype, s in enumerate(sb): for fstep in range(offsets.shape[0]): - if not (target_source_like_tokens_lens[ib, itype, fstep].sum() == 0): # if not empty + if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: # if not empty s.target_source_like_idxs_embed[fstep] = torch.cat( [ torch.arange(offset, offset + token_len, dtype=torch.int64) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 67180ea7d..549d1ce69 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -621,7 +621,8 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all def embed_cells_targets_source_like(self, model_params: ModelParams, streams_data) -> torch.Tensor: - """Embeds target data similar to source tokens for 
each fstep and stream separately and rearranges it to cell-wise order + """Embeds target data similar to source tokens for each fstep and stream separately and + rearranges it to cell-wise order Args: model_params : Query and embedding parameters streams_data : Used to initialize first tokens for pre-processing @@ -665,7 +666,8 @@ def embed_cells_targets_source_like(self, model_params: ModelParams, streams_dat for _, sb in enumerate(streams_data): for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)): for fstep in range(num_fsteps): - # TODO: KCT: should we actually remove the below check and just return an empty tensor? + # TODO: KCT: should we actually remove the below check and just + # return an empty tensor? if s.target_source_like_tokens_lens[fstep].sum() != 0: idxs = s.target_source_like_idxs_embed[fstep] idxs_pe = s.target_source_like_idxs_embed_pe[fstep] From 4304352b83112ecc6190468bd29b85f1c47d4b54 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 14:04:53 +0200 Subject: [PATCH 14/17] ruff --- .../datasets/multi_stream_data_sampler.py | 24 ++++++++++--------- src/weathergen/datasets/stream_data.py | 17 +++++++++---- src/weathergen/model/model.py | 11 ++++++--- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py index 1a157705a..a996806e7 100644 --- a/src/weathergen/datasets/multi_stream_data_sampler.py +++ b/src/weathergen/datasets/multi_stream_data_sampler.py @@ -385,17 +385,19 @@ def __iter__(self): target_raw_source_like = torch.from_numpy( np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1) ) - (tt_cells_source_like, tt_lens_source_like, tt_centroids_source_like) = ( - self.tokenizer.batchify_source( - stream_info, - torch.from_numpy(rdata.coords), - torch.from_numpy(rdata.geoinfos), - torch.from_numpy(rdata.data), - rdata.datetimes, - (time_win2.start, time_win2.end), - ds, - "target_normalizer", - ) + ( + tt_cells_source_like, + tt_lens_source_like, + tt_centroids_source_like, + ) = self.tokenizer.batchify_source( + stream_info, + torch.from_numpy(rdata.coords), + torch.from_numpy(rdata.geoinfos), + torch.from_numpy(rdata.data), + rdata.datetimes, + (time_win2.start, time_win2.end), + ds, + "target_normalizer", ) stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t) diff --git a/src/weathergen/datasets/stream_data.py b/src/weathergen/datasets/stream_data.py index 0dc44fb00..e096a230b 100644 --- a/src/weathergen/datasets/stream_data.py +++ b/src/weathergen/datasets/stream_data.py @@ -73,7 +73,9 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i self.target_source_like_centroids = [[] for _ in range(forecast_steps + 1)] self.target_source_like_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)] - self.target_source_like_idxs_embed_pe = [torch.tensor([]) for _ in range(forecast_steps + 1)] + self.target_source_like_idxs_embed_pe = [ + torch.tensor([]) for _ in range(forecast_steps + 1) + ] def to_device(self, device="cuda") -> None: """ @@ -100,7 +102,9 @@ def to_device(self, device="cuda") -> None: self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True) self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True) - self.target_source_like_raw = [t.to(device, non_blocking=True) for t in self.target_source_like_raw] + self.target_source_like_raw = [ + t.to(device, non_blocking=True) for t in self.target_source_like_raw + ] 
self.target_source_like_tokens_lens = [ t.to(device, non_blocking=True) for t in self.target_source_like_tokens_lens ] @@ -150,7 +154,9 @@ def add_empty_target_source_like(self, fstep: int) -> None: """ self.target_source_like_raw[fstep] += [torch.tensor([])] - self.target_source_like_tokens_lens[fstep] += [torch.zeros([self.nhc_source], dtype=torch.int32)] + self.target_source_like_tokens_lens[fstep] += [ + torch.zeros([self.nhc_source], dtype=torch.int32) + ] self.target_source_like_tokens_cells[fstep] += [torch.tensor([])] self.target_source_like_centroids[fstep] += [torch.tensor([])] @@ -392,7 +398,10 @@ def merge_inputs(self) -> None: # collect all source like tokens in current stream and add to # batch sample list when non-empty for fstep in range(len(self.target_source_like_tokens_cells)): - if torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum() > 0: + if ( + torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum() + > 0 + ): self.target_source_like_raw[fstep] = torch.cat(self.target_source_like_raw[fstep]) # collect by merging entries per cells, preserving cell structure diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 549d1ce69..ea423fe4e 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -551,7 +551,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca if self.cf.get("encode_targets_latent", False): with torch.no_grad(): tokens_targets = [] - tokens_targets_source_like = self.embed_cells_targets_source_like(model_params, streams_data) + tokens_targets_source_like = self.embed_cells_targets_source_like( + model_params, streams_data + ) for fstep in range(len(tokens_targets_source_like)): tokens_target, _ = self.assimilate_local( model_params, tokens_targets_source_like[fstep], source_cell_lens @@ -620,7 +622,9 @@ def embed_cells(self, model_params: ModelParams, streams_data) -> torch.Tensor: return tokens_all - def embed_cells_targets_source_like(self, model_params: ModelParams, streams_data) -> torch.Tensor: + def embed_cells_targets_source_like( + self, model_params: ModelParams, streams_data + ) -> torch.Tensor: """Embeds target data similar to source tokens for each fstep and stream separately and rearranges it to cell-wise order Args: @@ -677,7 +681,8 @@ def embed_cells_targets_source_like(self, model_params: ModelParams, streams_dat idxs = idxs.unsqueeze(1).repeat((1, self.cf.ae_local_dim_embed)) x_embed = embed( - s.target_source_like_tokens_cells[fstep], s.target_source_like_centroids[fstep] + s.target_source_like_tokens_cells[fstep], + s.target_source_like_centroids[fstep], ).flatten(0, 1) # scatter write to reorder from per stream to per cell ordering From 97ffb58995823b8af4a8d531a95b9294975eb985 Mon Sep 17 00:00:00 2001 From: Kerem Tezcan Date: Mon, 29 Sep 2025 14:12:51 +0200 Subject: [PATCH 15/17] removed some KCt comments --- src/weathergen/model/engines.py | 2 +- src/weathergen/model/model.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 29bae5806..794468db1 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -40,7 +40,7 @@ def __init__(self, cf: Config, sources_size) -> None: :param sources_size: List of source sizes for each stream. """ self.cf = cf - self.sources_size = sources_size # KCT:iss130, what is this? 
+        self.sources_size = sources_size
         self.embeds = torch.nn.ModuleList()

     def create(self) -> torch.nn.ModuleList:
diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py
index ea423fe4e..d46c528cd 100644
--- a/src/weathergen/model/model.py
+++ b/src/weathergen/model/model.py
@@ -655,7 +655,7 @@ def embed_cells_targets_source_like(
             offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1)
             num_fsteps = target_source_like_tokens_lens.shape[
                 2
-            ]  # TODO: KCT, if there are diff no of tokens per fstep, this may fail
+            ]
             tokens_all = []
             for fstep in range(num_fsteps):
                 tokens_all.append(
@@ -670,8 +670,6 @@ def embed_cells_targets_source_like(
         for _, sb in enumerate(streams_data):
             for _, (s, embed) in enumerate(zip(sb, self.embeds, strict=False)):
                 for fstep in range(num_fsteps):
-                    # TODO: KCT: should we actually remove the below check and just
-                    # return an empty tensor?
                     if s.target_source_like_tokens_lens[fstep].sum() != 0:
                         idxs = s.target_source_like_idxs_embed[fstep]
                         idxs_pe = s.target_source_like_idxs_embed_pe[fstep]

From 0946c0019daa6dba604521958d374d1e6efdf9d9 Mon Sep 17 00:00:00 2001
From: Kerem Tezcan
Date: Tue, 30 Sep 2025 15:45:39 +0200
Subject: [PATCH 16/17] added empty tensor for missing timestep

---
 src/weathergen/model/model.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py
index d46c528cd..c221f5d8e 100644
--- a/src/weathergen/model/model.py
+++ b/src/weathergen/model/model.py
@@ -555,12 +555,16 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca
         if self.cf.get("encode_targets_latent", False):
             with torch.no_grad():
                 tokens_targets = []
                 tokens_targets_source_like = self.embed_cells_targets_source_like(
                     model_params, streams_data
                 )
                 for fstep in range(len(tokens_targets_source_like)):
-                    tokens_target, _ = self.assimilate_local(
-                        model_params, tokens_targets_source_like[fstep], source_cell_lens
-                    )
-                    tokens_target = self.assimilate_global(model_params, tokens_target)
-                    tokens_target_det = tokens_target.detach()  # explicitly detach as well
-                    tokens_targets.append(tokens_target_det)
+                    if tokens_targets_source_like[fstep].sum()==0:
+                        # if the input is empty, return an empty tensor
+                        tokens_targets.append(torch.tensor([]).detach())
+                    else:
+                        tokens_target, _ = self.assimilate_local(
+                            model_params, tokens_targets_source_like[fstep], source_cell_lens
+                        )
+                        tokens_target = self.assimilate_global(model_params, tokens_target)
+                        tokens_target_det = tokens_target.detach()  # explicitly detach as well
+                        tokens_targets.append(tokens_target_det)

         return_dict = {"preds_all": preds_all, "posteriors": posteriors}
         if self.cf.get("encode_targets_latent", False):
@@ -689,6 +693,8 @@ def embed_cells_targets_source_like(
                             0, idxs, x_embed + model_params.pe_embed[idxs_pe]
                         )
                         tokens_all_scattered.append(tokens_all_fstep)
+                    else:
+                        tokens_all_scattered.append(torch.tensor([]))
         return tokens_all_scattered

From b5f8ca3dffce84638d8c9791f9a4935b70646cf4 Mon Sep 17 00:00:00 2001
From: Kerem Tezcan
Date: Tue, 30 Sep 2025 16:16:09 +0200
Subject: [PATCH 17/17] ruff

---
 src/weathergen/model/engines.py | 2 +-
 src/weathergen/model/model.py | 8 +++-----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py
index 794468db1..1bb2d6830 100644
--- a/src/weathergen/model/engines.py
+++ b/src/weathergen/model/engines.py
@@ -40,7 +40,7 @@ def __init__(self, cf: Config, sources_size) -> None:
         :param sources_size: List of source sizes for each stream.
""" self.cf = cf - self.sources_size = sources_size + self.sources_size = sources_size self.embeds = torch.nn.ModuleList() def create(self) -> torch.nn.ModuleList: diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index c221f5d8e..5f068e9d9 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -555,10 +555,10 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca model_params, streams_data ) for fstep in range(len(tokens_targets_source_like)): - if tokens_targets_source_like[fstep].sum()==0: + if tokens_targets_source_like[fstep].sum() == 0: # if the input is empty, return an empty tensor tokens_targets.append(torch.tensor([]).detach()) - else: + else: tokens_target, _ = self.assimilate_local( model_params, tokens_targets_source_like[fstep], source_cell_lens ) @@ -657,9 +657,7 @@ def embed_cells_targets_source_like( ] ) offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1) - num_fsteps = target_source_like_tokens_lens.shape[ - 2 - ] + num_fsteps = target_source_like_tokens_lens.shape[2] tokens_all = [] for fstep in range(num_fsteps): tokens_all.append(