From e3198599908ad2a6c13bbf0dc3d0657386162430 Mon Sep 17 00:00:00 2001 From: Michael Anstett Date: Wed, 4 Mar 2026 13:46:16 -0500 Subject: [PATCH 1/5] Restructure to do parallel fetch of obs files --- src/swell/tasks/get_observations.py | 200 ++++++++++++++++------------ 1 file changed, 118 insertions(+), 82 deletions(-) diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index f8c561450..96d1134c9 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -14,6 +14,7 @@ import r2d2 import shutil from typing import Union +from multiprocessing import Pool from datetime import timedelta, datetime as dt from swell.tasks.base.task_base import taskBase @@ -29,8 +30,8 @@ 'geos_marine': 'mom6', } -# -------------------------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------------- class GetObservations(taskBase): @@ -147,26 +148,28 @@ def execute(self) -> None: # Read observation ioda names ioda_names_list = get_ioda_names_list() + # Create a dictionary of all fetch criteria + # ----------------------------------------- + r2d2_fetch_dicts = [] + + observation_dicts = {} + # Loop over observation operators # ------------------------------- for observation in observations: # Open the observation operator dictionary # ---------------------------------------- - observation_dict = self.jedi_rendering.render_interface_observations(observation) + observation_dicts[observation] = observation_dict = self.jedi_rendering.render_interface_observations(observation) # Get the set obs providers for each observation # ---------------------------------------------- obs_provider = get_provider_for_observation(observation, ioda_names_list, self.logger) - # Fetch observation files - # ----------------------- - combine_input_files = [] # Here, we are fetching for obs_num, obs_time in enumerate(obs_list_dto): obs_window_begin = dt.strftime(obs_time, datetime_formats['iso_format']) target_file = os.path.join(self.cycle_dir(), f'{observation}.{obs_num}.nc4') - combine_input_files.append(target_file) fetch_criteria = { 'item': 'observation', # Required for r2d2 v3 @@ -176,46 +179,10 @@ def execute(self) -> None: 'window_start': obs_window_begin, # From filename timestamp 'window_length': obs_window_length, # From filename 'target_file': target_file, # Where to save + 'fetch_empty': True } - try: - r2d2.fetch(**fetch_criteria) - self.logger.info(f"Successfully fetched {target_file}") - except Exception: - self.logger.info( - f"Failed to fetch {target_file}. " - "Fetch empty observation instead." - ) - - # fetch empty obs - r2d2.fetch( - item='observation', - provider='empty_provider', - observation_type='empty_type', - file_extension='nc4', - window_start='19700101T030000Z', - window_length='PT6H', - target_file=target_file, - ) - - # Check how many of the combine_input_files exist in the cycle directory. - # If all of them are missing proceed without creating an observation input - # file since bias correction files still need to be propagated to the next cycle - # for cycling VarBC. - # ----------------------------------------------------------------------- - if not any([os.path.exists(f) for f in combine_input_files]): - self.logger.info(f'None of the {observation} files exist for this cycle!') - else: - jedi_obs_file = observation_dict['obs space']['obsdatain']['engine']['obsfile'] - self.logger.info(f'Processing observation file {jedi_obs_file}') - # If obs_list_dto has one member, then just rename the file - # --------------------------------------------------------- - if len(obs_list_dto) == 1: - os.rename(combine_input_files[0], jedi_obs_file) - else: - self.read_and_combine(combine_input_files, jedi_obs_file) - # Change permission - os.chmod(jedi_obs_file, 0o644) + r2d2_fetch_dicts.append(fetch_criteria) # Otherwise there is only work to do if the observation operator has bias correction # ---------------------------------------------------------------------------------- @@ -269,32 +236,29 @@ def execute(self) -> None: if fetch_required: # Fetch coefficients file (.acftbias or .satbias) self.logger.info(f'Processing bias file {target_bccoef}') - r2d2.fetch( - item='bias_correction', - target_file=target_bccoef, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension=bias_file_ext, - file_type=bias_file_type, - date=background_time_iso - ) - - r2d2.fetch( - item='bias_correction', - target_file=target_bccovr, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension=bias_file_ext + '_cov', - file_type=bias_err_type, # obsbias_coeff_errors Official JCSDA enum - date=background_time_iso - ) - # Change permission - os.chmod(target_bccoef, 0o644) - os.chmod(target_bccovr, 0o644) + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_bccoef, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': bias_file_ext, + 'file_type': bias_file_type, + 'date': background_time_iso + }) + + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_bccovr, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': bias_file_ext + '_cov', + 'file_type': bias_err_type, # obsbias_coeff_errors Official JCSDA enum + 'date': background_time_iso + }) # Skip time lapse part for aircraft observations # ---------------------------------------------- @@ -307,20 +271,92 @@ def execute(self) -> None: self.logger.info(f'Processing satellite time lapse file {target_file}') - r2d2.fetch( - item='bias_correction', - target_file=target_file, - model=r2d2_model, - experiment=obs_experiment, - provider='gsi', - observation_type=observation, - file_extension='tlapse', - file_type='obsbias_tlapse', # Official JCSDA enum - date=background_time_iso - ) + r2d2_fetch_dicts.append({ + 'item': 'bias_correction', + 'target_file': target_file, + 'model': r2d2_model, + 'experiment': obs_experiment, + 'provider': 'gsi', + 'observation_type': observation, + 'file_extension': 'tlapse', + 'file_type': 'obsbias_tlapse', # Official JCSDA enum + 'date': background_time_iso + }) + + # Run through all files to fetch + # ------------------------------ + number_of_workers = 4 + self.logger.info(f'Running parallel plot generation with {number_of_workers} workers') + with Pool(processes=number_of_workers) as pool: + pool.map(self.run_r2d2_fetch, r2d2_fetch_dicts) + + # Iterate through observation files to read and combine + # ----------------------------------------------------- + for observation in observations: + observation_dict = observation_dicts[observation] + + # Fetch observation files + # ----------------------- + combine_input_files = [] + # Here, we are fetching + for obs_num, obs_time in enumerate(obs_list_dto): + obs_window_begin = dt.strftime(obs_time, datetime_formats['iso_format']) + target_file = os.path.join(self.cycle_dir(), f'{observation}.{obs_num}.nc4') + combine_input_files.append(target_file) + + # Check how many of the combine_input_files exist in the cycle directory. + # If all of them are missing proceed without creating an observation input + # file since bias correction files still need to be propagated to the next cycle + # for cycling VarBC. + # ----------------------------------------------------------------------- + if not any([os.path.exists(f) for f in combine_input_files]): + self.logger.info(f'None of the {observation} files exist for this cycle!') + else: + jedi_obs_file = observation_dict['obs space']['obsdatain']['engine']['obsfile'] + self.logger.info(f'Processing observation file {jedi_obs_file}') + # If obs_list_dto has one member, then just rename the file + # --------------------------------------------------------- + if len(obs_list_dto) == 1: + os.rename(combine_input_files[0], jedi_obs_file) + else: + self.read_and_combine(combine_input_files, jedi_obs_file) # Change permission - os.chmod(target_file, 0o644) + os.chmod(jedi_obs_file, 0o644) + + # ---------------------------------------------------------------------------------------------- + + def run_r2d2_fetch(self, r2d2_dict: dict) -> None: + fetch_empty_obs = r2d2_dict.pop('fetch_empty', False) + + target_file = r2d2_dict['target_file'] + + try: + r2d2.fetch(**r2d2_dict) + self.logger.info(f"Successfully fetched {target_file}") + except Exception as e: + # If this is + if fetch_empty_obs: + self.logger.info(f"Failed to fetch {target_file}. Fetch empty observation instead.") + empty_obs_file = os.path.join(self.cycle_dir(), 'empty_obs.nc4') + if not os.path.exists(empty_obs_file): + # fetch empty obs + r2d2.fetch( + item='observation', + provider='empty_provider', + observation_type='empty_type', + file_extension='nc4', + window_start='19700101T030000Z', + window_length='PT6H', + target_file=empty_obs_file, + ) + + shutil.copy(empty_obs_file, target_file) + + else: + raise Exception(e) + + os.chmod(target_file, 0o644) # ---------------------------------------------------------------------------------------------- From c1ea95da819dd2bb043192f54476f51ddb08a2bf Mon Sep 17 00:00:00 2001 From: Michael Anstett Date: Wed, 4 Mar 2026 14:06:14 -0500 Subject: [PATCH 2/5] Fix multiprocessing --- src/swell/tasks/get_observations.py | 77 ++++++++++++++++------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index 96d1134c9..90ca3b435 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -30,6 +30,43 @@ 'geos_marine': 'mom6', } + # ---------------------------------------------------------------------------------------------- + +def run_r2d2_fetch(r2d2_dict: dict) -> None: + + fetch_empty_obs = r2d2_dict.pop('fetch_empty', False) + cycle_dir = r2d2_dict.pop('cycle_dir') + logger = r2d2_dict.pop('logger') + + target_file = r2d2_dict['target_file'] + + try: + r2d2.fetch(**r2d2_dict) + logger.info(f"Successfully fetched {target_file}") + except Exception as e: + # If this is + if fetch_empty_obs: + logger.info(f"Failed to fetch {target_file}. Fetch empty observation instead.") + empty_obs_file = os.path.join(cycle_dir, 'empty_obs.nc4') + if not os.path.exists(empty_obs_file): + # fetch empty obs + r2d2.fetch( + item='observation', + provider='empty_provider', + observation_type='empty_type', + file_extension='nc4', + window_start='19700101T030000Z', + window_length='PT6H', + target_file=empty_obs_file, + ) + + shutil.copy(empty_obs_file, target_file) + + else: + raise Exception(e) + + os.chmod(target_file, 0o644) + # -------------------------------------------------------------------------------------------------- @@ -283,12 +320,16 @@ def execute(self) -> None: 'date': background_time_iso }) + for fetch_dict in r2d2_fetch_dicts: + fetch_dict['logger'] = self.logger + fetch_dict['cycle_dir'] = self.cycle_dir() + # Run through all files to fetch # ------------------------------ number_of_workers = 4 self.logger.info(f'Running parallel plot generation with {number_of_workers} workers') with Pool(processes=number_of_workers) as pool: - pool.map(self.run_r2d2_fetch, r2d2_fetch_dicts) + pool.map(run_r2d2_fetch, r2d2_fetch_dicts) # Iterate through observation files to read and combine # ----------------------------------------------------- @@ -326,40 +367,6 @@ def execute(self) -> None: # ---------------------------------------------------------------------------------------------- - def run_r2d2_fetch(self, r2d2_dict: dict) -> None: - fetch_empty_obs = r2d2_dict.pop('fetch_empty', False) - - target_file = r2d2_dict['target_file'] - - try: - r2d2.fetch(**r2d2_dict) - self.logger.info(f"Successfully fetched {target_file}") - except Exception as e: - # If this is - if fetch_empty_obs: - self.logger.info(f"Failed to fetch {target_file}. Fetch empty observation instead.") - empty_obs_file = os.path.join(self.cycle_dir(), 'empty_obs.nc4') - if not os.path.exists(empty_obs_file): - # fetch empty obs - r2d2.fetch( - item='observation', - provider='empty_provider', - observation_type='empty_type', - file_extension='nc4', - window_start='19700101T030000Z', - window_length='PT6H', - target_file=empty_obs_file, - ) - - shutil.copy(empty_obs_file, target_file) - - else: - raise Exception(e) - - os.chmod(target_file, 0o644) - - # ---------------------------------------------------------------------------------------------- - def get_tlapse_files(self, observation_dict: dict) -> Union[None, int]: # Function to locate instances of tlapse in the obs operator config From a2157a1914da1afd8348b730c438041af491295c Mon Sep 17 00:00:00 2001 From: Michael Anstett Date: Wed, 4 Mar 2026 14:24:33 -0500 Subject: [PATCH 3/5] Add comments --- src/swell/tasks/get_observations.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index 90ca3b435..6ca36ab38 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -30,10 +30,21 @@ 'geos_marine': 'mom6', } - # ---------------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------------------- def run_r2d2_fetch(r2d2_dict: dict) -> None: + """Runs fetch command for all types of obs files + + Arguments: + r2d2_dict: Dictionary of r2d2 fetch parameters, ALSO including additional information including: + **r2d2_dict['fetch_empty']: bool whether fetching empty obs file is appropriate + **r2d2_dict['cycle_dir']: Experiment cycle directory (has to be specified this way for multiprocessing) + **r2d2_dict['logger']: Swell logger (has to be specified this way for multiprocessing) + + These values will be popped from the dictionary before running the fetch command + """ + fetch_empty_obs = r2d2_dict.pop('fetch_empty', False) cycle_dir = r2d2_dict.pop('cycle_dir') logger = r2d2_dict.pop('logger') @@ -44,12 +55,12 @@ def run_r2d2_fetch(r2d2_dict: dict) -> None: r2d2.fetch(**r2d2_dict) logger.info(f"Successfully fetched {target_file}") except Exception as e: - # If this is + # If this can be an empty obs file, fetch or copy empty file to the target file if fetch_empty_obs: logger.info(f"Failed to fetch {target_file}. Fetch empty observation instead.") empty_obs_file = os.path.join(cycle_dir, 'empty_obs.nc4') if not os.path.exists(empty_obs_file): - # fetch empty obs + # fetch empty obs, if it doesn't exist r2d2.fetch( item='observation', provider='empty_provider', @@ -59,12 +70,14 @@ def run_r2d2_fetch(r2d2_dict: dict) -> None: window_length='PT6H', target_file=empty_obs_file, ) - + + # Copy the empty file to the target file directory shutil.copy(empty_obs_file, target_file) else: raise Exception(e) - + + # Change the permissions os.chmod(target_file, 0o644) @@ -189,6 +202,8 @@ def execute(self) -> None: # ----------------------------------------- r2d2_fetch_dicts = [] + # Dictionary tracking all observation files + # ----------------------------------------- observation_dicts = {} # Loop over observation operators @@ -340,7 +355,7 @@ def execute(self) -> None: # Fetch observation files # ----------------------- combine_input_files = [] - # Here, we are fetching + for obs_num, obs_time in enumerate(obs_list_dto): obs_window_begin = dt.strftime(obs_time, datetime_formats['iso_format']) target_file = os.path.join(self.cycle_dir(), f'{observation}.{obs_num}.nc4') From 59dc2a3ae9f94db9e1b32eaacfca8c6df10a1bf0 Mon Sep 17 00:00:00 2001 From: Michael Anstett Date: Wed, 4 Mar 2026 14:31:13 -0500 Subject: [PATCH 4/5] Code test fixes --- src/swell/tasks/get_observations.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index 6ca36ab38..617157f08 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -32,6 +32,7 @@ # ---------------------------------------------------------------------------------------------- + def run_r2d2_fetch(r2d2_dict: dict) -> None: """Runs fetch command for all types of obs files @@ -39,8 +40,9 @@ def run_r2d2_fetch(r2d2_dict: dict) -> None: Arguments: r2d2_dict: Dictionary of r2d2 fetch parameters, ALSO including additional information including: **r2d2_dict['fetch_empty']: bool whether fetching empty obs file is appropriate - **r2d2_dict['cycle_dir']: Experiment cycle directory (has to be specified this way for multiprocessing) - **r2d2_dict['logger']: Swell logger (has to be specified this way for multiprocessing) + **r2d2_dict['cycle_dir']: Experiment cycle directory + **r2d2_dict['logger']: Swell logger + (specified this way for multiprocessing) These values will be popped from the dictionary before running the fetch command """ @@ -70,13 +72,13 @@ def run_r2d2_fetch(r2d2_dict: dict) -> None: window_length='PT6H', target_file=empty_obs_file, ) - + # Copy the empty file to the target file directory shutil.copy(empty_obs_file, target_file) - + else: raise Exception(e) - + # Change the permissions os.chmod(target_file, 0o644) @@ -201,7 +203,7 @@ def execute(self) -> None: # Create a dictionary of all fetch criteria # ----------------------------------------- r2d2_fetch_dicts = [] - + # Dictionary tracking all observation files # ----------------------------------------- observation_dicts = {} @@ -212,7 +214,8 @@ def execute(self) -> None: # Open the observation operator dictionary # ---------------------------------------- - observation_dicts[observation] = observation_dict = self.jedi_rendering.render_interface_observations(observation) + observation_dicts[observation] = observation_dict = \ + self.jedi_rendering.render_interface_observations(observation) # Get the set obs providers for each observation # ---------------------------------------------- @@ -308,7 +311,7 @@ def execute(self) -> None: 'provider': 'gsi', 'observation_type': observation, 'file_extension': bias_file_ext + '_cov', - 'file_type': bias_err_type, # obsbias_coeff_errors Official JCSDA enum + 'file_type': bias_err_type, # obsbias_coeff_errors Official JCSDA enum 'date': background_time_iso }) From 7fd9760c1b668dca92a616082b40fa1123e982a1 Mon Sep 17 00:00:00 2001 From: Michael Anstett Date: Thu, 5 Mar 2026 10:29:27 -0500 Subject: [PATCH 5/5] Fix message --- src/swell/tasks/get_observations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/swell/tasks/get_observations.py b/src/swell/tasks/get_observations.py index 617157f08..1e1e54f45 100644 --- a/src/swell/tasks/get_observations.py +++ b/src/swell/tasks/get_observations.py @@ -345,7 +345,7 @@ def execute(self) -> None: # Run through all files to fetch # ------------------------------ number_of_workers = 4 - self.logger.info(f'Running parallel plot generation with {number_of_workers} workers') + self.logger.info(f'Fetching observations in parallel with {number_of_workers} workers') with Pool(processes=number_of_workers) as pool: pool.map(run_r2d2_fetch, r2d2_fetch_dicts)