25 changes: 10 additions & 15 deletions src/weathergen/datasets/data_reader_base.py
@@ -37,18 +37,6 @@
_DT_ZERO = np.datetime64("1850-01-01T00:00")


@dataclass
class TimeIndexRange:
"""
Defines a time window for indexing into datasets.
It is defined as number of hours since the start of the dataset.
"""

start: TIndex
end: TIndex


@dataclass
class DTRange:
"""
@@ -132,10 +120,14 @@ def __init__(

        assert self.t_start < self.t_end, "start datetime has to be before end datetime"
assert self.t_start > _DT_ZERO, "start datetime has to be >= 1850-01-01T00:00."
_logger.info(
f"Time window handler: start={self.t_start}, end={self.t_end},"
f"len_hrs={self.t_window_len}, step_hrs={self.t_window_step}"
)

def get_index_range(self) -> TimeIndexRange:
    def get_index_range(self) -> range:
"""
Temporal window corresponding to index
Range of indices identifying time ranges, from start to end.
Parameters
----------
@@ -151,7 +143,7 @@ def get_index_range(self) -> TimeIndexRange:
idx_end = np.int64((self.t_end - self.t_start) // self.t_window_step)
assert idx_start <= idx_end, f"time window idxs invalid: {idx_start} <= {idx_end}"

return TimeIndexRange(idx_start, idx_end)
return range(idx_start, idx_end)

def window(self, idx: TIndex) -> DTRange:
"""
@@ -171,6 +163,9 @@ def window(self, idx: TIndex) -> DTRange:
t_end_win = t_start_win + self.t_window_len

return DTRange(t_start_win, t_end_win)

    def get_forecast_len(self, forecast_step: int) -> int:
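        """Number of window-index steps covered by `forecast_step + 1` consecutive windows."""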
return (int(self.t_window_len) * (forecast_step + 1)) // int(self.t_window_step)


@dataclass
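# Standalone sketch of the window-index arithmetic above (the dates, lengths, and the use of
# np.timedelta64 for len_hrs/step_hrs are illustrative assumptions, not the real config):
import numpy as np

t_start = np.datetime64("2000-01-01T00:00")
t_end = np.datetime64("2000-01-11T00:00")
t_window_len = np.timedelta64(6, "h")   # len_hrs
t_window_step = np.timedelta64(6, "h")  # step_hrs

# get_index_range(): one index per window, returned as a plain range
idx_start = np.int64(0)
idx_end = np.int64((t_end - t_start) // t_window_step)  # 40 windows over 10 days
index_range = range(idx_start, idx_end)

# window(idx): datetime bounds of the idx-th window
idx = 3
t_start_win = t_start + idx * t_window_step
t_end_win = t_start_win + t_window_len

# get_forecast_len(forecast_step): index steps spanned by forecast_step + 1 windows
forecast_step = 2
forecast_len = (int(t_window_len) * (forecast_step + 1)) // int(t_window_step)  # -> 3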
210 changes: 103 additions & 107 deletions src/weathergen/datasets/multi_stream_data_sampler.py
@@ -19,7 +19,6 @@
ReaderData,
TimeWindowHandler,
TIndex,
str_to_datetime64,
)
from weathergen.datasets.data_reader_fesom import DataReaderFesom
from weathergen.datasets.data_reader_obs import DataReaderObs
@@ -33,6 +32,7 @@
compute_offsets_scatter_embed,
compute_source_cell_lens,
)
from weathergen.utils.distributed import get_rank, get_world_size
from weathergen.utils.logger import logger
from weathergen.utils.train_logger import Stage

@@ -44,36 +44,23 @@ class MultiStreamDataSampler(torch.utils.data.IterableDataset):
def __init__(
self,
cf,
start_date_,
end_date_,
start_date,
end_date,
batch_size,
samples_per_epoch,
stage: Stage,
shuffle=True,
):
super(MultiStreamDataSampler, self).__init__()

start_date = str_to_datetime64(start_date_)
end_date = str_to_datetime64(end_date_)

assert end_date > start_date, (end_date, start_date)

self.mask_value = 0.0
self._stage = stage

self.len_hrs: int = cf.len_hrs
self.step_hrs: int = cf.step_hrs
self.time_window_handler = TimeWindowHandler(start_date, end_date, cf.len_hrs, cf.step_hrs)
logger.info(
f"Time window handler: start={start_date}, end={end_date},"
f"len_hrs={cf.len_hrs}, step_hrs={cf.step_hrs}"
)

self.forecast_offset = cf.forecast_offset
self.forecast_delta_hrs = (
cf.forecast_delta_hrs if cf.forecast_delta_hrs > 0 else self.len_hrs
)
assert self.forecast_delta_hrs == self.len_hrs, "Only supported option at the moment"
self.forecast_delta_hrs = cf.forecast_delta_hrs if cf.forecast_delta_hrs > 0 else cf.len_hrs
assert self.forecast_delta_hrs == cf.len_hrs, "Only supported option at the moment"
self.forecast_steps = np.array(
[cf.forecast_steps] if isinstance(cf.forecast_steps, int) else cf.forecast_steps
)
@@ -82,87 +69,24 @@ def __init__(
logger.warning("forecast policy is not None but number of forecast steps is 0.")
self.forecast_policy = cf.forecast_policy

self.len = 100000000

self.streams_datasets: list[list[AnyDataReader]] = []
for _, stream_info in enumerate(cf.streams):
self.streams_datasets.append([])

for fname in stream_info["filenames"]:
kwargs = {
"tw_handler": self.time_window_handler,
"stream_info": stream_info,
}
dataset: type[AnyDataReader] | None = None
match stream_info["type"]:
case "obs":
dataset = DataReaderObs
datapath = cf.data_path_obs
# kwargs["end"] = end_date_padded # TODO: implement the padding
case "anemoi":
dataset = DataReaderAnemoi
datapath = cf.data_path_anemoi
case "fesom":
dataset = DataReaderFesom
datapath = cf.data_path_fesom
case "icon":
dataset = IconDataset
datapath = cf.data_path_icon
case _:
msg = f"Unsupported stream type {stream_info['type']}"
f"for stream name '{stream_info['name']}'."
raise ValueError(msg)

datapath = pathlib.Path(datapath)
fname = pathlib.Path(fname)
# dont check if file exists since zarr stores might be directories
if fname.exists():
# check if fname is a valid path to allow for simple overwriting
filename = fname
else:
filename = pathlib.Path(datapath) / fname

if not filename.exists(): # see above
msg = (
f"Did not find input data for {stream_info['type']} "
f"stream '{stream_info['name']}': {filename}."
)
raise FileNotFoundError(msg)

ds_type = stream_info["type"]
logger.info(
f"Opening dataset with type: {ds_type}"
+ f" from stream config {stream_info['name']}.",
)
ds = dataset(filename=filename, **kwargs)

fsm = self.forecast_steps[0]
if len(ds) > 0:
self.len = min(self.len, len(ds) - (self.len_hrs * (fsm + 1)) // self.step_hrs)

# MODIFIES config !!!
stream_info[str(self._stage) + "_source_channels"] = ds.source_channels
stream_info[str(self._stage) + "_target_channels"] = ds.target_channels
self.streams_datasets: list[list[AnyDataReader]] = [
create_datasets(stream_info, self.time_window_handler, cf)
for stream_info in cf.streams
]

self.streams_datasets[-1] += [ds]
# MODIFIES config !!!
for stream_info, stream_datasets in zip(cf.streams, self.streams_datasets, strict=True):
# assume all datasets within one stream have the same channels
sample_ds = stream_datasets[0]
stream_info[f"{self._stage}_source_channels"] = sample_ds.source_channels
stream_info[f"{self._stage}_target_channels"] = sample_ds.target_channels

        index_range = self.time_window_handler.get_index_range()
self.len = int(index_range.end - index_range.start)
self.len = min(self.len, samples_per_epoch if samples_per_epoch else self.len)
# adjust len to split loading across all workers and ensure it is multiple of batch_size
len_chunk = ((self.len // cf.num_ranks) // batch_size) * batch_size
self.len = min(self.len, len_chunk)
logger.info(f"index_range={index_range}, len={self.len}, len_chunk={len_chunk}")

self.rank = cf.rank
self.num_ranks = cf.num_ranks
self._len = self._get_len(self.time_window_handler, samples_per_epoch, batch_size)

self.streams = cf.streams
self.shuffle = shuffle
# TODO: remove options that are no longer supported
self.input_window_steps = cf.input_window_steps
self.embed_local_coords = cf.embed_local_coords
self.embed_centroids_local_coords = cf.embed_centroids_local_coords

self.input_window_steps = cf.input_window_steps # TODO is deprecated?
self.sampling_rate_target = cf.sampling_rate_target

self.batch_size = batch_size
@@ -233,7 +157,7 @@ def reset(self):
# value in worker_workset()
self.rng = np.random.default_rng(self.data_loader_rng_seed)

fsm = (
fsm: int = (
self.forecast_steps[min(self.epoch, len(self.forecast_steps) - 1)]
if self.forecast_policy != "random"
else self.forecast_steps.max()
@@ -243,12 +167,15 @@ def reset(self):

# data
index_range = self.time_window_handler.get_index_range()
idx_end = index_range.end
# native length of datasets, independent of epoch length that has potentially been specified
forecast_len = (self.len_hrs * (fsm + 1)) // self.step_hrs
idx_end -= forecast_len + self.forecast_offset
assert idx_end > 0, "dataset size too small for forecast range"
self.perms = np.arange(index_range.start, idx_end)
adjusted_index_range = range(
index_range.start,
index_range.stop
- self.time_window_handler.get_forecast_len(fsm)
            - self.forecast_offset,
)
assert adjusted_index_range.stop > 0, "dataset size too small for forecast range"
self.perms = np.array(adjusted_index_range)
if self.shuffle:
self.perms = self.rng.permutation(self.perms)
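
# Standalone sketch of the epoch-index construction in reset() above; the window count,
# forecast length, and offset are made-up values, the arithmetic mirrors the code:
import numpy as np

index_range = range(0, 40)   # TimeWindowHandler.get_index_range()
forecast_len = 3             # TimeWindowHandler.get_forecast_len(fsm)
forecast_offset = 1

usable = range(index_range.start, index_range.stop - forecast_len - forecast_offset)
rng = np.random.default_rng(0)
perms = rng.permutation(np.array(usable))   # shuffled window start indices for this epoch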

@@ -288,7 +215,7 @@ def __iter__(self):
len[*] : number of streams
"""
iter_start, iter_end = self.worker_workset()
logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={self.len}")
logger.info(f"iter_start={iter_start}, iter_end={iter_end}, len={len(self)}")

        # create new shuffling
self.reset()
@@ -358,8 +285,8 @@ def __iter__(self):
for fstep in range(
self.forecast_offset, self.forecast_offset + forecast_dt + 1
):
step_forecast_dt = (
idx + (self.forecast_delta_hrs * fstep) // self.step_hrs
step_forecast_dt = idx + (self.forecast_delta_hrs * fstep) // int(
self.time_window_handler.t_window_step
)
time_win2 = self.time_window_handler.window(step_forecast_dt)

@@ -408,16 +335,30 @@ def __iter__(self):

###################################################
def __len__(self):
return self.len
return self._len

@staticmethod
def _get_len(twh: TimeWindowHandler, samples_per_epoch, batch_size) -> int:
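        """Per-rank epoch length: the number of time windows, capped by samples_per_epoch,
        split across ranks and rounded down to a whole number of batches."""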
index_range = twh.get_index_range()
_len = len(index_range)
samples_per_epoch = samples_per_epoch if samples_per_epoch else _len
_len = min(_len, samples_per_epoch)

# adjust len to split loading across all workers and ensure it is multiple of batch_size
len_chunk = ((_len // get_world_size()) // batch_size) * batch_size
_len = min(_len, len_chunk)
logger.info(f"index_range={index_range}, len={_len}, len_chunk={len(_len)}")
return _len
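
# Standalone sketch of the per-rank length computation in _get_len (numbers are assumed):
def per_rank_len(
    n_windows: int, samples_per_epoch: int | None, batch_size: int, world_size: int
) -> int:
    n = min(n_windows, samples_per_epoch) if samples_per_epoch else n_windows
    # split across ranks, then round down to a whole number of batches
    len_chunk = ((n // world_size) // batch_size) * batch_size
    return min(n, len_chunk)

assert per_rank_len(1000, None, 8, 4) == 248  # 1000 // 4 = 250 -> 31 full batches of 8
assert per_rank_len(1000, 100, 8, 4) == 24    # capped at 100 -> 25 per rank -> 3 batches of 8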

###################################################
def worker_workset(self):
local_start, local_end = self.rank * self.len, (self.rank + 1) * self.len
rank = get_rank()
local_start, local_end = rank * len(self), (rank + 1) * len(self)

worker_info = torch.utils.data.get_worker_info()

if worker_info is None:
assert self.num_ranks == 1
assert get_world_size() == 1
iter_start = 0
iter_end = len(self)

Expand All @@ -442,8 +383,63 @@ def worker_workset(self):
if worker_info.id + 1 == worker_info.num_workers:
iter_end = local_end
logger.info(
f"{self.rank}::{worker_info.id}"
f"{rank}::{worker_info.id}"
+ f" : dataset [{local_start},{local_end}) : [{iter_start},{iter_end})"
)

return iter_start, iter_end


def create_datasets(stream_info, time_window_handler, cf) -> list[DataReaderBase]:
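    """Open one data reader per file listed in the stream config and return them as a list."""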
datasets: list[DataReaderBase] = []

for fname in stream_info["filenames"]:
kwargs = {
"tw_handler": time_window_handler,
"stream_info": stream_info,
}
dataset: type[AnyDataReader] | None = None
match stream_info["type"]:
case "obs":
dataset = DataReaderObs
datapath = cf.data_path_obs
# kwargs["end"] = end_date_padded # TODO: implement the padding
case "anemoi":
dataset = DataReaderAnemoi
datapath = cf.data_path_anemoi
case "fesom":
dataset = DataReaderFesom
datapath = cf.data_path_fesom
case "icon":
dataset = IconDataset
datapath = cf.data_path_icon
case _:
msg = f"Unsupported stream type {stream_info['type']}"
f"for stream name '{stream_info['name']}'."
raise ValueError(msg)

datapath = pathlib.Path(datapath)
fname = pathlib.Path(fname)
        # don't check if file exists since zarr stores might be directories
if fname.exists():
# check if fname is a valid path to allow for simple overwriting
filename = fname
else:
filename = pathlib.Path(datapath) / fname

if not filename.exists(): # see above
msg = (
f"Did not find input data for {stream_info['type']} "
f"stream '{stream_info['name']}': {filename}."
)
raise FileNotFoundError(msg)

ds_type = stream_info["type"]
        logger.info(
            f"Opening dataset with type: {ds_type} from stream config {stream_info['name']}."
        )
ds = dataset(filename=filename, **kwargs)

datasets += [ds]

return datasets
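
# Illustrative layout of the structure built from create_datasets across streams
# (stream names, filenames, and the placeholder strings are made up):
streams = [
    {"name": "era5", "filenames": ["era5.zarr"]},
    {"name": "synop", "filenames": ["synop-2010.zarr", "synop-2011.zarr"]},
]
# one inner list per stream, one reader per filename
streams_datasets = [[f"<reader for {fn}>" for fn in s["filenames"]] for s in streams]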