diff --git a/src/swell/configuration/jedi/interfaces/geos_cf/model/background_error.py b/src/swell/configuration/jedi/interfaces/geos_cf/model/background_error.py index 2327c3b14..b9d61760a 100644 --- a/src/swell/configuration/jedi/interfaces/geos_cf/model/background_error.py +++ b/src/swell/configuration/jedi/interfaces/geos_cf/model/background_error.py @@ -7,19 +7,152 @@ # -------------------------------------------------------------------------------------------------- from collections.abc import Mapping +from swell.configuration.jedi.interfaces.geos_cf.model.shared import \ + field_io_names + +# Central block builders +# -------------------------------------------------------------------------------------------------- + + +def _build_identity_central_block(template_dict: Mapping) -> Mapping: + """Identity covariance central block""" + return {'saber block name': 'ID'} + + +def _build_bump_nicas_central_block(template_dict: Mapping) -> Mapping: + """BUMP/NICAS covariance central block""" + field_aliases = [] + for var in template_dict.get('analysis_variables', []): + field_aliases.append({ + 'in code': var, + 'in file': 'fixedlevel_H_500_300_V_00_03' + }) + + return { + 'saber block name': 'BUMP_NICAS', + 'read': { + 'io': { + 'data directory': f"{template_dict['cycle_dir']}/", + 'files prefix': 'geos_cf', + 'alias': field_aliases + }, + 'drivers': { + 'multivariate strategy': 'univariate', + 'read local nicas': True + } + } + } + +# Outer block builders +# -------------------------------------------------------------------------------------------------- + + +def _build_stddev_bkg_scaled(template_dict: Mapping) -> Mapping: + """Standard deviation outer block""" + return { + 'saber block name': 'StdDev', + 'stddev scale factor': '0.25', + 'read': { + 'model file': { + 'datetime': template_dict['local_background_time_iso'], + 'set datetime on read': True, + 'filetype': 'cube sphere history', + 'max allowable geometry difference': 0.1, + 'datapath': template_dict['cycle_dir'], + 'filename': 'bkg.%yyyy%mm%ddT%hh%MM%ssZ.nc4', + 'field io names': field_io_names, + } + } + } + + +def _build_stddev_fixed_values(template_dict: Mapping) -> Mapping: + """Standard deviation outer block with fixed values from analysis variables""" + # Get the stddev values for each analysis variable + # Example: {'volume_mixing_ratio_of_no2': 5e-9, 'volume_mixing_ratio_of_o3': 1e-8} + stddev_dict = { + 'volume_mixing_ratio_of_no2': 5e-9, + 'volume_mixing_ratio_of_o3': 1e-8, + 'volume_mixing_ratio_of_co': 5e-8, + } + + # Build the standard deviations list + standard_deviations = [ + {'variable': var, 'stddev': value} + for var, value in stddev_dict.items() + ] + + return { + 'saber block name': 'StdDev', + 'standard deviations': standard_deviations + } + +# Main assembler # -------------------------------------------------------------------------------------------------- def background_error(template_dict: Mapping) -> Mapping: + """ + Assemble background error covariance configuration from modular components. + + Parameters + ---------- + template_dict : Mapping + Dictionary containing configuration parameters including: + - central_block: Type of central block ('identity', 'bump_nicas', 'diffusion') + - outer_block: add StdDev outer block + + + Returns + ------- + Mapping + JEDI-compatible background error covariance configuration + """ + + # Select central block builder + central_block_select = template_dict.get('saber_central_block', 'identity') + + central_block_types = { + 'identity': _build_identity_central_block, + 'bump_nicas': _build_bump_nicas_central_block, + } - background_error = { + # Build the central block + if central_block_select not in central_block_types: + raise ValueError( + f"Unknown background_error_model '{central_block_select}'. " + f"Choose from: {list(central_block_types.keys())}" + ) + + central_block = central_block_types[central_block_select](template_dict) + + # Assemble base configuration + background_error_config = { 'covariance model': 'SABER', - 'saber central block': { - 'saber block name': 'ID' - }, + 'saber central block': central_block, } - return background_error + # Optionally add outer blocks (skip for identity central block) + if central_block_select != 'identity' and 'saber_outer_block' in template_dict: + + outer_block_select = template_dict.get('saber_outer_block') + + outer_block_types = { + 'stddev_bkg_scaled': _build_stddev_bkg_scaled, + 'stddev_fixed_values': _build_stddev_fixed_values, + } + + if outer_block_select not in outer_block_types: + raise ValueError( + f"Unknown outer_block '{outer_block_select}'. " + f"Choose from: {list(outer_block_types.keys())}" + ) + + outer_block = outer_block_types[outer_block_select](template_dict) + + background_error_config['saber outer blocks'] = [outer_block] + + return background_error_config # -------------------------------------------------------------------------------------------------- diff --git a/src/swell/configuration/jedi/interfaces/geos_cf/model/stage_cycle.py b/src/swell/configuration/jedi/interfaces/geos_cf/model/stage_cycle.py index c6d50408c..c2a490247 100644 --- a/src/swell/configuration/jedi/interfaces/geos_cf/model/stage_cycle.py +++ b/src/swell/configuration/jedi/interfaces/geos_cf/model/stage_cycle.py @@ -15,15 +15,35 @@ def stage_cycle(template_dict: Mapping) -> Mapping: cycle_dir = template_dict['cycle_dir'] swell_static_files = template_dict['swell_static_files'] + npx_proc = template_dict['npx_proc'] + npy_proc = template_dict['npy_proc'] stage_cycle = [ {'copy_files': { 'directories': [ - [f'{swell_static_files}/jedi/interfaces/geos_cf/fv3_files/*', f'{cycle_dir}/'] + [ + f'{swell_static_files}/jedi/interfaces/geos_cf/fv3_files/*', + f'{cycle_dir}/' + ] ] }} ] + # Only link NICAS files when the central block corresponds to BUMP_NICAS + central_block_select = template_dict.get('saber_central_block', 'none') + if central_block_select == 'bump_nicas': + stage_cycle.append({ + 'link_files': { + 'directories': [ + [ + f'{swell_static_files}/jedi/interfaces/geos_cf/nicas/' + f'layout_{npx_proc}x{npy_proc}x6/*', + f'{cycle_dir}/' + ] + ] + } + }) + return stage_cycle # -------------------------------------------------------------------------------------------------- diff --git a/src/swell/configuration/jedi/interfaces/geos_cf/task_questions.yaml b/src/swell/configuration/jedi/interfaces/geos_cf/task_questions.yaml index 8db8f8996..ed1230bef 100644 --- a/src/swell/configuration/jedi/interfaces/geos_cf/task_questions.yaml +++ b/src/swell/configuration/jedi/interfaces/geos_cf/task_questions.yaml @@ -4,10 +4,17 @@ analysis_variables: options: - volume_mixing_ratio_of_no2 -background_error_model: - default_value: identity +saber_central_block: + default_value: bump_nicas options: - identity + - bump_nicas + +saber_outer_block: + default_value: stddev_bkg_scaled + options: + - stddev_bkg_scaled + - stddev_fixed_values background_experiment: default_value: swell_test diff --git a/src/swell/suites/3dvar_cf/suite_config.py b/src/swell/suites/3dvar_cf/suite_config.py index 0518a87ac..607c3cfa7 100644 --- a/src/swell/suites/3dvar_cf/suite_config.py +++ b/src/swell/suites/3dvar_cf/suite_config.py @@ -41,6 +41,8 @@ class SuiteConfig(QuestionContainer, Enum): qd.npx_proc(2), qd.npy_proc(2), qd.vertical_resolution(72), + qd.saber_central_block('bump_nicas'), + qd.saber_outer_block('stddev_bkg_scaled'), qd.analysis_variables(["volume_mixing_ratio_of_no2"]), qd.background_experiment("swell_test"), qd.observations([ diff --git a/src/swell/tasks/run_jedi_variational_executable.py b/src/swell/tasks/run_jedi_variational_executable.py index 3c9e0b5a8..138ebd3cc 100644 --- a/src/swell/tasks/run_jedi_variational_executable.py +++ b/src/swell/tasks/run_jedi_variational_executable.py @@ -69,6 +69,8 @@ def execute(self) -> None: self.jedi_rendering.add_key('minimizer', self.config.minimizer()) self.jedi_rendering.add_key('number_of_iterations', number_of_iterations[0]) self.jedi_rendering.add_key('analysis_variables', self.config.analysis_variables()) + self.jedi_rendering.add_key('saber_central_block', self.config.saber_central_block()) + self.jedi_rendering.add_key('saber_outer_block', self.config.saber_outer_block()) self.jedi_rendering.add_key('gradient_norm_reduction', self.config.gradient_norm_reduction()) self.jedi_rendering.add_key('marine_models', self.config.marine_models(None)) diff --git a/src/swell/tasks/stage_jedi.py b/src/swell/tasks/stage_jedi.py index a9baa19ab..06847a72e 100644 --- a/src/swell/tasks/stage_jedi.py +++ b/src/swell/tasks/stage_jedi.py @@ -46,17 +46,23 @@ def execute(self) -> None: swell_static_files = swell_static_files_user vertical_resolution = self.config.vertical_resolution() + npx_proc = self.config.npx_proc(None) + npy_proc = self.config.npy_proc(None) gsibec_configuration = self.config.gsibec_configuration(None) gsibec_nlats = self.config.gsibec_nlats(None) gsibec_nlons = self.config.gsibec_nlons(None) + saber_central_block = self.config.saber_central_block(None) # Add jedi interface template keys self.jedi_rendering.add_key('horizontal_resolution', horizontal_resolution) self.jedi_rendering.add_key('swell_static_files', swell_static_files) self.jedi_rendering.add_key('vertical_resolution', vertical_resolution) + self.jedi_rendering.add_key('npx_proc', npx_proc) + self.jedi_rendering.add_key('npy_proc', npy_proc) self.jedi_rendering.add_key('gsibec_configuration', gsibec_configuration) self.jedi_rendering.add_key('gsibec_nlats', gsibec_nlats) self.jedi_rendering.add_key('gsibec_nlons', gsibec_nlons) + self.jedi_rendering.add_key('saber_central_block', saber_central_block) # Open the stage configuration file # --------------------------------- diff --git a/src/swell/tasks/task_questions.py b/src/swell/tasks/task_questions.py index 8347850a7..9081e93ea 100644 --- a/src/swell/tasks/task_questions.py +++ b/src/swell/tasks/task_questions.py @@ -86,6 +86,8 @@ class TaskQuestions(QuestionContainer, Enum): qd.gsibec_nlons(), qd.number_of_iterations(), qd.total_processors(), + qd.saber_central_block(), + qd.saber_outer_block(), ] ) @@ -769,10 +771,13 @@ class TaskQuestions(QuestionContainer, Enum): list_name="StageJedi", questions=[ swell_static_file_questions, + qd.npx_proc(), + qd.npy_proc(), qd.gsibec_configuration(), qd.gsibec_nlats(), qd.gsibec_nlons(), qd.horizontal_resolution(), + qd.saber_central_block(), qd.vertical_resolution() ] ) diff --git a/src/swell/utilities/question_defaults.py b/src/swell/utilities/question_defaults.py index 987f64322..f2474aa68 100644 --- a/src/swell/utilities/question_defaults.py +++ b/src/swell/utilities/question_defaults.py @@ -239,6 +239,34 @@ class background_error_model(TaskQuestion): # -------------------------------------------------------------------------------------------------- + @dataclass + class saber_central_block(TaskQuestion): + default_value: str = "defer_to_model" + question_name: str = "saber_central_block" + ask_question: bool = True + options: str = "defer_to_model" + models: List[str] = mutable_field([ + "all_models" + ]) + prompt: str = "Which saber central block do you want to use?" + widget_type: WType = WType.STRING + + # -------------------------------------------------------------------------------------------------- + + @dataclass + class saber_outer_block(TaskQuestion): + default_value: str = "defer_to_model" + question_name: str = "saber_outer_block" + ask_question: bool = True + options: str = "defer_to_model" + models: List[str] = mutable_field([ + "all_models" + ]) + prompt: str = "Which saber outer blocks do you want to use?" + widget_type: WType = WType.STRING + + # -------------------------------------------------------------------------------------------------- + @dataclass class background_experiment(TaskQuestion): default_value: str = "defer_to_model" diff --git a/src/swell/utilities/render_jedi_interface_files.py b/src/swell/utilities/render_jedi_interface_files.py index 779a73fb1..c0c1c122b 100644 --- a/src/swell/utilities/render_jedi_interface_files.py +++ b/src/swell/utilities/render_jedi_interface_files.py @@ -116,6 +116,8 @@ def __init__( 'obs_filenames', 'packet_ensemble_members', 'perhost', + 'saber_central_block', + 'saber_outer_block', 'skip_ensemble_hofx', 'swell_static_files', 'start_cycle_point',