1 change: 1 addition & 0 deletions config/default_config.yml
@@ -66,6 +66,7 @@ latent_noise_gamma: 2.0
latent_noise_saturate_encodings: 5
latent_noise_use_additive_noise: False
latent_noise_deterministic_latents: True
encode_targets_latent: False

loss_fcts:
-
29 changes: 29 additions & 0 deletions src/weathergen/datasets/multi_stream_data_sampler.py
@@ -31,6 +31,7 @@
from weathergen.datasets.utils import (
compute_idxs_predict,
compute_offsets_scatter_embed,
compute_offsets_scatter_embed_target_source_like,
compute_source_cell_lens,
)
from weathergen.utils.logger import logger
@@ -348,6 +349,7 @@ def __iter__(self):
rdata.datetimes,
(time_win1.start, time_win1.end),
ds,
"source_normalizer",
)

stream_data.add_source(rdata_wrapped, ss_lens, ss_cells, ss_centroids)
@@ -367,6 +369,7 @@

if rdata.is_empty():
stream_data.add_empty_target(fstep)
stream_data.add_empty_target_source_like(fstep)
else:
(tt_cells, tc, tt_c, tt_t) = self.tokenizer.batchify_target(
stream_info,
@@ -379,7 +382,32 @@
ds,
)

target_raw_source_like = torch.from_numpy(
np.concatenate((rdata.coords, rdata.geoinfos, rdata.data), 1)
)
(
tt_cells_source_like,
tt_lens_source_like,
tt_centroids_source_like,
) = self.tokenizer.batchify_source(
stream_info,
torch.from_numpy(rdata.coords),
torch.from_numpy(rdata.geoinfos),
torch.from_numpy(rdata.data),
rdata.datetimes,
(time_win2.start, time_win2.end),
ds,
"target_normalizer",
)

stream_data.add_target(fstep, tt_cells, tc, tt_c, tt_t)
stream_data.add_target_source_like(
fstep,
target_raw_source_like,
tt_lens_source_like,
tt_cells_source_like,
tt_centroids_source_like,
)

# merge inputs for sources and targets for current stream
stream_data.merge_inputs()
@@ -398,6 +426,7 @@

# compute offsets for scatter computation after embedding
batch = compute_offsets_scatter_embed(batch)
batch = compute_offsets_scatter_embed_target_source_like(batch)

# compute offsets and auxiliary data needed for prediction computation
# (info is not per stream so separate data structure)
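Note (not part of the diff): a minimal sketch of how the raw source-like target tensor above is assembled. Coordinates, geoinfos, and data channels are concatenated column-wise before being handed to add_target_source_like; the array shapes here are hypothetical.

import numpy as np
import torch

# hypothetical shapes: 4 observations, 2 coordinate columns, 1 geoinfo column, 3 data channels
coords = np.zeros((4, 2), dtype=np.float32)
geoinfos = np.zeros((4, 1), dtype=np.float32)
data = np.zeros((4, 3), dtype=np.float32)

# column-wise concatenation, as in target_raw_source_like above
target_raw_source_like = torch.from_numpy(np.concatenate((coords, geoinfos, data), 1))
print(target_raw_source_like.shape)  # torch.Size([4, 6])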
110 changes: 110 additions & 0 deletions src/weathergen/datasets/stream_data.py
@@ -66,6 +66,17 @@ def __init__(self, idx: int, forecast_steps: int, nhc_source: int, nhc_target: i
self.source_idxs_embed = torch.tensor([])
self.source_idxs_embed_pe = torch.tensor([])

# the fields below are for targets that are tokenized like sources
self.target_source_like_raw = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_tokens_lens = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_tokens_cells = [[] for _ in range(forecast_steps + 1)]
self.target_source_like_centroids = [[] for _ in range(forecast_steps + 1)]

self.target_source_like_idxs_embed = [torch.tensor([]) for _ in range(forecast_steps + 1)]
self.target_source_like_idxs_embed_pe = [
torch.tensor([]) for _ in range(forecast_steps + 1)
]

def to_device(self, device="cuda") -> None:
"""
Move data to GPU
@@ -91,6 +102,26 @@ def to_device(self, device="cuda") -> None:
self.source_idxs_embed = self.source_idxs_embed.to(device, non_blocking=True)
self.source_idxs_embed_pe = self.source_idxs_embed_pe.to(device, non_blocking=True)

self.target_source_like_raw = [
t.to(device, non_blocking=True) for t in self.target_source_like_raw
]
self.target_source_like_tokens_lens = [
t.to(device, non_blocking=True) for t in self.target_source_like_tokens_lens
]
self.target_source_like_tokens_cells = [
t.to(device, non_blocking=True) for t in self.target_source_like_tokens_cells
]
self.target_source_like_centroids = [
t.to(device, non_blocking=True) for t in self.target_source_like_centroids
]

self.target_source_like_idxs_embed = [
t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed
]
self.target_source_like_idxs_embed_pe = [
t.to(device, non_blocking=True) for t in self.target_source_like_idxs_embed_pe
]

return self

def add_empty_source(self, source: IOReaderData) -> None:
@@ -111,6 +142,24 @@ def add_empty_source(self, source: IOReaderData) -> None:
self.source_tokens_cells += [torch.tensor([])]
self.source_centroids += [torch.tensor([])]

def add_empty_target_source_like(self, fstep: int) -> None:
"""
Add an empty target, tokenized like a source, for one forecast step.
Parameters
----------
fstep : int
    forecast step for which the empty target is added
Returns
-------
None
"""

self.target_source_like_raw[fstep] += [torch.tensor([])]
self.target_source_like_tokens_lens[fstep] += [
torch.zeros([self.nhc_source], dtype=torch.int32)
]
self.target_source_like_tokens_cells[fstep] += [torch.tensor([])]
self.target_source_like_centroids[fstep] += [torch.tensor([])]

def add_empty_target(self, fstep: int) -> None:
"""
Add an empty target for an input.
@@ -159,6 +208,34 @@ def add_source(
self.source_tokens_cells += [ss_cells]
self.source_centroids += [ss_centroids]

def add_target_source_like(
self,
fstep: int,
tt_raw: torch.tensor,
tt_lens: torch.tensor,
tt_cells: list,
tt_centroids: list,
) -> None:
"""
Add data for a target that is tokenized like a source, for one forecast step.
Parameters
----------
fstep : int
    forecast step index
tt_raw : torch.tensor( number of data points in time window , number of channels )
tt_lens : torch.tensor( number of healpix cells )
tt_cells : list( number of healpix cells )
    [ torch.tensor( tokens per cell, token size, number of channels) ]
tt_centroids : list( number of healpix cells )
    [ torch.tensor( for source , 5) ]
Returns
-------
None
"""

self.target_source_like_raw[fstep] += [tt_raw]
self.target_source_like_tokens_lens[fstep] += [tt_lens]
self.target_source_like_tokens_cells[fstep] += [tt_cells]
self.target_source_like_centroids[fstep] += [tt_centroids]

def add_target(
self,
fstep: int,
@@ -318,6 +395,39 @@ def merge_inputs(self) -> None:
self.source_tokens_cells = torch.tensor([])
self.source_centroids = torch.tensor([])

# collect all source-like target tokens in current stream and add to
# batch sample list when non-empty
for fstep in range(len(self.target_source_like_tokens_cells)):
if (
torch.tensor([len(s) for s in self.target_source_like_tokens_cells[fstep]]).sum()
> 0
):
self.target_source_like_raw[fstep] = torch.cat(self.target_source_like_raw[fstep])

# collect by merging entries per cell, preserving cell structure
self.target_source_like_tokens_cells[fstep] = self._merge_cells(
self.target_source_like_tokens_cells[fstep], self.nhc_source
)
self.target_source_like_centroids[fstep] = self._merge_cells(
self.target_source_like_centroids[fstep], self.nhc_source
)
# lens can be stacked and summed
self.target_source_like_tokens_lens[fstep] = torch.stack(
self.target_source_like_tokens_lens[fstep]
).sum(0)

# remove NaNs
idx = torch.isnan(self.target_source_like_tokens_cells[fstep])
self.target_source_like_tokens_cells[fstep][idx] = self.mask_value
idx = torch.isnan(self.target_source_like_centroids[fstep])
self.target_source_like_centroids[fstep][idx] = self.mask_value

else:
self.target_source_like_raw[fstep] = torch.tensor([])
self.target_source_like_tokens_lens[fstep] = torch.zeros([self.nhc_source])
self.target_source_like_tokens_cells[fstep] = torch.tensor([])
self.target_source_like_centroids[fstep] = torch.tensor([])

# targets
for fstep in range(len(self.target_coords)):
# collect all targets in current stream and add to batch sample list when non-empty
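Note (not part of the diff): a toy sketch of the merge step above; the helper below is a hypothetical stand-in for the class's own _merge_cells. Per-cell token lists from several readers are concatenated cell by cell, token lengths are stacked and summed, and NaNs are replaced by the mask value.

import torch

def merge_cells_toy(per_reader_cells):
    # hypothetical stand-in for StreamData._merge_cells: concatenate the entries
    # of each healpix cell across readers, preserving cell order
    n_cells = len(per_reader_cells[0])
    per_cell = [
        torch.cat([reader[c] for reader in per_reader_cells if reader[c].numel() > 0] or [torch.tensor([])])
        for c in range(n_cells)
    ]
    return torch.cat(per_cell)

# two readers, three healpix cells each (1-d toy tokens)
reader_a = [torch.tensor([1.0]), torch.tensor([]), torch.tensor([float("nan")])]
reader_b = [torch.tensor([2.0]), torch.tensor([3.0]), torch.tensor([])]

cells = merge_cells_toy([reader_a, reader_b])                                   # tensor([1., 2., 3., nan])
lens = torch.stack([torch.tensor([1, 0, 1]), torch.tensor([1, 1, 0])]).sum(0)  # tensor([2, 1, 1])

mask_value = 0.0  # assumed here; StreamData defines its own mask_value
cells[torch.isnan(cells)] = mask_value                                          # tensor([1., 2., 3., 0.])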
11 changes: 9 additions & 2 deletions src/weathergen/datasets/tokenizer_forecast.py
@@ -41,13 +41,20 @@ def batchify_source(
source: np.array,
times: np.array,
time_win: tuple,
normalizer, # dataset
normalizer, # dataset,
use_normalizer: str, # "source_normalizer" or "target_normalizer"
):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)

channel_normalizer = (
normalizer.normalize_source_channels
if use_normalizer == "source_normalizer"
else normalizer.normalize_target_channels
)

tokenize_window = partial(
tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
time_win=time_win,
@@ -56,7 +63,7 @@
hpy_verts_rots=self.hpy_verts_rots_source[-1],
n_coords=normalizer.normalize_coords,
n_geoinfos=normalizer.normalize_geoinfos,
n_data=normalizer.normalize_source_channels,
n_data=channel_normalizer,
enc_time=encode_times_source,
)

9 changes: 8 additions & 1 deletion src/weathergen/datasets/tokenizer_masking.py
@@ -48,12 +48,19 @@ def batchify_source(
times: np.array,
time_win: tuple,
normalizer, # dataset
use_normalizer: str, # "source_normalizer" or "target_normalizer"
):
init_loggers()
token_size = stream_info["token_size"]
is_diagnostic = stream_info.get("diagnostic", False)
tokenize_spacetime = stream_info.get("tokenize_spacetime", False)

channel_normalizer = (
normalizer.normalize_source_channels
if use_normalizer == "source_normalizer"
else normalizer.normalize_target_channels
)

tokenize_window = partial(
tokenize_window_spacetime if tokenize_spacetime else tokenize_window_space,
time_win=time_win,
@@ -62,7 +69,7 @@
hpy_verts_rots=self.hpy_verts_rots_source[-1],
n_coords=normalizer.normalize_coords,
n_geoinfos=normalizer.normalize_geoinfos,
n_data=normalizer.normalize_source_channels,
n_data=channel_normalizer,
enc_time=encode_times_source,
)

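Note (not part of the diff): a minimal sketch of the normalizer-selection idiom added to both tokenizers; ToyNormalizer is a hypothetical stand-in for the dataset object. The same batchify_source path normalizes data channels with either the source or the target statistics, chosen by the use_normalizer string.

import numpy as np

class ToyNormalizer:
    # hypothetical stand-in exposing the two methods used above
    def normalize_source_channels(self, x):
        return (x - 1.0) / 2.0
    def normalize_target_channels(self, x):
        return (x - 3.0) / 4.0

def pick_channel_normalizer(normalizer, use_normalizer: str):
    # same dispatch as in batchify_source above
    return (
        normalizer.normalize_source_channels
        if use_normalizer == "source_normalizer"
        else normalizer.normalize_target_channels
    )

n_data = pick_channel_normalizer(ToyNormalizer(), "target_normalizer")
print(n_data(np.array([3.0, 7.0])))  # [0. 1.]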
78 changes: 78 additions & 0 deletions src/weathergen/datasets/utils.py
@@ -677,6 +677,84 @@ def compute_offsets_scatter_embed(batch: StreamData) -> StreamData:
return batch


def compute_offsets_scatter_embed_target_source_like(batch: StreamData) -> StreamData:
"""
Compute auxiliary information for scatter operation that changes from stream-centric to
cell-centric computations

Parameters
----------
batch : StreamData
batch of stream data information for which offsets have to be computed

Returns
-------
StreamData
stream data with offsets added as members
"""

# collect target_source_like_tokens_lens for all stream data objects
target_source_like_tokens_lens = torch.stack(
[
torch.stack(
[
torch.stack(
[
s.target_source_like_tokens_lens[fstep]
if len(s.target_source_like_tokens_lens[fstep]) > 0
else torch.tensor([])
for fstep in range(len(s.target_source_like_tokens_lens))
]
)
for s in stl_b
]
)
for stl_b in batch
]
)

# precompute index sets for scatter operation after embed
offsets_base = target_source_like_tokens_lens.sum(1).sum(0).cumsum(1)
# shift the offsets for each fstep by one to the right, add a zero to the
# beginning as the first token starts at 0
zeros_col = torch.zeros(
(offsets_base.shape[0], 1), dtype=offsets_base.dtype, device=offsets_base.device
)
offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)
offsets_pe = torch.zeros_like(offsets)

Collaborator:
Can you expand on the comment? It's not clear to me why this is necessary.

Contributor Author:
I rephrased it to:

    # shift the offsets for each fstep by one to the right, add a zero to the
    # beginning as the first token starts at 0

for ib, sb in enumerate(batch):
for itype, s in enumerate(sb):
for fstep in range(offsets.shape[0]):
if target_source_like_tokens_lens[ib, itype, fstep].sum() != 0: # if not empty
s.target_source_like_idxs_embed[fstep] = torch.cat(
[
torch.arange(offset, offset + token_len, dtype=torch.int64)
for offset, token_len in zip(
offsets[fstep],
target_source_like_tokens_lens[ib, itype, fstep],
strict=False,
)
]
)
s.target_source_like_idxs_embed_pe[fstep] = torch.cat(
[
torch.arange(offset, offset + token_len, dtype=torch.int32)
for offset, token_len in zip(
offsets_pe[fstep],
target_source_like_tokens_lens[ib][itype][fstep],
strict=False,
)
]
)

# advance offsets
offsets[fstep] += target_source_like_tokens_lens[ib][itype][fstep]
offsets_pe[fstep] += target_source_like_tokens_lens[ib][itype][fstep]

return batch


def compute_idxs_predict(forecast_dt: int, batch: StreamData) -> list:
"""
Compute auxiliary information for prediction
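Note (not part of the diff): to make the offset construction discussed in the review thread above concrete, a small self-contained sketch with toy token lengths (hypothetical values). The cumulative sum gives the end offset of each cell, shifting right and prepending a zero column gives the start offsets, and per-cell index ranges are then concatenated, analogously to target_source_like_idxs_embed.

import torch

# toy token lengths: rows are forecast steps, columns are healpix cells
tokens_lens = torch.tensor([[2, 3, 1],
                            [4, 0, 2]])

offsets_base = tokens_lens.cumsum(1)                             # [[2, 5, 6], [4, 4, 6]]

# shift right by one and prepend a zero column: the first token starts at 0
zeros_col = torch.zeros((offsets_base.shape[0], 1), dtype=offsets_base.dtype)
offsets = torch.cat([zeros_col, offsets_base[:, :-1]], dim=1)    # [[0, 2, 5], [0, 4, 4]]

# per-cell index ranges for forecast step 0
idxs_step0 = torch.cat(
    [torch.arange(o, o + n) for o, n in zip(offsets[0], tokens_lens[0])]
)
print(idxs_step0)  # tensor([0, 1, 2, 3, 4, 5])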
2 changes: 1 addition & 1 deletion src/weathergen/model/engines.py
@@ -40,7 +40,7 @@ def __init__(self, cf: Config, sources_size) -> None:
:param sources_size: List of source sizes for each stream.
"""
self.cf = cf
self.sources_size = sources_size # KCT:iss130, what is this?
self.sources_size = sources_size
self.embeds = torch.nn.ModuleList()

def create(self) -> torch.nn.ModuleList: