diff --git a/src/vxingest/partial_sums_to_cb/partial_sums_builder.py b/src/vxingest/partial_sums_to_cb/partial_sums_builder.py index 65d53aa3..b65c6e9f 100644 --- a/src/vxingest/partial_sums_to_cb/partial_sums_builder.py +++ b/src/vxingest/partial_sums_to_cb/partial_sums_builder.py @@ -9,6 +9,7 @@ import copy import cProfile import datetime as dt +import json import logging import re import sys @@ -292,16 +293,24 @@ def handle_named_function(self, named_function_def): func = parts[0].replace("&", "") params = [] if len(parts) > 1: - params = parts[1].split(",") - dict_params = {} - for _p in params: - # be sure to slice the * off of the front of the param - # translate_template_item returns an array of tuples - value,interp_value, one for each station - # ordered by domain_stations. - if _p[0] == "&" or _p[0] == "*": - dict_params[_p[1:]] = self.translate_template_item(_p) + if parts[1][0] == "{": + params = json.loads( + parts[1].replace("'", '"') + ) # json loads requires double quotes around key/val strings else: - dict_params[_p] = self.translate_template_item(_p) + params = parts[1].split(",") + if isinstance(params, dict): + dict_params = params + else: + dict_params = {} + for _p in params: + # be sure to slice the * off of the front of the param + # translate_template_item returns an array of tuples - value,interp_value, one for each station + # ordered by domain_stations. + if _p[0] == "&" or _p[0] == "*": + dict_params[_p[1:]] = self.translate_template_item(_p) + else: + dict_params[_p] = self.translate_template_item(_p) # call the named function using getattr replace_with = getattr(self, func)(dict_params) except Exception as _e: @@ -842,20 +851,37 @@ def get_document_map(self): # named functions def handle_sum(self, params_dict): - """calculate sums for a given data set - i.e. model, region, fcstValidEpoch, fcstLen""" + """Calculate partial sums on matching model & obs values + for a given data set - i.e. 
model, region, fcstValidEpoch, fcstLen + + Args: + params_dict (dict): Expects one of the following formats: + {'var_name': 'var_name'} (has one item) + {'model': 'model_var_name', 'obs': 'obs_var_name'} + Returns: + dict of calculated sum stats + """ try: - keys = list(params_dict.keys()) - variable = keys[0] + if "model" in params_dict: + model_var_name = params_dict["model"] + else: + model_var_name = list(params_dict.keys())[0] + if "obs" in params_dict: + obs_var_name = params_dict["obs"] + else: + obs_var_name = model_var_name + obs_vals = [] model_vals = [] diff_vals = [] diff_vals_squared = [] abs_diff_vals = [] + for name in self.domain_stations: if name in self.obs_data and name in self.model_data["data"]: obs_elem = self.obs_data[name] model_elem = self.model_data["data"][name] - if variable == "RH": + if obs_var_name == "RH" or model_var_name == "RH": if ( "RH" not in obs_elem and obs_elem["DewPoint"] is not None @@ -878,7 +904,9 @@ def handle_sum(self, params_dict): model_elem["DewPoint"] * units.degF, ).magnitude ) * 100 - if variable == "UW" or variable == "VW": + if (obs_var_name == "UW" or model_var_name == "UW") or ( + obs_var_name == "VW" or model_var_name == "VW" + ): # wind direction in the data is from 0 to 360 and we need it from -180 to 180 if ( ("UW" not in obs_elem or "VW" not in obs_elem) @@ -902,8 +930,8 @@ def handle_sum(self, params_dict): ) model_elem["UW"] = wind_components_t[0].magnitude model_elem["VW"] = wind_components_t[1].magnitude - obs_var = obs_elem.get(variable) - model_var = model_elem.get(variable) + obs_var = obs_elem.get(obs_var_name) + model_var = model_elem.get(model_var_name) # If there is no observation or model data for this variable for this station, skip it if obs_var is not None and model_var is not None: obs_vals.append(obs_var) diff --git a/tests/vxingest/partial_sums_to_cb/test_unit_partial_sums_builder.py b/tests/vxingest/partial_sums_to_cb/test_unit_partial_sums_builder.py new file mode 100644 index 
"""Unit tests for ``PartialSumsSurfaceModelObsBuilderV01.handle_sum``.

Test module: tests/vxingest/partial_sums_to_cb/test_unit_partial_sums_builder.py
"""

import pytest

from vxingest.partial_sums_to_cb.partial_sums_builder import (
    PartialSumsSurfaceModelObsBuilderV01,
)


def _attach_station_data(builder, model, obs):
    """Point *builder* at the two test stations and their model/obs records."""
    builder.domain_stations = ["KAAA", "KBBB"]
    builder.obs_data = obs
    builder.model_data = model
    return builder


@pytest.fixture
def dummy_builder():
    """Return a builder constructed from placeholder arguments."""
    return PartialSumsSurfaceModelObsBuilderV01("load_spec", {"template": ""})


@pytest.fixture
def model_data():
    """Return model records for two stations, nested under the "data" key."""
    per_station = {
        "KAAA": {"temperature": 25, "temperature_adj": 22},
        "KBBB": {"temperature": 10, "temperature_adj": 14},
    }
    return {"data": per_station}


@pytest.fixture
def obs_data():
    """Return observation records for the same two stations."""
    return {
        "KAAA": {"temperature": 23},
        "KBBB": {"temperature": 15},
    }


def test_handle_sum_simple_param(dummy_builder, model_data, obs_data):
    """handle_sum() given a one-item params_dict naming a single variable."""
    builder = _attach_station_data(dummy_builder, model_data, obs_data)

    # model "temperature" (25, 10) vs obs "temperature" (23, 15):
    # the expected totals correspond to per-station differences of 2 and -5.
    expected = {
        "num_recs": 2,
        "sum_obs": 38,
        "sum_model": 35,
        "sum_diff": -3,
        "sum2_diff": 29,
        "sum_abs": 7,
    }

    assert builder.handle_sum({"temperature": "temperature"}) == expected


def test_handle_sum_obj_param(dummy_builder, model_data, obs_data):
    """handle_sum() given a params_dict with separate model and obs variables."""
    builder = _attach_station_data(dummy_builder, model_data, obs_data)

    # model "temperature_adj" (22, 14) vs obs "temperature" (23, 15):
    # the expected totals correspond to per-station differences of -1 and -1.
    expected = {
        "num_recs": 2,
        "sum_obs": 38,
        "sum_model": 36,
        "sum_diff": -2,
        "sum2_diff": 2,
        "sum_abs": 2,
    }
    variable_names = {"model": "temperature_adj", "obs": "temperature"}

    assert builder.handle_sum(variable_names) == expected