From 49fe44d77f468c9e3212e97b213af80da8f07662 Mon Sep 17 00:00:00 2001
From: Christian Lessig
Date: Fri, 29 Aug 2025 05:17:43 +0000
Subject: [PATCH 1/5] Adding data reader for MetNo SYNOP dataset. Work in
 progress

---
 src/weathergen/datasets/data_reader_synop.py  | 268 ++++++++++++++++++
 .../datasets/multi_stream_data_sampler.py     |   4 +
 2 files changed, 272 insertions(+)
 create mode 100644 src/weathergen/datasets/data_reader_synop.py

diff --git a/src/weathergen/datasets/data_reader_synop.py b/src/weathergen/datasets/data_reader_synop.py
new file mode 100644
index 000000000..a960ac6c5
--- /dev/null
+++ b/src/weathergen/datasets/data_reader_synop.py
@@ -0,0 +1,268 @@
+# (C) Copyright 2025 WeatherGenerator contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+from pathlib import Path
+from typing import override
+
+import anemoi.datasets as anemoi_datasets
+import numpy as np
+import xarray as xr
+from anemoi.datasets.data import MissingDateError
+from anemoi.datasets.data.dataset import Dataset
+from numpy.typing import NDArray
+
+from weathergen.datasets.data_reader_base import (
+    DataReaderTimestep,
+    ReaderData,
+    TimeWindowHandler,
+    TIndex,
+    check_reader_data,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class DataReaderSynop(DataReaderTimestep):
+    "Wrapper for SYNOP datasets from MetNo in NetCDF format"
+
+    def __init__(
+        self,
+        tw_handler: TimeWindowHandler,
+        filename: Path,
+        stream_info: dict,
+    ) -> None:
+        """
+        Construct data reader for MetNo SYNOP dataset
+
+        Parameters
+        ----------
+        filename :
+            filename (and path) of dataset
+        stream_info :
+            information about stream
+
+        Returns
+        -------
+        None
+        """
+
+        # open dataset to check that it is compatible with requested parameters
+        ds = xr.open_dataset(filename, engine="netcdf4")
+
+        import code
+
+        code.interact(local=locals())
+
+        # If there is no overlap with the time range, the dataset will be empty
+        if tw_handler.t_start >= ds.time.max() or tw_handler.t_end <= ds.time.min():
+            name = stream_info["name"]
+            _logger.warning(f"{name} does not overlap the data loader window. Stream is skipped.")
+            super().__init__(tw_handler, stream_info)
+            self.init_empty()
+            return
+
+        kwargs = {}
+        if "frequency" in stream_info:
+            # kwargs["frequency"] = str_to_timedelta(stream_info["frequency"])
+            assert False, "Frequency sub-sampling currently not supported"
+        ds: Dataset = anemoi_datasets.open_dataset(
+            ds, **kwargs, start=tw_handler.t_start, end=tw_handler.t_end
+        )
+
+        period = (ds.time[1] - ds.time[0]).values
+        data_start_time = ds.dates[0]
+        data_end_time = ds.dates[-1]
+        assert data_start_time is not None and data_end_time is not None, (
+            data_start_time,
+            data_end_time,
+        )
+        super().__init__(
+            tw_handler,
+            stream_info,
+            data_start_time,
+            data_end_time,
+            period,
+        )
+        # If there is no overlap with the time range, no need to keep the dataset.
+        if tw_handler.t_start >= data_end_time or tw_handler.t_end <= data_start_time:
+            self.init_empty()
+            return
+        else:
+            self.ds = ds
+            self.len = len(ds)
+
+        # caches lats and lons
+        self.latitudes = _clip_lat(ds.latitude)
+        self.longitudes = _clip_lon(ds.longitude)
+
+        # select/filter requested source channels
+        self.source_idx = self.select_channels(ds, "source")
+        self.source_channels = [ds.variables[i] for i in self.source_idx]
+
+        # select/filter requested target channels
+        self.target_idx = self.select_channels(ds, "target")
+        self.target_channels = [ds.variables[i] for i in self.target_idx]
+
+        self.geoinfo_channels = ["altitude"]
+        self.geoinfo_idx = [2]
+
+        ds_name = stream_info["name"]
+        _logger.info(f"{ds_name}: source channels: {self.source_channels}")
+        _logger.info(f"{ds_name}: target channels: {self.target_channels}")
+        _logger.info(f"{ds_name}: geoinfo channels: {self.geoinfo_channels}")
+
+        self.properties = {
+            "stream_id": 0,
+        }
+        self.mean = ds.statistics["mean"]
+        self.stdev = ds.statistics["stdev"]
+
+    @override
+    def init_empty(self) -> None:
+        super().init_empty()
+        self.ds = None
+        self.len = 0
+
+    @override
+    def length(self) -> int:
+        return self.len
+
+    @override
+    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
+        """
+        Get data for window (for either source or target, through public interface)
+
+        Parameters
+        ----------
+        idx : TIndex
+            Index of temporal window
+        channels_idx : np.array
+            Selection of channels
+
+        Returns
+        -------
+        ReaderData providing coords, geoinfos, data, datetimes
+        """
+
+        (t_idxs, dtr) = self._get_dataset_idxs(idx)
+
+        if self.ds is None or self.len == 0 or len(t_idxs) == 0:
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        assert t_idxs[0] >= 0, "index must be non-negative"
+        didx_start = t_idxs[0]
+        # End is inclusive
+        didx_end = t_idxs[-1] + 1
+
+        # extract number of time steps and collapse ensemble dimension
+        # ds is a wrapper around zarr with get_coordinate_selection not being exposed since
+        # subsetting is pushed to the ctor via frequency argument; this also ensures that no sub-
+        # sampling is required here
+        try:
+            data = self.ds[didx_start:didx_end][:, :, 0].astype(np.float32)
+        except MissingDateError as e:
+            _logger.debug(f"Date not present in anemoi dataset: {str(e)}. Skipping.")
+            return ReaderData.empty(
+                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
+            )
+
+        # extract channels
+        data = (
+            data[:, list(channels_idx)]
+            .transpose([0, 2, 1])
+            .reshape((data.shape[0] * data.shape[2], -1))
+        )
+
+        # construct lat/lon coords
+        latlon = np.concatenate(
+            [
+                np.expand_dims(self.latitudes, 0),
+                np.expand_dims(self.longitudes, 0),
+            ],
+            axis=0,
+        ).transpose()
+        # repeat latlon len(t_idxs) times
+        coords = np.vstack((latlon,) * len(t_idxs))
+
+        # empty geoinfos for anemoi
+        geoinfos = np.zeros((len(data), 0), dtype=data.dtype)
+
+        # date time matching #data points of data
+        # Assuming a fixed frequency for the dataset
+        datetimes = np.repeat(self.ds.dates[didx_start:didx_end], len(data) // len(t_idxs))
+
+        rd = ReaderData(
+            coords=coords,
+            geoinfos=geoinfos,
+            data=data,
+            datetimes=datetimes,
+        )
+        check_reader_data(rd, dtr)
+
+        return rd
+
+    def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
+        """
+        Select source or target channels
+
+        Parameters
+        ----------
+        ds :
+            raw anemoi dataset with available channels
+        ch_type :
+            "source" or "target", i.e. channel type to select
+
+        Returns
+        -------
+        Sorted indices of the selected channels
+        """
+
+        channels = self.stream_info.get(ch_type)
+        channels_exclude = self.stream_info.get(ch_type + "_exclude", [])
+        # sanity check
+        is_empty = len(channels) == 0 if channels is not None else False
+        if is_empty:
+            stream_name = self.stream_info["name"]
+            _logger.warning(f"No channel for {stream_name} for {ch_type}.")
+
+        channels = list(ds.keys())
+        chs_idx = np.sort(
+            [
+                ds.name_to_index[k]
+                for (k, v) in ds.typed_variables.items()
+                if (
+                    not v.is_computed_forcing
+                    and not v.is_constant_in_time
+                    and (
+                        np.array([f in k for f in channels]).any() if channels is not None else True
+                    )
+                    and not np.array([f in k for f in channels_exclude]).any()
+                )
+            ]
+        )
+
+        return chs_idx
+
+
+def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
+    """
+    Clip latitudes to the range [-90, 90], reflecting out-of-range values.
+    """
+    return (2 * np.clip(lats, -90.0, 90.0) - lats).astype(np.float32)
+
+
+def _clip_lon(lons: NDArray) -> NDArray[np.float32]:
+    """
+    Wrap longitudes periodically to the range [-180, 180].
+    """
+    return ((lons + 180.0) % 360.0 - 180.0).astype(np.float32)
diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index ef2d6be24..053522b50 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -23,6 +23,7 @@
 )
 from weathergen.datasets.data_reader_fesom import DataReaderFesom
 from weathergen.datasets.data_reader_obs import DataReaderObs
+from weathergen.datasets.data_reader_synop import DataReaderSynop
 from weathergen.datasets.icon_dataset import IconDataset
 from weathergen.datasets.masking import Masker
 from weathergen.datasets.stream_data import StreamData
@@ -108,6 +109,9 @@ def __init__(
             case "icon":
                 dataset = IconDataset
                 datapath = cf.data_path_icon
+            case "synop":
+                dataset = DataReaderSynop
+                datapath = cf.data_path_obs
             case _:
                 msg = f"Unsupported stream type {stream_info['type']}"
                 f"for stream name '{stream_info['name']}'."
From bc796dc6d3b384fac1aa3495f01339babee3ccb2 Mon Sep 17 00:00:00 2001
From: Christian Lessig
Date: Fri, 29 Aug 2025 07:00:31 +0000
Subject: [PATCH 2/5] Working prototype

---
 src/weathergen/datasets/data_reader_synop.py | 108 ++++++++++---------
 1 file changed, 55 insertions(+), 53 deletions(-)

diff --git a/src/weathergen/datasets/data_reader_synop.py b/src/weathergen/datasets/data_reader_synop.py
index a960ac6c5..9460b03d5 100644
--- a/src/weathergen/datasets/data_reader_synop.py
+++ b/src/weathergen/datasets/data_reader_synop.py
@@ -11,11 +11,8 @@
 from pathlib import Path
 from typing import override
 
-import anemoi.datasets as anemoi_datasets
 import numpy as np
 import xarray as xr
-from anemoi.datasets.data import MissingDateError
-from anemoi.datasets.data.dataset import Dataset
 from numpy.typing import NDArray
 
 from weathergen.datasets.data_reader_base import (
@@ -56,10 +53,6 @@ def __init__(
         # open dataset to check that it is compatible with requested parameters
         ds = xr.open_dataset(filename, engine="netcdf4")
 
-        import code
-
-        code.interact(local=locals())
-
         # If there is no overlap with the time range, the dataset will be empty
         if tw_handler.t_start >= ds.time.max() or tw_handler.t_end <= ds.time.min():
             name = stream_info["name"]
@@ -72,13 +65,10 @@ def __init__(
         if "frequency" in stream_info:
             # kwargs["frequency"] = str_to_timedelta(stream_info["frequency"])
             assert False, "Frequency sub-sampling currently not supported"
-        ds: Dataset = anemoi_datasets.open_dataset(
-            ds, **kwargs, start=tw_handler.t_start, end=tw_handler.t_end
-        )
 
         period = (ds.time[1] - ds.time[0]).values
-        data_start_time = ds.dates[0]
-        data_end_time = ds.dates[-1]
+        data_start_time = ds.time[0].values
+        data_end_time = ds.time[-1].values
         assert data_start_time is not None and data_end_time is not None, (
             data_start_time,
             data_end_time,
@@ -98,20 +88,26 @@ def __init__(
             self.ds = ds
             self.len = len(ds)
 
+        self.offset_data_channels = 4
+        self.fillvalue = ds["air_temperature"][0, 0].values.item()
+
         # caches lats and lons
-        self.latitudes = _clip_lat(ds.latitude)
-        self.longitudes = _clip_lon(ds.longitude)
+        self.latitudes = _clip_lat(np.array(ds.latitude, dtype=np.float32))
+        self.longitudes = _clip_lon(np.array(ds.longitude, dtype=np.float32))
+
+        self.geoinfos = np.array(ds.altitude, dtype=np.float32)
+        self.geoinfo_channels = []  # ["altitude"]
+        self.geoinfo_idx = []  # [2]
+
+        self.channels_file = [k for k in self.ds.keys()]
 
         # select/filter requested source channels
         self.source_idx = self.select_channels(ds, "source")
-        self.source_channels = [ds.variables[i] for i in self.source_idx]
+        self.source_channels = [self.channels_file[i] for i in self.source_idx]
 
         # select/filter requested target channels
         self.target_idx = self.select_channels(ds, "target")
-        self.target_channels = [ds.variables[i] for i in self.target_idx]
-
-        self.geoinfo_channels = ["altitude"]
-        self.geoinfo_idx = [2]
+        self.target_channels = [self.channels_file[i] for i in self.target_idx]
 
         ds_name = stream_info["name"]
         _logger.info(f"{ds_name}: source channels: {self.source_channels}")
@@ -121,8 +117,29 @@ def __init__(
         self.properties = {
             "stream_id": 0,
         }
-        self.mean = ds.statistics["mean"]
-        self.stdev = ds.statistics["stdev"]
+
+        self.mean, self.stdev = self.compute_mean_stdev()
+
+    def compute_mean_stdev(self) -> (np.array, np.array):
+        _logger.info("Starting computation of mean and stdev.")
+
+        mean = [0.0 for _ in range(self.offset_data_channels)]
+        stdev = [1.0 for _ in range(self.offset_data_channels)]
+
+        data_channels_file = [k for k in self.ds.keys()][self.offset_data_channels :]
+        for ch in data_channels_file:
+            data = np.array(self.ds[ch], np.float64)
+            mask = data == self.fillvalue
+            data[mask] = np.nan
+            mean += [np.nanmean(data.flatten())]
+            stdev += [np.nanstd(data.flatten())]
+
+        mean = np.array(mean)
+        stdev = np.array(stdev)
+
+        _logger.info("Finished computation of mean and stdev.")
+
+        return mean, stdev
 
     @override
     def init_empty(self) -> None:
@@ -167,20 +184,11 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
         # ds is a wrapper around zarr with get_coordinate_selection not being exposed since
         # subsetting is pushed to the ctor via frequency argument; this also ensures that no sub-
         # sampling is required here
-        try:
-            data = self.ds[didx_start:didx_end][:, :, 0].astype(np.float32)
-        except MissingDateError as e:
-            _logger.debug(f"Date not present in anemoi dataset: {str(e)}. Skipping.")
-            return ReaderData.empty(
-                num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
-            )
-
-        # extract channels
-        data = (
-            data[:, list(channels_idx)]
-            .transpose([0, 2, 1])
-            .reshape((data.shape[0] * data.shape[2], -1))
-        )
+        sel_channels = [self.channels_file[i] for i in channels_idx]
+        data = np.stack([self.ds[ch].isel(time=slice(didx_start, didx_end)) for ch in sel_channels])
+        data = data.transpose([1, 2, 0]).reshape((data.shape[1] * data.shape[2], data.shape[0]))
+        mask = data == self.fillvalue
+        data[mask] = np.nan
 
         # construct lat/lon coords
         latlon = np.concatenate(
@@ -193,12 +201,16 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
         # repeat latlon len(t_idxs) times
         coords = np.vstack((latlon,) * len(t_idxs))
 
+        # import code
+        # code.interact( local=locals())
+
         # empty geoinfos for anemoi
+        # TODO: altitudes
         geoinfos = np.zeros((len(data), 0), dtype=data.dtype)
 
         # date time matching #data points of data
         # Assuming a fixed frequency for the dataset
-        datetimes = np.repeat(self.ds.dates[didx_start:didx_end], len(data) // len(t_idxs))
+        datetimes = np.repeat(self.ds.time[didx_start:didx_end].values, len(data) // len(t_idxs))
 
         rd = ReaderData(
             coords=coords,
@@ -227,7 +239,9 @@ def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
 
         """
 
-        channels = self.stream_info.get(ch_type)
+        channels_file = [k for k in ds.keys()][self.offset_data_channels :]
+
+        channels = self.stream_info.get(ch_type, channels_file)
         channels_exclude = self.stream_info.get(ch_type + "_exclude", [])
         # sanity check
         is_empty = len(channels) == 0 if channels is not None else False
@@ -235,23 +249,11 @@ def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
             stream_name = self.stream_info["name"]
             _logger.warning(f"No channel for {stream_name} for {ch_type}.")
 
-        channels = list(ds.keys())
-        chs_idx = np.sort(
-            [
-                ds.name_to_index[k]
-                for (k, v) in ds.typed_variables.items()
-                if (
-                    not v.is_computed_forcing
-                    and not v.is_constant_in_time
-                    and (
-                        np.array([f in k for f in channels]).any() if channels is not None else True
-                    )
-                    and not np.array([f in k for f in channels_exclude]).any()
-                )
-            ]
-        )
+        chs_idx = np.sort([channels_file.index(ch) for ch in channels])
+        chs_idx_exclude = np.sort([channels_file.index(ch) for ch in channels_exclude])
+        chs_idx = [idx for idx in chs_idx if idx not in chs_idx_exclude]
 
-        return chs_idx
+        return np.array(chs_idx) + self.offset_data_channels
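
The statistics pass added in this patch masks the NetCDF fill value before reducing,
so missing observations do not bias the normalization. The same pattern in isolation,
with a made-up array shape and the conventional netCDF float fill value standing in
for the value the reader takes from the file:

    import numpy as np

    fillvalue = 9.96921e36  # conventional netCDF float _FillValue (stand-in here)
    rng = np.random.default_rng(0)

    # (time, station) array for one channel, with ~10% missing observations
    data = rng.normal(280.0, 5.0, size=(48, 100))
    data[rng.random(data.shape) < 0.1] = fillvalue

    data[data == fillvalue] = np.nan      # mask sentinel values
    mean = np.nanmean(data.flatten())     # mean over valid observations only
    stdev = np.nanstd(data.flatten())     # stdev over valid observations only
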
From 8889c798eec76b60acd7bbe43b091b1bb4818e1f Mon Sep 17 00:00:00 2001
From: Christian Lessig
Date: Wed, 3 Sep 2025 09:25:51 +0000
Subject: [PATCH 3/5] Renaming dataset/reader type from synop to station

---
 src/weathergen/datasets/multi_stream_data_sampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/weathergen/datasets/multi_stream_data_sampler.py b/src/weathergen/datasets/multi_stream_data_sampler.py
index 053522b50..20228b7b2 100644
--- a/src/weathergen/datasets/multi_stream_data_sampler.py
+++ b/src/weathergen/datasets/multi_stream_data_sampler.py
@@ -109,7 +109,7 @@ def __init__(
             case "icon":
                 dataset = IconDataset
                 datapath = cf.data_path_icon
-            case "synop":
+            case "station":
                 dataset = DataReaderSynop
                 datapath = cf.data_path_obs
             case _:
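
After the rename, a stream selects this reader with type "station" rather than
"synop". A hypothetical config entry and the dispatch it passes through; the keys
shown mirror the match statement above, the full stream schema is not part of this
series:

    stream_info = {
        "name": "metno_synop",
        "type": "station",  # "synop" before this patch
        "source": ["air_temperature"],
        "target": ["air_temperature"],
    }

    match stream_info["type"]:
        case "station":
            reader = "DataReaderSynop"  # the sampler binds the actual class
        case _:
            msg = f"Unsupported stream type {stream_info['type']}"
            raise ValueError(msg)
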
From e5f3b1a4e1725a5a3e8b8d6611c84a4225e1beed Mon Sep 17 00:00:00 2001
From: Christian Lessig
Date: Wed, 3 Sep 2025 09:26:40 +0000
Subject: [PATCH 4/5] Cleaned up code and made it more general

---
 src/weathergen/datasets/data_reader_synop.py | 68 +++++++++++---------
 1 file changed, 36 insertions(+), 32 deletions(-)

diff --git a/src/weathergen/datasets/data_reader_synop.py b/src/weathergen/datasets/data_reader_synop.py
index 9460b03d5..a37d79162 100644
--- a/src/weathergen/datasets/data_reader_synop.py
+++ b/src/weathergen/datasets/data_reader_synop.py
@@ -50,6 +50,8 @@ def __init__(
         None
         """
 
+        np32 = np.float32
+
         # open dataset to check that it is compatible with requested parameters
         ds = xr.open_dataset(filename, engine="netcdf4")
 
@@ -58,12 +60,10 @@ def __init__(
             name = stream_info["name"]
             _logger.warning(f"{name} does not overlap the data loader window. Stream is skipped.")
             super().__init__(tw_handler, stream_info)
-            self.init_empty()
+            self._init_empty()
             return
 
-        kwargs = {}
         if "frequency" in stream_info:
-            # kwargs["frequency"] = str_to_timedelta(stream_info["frequency"])
             assert False, "Frequency sub-sampling currently not supported"
 
         period = (ds.time[1] - ds.time[0]).values
@@ -82,7 +82,7 @@ def __init__(
         )
         # If there is no overlap with the time range, no need to keep the dataset.
         if tw_handler.t_start >= data_end_time or tw_handler.t_end <= data_start_time:
-            self.init_empty()
+            self._init_empty()
             return
         else:
             self.ds = ds
@@ -90,16 +90,19 @@ def __init__(
 
         self.offset_data_channels = 4
         self.fillvalue = ds["air_temperature"][0, 0].values.item()
+        self.channels_file = [k for k in self.ds.keys()]
 
         # caches lats and lons
-        self.latitudes = _clip_lat(np.array(ds.latitude, dtype=np.float32))
-        self.longitudes = _clip_lon(np.array(ds.longitude, dtype=np.float32))
-
-        self.geoinfos = np.array(ds.altitude, dtype=np.float32)
-        self.geoinfo_channels = []  # ["altitude"]
-        self.geoinfo_idx = []  # [2]
-
-        self.channels_file = [k for k in self.ds.keys()]
+        lat_name = stream_info.get("latitude_name", "latitude")
+        self.latitudes = _clip_lat(np.array(ds[lat_name], dtype=np32))
+        lon_name = stream_info.get("longitude_name", "longitude")
+        self.longitudes = _clip_lon(np.array(ds[lon_name], dtype=np32))
+
+        self.geoinfo_channels = stream_info.get("geoinfos", [])
+        self.geoinfo_idx = [self.channels_file.index(ch) for ch in self.geoinfo_channels]
+        # cache geoinfos
+        self.geoinfo_data = np.stack([np.array(ds[ch], dtype=np32) for ch in self.geoinfo_channels])
+        self.geoinfo_data = self.geoinfo_data.transpose()
 
         # select/filter requested source channels
         self.source_idx = self.select_channels(ds, "source")
@@ -118,9 +121,10 @@ def __init__(
             "stream_id": 0,
         }
 
-        self.mean, self.stdev = self.compute_mean_stdev()
+        # TODO: this should be stored/cached
+        self.mean, self.stdev = self._compute_mean_stdev()
 
-    def compute_mean_stdev(self) -> (np.array, np.array):
+    def _compute_mean_stdev(self) -> (np.array, np.array):
         _logger.info("Starting computation of mean and stdev.")
 
         mean = [0.0 for _ in range(self.offset_data_channels)]
@@ -142,13 +146,19 @@ def compute_mean_stdev(self) -> (np.array, np.array):
         return mean, stdev
 
     @override
-    def init_empty(self) -> None:
-        super().init_empty()
+    def _init_empty(self) -> None:
+        super()._init_empty()
         self.ds = None
         self.len = 0
 
     @override
     def length(self) -> int:
+        """
+        Length of dataset
+
+        Returns:
+            Length
+        """
         return self.len
 
     @override
@@ -185,8 +195,10 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
         # subsetting is pushed to the ctor via frequency argument; this also ensures that no sub-
         # sampling is required here
         sel_channels = [self.channels_file[i] for i in channels_idx]
-        data = np.stack([self.ds[ch].isel(time=slice(didx_start, didx_end)) for ch in sel_channels])
+        data = self.ds[sel_channels].isel(time=slice(didx_start, didx_end)).to_array().values
+        # flatten along time dimension
         data = data.transpose([1, 2, 0]).reshape((data.shape[1] * data.shape[2], data.shape[0]))
+        # set invalid values to NaN
         mask = data == self.fillvalue
         data[mask] = np.nan
 
@@ -198,18 +210,12 @@ def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
             ],
             axis=0,
         ).transpose()
-        # repeat latlon len(t_idxs) times
-        coords = np.vstack((latlon,) * len(t_idxs))
-
-        # import code
-        # code.interact( local=locals())
 
-        # empty geoinfos for anemoi
-        # TODO: altitudes
-        geoinfos = np.zeros((len(data), 0), dtype=data.dtype)
+        # repeat len(t_idxs) times
+        coords = np.vstack((latlon,) * len(t_idxs))
+        geoinfos = np.vstack((self.geoinfo_data,) * len(t_idxs))
 
         # date time matching #data points of data
-        # Assuming a fixed frequency for the dataset
         datetimes = np.repeat(self.ds.time[didx_start:didx_end].values, len(data) // len(t_idxs))
 
         rd = ReaderData(
@@ -239,23 +245,20 @@ def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
 
         """
 
-        channels_file = [k for k in ds.keys()][self.offset_data_channels :]
-
-        channels = self.stream_info.get(ch_type, channels_file)
-        channels_exclude = self.stream_info.get(ch_type + "_exclude", [])
+        channels = self.stream_info.get(ch_type)
+        assert not channels, f"{ch_type} channels need to be specified"
         # sanity check
         is_empty = len(channels) == 0 if channels is not None else False
         if is_empty:
             stream_name = self.stream_info["name"]
             _logger.warning(f"No channel for {stream_name} for {ch_type}.")
 
-        chs_idx = np.sort([channels_file.index(ch) for ch in channels])
-        chs_idx_exclude = np.sort([channels_file.index(ch) for ch in channels_exclude])
-        chs_idx = [idx for idx in chs_idx if idx not in chs_idx_exclude]
+        chs_idx = np.sort([self.channels_file.index(ch) for ch in channels])
 
         return np.array(chs_idx) + self.offset_data_channels
 
 
+# TODO: move to base class
 def _clip_lat(lats: NDArray) -> NDArray[np.float32]:
     """
     Clip latitudes to the range [-90, 90], reflecting out-of-range values.
     """
     return (2 * np.clip(lats, -90.0, 90.0) - lats).astype(np.float32)
 
 
+# TODO: move to base class
 def _clip_lon(lons: NDArray) -> NDArray[np.float32]:
     """
     Wrap longitudes periodically to the range [-180, 180].
     """
     return ((lons + 180.0) % 360.0 - 180.0).astype(np.float32)

From 6bf185cf6d6228c9114407b7997b7278ab69f3df Mon Sep 17 00:00:00 2001
From: Christian Lessig
Date: Thu, 4 Sep 2025 20:24:02 +0000
Subject: [PATCH 5/5] Fixing handling of normalization

---
 src/weathergen/datasets/data_reader_synop.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/weathergen/datasets/data_reader_synop.py b/src/weathergen/datasets/data_reader_synop.py
index a37d79162..913efb624 100644
--- a/src/weathergen/datasets/data_reader_synop.py
+++ b/src/weathergen/datasets/data_reader_synop.py
@@ -123,15 +123,15 @@ def __init__(
 
         # TODO: this should be stored/cached
         self.mean, self.stdev = self._compute_mean_stdev()
+        self.mean_geoinfo = self.mean[self.geoinfo_idx]
+        self.stdev_geoinfo = self.stdev[self.geoinfo_idx]
 
     def _compute_mean_stdev(self) -> (np.array, np.array):
         _logger.info("Starting computation of mean and stdev.")
 
-        mean = [0.0 for _ in range(self.offset_data_channels)]
-        stdev = [1.0 for _ in range(self.offset_data_channels)]
+        mean, stdev = [], []
 
-        data_channels_file = [k for k in self.ds.keys()][self.offset_data_channels :]
-        for ch in data_channels_file:
+        for ch in self.channels_file:
             data = np.array(self.ds[ch], np.float64)
             mask = data == self.fillvalue
             data[mask] = np.nan
@@ -246,7 +246,7 @@ def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
         """
 
         channels = self.stream_info.get(ch_type)
-        assert not channels, f"{ch_type} channels need to be specified"
+        assert channels is not None, f"{ch_type} channels need to be specified"
         # sanity check
         is_empty = len(channels) == 0 if channels is not None else False
         if is_empty:
@@ -255,7 +255,7 @@ def select_channels(self, ds, ch_type: str) -> NDArray[np.int64]:
 
         chs_idx = np.sort([self.channels_file.index(ch) for ch in channels])
 
-        return np.array(chs_idx) + self.offset_data_channels
+        return np.array(chs_idx)
 
 
 # TODO: move to base class
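
With this final patch, _compute_mean_stdev covers every variable in channels_file,
so source, target, and geoinfo indices address the statistics arrays directly and
the offset_data_channels shift is gone from select_channels. A small sketch of the
resulting indexing convention, with illustrative channel names:

    import numpy as np

    channels_file = ["station_id", "latitude", "longitude", "altitude",
                     "air_temperature", "dew_point_temperature"]
    mean = np.zeros(len(channels_file))   # stand-ins for the computed statistics
    stdev = np.ones(len(channels_file))

    geoinfo_idx = [channels_file.index("altitude")]
    mean_geoinfo, stdev_geoinfo = mean[geoinfo_idx], stdev[geoinfo_idx]

    source_idx = np.sort([channels_file.index(ch) for ch in ["air_temperature"]])
    assert channels_file[source_idx[0]] == "air_temperature"  # used without offset
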