diff --git a/config/streams/fesom/fesom.yml b/config/streams/fesom/fesom.yml index ef315e22f..789011e2d 100644 --- a/config/streams/fesom/fesom.yml +++ b/config/streams/fesom/fesom.yml @@ -10,6 +10,7 @@ FESOM_NODE : type : fesom filenames : ['ocean_node'] + target_file: "/work/ab0995/a270088/Kacper/weathergenertor/AWICM3/ocean_elem" loss_weight : 1. source : null target : null diff --git a/config/streams/fesom/fesom_elem.yml b/config/streams/fesom/fesom_elem.yml index 8afe69435..f9c07e847 100644 --- a/config/streams/fesom/fesom_elem.yml +++ b/config/streams/fesom/fesom_elem.yml @@ -10,6 +10,7 @@ FESOM_ELEM : type : fesom filenames : ['ocean_elem'] + target_file: "/work/ab0995/a270088/Kacper/weathergenertor/AWICM3/ocean_node" loss_weight : 1. source : null target : null diff --git a/src/weathergen/datasets/data_reader_fesom.py b/src/weathergen/datasets/data_reader_fesom.py index bcffc632f..3021f15f9 100644 --- a/src/weathergen/datasets/data_reader_fesom.py +++ b/src/weathergen/datasets/data_reader_fesom.py @@ -48,6 +48,15 @@ def __init__( self.filenames = sorted(glob.glob(str(filename) + "/*")) self._tw_handler = tw_handler self._stream_info = stream_info + self.target_files = self.filenames + + self._src_lat_conv = False + self._src_lon_conv = False + self._trg_lat_conv = False + self._trg_lon_conv = False + + if "target_file" in stream_info: + self.target_files = sorted(glob.glob(str(stream_info["target_file"]) + "/*")) if len(self.filenames) == 0: self.init_empty() @@ -55,8 +64,10 @@ def __init__( return # Initialize data-dependent attributes to None. They will be set by _lazy_init. - self.time: da.Array | None = None - self.data: da.Array | None = None + self.source_time: da.Array | None = None + self.source_data: da.Array | None = None + self.target_time: da.Array | None = None + self.target_data: da.Array | None = None self.len = 0 # Default length is 0 until initialized self.source_channels = [] self.source_idx = [] @@ -65,10 +76,10 @@ def __init__( self.geoinfo_channels = [] self.geoinfo_idx = [] self.properties = {} - self._lat_needs_conversion = False - self._lon_needs_conversion = False + self.fake_specs = {} + self.fake_target = False - if len(self.filenames) == 0: + if len(self.filenames) == 0 or len(self.target_files) == 0: name = stream_info["name"] _logger.warning( f"{name} couldn't find any files matching {filename}. Stream is skipped." @@ -83,6 +94,39 @@ def __init__( # This flag ensures initialization happens only once per worker self._initialized = False + # print(f"checking stream info {list(stream_info.keys())}") + + def _get_mesh_size(self, group: zarr.Group) -> int: + if "nod2" in group.data.attrs: + return group.data.attrs["nod2"] + else: + return group.data.attrs["n_points"] + + def _reorder_groups(self, colnames: list[str], groups: list[zarr.Group]) -> list[da.Array]: + reordered_data_arrays: list[da.Array] = [] + + for group in groups: + local_colnames = group["data"].attrs["colnames"] + + # If the order is already correct, no need to do anything. + if local_colnames == colnames: + reordered_data_arrays.append(da.from_zarr(group["data"])) + else: + # Create the list of indices to re-shuffle the columns. + reorder_indices = [local_colnames.index(name) for name in colnames] + + # Lazily re-index the dask array. This operation is not executed immediately. 
+ dask_array = da.from_zarr(group["data"]) + reordered_array = dask_array[:, reorder_indices] + reordered_data_arrays.append(reordered_array) + + return reordered_data_arrays + + def _remove_lonlat(self, colnames: list[str]) -> list[str]: + temp_colnames = list(colnames) + temp_colnames.remove("lat") + temp_colnames.remove("lon") + return temp_colnames def _lazy_init(self) -> None: """ @@ -92,45 +136,94 @@ def _lazy_init(self) -> None: if self._initialized: return + _logger.info(f"Initialising {self._stream_info['name']}") + # Each worker now opens its own file handles safely - groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.filenames] - times: list[zarr.Array] = [group["dates"] for group in groups] - self.time = da.concatenate(times, axis=0) + s_groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.filenames] + t_groups: list[zarr.Group] = [zarr.open_group(name, mode="r") for name in self.target_files] + + s_times: list[zarr.Array] = [group["dates"] for group in s_groups] + t_times: list[zarr.Array] = [group["dates"] for group in t_groups] + + self.source_time = da.concatenate(s_times, axis=0) + self.target_time = da.concatenate(t_times, axis=0) # Use the first group for metadata - first_group = groups[0] - if "nod2" in first_group.data.attrs: - self.mesh_size = first_group.data.attrs["nod2"] - else: - self.mesh_size = first_group.data.attrs["n_points"] + self.source_mesh_size = self._get_mesh_size(s_groups[0]) + self.target_mesh_size = self._get_mesh_size(t_groups[0]) # Metadata reading is cheap, but let's do it with the rest of the init - start_ds = self.time[0][0].compute() - end_ds = self.time[-1][0].compute() + self.start_source = self.source_time[0][0].compute() + self.end_source = self.source_time[-1][0].compute() + + if self.start_source > self._tw_handler.t_end or self.end_source < self._tw_handler.t_start: + name = self._stream_info["name"] + _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.") + self.init_empty() + self._initialized = True + return + + self.start_target = self.target_time[0][0].compute() + self.end_target = self.target_time[-1][0].compute() - if start_ds > self._tw_handler.t_end or end_ds < self._tw_handler.t_start: + if self.start_target > self._tw_handler.t_end or self.end_target < self._tw_handler.t_start: name = self._stream_info["name"] _logger.warning(f"{name} is not supported over data loader window. 
Stream is skipped.") self.init_empty() self._initialized = True return - period = (self.time[self.mesh_size][0] - self.time[0][0]).compute() + self.source_period = ( + self.source_time[self.source_mesh_size][0] - self.source_time[0][0] + ).compute() + self.target_period = ( + self.target_time[self.target_mesh_size][0] - self.target_time[0][0] + ).compute() # Re-initialize the parent class with correct time info - super().__init__(self._tw_handler, self._stream_info, start_ds, end_ds, period) + super().__init__( # Initialise only for source as source-target split is not supported + self._tw_handler, + self._stream_info, + self.start_source, + self.end_source, + self.source_period, + ) + + if ( + self._tw_handler.t_start > self.start_source + and self._tw_handler.t_start > self.end_source + ): + self.source_start_idx = ( + (self._tw_handler.t_start - self.start_source) // self.source_period + 1 + ) * self.source_mesh_size + else: + self.source_start_idx = 0 - if self._tw_handler.t_start > start_ds: - self.start_idx = ((self._tw_handler.t_start - start_ds) // period + 1) * self.mesh_size + if ( + self._tw_handler.t_start > self.start_target + and self._tw_handler.t_start > self.end_target + ): + self.target_start_idx = ( + (self._tw_handler.t_start - self.start_target) // self.target_period + 1 + ) * self.target_mesh_size else: - self.start_idx = 0 + self.target_start_idx = 0 - self.end_idx = ((self._tw_handler.t_end - start_ds) // period + 1) * self.mesh_size + self.source_end_idx = ( + (self._tw_handler.t_end - self.start_source) // self.source_period + 1 + ) * self.source_mesh_size + self.target_end_idx = ( + (self._tw_handler.t_end - self.start_target) // self.target_period + 1 + ) * self.target_mesh_size - if self.end_idx > len(self.time): - self.end_idx = len(self.time) + if self.source_end_idx > len(self.source_time): + self.source_end_idx = len(self.source_time) + if self.target_end_idx > len(self.target_time): + self.target_end_idx = len(self.target_time) - self.len = (self.end_idx - self.start_idx) // self.mesh_size + self.source_len = (self.source_end_idx - self.source_start_idx) // self.source_mesh_size + self.target_len = (self.target_end_idx - self.target_start_idx) // self.target_mesh_size + self.len = min(self.source_len, self.target_len) # Check for a valid length after calculations if self.len <= 0: @@ -138,99 +231,133 @@ def _lazy_init(self) -> None: self._initialized = True return - self.colnames: list[str] = list(first_group.data.attrs["colnames"]) - self.cols_idx = list(np.arange(len(self.colnames))) - self.lat_index = self.colnames.index("lat") - self.lon_index = self.colnames.index("lon") - - reordered_data_arrays: list[zarr.Group] = [] + self.source_colnames: list[str] = list(s_groups[0].data.attrs["colnames"]) + self.target_colnames: list[str] = list(t_groups[0].data.attrs["colnames"]) - for group in groups: - local_colnames = group["data"].attrs["colnames"] + self.source_cols_idx = list(np.arange(len(self.source_colnames), dtype=int)) + self.target_cols_idx = list(np.arange(len(self.target_colnames), dtype=int)) - # If the order is already correct, no need to do anything. - if local_colnames == self.colnames: - reordered_data_arrays.append(da.from_zarr(group["data"])) - else: - # Create the list of indices to re-shuffle the columns. 
-                reorder_indices = [local_colnames.index(name) for name in self.colnames]
+        self.src_lat_index: int = self.source_colnames.index("lat")
+        self.src_lon_index: int = self.source_colnames.index("lon")
+        self.trg_lat_index: int = self.target_colnames.index("lat")
+        self.trg_lon_index: int = self.target_colnames.index("lon")
 
-                # Lazily re-index the dask array. This operation is not executed immediately.
-                dask_array = da.from_zarr(group["data"])
-                reordered_array = dask_array[:, reorder_indices]
-                reordered_data_arrays.append(reordered_array)
+        source_reordered = self._reorder_groups(self.source_colnames, s_groups)
+        target_reordered = self._reorder_groups(self.target_colnames, t_groups)
 
         # Modify a copy, not the original list while iterating
-        temp_colnames = list(self.colnames)
-        temp_colnames.remove("lat")
-        temp_colnames.remove("lon")
-        self.colnames = temp_colnames
+        self.source_colnames = self._remove_lonlat(self.source_colnames)
+        self.target_colnames = self._remove_lonlat(self.target_colnames)
+
+        self.source_cols_idx.remove(self.src_lat_index)
+        self.source_cols_idx.remove(self.src_lon_index)
+        self.source_cols_idx = np.array(self.source_cols_idx)
 
-        self.cols_idx.remove(self.lat_index)
-        self.cols_idx.remove(self.lon_index)
-        self.cols_idx = np.array(self.cols_idx)
+        self.target_cols_idx.remove(self.trg_lat_index)
+        self.target_cols_idx.remove(self.trg_lon_index)
+        self.target_cols_idx = np.array(self.target_cols_idx)
 
-        self.properties = {"stream_id": first_group.data.attrs["obs_id"]}
+        self.properties = {"stream_id": s_groups[0].data.attrs["obs_id"]}
 
-        self.mean = np.concatenate((np.array([0, 0]), np.array(first_group.data.attrs["means"])))
-        self.stdev = np.sqrt(
-            np.concatenate((np.array([1, 1]), np.array(first_group.data.attrs["std"])))
+        self.source_mean = np.concatenate(
+            (np.array([0, 0]), np.array(s_groups[0].data.attrs["means"]))
+        )
+        self.source_stdev = np.sqrt(
+            np.concatenate((np.array([1, 1]), np.array(s_groups[0].data.attrs["std"])))
         )
-        self.stdev[self.stdev <= 1e-5] = 1.0
+        self.source_stdev[self.source_stdev <= 1e-5] = 1.0
 
-        self.data = da.concatenate(reordered_data_arrays, axis=0)
+        self.target_mean = np.concatenate(
+            (np.array([0, 0]), np.array(t_groups[0].data.attrs["means"]))
+        )
+        self.target_stdev = np.sqrt(
+            np.concatenate((np.array([1, 1]), np.array(t_groups[0].data.attrs["std"])))
+        )
+        self.target_stdev[self.target_stdev <= 1e-5] = 1.0
+
+        self.source = da.concatenate(source_reordered, axis=0)
+        self.target = da.concatenate(target_reordered, axis=0)
 
-        first_timestep_lats = self.data[: self.mesh_size, self.lat_index].compute()
-        first_timestep_lons = self.data[: self.mesh_size, self.lon_index].compute()
+        source_channels = self._stream_info.get("source")
+        source_excl = self._stream_info.get("source_exclude")
+        self.source_channels, self.source_idx = (
+            self.select(self.source_colnames, self.source_cols_idx, source_channels, source_excl)
+            if source_channels or source_excl
+            else (self.source_colnames, self.source_cols_idx)
+        )
 
-        if np.any(first_timestep_lats > 90.0):
+        target_channels = self._stream_info.get("target")
+        target_excl = self._stream_info.get("target_exclude")
+        self.target_channels, self.target_idx = (
+            self.select(self.target_colnames, self.target_cols_idx, target_channels, target_excl)
+            if target_channels or target_excl
+            else (self.target_colnames, self.target_cols_idx)
+        )
+
+        src_timestep_lats = self.source[: self.source_mesh_size, self.src_lat_index].compute()
+        trg_timestep_lats = self.target[: self.target_mesh_size, self.trg_lat_index].compute()
+
+        if np.any(src_timestep_lats > 90.0):
             _logger.warning(
-                f"Latitude for stream '{self._stream_info['name']}' appears to be in a [0, 180] "
-                f"format. It will be automatically converted to the required [-90, 90] format."
+                f"Latitude for stream '{self._stream_info['name']}' "
+                f"source appears to be in a [0, 180] format. "
+                f"It will be automatically converted to the required [-90, 90] format."
             )
-            self._lat_needs_conversion = True
+            self._src_lat_conv = True
 
-        if np.any(first_timestep_lons > 180.0):
+        if np.any(trg_timestep_lats > 90.0):
             _logger.warning(
-                f"Longitude for stream '{self._stream_info['name']}' appears to be in a [0, 360] "
-                f"format. It will be automatically converted to the required [-180, 180] format."
+                f"Latitude for stream '{self._stream_info['name']}' "
+                f"target appears to be in a [0, 180] format. "
+                f"It will be automatically converted to the required [-90, 90] format."
             )
-            self._lon_needs_conversion = True
+            self._trg_lat_conv = True
 
-        source_channels = self._stream_info.get("source")
-        source_excl = self._stream_info.get("source_exclude")
-        self.source_channels, self.source_idx = self.select_channels(source_channels, source_excl)
+        src_timestep_lons = self.source[: self.source_mesh_size, self.src_lon_index].compute()
+        trg_timestep_lons = self.target[: self.target_mesh_size, self.trg_lon_index].compute()
 
-        target_channels = self._stream_info.get("target")
-        target_excl = self._stream_info.get("target_exclude")
-        self.target_channels, self.target_idx = self.select_channels(target_channels, target_excl)
+        if np.any(src_timestep_lons > 180.0):
+            _logger.warning(
+                f"Longitude for stream '{self._stream_info['name']}' "
+                f"source appears to be in a [0, 360] format. "
+                f"It will be automatically converted to the required [-180, 180] format."
+            )
+            self._src_lon_conv = True
+
+        if np.any(trg_timestep_lons > 180.0):
+            _logger.warning(
+                f"Longitude for stream '{self._stream_info['name']}' "
+                f"target appears to be in a [0, 360] format. "
+                f"It will be automatically converted to the required [-180, 180] format."
+            )
+            self._trg_lon_conv = True
 
         self.geoinfo_channels = []
         self.geoinfo_idx = []
 
         self._initialized = True
 
-    def select_channels(
-        self, ch_filters: list[str] | None, excl: list[str] | None = None
+    def select(
+        self,
+        colnames: list[str],
+        cols_idx: NDArray,
+        ch_filters: list[str] | None,
+        excl: list[str] | None = None,
     ) -> tuple[list[str], NDArray]:
-        """
-        Allow user to specify which columns they want to access.
-        Get functions only returned for these specified columns.
- """ if excl and ch_filters: mask = [ any(f == c for f in ch_filters) and all(ex not in c for ex in excl) - for c in self.colnames + for c in colnames ] elif ch_filters: - mask = [any(f == c for f in ch_filters) for c in self.colnames] + mask = [any(f == c for f in ch_filters) for c in colnames] elif excl: - mask = [all(ex not in c for ex in excl) for c in self.colnames] + mask = [all(ex not in c for ex in excl) for c in colnames] else: - return self.colnames, self.cols_idx + assert False, "Cannot use select with both ch_filters and excl as None" - selected_cols_idx = self.cols_idx[np.where(mask)[0]] - selected_colnames = [self.colnames[i] for i in np.where(mask)[0]] + selected_cols_idx = cols_idx[np.where(mask)[0]] + selected_colnames = [colnames[i] for i in np.where(mask)[0]] return selected_colnames, selected_cols_idx @override @@ -244,10 +371,9 @@ def length(self) -> int: self._lazy_init() return self.len - @override - def _get_dataset_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: + def _get_source_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: """ - Get dataset indexes for a given time window index, when the dataset is periodic. + Get source dataset indexes for a given time window index, when the dataset is periodic. This function assumes state of a variable is persistent, thus if no data is found in the time window, last measurement is used before the beggining of the windows is used. @@ -268,66 +394,160 @@ def _get_dataset_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: dtr = tw_handler.window(idx) # If there is no or only marginal overlap with the dataset, return empty index ranges if ( - not self.data_start_time - or not self.data_end_time - or dtr.end < self.data_start_time - or dtr.start > self.data_end_time - or dtr.start < self.data_start_time - or dtr.end > self.data_end_time - or (self.data_end_time is not None and dtr.start > self.data_end_time) + not self.start_source + or not self.end_source + or dtr.end < self.start_source + or dtr.start > self.end_source + or dtr.start < self.start_source + or dtr.end > self.end_source + or (self.end_source is not None and dtr.start > self.end_source) ): return (np.array([], dtype=np.int64), dtr) # relative time in dataset - delta_t_start = dtr.start - self.data_start_time - delta_t_end = dtr.end - self.data_start_time - t_epsilon + delta_t_start = dtr.start - self.start_source + delta_t_end = dtr.end - self.start_source - t_epsilon assert isinstance(delta_t_start, np.timedelta64), "delta_t_start must be timedelta64" - start_didx = delta_t_start // self.period - end_didx = delta_t_end // self.period + start_didx = delta_t_start // self.source_period + end_didx = delta_t_end // self.source_period # adjust start_idx if not exactly on start time - if (delta_t_start % self.period) > np.timedelta64(0, "s"): + if (delta_t_start % self.source_period) > np.timedelta64(0, "s"): # empty window in between two timesteps if start_didx == end_didx: return (np.array([start_didx], dtype=np.int64), dtr) start_didx += 1 - end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.period) + end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.source_period) + return (np.arange(start_didx, end_didx + 1, dtype=np.int64), dtr) + + def _get_target_idxs(self, idx: TIndex) -> tuple[NDArray, DTRange]: + """ + Get target dataset indexes for a given time window index, when the dataset is periodic. 
+
+        This function assumes the state of a variable is persistent: if no data is found
+        in the time window, the last measurement before the beginning of the window is used.
+
+        Parameters
+        ----------
+        idx : TIndex
+            Index of the time window.
+
+        Returns
+        -------
+        NDArray[np.int64]
+            Array of dataset indexes corresponding to the time window.
+        """
+        tw_handler = self.time_window_handler
+
+        # Function is separated from the class to allow testing without instantiating the class.
+        dtr = tw_handler.window(idx)
+        # If there is no or only marginal overlap with the dataset, return empty index ranges
+        if (
+            not self.start_target
+            or not self.end_target
+            or dtr.end < self.start_target
+            or dtr.start > self.end_target
+            or dtr.start < self.start_target
+            or dtr.end > self.end_target
+            or (self.end_target is not None and dtr.start > self.end_target)
+        ):
+            return (np.array([], dtype=np.int64), dtr)
+        # relative time in dataset
+        delta_t_start = dtr.start - self.start_target
+        delta_t_end = dtr.end - self.start_target - t_epsilon
+        assert isinstance(delta_t_start, np.timedelta64), "delta_t_start must be timedelta64"
+        start_didx = delta_t_start // self.target_period
+        end_didx = delta_t_end // self.target_period
+
+        # adjust start_idx if not exactly on start time
+        if (delta_t_start % self.target_period) > np.timedelta64(0, "s"):
+            # empty window in between two timesteps
+            if start_didx == end_didx:
+                return (np.array([start_didx], dtype=np.int64), dtr)
+            start_didx += 1
+
+        end_didx = start_didx + int((dtr.end - dtr.start - t_epsilon) / self.target_period)
         return (np.arange(start_didx, end_didx + 1, dtype=np.int64), dtr)
 
     @override
-    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+    def get_source(self, idx: TIndex) -> ReaderData:
+        self._lazy_init()
+        (t_idxs, dtr) = self._get_source_idxs(idx)
+        if self.len == 0 or len(t_idxs) == 0:
+            return ReaderData.empty(
+                num_data_fields=len(self.source_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        start_row = t_idxs[0] * self.source_mesh_size
+        end_row = (t_idxs[-1] + 1) * self.source_mesh_size
+
+        # Note: we read all columns from start_row to end_row once,
+        # then select the ones we need. This is more efficient for Zarr.
+        full_data_slice = self.source[start_row:end_row]
+        datetimes_lazy = self.source_time[start_row:end_row]
+
+        # Define the specific slices we need from the larger block
+        data_lazy = full_data_slice[:, self.source_idx]
+        lat_lazy = full_data_slice[:, self.src_lat_index]
+        lon_lazy = full_data_slice[:, self.src_lon_index]
+
+        # Dask optimizes this to a single (or few) efficient read operation(s).
+        data, lat, lon, datetimes = dask.compute(
+            data_lazy, lat_lazy, lon_lazy, datetimes_lazy, scheduler="single-threaded"
+        )
+
+        if self._src_lat_conv:
+            lat = 90.0 - lat
+
+        if self._src_lon_conv:
+            lon = ((lon + 180.0) % 360.0) - 180.0
+
+        coords = np.stack([lat, lon], axis=1)
+        geoinfos = np.zeros((data.shape[0], 0), dtype=data.dtype)
+        datetimes = np.squeeze(datetimes)
+
+        rd = ReaderData(
+            coords=coords,
+            geoinfos=geoinfos,
+            data=data,
+            datetimes=datetimes,
+        )
+
+        return rd
+
+    @override
+    def get_target(self, idx: TIndex) -> ReaderData:
         self._lazy_init()
-        (t_idxs, dtr) = self._get_dataset_idxs(idx)
+        (t_idxs, dtr) = self._get_target_idxs(idx)
         if self.len == 0 or len(t_idxs) == 0:
             return ReaderData.empty(
-                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+                num_data_fields=len(self.target_idx), num_geo_fields=len(self.geoinfo_idx)
             )
 
-        start_row = t_idxs[0] * self.mesh_size
-        end_row = (t_idxs[-1] + 1) * self.mesh_size
+        start_row = t_idxs[0] * self.target_mesh_size
+        end_row = (t_idxs[-1] + 1) * self.target_mesh_size
 
         # Note: we read all columns from start_row to end_row once,
         # then select the ones we need. This is more efficient for Zarr.
-        full_data_slice = self.data[start_row:end_row]
-        time_slice = self.time[start_row:end_row]
+        full_data_slice = self.target[start_row:end_row]
+        datetimes_lazy = self.target_time[start_row:end_row]
 
         # Define the specific slices we need from the larger block
-        data_lazy = full_data_slice[:, channels_idx]
-        lat_lazy = full_data_slice[:, self.lat_index]
-        lon_lazy = full_data_slice[:, self.lon_index]
-        datetimes_lazy = time_slice
+        data_lazy = full_data_slice[:, self.target_idx]
+        lat_lazy = full_data_slice[:, self.trg_lat_index]
+        lon_lazy = full_data_slice[:, self.trg_lon_index]
 
         # Dask optimizes this to a single (or few) efficient read operation(s).
         data, lat, lon, datetimes = dask.compute(
             data_lazy, lat_lazy, lon_lazy, datetimes_lazy, scheduler="single-threaded"
         )
 
-        if self._lat_needs_conversion:
+        if self._trg_lat_conv:
             lat = 90.0 - lat
 
-        if self._lon_needs_conversion:
+        if self._trg_lon_conv:
             lon = ((lon + 180.0) % 360.0) - 180.0
 
         coords = np.stack([lat, lon], axis=1)
@@ -342,3 +562,87 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
         )
 
         return rd
+
+    @override
+    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+        return self.get_source(idx)
+
+    @override
+    def normalize_source_channels(self, source: NDArray) -> NDArray:
+        """
+        Normalize source channels
+
+        Parameters
+        ----------
+        source :
+            data to be normalized
+
+        Returns
+        -------
+        Normalized data
+        """
+        assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
+        for i, ch in enumerate(self.source_idx):
+            source[..., i] = (source[..., i] - self.source_mean[ch]) / self.source_stdev[ch]
+
+        return source
+
+    @override
+    def normalize_target_channels(self, target: NDArray) -> NDArray:
+        """
+        Normalize target channels
+
+        Parameters
+        ----------
+        target :
+            data to be normalized
+
+        Returns
+        -------
+        Normalized data
+        """
+        assert target.shape[-1] == len(self.target_idx), "incorrect number of target channels"
+        for i, ch in enumerate(self.target_idx):
+            target[..., i] = (target[..., i] - self.target_mean[ch]) / self.target_stdev[ch]
+
+        return target
+
+    @override
+    def denormalize_source_channels(self, source: NDArray) -> NDArray:
+        """
+        Denormalize source channels
+
+        Parameters
+        ----------
+        source :
+            data to be denormalized
+
+        Returns
+        -------
+        Denormalized data
+        """
+        assert source.shape[-1] == len(self.source_idx), "incorrect number of source channels"
+        for i, ch in enumerate(self.source_idx):
+            source[..., i] = (source[..., i] * self.source_stdev[ch]) + self.source_mean[ch]
+
+        return source
+
+    @override
+    def denormalize_target_channels(self, data: NDArray) -> NDArray:
+        """
+        Denormalize target channels
+
+        Parameters
+        ----------
+        data :
+            data to be denormalized (target or pred)
+
+        Returns
+        -------
+        Denormalized data
+        """
+        assert data.shape[-1] == len(self.target_idx), "incorrect number of target channels"
+        for i, ch in enumerate(self.target_idx):
+            data[..., i] = (data[..., i] * self.target_stdev[ch]) + self.target_mean[ch]
+
+        return data
diff --git a/src/weathergen/datasets/masking.py b/src/weathergen/datasets/masking.py
index 58d5d5731..f8769de7f 100644
--- a/src/weathergen/datasets/masking.py
+++ b/src/weathergen/datasets/masking.py
@@ -75,9 +75,10 @@ def __init__(self, cf: Config):
 
         if self.current_strategy == "channel":
             # Ensure that masking_strategy_config contains either 'global' or 'per_cell'
-            assert self.masking_strategy_config.get("mode") in ["global", "per_cell"], (
-                "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'."
-            )
+            assert self.masking_strategy_config.get("mode") in [
+                "global",
+                "per_cell",
+            ], "masking_strategy_config must contain 'mode' key with value 'global' or 'per_cell'."
 
             # check all streams that source and target channels are identical
             for stream in cf.streams:
@@ -277,6 +278,9 @@ def mask_target(
 
         # process all tokens used for embedding
         for cc, pp in zip(target_tokenized_data, self.perm_sel, strict=True):
+            if len(cc) == 0:  # Skip if there's no target data
+                continue
+
             if self.current_strategy == "channel":
                 # If masking strategy is channel, handle target tokens differently.
                 # We don't have Booleans per cell, instead per channel per cell,
@@ -293,11 +297,28 @@ def mask_target(
             elif self.current_strategy == "causal":
                 # select only the target times where mask is True
-                selected_tensors = [c for i, c in enumerate(cc) if pp[i]]
-
+                if len(cc) == len(pp):
+                    selected_tensors = [c for i, c in enumerate(cc) if pp[i]]
+                elif len(pp) == 0:
+                    selected_tensors = cc
+                else:  # If lengths of target and mask don't match, create a new mask
+                    ratio = np.sum(pp) / len(pp)  # Ratio of masked tokens in source
+                    indx = max(1, int(ratio * len(cc)))  # Apply the same ratio to the target
+                    selected_tensors = cc[:indx]
+
+            elif self.current_strategy == "healpix":
+                selected_tensors = (
+                    cc if len(pp) > 0 and pp[0] else []
+                )  # All tokens inside a healpix cell have the same mask
+
+            elif self.current_strategy == "random":
+                # For random masking, we simply select the tensors where the mask is True.
+                # When there's no mask it's assumed to be False. This is done via strict=False.
+                selected_tensors = [c for c, p in zip(cc, pp, strict=False) if p]
             else:
-                # For other masking strategies, we simply select the tensors where the mask is True.
-                selected_tensors = [c for c, p in zip(cc, pp, strict=True) if p]
+                raise NotImplementedError(
+                    f"Masking strategy {self.current_strategy} is not supported."
+                )
 
             # Append the selected tensors to the processed_target_tokens list.
             if selected_tensors:
@@ -487,14 +508,16 @@ def _generate_causal_mask(
         # Create masks with list comprehension
         # Needed to handle variable lengths
         full_mask = [
-            np.concatenate(
-                [
-                    np.zeros(start_idx, dtype=bool),
-                    np.ones(max(0, token_len - start_idx), dtype=bool),
-                ]
+            (
+                np.concatenate(
+                    [
+                        np.zeros(start_idx, dtype=bool),
+                        np.ones(max(0, token_len - start_idx), dtype=bool),
+                    ]
+                )
+                if token_len > 1
+                else (np.zeros(1, dtype=bool) if token_len == 1 else np.array([], dtype=bool))
             )
-            if token_len > 1
-            else (np.zeros(1, dtype=bool) if token_len == 1 else np.array([], dtype=bool))
             for token_len, start_idx in zip(token_lens, start_mask_indices, strict=False)
         ]
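Usage note (not part of the patch): `target_file` is optional; when it is absent the reader falls back to `filenames`, so existing stream configs keep working. A hypothetical stream entry that reads node data but supervises on element data could look like the sketch below — the stream name and path are placeholders, not values from this change.

FESOM_NODE_TO_ELEM :
  type : fesom
  filenames : ['ocean_node']
  target_file: "/path/to/ocean_elem"
  loss_weight : 1.
  source : null
  target : null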