def run_r2d2_fetch(r2d2_dict: dict) -> None:
    """Run a single r2d2 fetch for an observation or bias-correction file.

    Designed to be called through ``multiprocessing.Pool.map``, so everything
    the call needs travels inside one picklable dictionary.

    Arguments:
        r2d2_dict: Dictionary of r2d2 fetch parameters, ALSO carrying extra
            entries that are popped before the fetch command is issued:
            * 'fetch_empty' (bool, optional): whether falling back to an
              empty observation file is acceptable (default False).
            * 'cycle_dir' (str): experiment cycle directory.
            * 'logger': Swell logger.

    Raises:
        Whatever ``r2d2.fetch`` raised, re-raised unchanged, when the fetch
        fails and 'fetch_empty' is not set.
    """

    fetch_empty_obs = r2d2_dict.pop('fetch_empty', False)
    cycle_dir = r2d2_dict.pop('cycle_dir')
    logger = r2d2_dict.pop('logger')

    target_file = r2d2_dict['target_file']

    try:
        r2d2.fetch(**r2d2_dict)
        logger.info(f"Successfully fetched {target_file}")
    except Exception:
        # If an empty obs file is not an acceptable substitute, re-raise the
        # original exception unchanged so callers see the real error type and
        # traceback (a bare `raise` instead of `raise Exception(e)`).
        if not fetch_empty_obs:
            raise

        logger.info(f"Failed to fetch {target_file}. "
                    "Fetch empty observation instead.")
        empty_obs_file = os.path.join(cycle_dir, 'empty_obs.nc4')
        if not os.path.exists(empty_obs_file):
            # Several pool workers may race to create the shared empty obs
            # file: fetch into a worker-unique temporary name, then move it
            # into place atomically with os.replace.
            tmp_empty_obs_file = f'{empty_obs_file}.{os.getpid()}'
            r2d2.fetch(
                item='observation',
                provider='empty_provider',
                observation_type='empty_type',
                file_extension='nc4',
                window_start='19700101T030000Z',
                window_length='PT6H',
                target_file=tmp_empty_obs_file,
            )
            os.replace(tmp_empty_obs_file, empty_obs_file)

        # Copy the empty file to the target file location
        shutil.copy(empty_obs_file, target_file)

    # Change the permissions so downstream tasks can read the file
    os.chmod(target_file, 0o644)
combine_input_files.append(target_file) fetch_criteria = { 'item': 'observation', # Required for r2d2 v3 @@ -176,46 +234,10 @@ def execute(self) -> None: 'window_start': obs_window_begin, # From filename timestamp 'window_length': obs_window_length, # From filename 'target_file': target_file, # Where to save + 'fetch_empty': True } - try: - r2d2.fetch(**fetch_criteria) - self.logger.info(f"Successfully fetched {target_file}") - except Exception: - self.logger.info( - f"Failed to fetch {target_file}. " - "Fetch empty observation instead." - ) - - # fetch empty obs - r2d2.fetch( - item='observation', - provider='empty_provider', - observation_type='empty_type', - file_extension='nc4', - window_start='19700101T030000Z', - window_length='PT6H', - target_file=target_file, - ) - - # Check how many of the combine_input_files exist in the cycle directory. - # If all of them are missing proceed without creating an observation input - # file since bias correction files still need to be propagated to the next cycle - # for cycling VarBC. 
- # ----------------------------------------------------------------------- - if not any([os.path.exists(f) for f in combine_input_files]): - self.logger.info(f'None of the {observation} files exist for this cycle!') - else: - jedi_obs_file = observation_dict['obs space']['obsdatain']['engine']['obsfile'] - self.logger.info(f'Processing observation file {jedi_obs_file}') - # If obs_list_dto has one member, then just rename the file - # --------------------------------------------------------- - if len(obs_list_dto) == 1: - os.rename(combine_input_files[0], jedi_obs_file) - else: - self.read_and_combine(combine_input_files, jedi_obs_file) - # Change permission - os.chmod(jedi_obs_file, 0o644) + r2d2_fetch_dicts.append(fetch_criteria) # Otherwise there is only work to do if the observation operator has bias correction # ---------------------------------------------------------------------------------- @@ -269,32 +291,29 @@ def execute(self) -> None: if fetch_required: # Fetch coefficients file (.acftbias or .satbias) self.logger.info(f'Processing bias file {target_bccoef}') - r2d2.fetch( - item='bias_correction', - target_file=target_bccoef, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension=bias_file_ext, - file_type=bias_file_type, - date=background_time_iso - ) - - r2d2.fetch( - item='bias_correction', - target_file=target_bccovr, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension=bias_file_ext + '_cov', - file_type=bias_err_type, # obsbias_coeff_errors Official JCSDA enum - date=background_time_iso - ) - # Change permission - os.chmod(target_bccoef, 0o644) - os.chmod(target_bccovr, 0o644) + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_bccoef, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': bias_file_ext, + 'file_type': 
bias_file_type, + 'date': background_time_iso + }) + + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_bccovr, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': bias_file_ext + '_cov', + 'file_type': bias_err_type, # obsbias_coeff_errors Official JCSDA enum + 'date': background_time_iso + }) # Skip time lapse part for aircraft observations # ---------------------------------------------- @@ -307,20 +326,62 @@ def execute(self) -> None: self.logger.info(f'Processing satellite time lapse file {target_file}') - r2d2.fetch( - item='bias_correction', - target_file=target_file, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension='tlapse', - file_type='obsbias_tlapse', # Official JCSDA enum - date=background_time_iso - ) + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_file, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': 'tlapse', + 'file_type': 'obsbias_tlapse', # Official JCSDA enum + 'date': background_time_iso + }) + + for fetch_dict in r2d2_fetch_dicts: + fetch_dict['logger'] = self.logger + fetch_dict['cycle_dir'] = self.cycle_dir() + + # Run through all files to fetch + # ------------------------------ + number_of_workers = 4 + self.logger.info(f'Fetching observations in parallel with {number_of_workers} workers') + with Pool(processes=number_of_workers) as pool: + pool.map(run_r2d2_fetch, r2d2_fetch_dicts) + + # Iterate through observation files to read and combine + # ----------------------------------------------------- + for observation in observations: + observation_dict = observation_dicts[observation] + + # Fetch observation files + # ----------------------- + combine_input_files = [] + + for obs_num, obs_time in enumerate(obs_list_dto): + obs_window_begin = 
dt.strftime(obs_time, datetime_formats['iso_format']) + target_file = os.path.join(self.cycle_dir(), f'{observation}.{obs_num}.nc4') + combine_input_files.append(target_file) + + # Check how many of the combine_input_files exist in the cycle directory. + # If all of them are missing proceed without creating an observation input + # file since bias correction files still need to be propagated to the next cycle + # for cycling VarBC. + # ----------------------------------------------------------------------- + if not any([os.path.exists(f) for f in combine_input_files]): + self.logger.info(f'None of the {observation} files exist for this cycle!') + else: + jedi_obs_file = observation_dict['obs space']['obsdatain']['engine']['obsfile'] + self.logger.info(f'Processing observation file {jedi_obs_file}') + # If obs_list_dto has one member, then just rename the file + # --------------------------------------------------------- + if len(obs_list_dto) == 1: + os.rename(combine_input_files[0], jedi_obs_file) + else: + self.read_and_combine(combine_input_files, jedi_obs_file) # Change permission - os.chmod(target_file, 0o644) + os.chmod(jedi_obs_file, 0o644) # ----------------------------------------------------------------------------------------------