Skip to content

Commit f1c6953

Browse files
feat(extractor): add type hints to some functions
1 parent: 8a6973c · commit: f1c6953

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

hat/extract_timeseries/extractor.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import xarray as xr
44
import numpy as np
5+
from typing import Any
56
from hat.core import load_da
67

78
from hat import _LOGGER as logger
@@ -30,7 +31,8 @@ def construct_mask(x_indices, y_indices, shape):
3031
def create_mask_from_index(df, shape):
3132
logger.info(f"Creating mask {shape} from index")
3233
logger.debug(f"DataFrame columns: {df.columns.tolist()}")
33-
x_indices, y_indices = df["x_index"].values, df["y_index"].values
34+
x_indices = df["x_index"].values
35+
y_indices = df["y_index"].values
3436
if np.any(x_indices < 0) or np.any(x_indices >= shape[0]) or np.any(y_indices < 0) or np.any(y_indices >= shape[1]):
3537
raise ValueError(
3638
f"Station indices out of grid bounds. Grid shape={shape}, "
@@ -56,7 +58,7 @@ def create_mask_from_coords(df, gridx, gridy, shape):
5658
return mask, duplication_indexes
5759

5860

59-
def parse_stations(station_config):
61+
def parse_stations(station_config: dict[str, Any]) -> pd.DataFrame:
6062
"""Read, filter, and normalize station DataFrame to canonical column names."""
6163
logger.debug(f"Reading station file, {station_config}")
6264
if "name" not in station_config:
@@ -116,12 +118,12 @@ def parse_stations(station_config):
116118
return df_renamed
117119

118120

119-
def _process_gribjump(grid_config, df):
121+
def _process_gribjump(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
120122
if "index_1d" not in df.columns:
121123
raise ValueError("Gribjump source requires 'index_1d' in station config.")
122124

123125
station_names = df["station_name"].values
124-
unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True)
126+
unique_indices, duplication_indexes = np.unique(df["index_1d"].values, return_inverse=True) # type: ignore[call-overload]
125127
# TODO: Double-check this. Converting indices to ranges is currently
126128
# faster than using indices directly, should be fixed in the gribjump
127129
# source.
@@ -138,7 +140,7 @@ def _process_gribjump(grid_config, df):
138140
return ds
139141

140142

141-
def _process_regular(grid_config, df):
143+
def _process_regular(grid_config: dict[str, Any], df: pd.DataFrame) -> xr.Dataset:
142144
station_names = df["station_name"].values
143145
da, var_name, x_dim, y_dim, shape = process_grid_inputs(grid_config)
144146

@@ -159,7 +161,7 @@ def _process_regular(grid_config, df):
159161
return ds
160162

161163

162-
def process_inputs(station_config, grid_config):
164+
def process_inputs(station_config: dict[str, Any], grid_config: dict[str, Any]) -> xr.Dataset:
163165
df = parse_stations(station_config)
164166
if "gribjump" in grid_config.get("source", {}):
165167
return _process_gribjump(grid_config, df)
@@ -189,7 +191,7 @@ def apply_mask(da, mask, coordx, coordy):
189191
return task.compute()
190192

191193

192-
def extractor(config):
194+
def extractor(config: dict[str, Any]) -> xr.Dataset:
193195
ds = process_inputs(config["station"], config["grid"])
194196
if config.get("output", None) is not None:
195197
logger.info(f"Saving output to {config['output']['file']}")

0 commit comments

Comments (0)