Changes from all commits

Commits (25)
1ab20ce
initial changes
kctezcan Sep 25, 2025
7a74aaa
clean up debug statements
kctezcan Sep 25, 2025
c4129a3
reading model output as dictionary
kctezcan Sep 25, 2025
4c2d12f
added the config parameter with false as def
kctezcan Sep 25, 2025
029d0ad
ruff
kctezcan Sep 25, 2025
6f61c16
.._normlaizer and removed one KCT comment
kctezcan Sep 29, 2025
2f4197e
removed a KCT and corrected some ..._normalizer comments
kctezcan Sep 29, 2025
f73adda
some comments and ruff changes
kctezcan Sep 29, 2025
cd30eb1
ruff
kctezcan Sep 29, 2025
5f42d43
removed run_evaluate
kctezcan Sep 29, 2025
5a38dc7
using !=
kctezcan Sep 29, 2025
d3976fb
merged changes from the upper branch
kctezcan Sep 29, 2025
721de02
removed some comments >>>>
kctezcan Sep 29, 2025
44f4a11
renamed everything srclk -> source_like
kctezcan Sep 29, 2025
b16d6ae
addressed ruff errors
kctezcan Sep 29, 2025
4304352
ruff
kctezcan Sep 29, 2025
97ffb58
removed some KCt comments
kctezcan Sep 29, 2025
c8fc70d
Merge branch 'ktezcan/dev/iss941_encode_targets' into ktezcan/dev/iss…
kctezcan Sep 29, 2025
f432e30
implemented a common embed_cells() for both source and targets
kctezcan Sep 30, 2025
88ffff5
Merge branch 'develop' into ktezcan/dev/iss941_encode_targets_sepfstep
kctezcan Sep 30, 2025
a367dc6
appending an empty tensor for target even if no tokens to embed
kctezcan Sep 30, 2025
c78c3c5
appending empty tensor for offsetted forecast steps
kctezcan Sep 30, 2025
32b5970
ruff
kctezcan Sep 30, 2025
d76ea22
Merge branch 'develop' into ktezcan/dev/iss941_encode_targets_sepfstep
kctezcan Oct 1, 2025
036ab9e
Merge branch 'develop' into ktezcan/dev/iss941_encode_targets_sepfstep
kctezcan Oct 1, 2025
1 change: 1 addition & 0 deletions config/default_config.yml
@@ -65,6 +65,7 @@ latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True
encode_targets_latent: False

loss_fcts:
-
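The new flag defaults to False. For context, a minimal sketch of how such a flag might be consumed downstream (the getattr access and the pairing with the call site are assumptions for illustration, not shown in this diff):

```python
# Hypothetical gating of the new target-encoding path on the config flag.
if getattr(cf, "encode_targets_latent", False):
    # Tokenize/embed targets like sources so they can enter the latent encoder.
    batch = compute_offsets_scatter_embed_target_source_like(batch)
```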
29 changes: 29 additions & 0 deletions src/weathergen/datasets/multi_stream_data_sampler.py
@@ -31,6 +31,7 @@
from weathergen.datasets.utils import (
compute_idxs_predict,
compute_offsets_scatter_embed,
compute_offsets_scatter_embed_target_source_like,
compute_source_cell_lens,
)
from weathergen.utils.logger import logger
@@ -348,6 +349,7 @@ def __iter__(self):
rdata.datetimes,
(time_win1.start, time_win1.end),
ds,
"source_normalizer",
)

stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids)
@@ -367,6 +369,7 @@

if rdata.is_empty():
stream_data.add_empty_target(fstep)
stream_data.add_empty_target_source_like(fstep)
else:
(tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target(
stream_info,
@@ -379,7 +382,32 @@
ds,
)

target_raw_source_like = torch.from_numpy(
np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1)
)
(
tt_cells_source_like,
tt_lens_source_like,
tt_centroids_source_like,
) = self.tokenizer.batchify_source(
stream_info,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win2.start, time_win2.end),
ds,
"target_normalizer",
)

stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t)
stream_data.add_target_source_like(
fstep,
target_raw_source_like,
tt_lens_source_like,
tt_cells_source_like,
tt_centroids_source_like,
)

# merge inputs for sources and targets for current stream
stream_data.merge_inputs()
@@ -398,6 +426,7 @@

# compute offsets for scatter computation after embedding
batch = compute_offsets_scatter_embed(batch)
batch = compute_offsets_scatter_embed_target_source_like(batch)

# compute offsets and auxiliary data needed for prediction computation
# (info is not per stream so separate data structure)
110 changes: 110 additions & 0 deletions src/weathergen/datasets/stream_data.py
@@ -66,6 +66,17 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i
self.source_idxs_embed = torch.tensor([])
self.source_idxs_embed_pe = torch.tensor([])

# below are for targets which are tokenized like sources
self.target_source_like_raw = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_tokens_lens = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_tokens_cells = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_centroids = [[] for _ in range(forecast_steps + 1)]

self.target_source_like_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)]
self.target_source_like_idxs_embed_pe = [
torch.tensor([]) for _ in range(forecast_steps + 1)
]

def to_device(self, device="cuda") -> None:
"""
Move data to GPU
@@ -91,6 +102,26 @@ def to_device(self, device="cuda") -> None:
self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True)
self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True)

self.target_source_like_raw = [
t.to(device, non_blocking=True) for t in self.target_source_like_raw
]
self.target_source_like_tokens_lens = [
t.to(device, non_blocking=True) for t in self.target_source_like_tokens_lens
]
self.target_source_like_tokens_cells = [
t.to(device, non_blocking=True) for t in self.target_source_like_tokens_cells
]
self.target_source_like_centroids = [
t.to(device, non_blocking=True) for t in self.target_source_like_centroids
]

self.target_source_like_idxs_embed = [
t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed
]
self.target_source_like_idxs_embed_pe = [
t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed_pe
]

return self

def add_empty_source(self, source: IOReaderData) -> None:
@@ -111,6 +142,24 @@ def add_empty_source(self, source: IOReaderData) -> None:
self.source_tokens_cells += [torch.tensor([])]
self.source_centroids += [torch.tensor([])]

def add_empty_target_source_like(self, fstep: int) -> None:
"""
Add an empty target for an input encoded like source.
Parameters
----------
None
Returns
-------
None
"""

self.target_source_like_raw[fstep] += [torch.tensor([])]
self.target_source_like_tokens_lens[fstep] += [
torch.zeros([self.nhc_source], dtype=torch.int32)
]
self.target_source_like_tokens_cells[fstep] += [torch.tensor([])]
self.target_source_like_centroids[fstep] += [torch.tensor([])]

def add_empty_target(self, fstep: int) -> None:
"""
Add an empty target for an input.
@@ -159,6 +208,34 @@ def add_source(
self.source_tokens_cells += [ss_cells]
self.source_centroids += [ss_centroids]

def add_target_source_like(
self,
fstep: int,
tt_raw: torch.tensor,
tt_lens: torch.tensor,
tt_cells: list,
tt_centroids: list,
) -> None:
"""
Add data for a target tokenized like a source, for one forecast step.
Parameters
----------
fstep : int
forecast step index
tt_raw : torch.tensor( number of data points in time window , number of channels )
tt_lens : torch.tensor( number of healpix cells )
tt_cells : list( number of healpix cells )
[ torch.tensor( tokens per cell, token size, number of channels) ]
tt_centroids : list( number of healpix cells )
[ torch.tensor( for target , 5) ]
Returns
-------
None
"""

self.target_source_like_raw[fstep] += [tt_raw]
self.target_source_like_tokens_lens[fstep] += [tt_lens]
self.target_source_like_tokens_cells[fstep] += [tt_cells]
self.target_source_like_centroids[fstep] += [tt_centroids]

def add_target(
self,
fstep: int,
@@ -318,6 +395,39 @@ def merge_inputs(self) -> None:
self.source_tokens_cells = torch.tensor([])
self.source_centroids = torch.tensor([])

# collect all source like tokens in current stream and add to
# batch sample list when non-empty
for fstep in range(len(self.target_source_like_tokens_cells)):
if (
torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
Contributor:

Replace

```python
if (
    torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
    > 0
):
```

with

```python
if any(len(s) > 0 for s in self.target_source_like_tokens_cells[fstep]):
```

for slightly better efficiency.

Maybe you can find a way to replace len(s) with a check that runs in constant time, without having to write multiple lines of code.

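One way to get the constant-time check hinted at above would be to maintain a running count next to the per-cell lists; a sketch with a hypothetical target_source_like_num_points field (not part of this PR):

```python
# In __init__ (hypothetical field): one counter per forecast step.
self.target_source_like_num_points = [0 for _ in range(forecast_steps + 1)]

# In add_target_source_like: update the counter alongside the lists.
self.target_source_like_num_points[fstep] += int(tt_lens.sum())

# In merge_inputs: the emptiness check becomes O(1).
if self.target_source_like_num_points[fstep] > 0:
    ...
```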
> 0
):
self.target_source_like_raw[fstep] = torch.cat(self.target_source_like_raw[fstep])

# collect by merging entries per cells, preserving cell structure
self.target_source_like_tokens_cells[fstep] = self._merge_cells(
self.target_source_like_tokens_cells[fstep], self.nhc_source
)
self.target_source_like_centroids[fstep] = self._merge_cells(
self.target_source_like_centroids[fstep], self.nhc_source
)
# lens can be stacked and summed
self.target_source_like_tokens_lens[fstep] = torch.stack(
self.target_source_like_tokens_lens[fstep]
).sum(0)

# remove NaNs
idx = torch.isnan(self.target_source_like_tokens_cells[fstep])
self.target_source_like_tokens_cells[fstep][idx] = self.mask_value
idx = torch.isnan(self.target_source_like_centroids[fstep])
self.target_source_like_centroids[fstep][idx] = self.mask_value

else:
self.target_source_like_raw[fstep] = torch.tensor([])
self.target_source_like_tokens_lens[fstep] = torch.zeros([self.nhc_source])
self.target_source_like_tokens_cells[fstep] = torch.tensor([])
self.target_source_like_centroids[fstep] = torch.tensor([])

# targets
for fstep in range(len(self.target_coords)):
# collect all targets in current stream and add to batch sample list when non-empty
11 changes: 9 additions & 2 deletions src/weathergen/datasets/tokenizer_forecast.py
@@ -41,13 +41,20 @@ def batchify_source(
source: np.array,
times: np.array,
time_win: tuple,
normalizer, # dataset
normalizer, # dataset,
use_normalizer: str, # "source_normalizer" or "target_normalizer"
Contributor:

Rename use_normalizer to channel_to_normalize: even though the type and possible values are clearly documented, use_normalizer suggests a boolean value.
Another option is to rename normalizer to normalizer_dataset or normalizer_ds, so the string parameter can simply be called normalizer instead of use_normalizer.

):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)

channel_normalizer = (
normalizer.normalize_source_channels
if use_normalizer == "source_normalizer"
else normalizer.normalize_target_channels
)

tokenize_window = partial(
tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
time_win=time_win,
@@ -56,7 +63,7 @@
hpy_verts_rots=self.hpy_verts_rots_source[-1],
n_coords=normalizer.normalize_coords,
n_geoinfos=normalizer.normalize_geoinfos,
n_data=normalizer.normalize_source_channels,
n_data=channel_normalizer,
enc_time=encode_times_source,
)

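Applied to batchify_source above, the reviewer's rename suggestion might look like this (a sketch only; the parameter names follow the review comments, not the merged code):

```python
def batchify_source(
    self,
    stream_info,
    source: np.array,
    times: np.array,
    time_win: tuple,
    normalizer_ds,              # dataset-level normalizer (renamed from `normalizer`)
    channel_to_normalize: str,  # "source_normalizer" or "target_normalizer"
):
    ...
```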
9 changes: 8 additions & 1 deletion src/weathergen/datasets/tokenizer_masking.py
@@ -48,12 +48,19 @@ def batchify_source(
times: np.array,
time_win: tuple,
normalizer, # dataset
use_normalizer: str, # "source_normalizer" or "target_normalizer"
Contributor:

Rename use_normalizer here as well, as in tokenizer_forecast.py (see the first comment).

):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)

channel_normalizer = (
normalizer.normalize_source_channels
if use_normalizer == "source_normalizer"
else normalizer.normalize_target_channels
)

tokenize_window = partial(
tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
time_win=time_win,
@@ -62,7 +69,7 @@
hpy_verts_rots=self.hpy_verts_rots_source[-1],
n_coords=normalizer.normalize_coords,
n_geoinfos=normalizer.normalize_geoinfos,
n_data=normalizer.normalize_source_channels,
n_data=channel_normalizer,
enc_time=encode_times_source,
)

78 changes: 78 additions & 0 deletions src/weathergen/datasets/utils.py
@@ -677,6 +677,84 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData:
return batch


def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> StreamData:
"""
Compute auxiliary information for scatter operation that changes from stream-centric to
cell-centric computations

Parameters
----------
batch : StreamData
batch of stream data information for which offsets have to be computed

Returns
-------
StreamData
stream data with offsets added as members
"""

# collect target_source_like_tokens_lens for all stream datas
target_source_like_tokens_lens = torch.stack(
[
torch.stack(
[
torch.stack(
[
s.target_source_like_tokens_lens[fstep]
if len(s.target_source_like_tokens_lens[fstep]) > 0
else torch.tensor([])
for fstep in range(len(s.target_source_like_tokens_lens))
]
)
for s in stl_b
]
)
for stl_b in batch
]
)
Contributor:

Use fewer lines, because it looks more complex than it actually is:

```python
target_source_like_tokens_lens = torch.stack([
    torch.stack([
        torch.stack([
            s.target_source_like_tokens_lens[fstep]
            if len(s.target_source_like_tokens_lens[fstep]) > 0
            else torch.tensor([])
            for fstep in range(len(s.target_source_like_tokens_lens))
        ]) for s in stl_b
    ]) for stl_b in batch
])
```

If this was caused by ruff then just forget about this comment...


# precompute index sets for scatter operation after embed
offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1)
# shift the offsets for each fstep by one to the right, add a zero to the
# beginning as the first token starts at 0
zeros_col = torch.zeros(
(offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device
)
offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
offsets_pe = torch.zeros_like(offsets)

for ib, sb in enumerate(batch):
for itype, s in enumerate(sb):
for fstep in range(offsets.shape[0]):
if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: # if not empty
Contributor:

Replace if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: with if any(target_source_like_tokens_lens[ib, itype, fstep]): for better efficiency.
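One caveat on the suggestion above: Python's built-in any() iterates a tensor element by element, so the vectorized tensor method is likely the better form here (sketch):

```python
# Vectorized non-zero check on the per-fstep lens tensor.
if target_source_like_tokens_lens[ib, itype, fstep].any():
    ...
```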

s.target_source_like_idxs_embed[fstep] = torch.cat(
[
torch.arange(offset, offset + token_len, dtype=torch.int64)
for offset, token_len in zip(
offsets[fstep],
target_source_like_tokens_lens[ib, itype, fstep],
strict=False,
)
]
)
s.target_source_like_idxs_embed_pe[fstep] = torch.cat(
[
torch.arange(offset, offset + token_len, dtype=torch.int32)
for offset, token_len in zip(
offsets_pe[fstep],
target_source_like_tokens_lens[ib][itype][fstep],
strict=False,
)
]
)

# advance offsets
offsets[fstep] += target_source_like_tokens_lens[ib][itype][fstep]
offsets_pe[fstep] += target_source_like_tokens_lens[ib][itype][fstep]

return batch


def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list:
"""
Compute auxiliary information for prediction
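To make the offsets logic above concrete, here is a minimal, self-contained sketch of the cumsum-and-shift pattern (the numbers are illustrative only):

```python
import torch

# Per-cell token counts, shape (forecast steps, cells).
lens = torch.tensor([[3, 0, 2, 4]])
base = lens.cumsum(1)  # inclusive prefix sums: [[3, 3, 5, 9]]
zeros_col = torch.zeros((lens.shape[0], 1), dtype=base.dtype)
# Shift right and prepend zero: the first token of cell 0 starts at offset 0.
offsets = torch.cat([zeros_col, base[:, :-1]], dim=1)  # [[0, 3, 3, 5]]
# Token i of cell c is scattered to offsets[fstep, c] + i.
```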
2 changes: 1 addition & 1 deletion src/weathergen/model/engines.py
@@ -40,7 +40,7 @@ def __init__(self, cf: Config, sources_size) -> None:
:param sources_size: List of source sizes for each stream.
"""
self.cf = cf
self.sources_size = sources_size # KCT:iss130, what is this?
self.sources_size = sources_size
self.embeds = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList: