22import pandas as pd
33import xarray as xr
44import numpy as np
5+ from typing import Any
56from hat .core import load_da
67
78from hat import _LOGGER as logger
@@ -30,7 +31,8 @@ def construct_mask(x_indices, y_indices, shape):
3031def create_mask_from_index (df , shape ):
3132 logger .info (f"Creating mask { shape } from index" )
3233 logger .debug (f"DataFrame columns: { df .columns .tolist ()} " )
33- x_indices , y_indices = df ["x_index" ].values , df ["y_index" ].values
34+ x_indices = df ["x_index" ].values
35+ y_indices = df ["y_index" ].values
3436 if np .any (x_indices < 0 ) or np .any (x_indices >= shape [0 ]) or np .any (y_indices < 0 ) or np .any (y_indices >= shape [1 ]):
3537 raise ValueError (
3638 f"Station indices out of grid bounds. Grid shape={ shape } , "
@@ -56,7 +58,7 @@ def create_mask_from_coords(df, gridx, gridy, shape):
5658 return mask , duplication_indexes
5759
5860
59- def parse_stations (station_config ) :
61+ def parse_stations (station_config : dict [ str , Any ]) -> pd . DataFrame :
6062 """Read, filter, and normalize station DataFrame to canonical column names."""
6163 logger .debug (f"Reading station file, { station_config } " )
6264 if "name" not in station_config :
@@ -116,12 +118,12 @@ def parse_stations(station_config):
116118 return df_renamed
117119
118120
119- def _process_gribjump (grid_config , df ) :
121+ def _process_gribjump (grid_config : dict [ str , Any ], df : pd . DataFrame ) -> xr . Dataset :
120122 if "index_1d" not in df .columns :
121123 raise ValueError ("Gribjump source requires 'index_1d' in station config." )
122124
123125 station_names = df ["station_name" ].values
124- unique_indices , duplication_indexes = np .unique (df ["index_1d" ].values , return_inverse = True )
126+ unique_indices , duplication_indexes = np .unique (df ["index_1d" ].values , return_inverse = True ) # type: ignore[call-overload]
125127 # TODO: Double-check this. Converting indices to ranges is currently
126128 # faster than using indices directly, should be fixed in the gribjump
127129 # source.
@@ -138,7 +140,7 @@ def _process_gribjump(grid_config, df):
138140 return ds
139141
140142
141- def _process_regular (grid_config , df ) :
143+ def _process_regular (grid_config : dict [ str , Any ], df : pd . DataFrame ) -> xr . Dataset :
142144 station_names = df ["station_name" ].values
143145 da , var_name , x_dim , y_dim , shape = process_grid_inputs (grid_config )
144146
@@ -159,7 +161,7 @@ def _process_regular(grid_config, df):
159161 return ds
160162
161163
162- def process_inputs (station_config , grid_config ) :
164+ def process_inputs (station_config : dict [ str , Any ], grid_config : dict [ str , Any ]) -> xr . Dataset :
163165 df = parse_stations (station_config )
164166 if "gribjump" in grid_config .get ("source" , {}):
165167 return _process_gribjump (grid_config , df )
@@ -189,7 +191,7 @@ def apply_mask(da, mask, coordx, coordy):
189191 return task .compute ()
190192
191193
192- def extractor (config ) :
194+ def extractor (config : dict [ str , Any ]) -> xr . Dataset :
193195 ds = process_inputs (config ["station" ], config ["grid" ])
194196 if config .get ("output" , None ) is not None :
195197 logger .info (f"Saving output to { config ['output' ]['file' ]} " )
0 commit comments