diff --git a/fluxy/io.py b/fluxy/io.py index 031db286..0967bec5 100644 --- a/fluxy/io.py +++ b/fluxy/io.py @@ -485,15 +485,6 @@ def read_model_output( if add_sites_to_flux and file_type == DataTypes.FLUX: ds_all[m] = add_sites_var(ds_all[m], filepath, m, period[i], config_data) - # Overwrite species attributes - current_species = ds_all[m].attrs.get("species", "not set") - ds_all[m].attrs["species"] = current_species - if species is not None and current_species != species: - logger.info( - f"'species' attribute in dataset {m} ({current_species}) differs from species {species}. It is overwritten." - ) - ds_all[m].attrs["species"] = species - return ds_all @@ -788,9 +779,10 @@ def edit_vars_and_attributes( xarray dataset with updated variables and attributes. """ - # Add inversion frequency to global attributes + # Add inversion frequency and exp name to global attributes if "frequency" not in ds.attrs: ds.attrs["frequency"] = frequency + ds.attrs["exp_name"] = model # Rename legacy variables name_dict = { @@ -807,14 +799,14 @@ def edit_vars_and_attributes( filename_tags = os.path.basename(model) m0 = filename_tags.split("_")[0].lower() - # check the species + # check (and overwrite) species attribute if species is not None: - if "species" not in ds.attrs: - ds.attrs["species"] = species - elif ds.attrs["species"] != species: + current_species = ds.attrs.get("species", "not set") + if current_species != species: logger.info( - f"Species {ds.attrs['species']} in dataset does not match species {species} in model {model}." + f"'species' attribute in dataset {model} ({current_species}) differs from species {species}. It is overwritten." 
) + ds.attrs["species"] = species file_type = DataTypes(file_type) diff --git a/fluxy/operators/mf.py b/fluxy/operators/mf.py index 3bef2d91..7b208151 100644 --- a/fluxy/operators/mf.py +++ b/fluxy/operators/mf.py @@ -2,7 +2,6 @@ import xarray as xr import pandas as pd import logging -from fluxy.operators.select import get_unique_sites, get_site_index from fluxy.operators.convert import get_variables from typing import Literal @@ -74,6 +73,18 @@ def compute_mf_difference( unique_platforms, platform_indices = np.unique( common_platforms, return_inverse=True ) + + # check species + species_set = {ds_left.attrs["species"], ds_right.attrs["species"]} + if len(species_set) != 1: + logger.warning( + f"Different species found {species_set} between the 2 compared datasets." + ) + species = "mix" + else: + species = list(species_set)[0] + + # create dataset with coord and dim ds_diff[key_name] = xr.Dataset( coords={ "time": ("index", common_index.get_level_values("time").values), @@ -82,6 +93,8 @@ def compute_mf_difference( }, attrs={ "description": f"Difference between {model_left} and {model_right}", + "exp_name": f"{ds_left.attrs['exp_name']} - {ds_right.attrs['exp_name']}", + "species": species, }, ) diff --git a/fluxy/operators/select.py b/fluxy/operators/select.py index 96c9d9db..a5ee7434 100644 --- a/fluxy/operators/select.py +++ b/fluxy/operators/select.py @@ -55,7 +55,7 @@ def slice_flux( """ ds_all_sliced = dict() - species_info = config_data.get("species_info",{}).get(species, None) + species_info = config_data.get("species_info", {}).get(species, None) if type(start_date) is str: start_date = [start_date] * len(ds_all.keys()) @@ -94,7 +94,7 @@ def slice_mf( ds_all: dict[str, xr.Dataset], start_date: str = None, end_date: str = None, - site: str = None, + site: str | list[str] | None = None, baseline_site: str = None, baseline_filename: str = "InTEM_baseline_timestamps", data_dir: os.PathLike | None = None, @@ -114,7 +114,7 @@ def slice_mf( end_date (str): Date 
to slice data to, e.g. '2022-01-01' would include all data up to 2021-12-31. - site (str): + site (str | list[str] | None): Obs site to select data from, e.g. 'MHD'. baseline_site (str): Site used to define baseline at, options for 'MHD', 'JFJ', or 'CMN'. @@ -225,29 +225,69 @@ def slice_mf( return ds_all -def slice_site(ds: xr.Dataset, site: str) -> xr.Dataset: +def slice_site( + ds: xr.Dataset | dict[str, xr.Dataset], + site: str | list[str], + raise_error: bool = True, +) -> xr.Dataset | dict[str, xr.Dataset] | None: """ Slices the dataset to only include data for a given site. Args: ds (xarray dataset): Dataset with mf data of a given model. - site (str): - Site of interest. + Can also be a dictionary of datasets, in which case the function + is applied to each dataset and return a dictionary of sliced datasets. + site (str | list[str]): + Site(s) of interest. + raise_error: + if True, raise an error if the site is not found in the dataset. Returns: ds (xarray dataset): - Dataset with mf data of a given model, sliced to only include data for the given site. + Dataset with mf data of a given model, sliced to only include data for the given site(s). """ - site_index = get_site_index(ds, site) + if isinstance(ds, dict): + ds_all_site = dict() + for m, ds_this in ds.items(): + logger.info(f"Slicing site {site} from {m}.") - if site_index is None: - raise ValueError(f"Site {site} not found in dataset.") + ds_sliced = slice_site(ds_this, site, raise_error=raise_error) + if ds_sliced is None: + logger.warning( + f"Site {site} not found in dataset for {m}. " + f"Continuing without {m} - {site}." 
+ ) + continue - mask = ds["number_of_identifier"] == site_index - ds = ds.where(mask, drop=True) + ds_all_site[m] = ds_sliced - return ds + return ds_all_site + + if isinstance(site, str): + sites = [site] + else: + sites = site + + site_indices = [] + for site in sites: + site_index = get_site_index(ds, site) + if site_index is not None: + site_indices.append(site_index) + else: + logger.warning(f"Site {site} not found for model {ds.attrs['exp_name']}.") + + mask = ds["number_of_identifier"].isin(site_indices) + if mask.any(): + ds = ds.where(mask, drop=True) + return ds + + msg = f"No data for any sites {sites} with indices {site_indices} in model {ds.attrs['exp_name']}." + if raise_error: + raise ValueError(msg) + else: + logger.warning(msg) + return None def slice_height(ds: xr.Dataset, intake_height: float) -> xr.Dataset: @@ -496,3 +536,25 @@ def clean_timeseries_missing_data( ds = ds.sortby("time") return ds + + +def check_site_list( + site_list: list[str] | None, ds_all: dict[str, xr.Dataset] +) -> list[str]: + """ + Check that every site in the list exists. If None, set it to all the sites available. + Args: + site_list: list of sites to check + ds_all: datasets into which check for the sites + Returns: + site_list: list of sites + """ + if site_list is None: + return get_unique_sites(ds_all) + available_sites = get_unique_sites(ds_all) + for site in site_list: + if site not in available_sites: + raise ValueError( + f"Site {site} not found in the datasets provided. Available sites are {available_sites}." 
+ ) + return site_list diff --git a/fluxy/operators/stats.py b/fluxy/operators/stats.py index b1214f11..8e347395 100644 --- a/fluxy/operators/stats.py +++ b/fluxy/operators/stats.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd import xarray as xr -from fluxy.operators.select import get_unique_sites, get_site_index +from fluxy.operators.select import get_unique_sites, slice_site def stats_observed_vs_simulated( @@ -77,17 +77,9 @@ def stats_observed_vs_simulated( # Compute stats for all sites and all models for site in sites_all: for model, ds in ds_all.items(): - site_index = get_site_index(ds, site) - if site_index is None: - logger.warning(f"Site {site} not found in model {model}.") + ds_site = slice_site(ds, site, raise_error=False) + if ds_site is None: continue - mask_site = ds["number_of_identifier"] == site_index - if not mask_site.any(): - logger.warning( - f"No data for site {site} with index {site_index} in model {model}." - ) - continue - ds_site = ds.where(mask_site, drop=True) # select what to compare obs = ds_site[obs_var] diff --git a/fluxy/plots/flux_timeseries.py b/fluxy/plots/flux_timeseries.py index 04d5af7e..52fd3f0b 100644 --- a/fluxy/plots/flux_timeseries.py +++ b/fluxy/plots/flux_timeseries.py @@ -5,7 +5,7 @@ import pandas as pd import xarray as xr import matplotlib.pyplot as plt -from typing import Tuple +from typing import Literal, Tuple from pathlib import Path from datetime import date, datetime, timedelta from calendar import isleap, month_abbr, monthrange @@ -34,9 +34,6 @@ } - - - def get_unit(ds_all: dict[str, xr.Dataset]) -> str: """ Determine unit of posterior estimations from datasets. If incoherencies between datasets, an error is raised. @@ -46,7 +43,12 @@ def get_unit(ds_all: dict[str, xr.Dataset]) -> str: unit: unit of posterior variables in dataset. 
""" - variables_to_check = ["flux_total_posterior_country", "posterior", "flux_total_prior_country", "prior"] + variables_to_check = [ + "flux_total_posterior_country", + "posterior", + "flux_total_prior_country", + "prior", + ] for var in variables_to_check: if all([var in ds for ds in ds_all.values()]): @@ -63,6 +65,7 @@ def get_unit(ds_all: dict[str, xr.Dataset]) -> str: f"Did not find any of the expected variables {variables_to_check} in every dataset. Thus couldn't determine unit." ) + def determine_subplots_arrangement(subplot_number: int) -> tuple[int, int]: """ Determine number of columns and rows for the figure given the number of subplots to make. @@ -157,7 +160,7 @@ def prepare_data_to_plot( plot_resample_and_original: If True, plots both the resampled data and the data as its original frequency. If False, only plots the resampled data. aggreg_month: if True, plot the data aggregated by month. Used to study seasonnal cycle. - only_overlapping: if True, only includes data from years when all models/species available. If False, + only_overlapping: if True, only includes data from years when all models/species available. If False, includes all available data. 
Returns: ds_to_plot : dictionnary of datasets to plot @@ -171,7 +174,7 @@ def prepare_data_to_plot( # Convert some inputs to list and check their size plot_separate, plot_combined, resample, rolling_mean = update_list_params( [plot_separate, plot_combined, resample, rolling_mean], - ['plot_separate','plot_combined','resample','rolling_mean'], + ["plot_separate", "plot_combined", "resample", "rolling_mean"], expected_size=len(ds_all_region.keys()), ) @@ -233,7 +236,9 @@ def prepare_data_to_plot( combined_models_dict = {"Mean": list(ds_all_region.keys())} else: combined_model_list = sum(combined_models_dict.values(), []) - check_missing_models = set(combined_model_list) - set(ds_all_region.keys()) + check_missing_models = set(combined_model_list) - set( + ds_all_region.keys() + ) if check_missing_models: raise ValueError( f"Models in `combined_model_list` are not available: {check_missing_models}. " @@ -251,46 +256,48 @@ def prepare_data_to_plot( ] } - combined_resample = [resamp for comb, resamp in zip(plot_combined, resample) if comb] - use_resampled = ( - len(unique_resample := set(combined_resample)) == 1 - and unique_resample not in ({None}, {False}) - ) + combined_resample = [ + resamp for comb, resamp in zip(plot_combined, resample) if comb + ] + use_resampled = len( + unique_resample := set(combined_resample) + ) == 1 and unique_resample not in ({None}, {False}) if use_resampled: - combined_models_dict = { - group_label: [f"{model}_resample" for model in model_list] - for group_label, model_list in combined_models_dict.items() - } - ds_to_combine = { - m: calc_rolling_mean(ds) if rm else ds - for rm, (m, ds) in zip(rolling_mean, ds_resampled.items()) - } - else: - ds_to_combine = { - m: calc_rolling_mean(ds) if rm else ds - for rm, (m, ds) in zip(rolling_mean, ds_all_region.items()) + combined_models_dict = { + group_label: [f"{model}_resample" for model in model_list] + for group_label, model_list in combined_models_dict.items() } + ds_to_combine = { + m: 
calc_rolling_mean(ds) if rm else ds + for rm, (m, ds) in zip( + rolling_mean, (ds_resampled if use_resampled else ds_all_region).items() + ) + } + for group_label, model_list in combined_models_dict.items(): - combine_mask = [model in model_list for model in ds_to_combine.keys()] - ds_combined = combine_dataset(ds_to_combine, combine_mask,only_overlapping) - ds_combined["combined"].attrs["model_label"] = group_label - if any('_resample' in s for s in model_list) and plot_resample_and_original: - ds_combined["combined"].attrs["model_label"] += " (resampled)" - new_key = group_label.replace(" ", "_") - ds_combined = {f"combined_{new_key}": ds_combined["combined"]} # rename key to include group label - ds_to_plot.update(ds_combined) + combine_mask = [model in model_list for model in ds_to_combine.keys()] + ds_combined = combine_dataset(ds_to_combine, combine_mask, only_overlapping) + ds_combined["combined"].attrs["model_label"] = group_label + if any("_resample" in s for s in model_list) and plot_resample_and_original: + ds_combined["combined"].attrs["model_label"] += " (resampled)" + new_key = group_label.replace(" ", "_") + ds_combined = { + f"combined_{new_key}": ds_combined["combined"] + } # rename key to include group label + ds_to_plot.update(ds_combined) # Determine plot color and label of each dataset color_usage = {k: 0 for k in map_model_colors.keys()} i_comb = 0 for m in ds_to_plot.keys(): + include_label = ds_to_plot[m].attrs.get("model_label", None) if "combined" in m: - include_label = ds_to_plot[m].attrs.get("model_label", None) - model_color = config.mean_color_palette[i_comb % len(config.mean_color_palette)] + model_color = config.mean_color_palette[ + i_comb % len(config.mean_color_palette) + ] i_comb += 1 else: - include_label = ds_to_plot[m].attrs.get("model_label", None) key_mc = [ k for k in map_model_colors.keys() @@ -308,110 +315,77 @@ def prepare_data_to_plot( return ds_to_plot -def add_posterior_plot( - ax: Axes, ds_region: xr.Dataset, 
highlighted_line: bool, add_post_unc: bool +def add_line_plot( + ax: Axes, + ds: xr.Dataset, + variable: Literal["posterior", "prior"], + highlighted_line: bool = False, + add_unc: bool = False, ) -> dict[str, dict]: """ - Plot the posterior data on the axis. The variable posterior of the dataset ds_region is plotted as a line (color and label found in the dataset - attributes) and the uncertainty (variables lower_posterior, upper_posterior in the dataset) is plotted as a semi-transparent filled space. - Axes: + Plot the posterior/prior data on the axis. The variable posterior/prior of the dataset ds is plotted as a line (color and label found in the dataset + attributes) and the uncertainty (variables lower_posterior/prior, upper_posterior/prior in the dataset) is plotted as a semi-transparent filled space. + Args: ax: axes on which to plot - ds_region: dataset containing posterior data - highlighted_line: if True, the linewidth is made bigger (3.0) than when False (1.5). Typicall used for the annexes to highlight the PARIS mean. - add_post_unc: if True, plots model uncertainty. - Returns: + ds: dataset containing posterior/prior data + variable: whether the posterior or prior should be plotted + highlighted_line: for posterior/prior, if True, the linewidth is made bigger (3.0/1.5) than when False (1.5/1.0). Typically used for the annexes to highlight the PARIS mean. + add_unc: if True, plots model uncertainty. 
Returns: - res: dataframe with one line per timestamp and 9 columns ("type", "model", "sector", "country", "species", + res: dataframe with one line per timestamp and 9 columns ("type", "model", "sector", "country", "species", "time", "mean_val", "min_unc", "max_unc") """ - - linew = 3 if highlighted_line else 1.5 - - time_as_datetime = ds_region.time.values.astype("datetime64[D]").tolist() - - ax.plot( - time_as_datetime, - ds_region.posterior, - label=ds_region.attrs["model_label"], - color=ds_region.attrs["model_color"], - linewidth=linew, - ) - - if add_post_unc: - ax.fill_between( - time_as_datetime, - ds_region.posterior_lower, - ds_region.posterior_upper, - alpha=0.2, - color=ds_region.attrs["model_color"], + if variable == "posterior": + kwargs_plot = dict( + ls="-", + lw=3 if highlighted_line else 1.5, + alpha=1.0, + label=ds.attrs["model_label"], + ) + alpha_unc = 0.2 + elif variable == "prior": + kwargs_plot = dict( + ls="--", + lw=1.5 if highlighted_line else 1.0, + alpha=1.0 if highlighted_line else 0.7, + label=ds.attrs["model_label"] + " prior", + ) + alpha_unc = 0.1 + else: + raise ValueError( + "Available options for 'variable' parameter are 'prior' and 'posterior'." ) - res = pd.DataFrame( - { - "type": ["posterior",] * ds_region.time.size, - "model": [ds_region.attrs["model_label"],] * ds_region.time.size, - "sector": [ds_region.attrs["sector"],] * ds_region.time.size, - "country": [ds_region.attrs["country"],] * ds_region.time.size, - "species": [ds_region.attrs["species"],] * ds_region.time.size, - "time": time_as_datetime, - "mean_val": ds_region.posterior.values, - "min_unc": ds_region.posterior_lower.values, - "max_unc": ds_region.posterior_upper.values, - } - ) - return res - - -def add_prior_plot( - ax: Axes, ds_region: xr.Dataset, annex_mode: bool, add_prior_unc: bool -) -> dict[str, dict]: - """ - Plot the variable prior of the dataset ds_region on the axis. 
- Axes: - ax: axes on which to plot - ds_region: dataset containing prior data - annex_bool: if True, the linewidth is made slighty smaller (1.0) and a transparency of 0.7 is applied to the prior uncertainty. - If False, linewidth is set to standard value (1.5) and no transparency s applied to the prior uncertainty (alpha=1.0). - add_prior_unc: if True add prior uncertainty on the plot as a semi-transparent filled space. - Returns: - res: dataframe with one line per timestamp and 7-9 columns ("type", "model", "sector", "country", "species", - "time", "mean_val" and "min_unc", "max_unc" if add_prior_unc) - """ - linewidth, alpha = (1.0, 0.7) if annex_mode else (1.5, 1.0) - - time_as_datetime = ds_region.time.values.astype("datetime64[D]").tolist() + time_as_datetime = ds.time.values.astype("datetime64[D]").tolist() ax.plot( time_as_datetime, - ds_region.prior, - label=ds_region.attrs["model_label"] + " prior", - color=ds_region.attrs["model_color"], - linestyle="dashed", - linewidth=linewidth, - alpha=alpha, + ds[variable], + color=ds.attrs["model_color"], + **kwargs_plot, ) res = pd.DataFrame( { - "type": ["prior",] * ds_region.time.size, - "model": [ds_region.attrs["model_label"],] * ds_region.time.size, - "sector": [ds_region.attrs["sector"],] * ds_region.time.size, - "country": [ds_region.attrs["country"],] * ds_region.time.size, - "species": [ds_region.attrs["species"],] * ds_region.time.size, "time": time_as_datetime, - "mean_val": ds_region.prior.values, + "mean_val": ds[variable].values, } ) - if add_prior_unc: + res["type"] = variable + res["model"] = ds.attrs["model_label"] + for attr in ["sector", "country", "species"]: + res[attr] = ds.attrs[attr] + + if add_unc: ax.fill_between( time_as_datetime, - ds_region.prior_lower, - ds_region.prior_upper, - alpha=0.1, - color=ds_region.attrs["model_color"], + ds[f"{variable}_lower"], + ds[f"{variable}_upper"], + alpha=alpha_unc, + color=ds.attrs["model_color"], ) - res["min_unc"] = ds_region.prior_lower.values - 
res["max_unc"] = ds_region.prior_upper.values + res["min_unc"] = ds[f"{variable}_lower"].values + res["max_unc"] = ds[f"{variable}_upper"].values return res @@ -452,7 +426,7 @@ def add_inventory_barplot( annex_mode: If True, replace Inventory label with a more concise version for National Inventory Report Annexes. plot_inventory_uncertainty: If True and uncertainty available, plots inventory error bars. If a list is provided, should be of same size as inventory_years. Returns: - res: dataframe with one line per timestamp and 7 columns ("type", "model", "sector", "country", "species", + res: dataframe with one line per timestamp and 7 columns ("type", "model", "sector", "country", "species", "time", "mean_val") """ @@ -489,11 +463,16 @@ def add_inventory_barplot( res = pd.DataFrame() for i_inv, inventory in enumerate(inventories_to_plot): time_as_datetime = inventory.time.values.astype("datetime64[D]").tolist() - - if plot_inventory_uncertainty[i_inv] is True and inventories_uncert_to_plot[i_inv] is not None and np.any(inventories_uncert_to_plot[i_inv] > 0): - yerr = inventories_uncert_to_plot[i_inv].values - else: - yerr = None + + this_uncert_to_plot = inventories_uncert_to_plot[i_inv] + + yerr = None + if ( + plot_inventory_uncertainty[i_inv] is True + and this_uncert_to_plot is not None + and np.any(this_uncert_to_plot > 0) + ): + yerr = this_uncert_to_plot.values ax.bar( time_as_datetime, @@ -502,23 +481,23 @@ def add_inventory_barplot( edgecolor=inventory.plot_color, align="edge", fill=False, - label=(f"NID {inventory.year}" if annex_mode else f"Inventory {inventory.year}"), + label=f"{'NID' if annex_mode else 'Inventory'} {inventory.year}", yerr=yerr, - error_kw={"ecolor":inventory.plot_color,"capsize":2}, + error_kw={"ecolor": inventory.plot_color, "capsize": 2}, zorder=0, ) tmp = pd.DataFrame( { - "type": ["inventory",] * inventory.time.size, - "model": [f"inventory_{inventory.year}",] * inventory.time.size, - "sector": [sector,] * inventory.time.size, - 
"country": [country,] * inventory.time.size, - "species": [species,] * inventory.time.size, "time": time_as_datetime, "mean_val": inventory.values, } ) + tmp["type"] = "inventory" + tmp["model"] = f"inventory_{inventory.year}" + tmp["sector"] = sector + tmp["country"] = country + tmp["species"] = species res = pd.concat([res, tmp], ignore_index=True) return res @@ -535,7 +514,7 @@ def add_sector_barplot( variable: variable to plot (either "posterior" or "prior") bottom_values: bottom values passed as argument toax.bar. Correspond to the previous heights of the stacks. Returns: - res: dataframe with one line per timestamp and 7 columns ("type", "model", "sector", "country", "species", + res: dataframe with one line per timestamp and 7 columns ("type", "model", "sector", "country", "species", "time", "mean_val") """ @@ -580,16 +559,17 @@ def add_sector_barplot( res = pd.DataFrame( { - "type": [variable,] * ds_sector.time.size, - "model": [ds_sector.attrs["model_label"],] * ds_sector.time.size, - "sector": [sector,] * ds_sector.time.size, - "country": [ds_sector.attrs["country"],] * ds_sector.time.size, - "species": [ds_sector.attrs["species"],] * ds_sector.time.size, "time": np.array(time_as_datetime) + offset, "mean_val": ds_sector[variable].values + bottom_values, } ) + res["type"] = variable + res["sector"] = sector + res["model"] = ds_sector.attrs["model_label"] + res["country"] = ds_sector.attrs["country"] + res["species"] = ds_sector.attrs["species"] + return res @@ -642,7 +622,10 @@ def prepare_inventory_sector_barplot( ) inventories = [inv.to_dataset(name="inv_data") for inv in inventories] - inventories_stdev = [inv_std.to_dataset(name="inv_data_stdev") if inv_std else None for inv_std in inventories_stdev] + inventories_stdev = [ + inv_std.to_dataset(name="inv_data_stdev") if inv_std else None + for inv_std in inventories_stdev + ] return inventories, inventories_stdev @@ -663,7 +646,7 @@ def add_ylim( dim: dimension following which the list of axes is 
made. values: list of values taken by the dimension and corresponding to the axes (should be the same length and order). res_dict: dictionnary containing the data plotted. Should have one key per regions plotted, the values being dictionnaries with 3 keys: "inventory", "posterior" and "prior"; - whose values are the output of add_inventory_barplot, add_posterior_plot, add_prior_plot). The data stored in them is used to infer the ylims. + whose values are the output of add_inventory_barplot, add_line_plot). The data stored in them is used to infer the ylims. fix_y_axes: if list, use it as params to ax.set_ylim; if bool and True, all subplots have the same y lim (the max value that can be found in res_dict); else the max of the data plotted in each subplots is used. set_global_leg: if True (and thus one common legend is plotted for all subplots in add_legend), a zoom of only 1.1 is made on the ymax, else it is 1.2 to make space for the legend. @@ -750,7 +733,11 @@ def add_ylabel( def add_xlims_and_ticks( - ax: Axes, yearly_freq: bool, plotted_data_df: dict[str, dict], aggreg_month: bool, xticks_at_centre: bool = False + ax: Axes, + yearly_freq: bool, + plotted_data_df: dict[str, dict], + aggreg_month: bool, + xticks_at_centre: bool = False, ): """ Add x limits, ticks and ticks labels to matplotlib axes. Optimize them by looking at if they are monthly, yearly, or monthly aggregated, and covered time range. @@ -758,7 +745,7 @@ def add_xlims_and_ticks( ax: axis to add xlim and xticks to. yearly_freq: set to True if the data plotted have a yearly frequency. res_dict: dictionnary containing the data plotted. Should have one key per regions plotted, the values being dictionnaries with 3 keys: "inventory", "posterior" and "prior"; - whose values are the output of add_inventory_barplot, add_posterior_plot, add_prior_plot). The time data stored in them is used to infer the xlims. + whose values are the output of add_inventory_barplot, add_line_plot). 
The time data stored in them is used to infer the xlims. aggreg_month: if True, the data plotted are supposed to be a monthly aggregated so 12 stciks are created, whose labels are the 3 first letters of each month. xticks_at_centre: if True, set the x ticks at the centre of each time period rather than at the beginning. """ @@ -785,7 +772,7 @@ def add_xlims_and_ticks( [date(year, 1, 1) for year in range(min_x.year, max_x.year, step)] ) if xticks_at_centre: - xticks = xticks.astype('datetime64[D]') + np.timedelta64(182, 'D') + xticks = xticks.astype("datetime64[D]") + np.timedelta64(182, "D") ax.set_xticks(xticks) ax.set_xticklabels(xticks.astype("datetime64[Y]")) else: @@ -802,7 +789,11 @@ def add_xlims_and_ticks( def add_legend( - fig: Figure, set_global_leg: bool, annex_mode: bool, plot_inventory: bool, inventory_years: int | list[int] | None = None, + fig: Figure, + set_global_leg: bool, + annex_mode: bool, + plot_inventory: bool, + inventory_years: int | list[int] | None = None, ): """ Add legend. @@ -876,6 +867,7 @@ def add_title(ax: Axes, country: str, r_data: dict, country_codes_as_titles: boo elif country_codes_as_titles == False: ax.set_title(f"{print_country}") + def add_vlines(ax: Axes, vline_dates: list[str]): """ Add vertical lines to matplotlib axes at specified dates. @@ -891,8 +883,8 @@ def add_vlines(ax: Axes, vline_dates: list[str]): linestyle="dotted", linewidth=2.5, ) - - + + def add_secondary_yaxis( ax: Axes, s_data: dict[str, dict], @@ -902,7 +894,7 @@ def add_secondary_yaxis( secondary_unit: str, ): """ - Add a secondary y-axis to the plot, converting from `unit` + Add a secondary y-axis to the plot, converting from `unit` to `secondary_unit`, including CO2-eq aware conversions. 
Args: @@ -916,10 +908,8 @@ def add_secondary_yaxis( # Get conversion factor between the two units conversion_factor = convert_units_co2eq( - from_unit = unit, - to_unit = secondary_unit, - species_info = s_data.get(species, {}) - ) + from_unit=unit, to_unit=secondary_unit, species_info=s_data.get(species, {}) + ) # Define foward and backward conversion functions def unit_to_secondary_unit(y): @@ -930,25 +920,19 @@ def secondary_unit_to_unit(y): # Create secondary y-axis secax = ax.secondary_yaxis( - 'right', - functions=(unit_to_secondary_unit, secondary_unit_to_unit) + "right", functions=(unit_to_secondary_unit, secondary_unit_to_unit) ) # Styling - sec_color = 'darkred' - secax.tick_params(axis='y', colors=sec_color) - secax.spines['right'].set_color(sec_color) - secax.spines['right'].set_linewidth(1.5) + sec_color = "darkred" + secax.tick_params(axis="y", colors=sec_color) + secax.spines["right"].set_color(sec_color) + secax.spines["right"].set_linewidth(1.5) # Labeling add_ylabel( - secax, - s_data, - species, - secondary_unit, - plot_type="country_plot", - sector=sector - ) + secax, s_data, species, secondary_unit, plot_type="country_plot", sector=sector + ) secax.yaxis.label.set_color("darkred") secax.yaxis.label.set_rotation(270) secax.yaxis.labelpad = 20 @@ -990,7 +974,7 @@ def plot_country_flux( plot_grid: bool = True, add_vline: list[str] | None = None, secondary_units: str | None = None, - only_overlapping: bool = True + only_overlapping: bool = True, ) -> Figure | tuple[Figure, dict[str, dict]]: """ Timeseries plot of prior and posterior country fluxes, from list of @@ -1058,7 +1042,6 @@ def plot_country_flux( s_data = config_data.get("species_info", {}) r_data = config_data.get("regions_info", {}) - plot_regions = format_plot_regions(plot_regions, ds_all) unit = get_unit(ds_all) @@ -1095,23 +1078,35 @@ def plot_country_flux( plot_resample_and_original=plot_resample_and_original, resample_uncert_correlation=resample_uncert_correlation, 
aggreg_month=aggreg_month, - only_overlapping=only_overlapping + only_overlapping=only_overlapping, ) # plot posterior and prior (if requested) for m, ds_region in ds_to_plot.items(): highlighted_post = ("combined" in m) & annex_mode - add_post_unc = (("combined" in m) & plot_combined_unc) | (("combined" not in m) & plot_separate_unc) + add_post_unc = (("combined" in m) & plot_combined_unc) | ( + ("combined" not in m) & plot_separate_unc + ) if "posterior" in ds_region.data_vars: - posterior_df = add_posterior_plot( - ax, ds_region, highlighted_post, add_post_unc + posterior_df = add_line_plot( + ax, + ds_region, + variable="posterior", + highlighted_line=highlighted_post, + add_unc=add_post_unc, ) plotted_data_df = pd.concat( [plotted_data_df, posterior_df], ignore_index=True ) - if add_prior and 'prior' in ds_region: - prior_df = add_prior_plot(ax, ds_region, annex_mode, add_prior_unc) + if add_prior and "prior" in ds_region: + prior_df = add_line_plot( + ax, + ds_region, + variable="prior", + highlighted_line=not annex_mode, + add_unc=add_prior_unc, + ) plotted_data_df = pd.concat( [plotted_data_df, prior_df], ignore_index=True ) @@ -1168,7 +1163,9 @@ def plot_country_flux( "yearly" in [ds.attrs["frequency"] for ds in ds_to_plot.values()] or resample == "year" ) - add_xlims_and_ticks(axes[-1], yearly_freq, plotted_data_df, aggreg_month, xticks_at_centre) + add_xlims_and_ticks( + axes[-1], yearly_freq, plotted_data_df, aggreg_month, xticks_at_centre + ) add_legend(fig, set_global_leg, annex_mode, plot_inventory, inventory_years) @@ -1197,7 +1194,7 @@ def plot_country_sector_flux_bar( resample: str | list[str] | None = None, resample_uncert_correlation: bool = False, rolling_mean: bool = False, - sectors: list[str] = ["agriculture", "waste", "energy", "industry"] + sectors: list[str] = ["agriculture", "waste", "energy", "industry"], ) -> Figure | list: """ Stacked bar plot of posterior fluxes, split by sector, for a single region, for a range of models. 
@@ -1391,21 +1388,23 @@ def plot_country_sector_flux_bar( return fig, plotted_data_df -def plot_all_species_stacked_bar(all_species: list[str], - ds_all_flux_scaled: dict[str, xr.Dataset], - regions: list[str], - config_data: dict[str, dict] = {}, - model_colors: dict[str, str] = {}, - model_labels: dict[str, str] = {}, - start_date: str | None = None, - end_date: str | None = None, - inventory_years: list[str] | None = None, - inventory_filename: str = "UNFCCC_inventory", - plot_inventory_uncertainty: bool = True, - data_dir: str | None = None, - sector: str = "total", - y_lim: list[float] | None = None - ) -> Figure: + +def plot_all_species_stacked_bar( + all_species: list[str], + ds_all_flux_scaled: dict[str, xr.Dataset], + regions: list[str], + config_data: dict[str, dict] = {}, + model_colors: dict[str, str] = {}, + model_labels: dict[str, str] = {}, + start_date: str | None = None, + end_date: str | None = None, + inventory_years: list[str] | None = None, + inventory_filename: str = "UNFCCC_inventory", + plot_inventory_uncertainty: bool = True, + data_dir: str | None = None, + sector: str = "total", + y_lim: list[float] | None = None, +) -> Figure: """ Stacked bar plot of posterior fluxes, summed over all species, for a single region, for a range of models. 
Args: @@ -1430,40 +1429,41 @@ def plot_all_species_stacked_bar(all_species: list[str], s_data = config_data.get("species_info", {}) r_data = config_data.get("regions_info", {}) species_colors = config.species_color_palette - + unit = [] - + for s in all_species: unit.append(get_unit(ds_all_flux_scaled[s])) if all(x == unit[0] for x in unit): country_flux_units_print = unit[0] else: - raise ValueError('Units for all species\' datasets are not equal.') + raise ValueError("Units for all species' datasets are not equal.") ds_to_plot = {} inventories_to_plot = {} - inventories_uncert_to_plot = {} - + models = [] - - fig,ax = plt.subplots(1,1,figsize=(10,6)) + + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) for s, species in enumerate(all_species): - ds_all_region = extract_region_flux(ds_all_flux_scaled[species], regions, r_data, sectors=sector) + ds_all_region = extract_region_flux( + ds_all_flux_scaled[species], regions, r_data, sectors=sector + ) ds_to_plot[species] = prepare_data_to_plot( ds_all_region=ds_all_region, model_labels=model_labels, model_colors=model_colors, plot_separate=True, plot_combined=False, - resample='year', + resample="year", rolling_mean=False, plot_resample_and_original=False, resample_uncert_correlation=False, aggreg_month=False, ) - - inventories_to_plot[species],inventories_uncert_to_plot[species] = retrieve_inventories( + + inventories_to_plot[species], inventories_uncert_to_plot = retrieve_inventories( data_dir, regions, species, @@ -1476,88 +1476,118 @@ def plot_all_species_stacked_bar(all_species: list[str], inventory_filename, sectors=sector, ) - + models.append(list(ds_to_plot[species].keys())[0]) - + if len(list(ds_to_plot[species].keys())) > 1: - logger.warning(f"Only the first model {list(ds_to_plot[species].keys())} will be plotted because this function is "+ - "currently only set up to plot inventory data and model output from one model.") + logger.warning( + f"Only the first model {list(ds_to_plot[species].keys())} will be 
plotted because this function is " + + "currently only set up to plot inventory data and model output from one model." + ) + + this_uncert = inventories_uncert_to_plot[0] + + posterior_diff = ( + ds_to_plot[species][models[s]]["posterior_upper"].values + - ds_to_plot[species][models[s]]["posterior_lower"].values + ) - if inventories_uncert_to_plot[species][0] is None: - inventories_uncert_to_plot[species][0] = np.zeros_like(inventories_to_plot[species][0].values) + if this_uncert is None: + this_uncert = np.zeros_like(inventories_to_plot[species][0].values) if s == 0: inv_plot_times = inventories_to_plot[species][0].time.values - plot_times = ds_to_plot[species][models[s]].time.values.astype('datetime64[Y]') - uncert_combined = ds_to_plot[species][models[s]]['posterior_upper'].values-ds_to_plot[species][models[s]]['posterior_lower'].values + plot_times = ds_to_plot[species][models[s]].time.values.astype( + "datetime64[Y]" + ) + uncert_combined = posterior_diff if plot_inventory_uncertainty: - inventories_uncert_combined = inventories_uncert_to_plot[species][0] + inventories_uncert_combined = this_uncert else: - uncert_combined = np.sqrt(uncert_combined**2 + (ds_to_plot[species][models[s]]['posterior_upper'].values - ds_to_plot[species][models[s]]['posterior_lower'].values)**2) + uncert_combined = np.sqrt(uncert_combined**2 + posterior_diff**2) if plot_inventory_uncertainty: - inventories_uncert_combined = np.sqrt(inventories_uncert_combined**2 + inventories_uncert_to_plot[species][0]**2) + inventories_uncert_combined = np.sqrt( + inventories_uncert_combined**2 + this_uncert**2 + ) - width = np.timedelta64(150,'D') + width = np.timedelta64(150, "D") for s, species in enumerate(all_species): - + if s == 0: bottom = None inv_bottom = None - inventory_label = f'{inventory_years[0]} Inventory' + inventory_label = f"{inventory_years[0]} Inventory" else: bottom = flux_sum inv_bottom = inventory_sum inventory_label = None - if s == (len(all_species)-1): - uncert = 
uncert_combined/2. + if s == (len(all_species) - 1): + uncert = uncert_combined / 2.0 if plot_inventory_uncertainty: - inventories_uncert = inventories_uncert_combined/2. + inventories_uncert = inventories_uncert_combined / 2.0 else: inventories_uncert = None else: uncert = None inventories_uncert = None - ax.bar(inv_plot_times+width, + ax.bar( + inv_plot_times + width, inventories_to_plot[species][0].values, width=width, bottom=inv_bottom, - color='lightgrey', - edgecolor='grey', + color="lightgrey", + edgecolor="grey", label=inventory_label, yerr=inventories_uncert, - error_kw={'capsize':2}) + error_kw={"capsize": 2}, + ) - ax.bar(plot_times, - ds_to_plot[species][models[s]]['posterior'].values, + ax.bar( + plot_times, + ds_to_plot[species][models[s]]["posterior"].values, width=width, bottom=bottom, color=species_colors[species], - label=s_data.get(species, {}).get('species_print', species), + label=s_data.get(species, {}).get("species_print", species), yerr=uncert, - error_kw={'capsize':2},alpha=0.8) + error_kw={"capsize": 2}, + alpha=0.8, + ) if s == 0: - flux_sum = ds_to_plot[species][models[s]]['posterior'].values + flux_sum = ds_to_plot[species][models[s]]["posterior"].values inventory_sum = inventories_to_plot[species][0].values else: - flux_sum += ds_to_plot[species][models[s]]['posterior'].values + flux_sum += ds_to_plot[species][models[s]]["posterior"].values inventory_sum += inventories_to_plot[species][0].values - ax.set_xticks(plot_times+(width/2)) - ax.set_xticklabels((plot_times+(width/2)).astype('datetime64[Y]')) + ax.set_xticks(plot_times + (width / 2)) + ax.set_xticklabels((plot_times + (width / 2)).astype("datetime64[Y]")) - #ax.set_ylabel(country_flux_units_print) - add_ylabel(ax,s_data,species=f'{regions} total',unit=country_flux_units_print, - plot_type='country_plot',sector=sector) + # ax.set_ylabel(country_flux_units_print) + add_ylabel( + ax, + s_data, + species=f"{regions} total", + unit=country_flux_units_print, + 
plot_type="country_plot", + sector=sector, + ) if y_lim: ax.set_ylim(y_lim) handles, labels = ax.get_legend_handles_labels() - ax.legend(handles[::-1], labels[::-1], - ncol=4,loc='upper right',borderpad=0.4,columnspacing=1.0, - fontsize=12) + ax.legend( + handles[::-1], + labels[::-1], + ncol=4, + loc="upper right", + borderpad=0.4, + columnspacing=1.0, + fontsize=12, + ) - return fig \ No newline at end of file + return fig diff --git a/fluxy/plots/mf_timeseries.py b/fluxy/plots/mf_timeseries.py index 23289fe4..0487dc71 100644 --- a/fluxy/plots/mf_timeseries.py +++ b/fluxy/plots/mf_timeseries.py @@ -1,7 +1,9 @@ import logging from typing import Literal - +from enum import Enum import numpy as np + +import pandas as pd import xarray as xr from datetime import date, timedelta from calendar import month_abbr @@ -20,12 +22,20 @@ clean_timeseries_missing_data, get_site_index, get_unique_site_height_pairs, + slice_site, + check_site_list, ) -from fluxy.plots.utils import set_min_decimal_points logger = logging.getLogger(__name__) +class PlotTypes(Enum): + SEPARATE = "separate" + TOGETHER = "together" + DIFF = "diff" + MULTIPLE_SITES = "multiple_sites" + + def plot_mf_timeseries(*args, **kwargs) -> plt.Figure: # Solve the legacy position of the include argument LEGACY_INCLUDE_POSITION = 9 @@ -138,7 +148,7 @@ def _prepare_var( elif unc_var.split("_")[-1] in ["prior", "posterior"]: unc_lower = (ds[var] - ds[unc_var]).expand_dims({"percentile": ["lower"]}) unc_upper = (ds[var] + ds[unc_var]).expand_dims({"percentile": ["upper"]}) - unc = xr.concat([unc_lower, unc_upper],dim="percentile") + unc = xr.concat([unc_lower, unc_upper], dim="percentile") unc.name = unc_var else: unc = ds[unc_var].expand_dims({"percentile": ["std"]}) @@ -206,12 +216,12 @@ def _retrieve_variable(ds, var, unc_var): def _prepare_data_to_plot( - ds_all: dict[str, xr.Dataset], + ds_all: dict[str, xr.Dataset | None], include: str | dict[str, str | None] | list | tuple, diff_include: None | list, 
aggreg_month: bool, time_freq_min: FrequencyType, - plot_type: Literal["separate", "together", "diff"], + plot_type: PlotTypes, ) -> dict[str, xr.Dataset]: """ Create dictionnary of datasets containing all the data that will be plotted. @@ -230,6 +240,9 @@ def _prepare_data_to_plot( and "percentile". "percentile" can take 4 values: "mean" (always present) being the main value, "lower" and "upper" (optionnals) which are the "upper" and "lower" boundaries of the associated uncertainty, and "std" (optionnal) which is the std / one side associated uncertainty. """ + + plot_type = PlotTypes(plot_type) + if not include: raise ValueError( "The include dictionary is empty. Please provide variables to include in the plot." @@ -249,8 +262,14 @@ def _prepare_data_to_plot( for m, ds in ds_all.items(): + if ds is None: + continue + # Check there is only one site in the dataset - if len(np.unique(ds.get("number_of_identifier", 0))) > 1: + if ( + len(np.unique(ds.get("number_of_identifier", 0))) > 1 + and plot_type != PlotTypes.MULTIPLE_SITES + ): raise ValueError( f"Dataset {m} contains more than one site. " "Use slice_site to select a single site." @@ -279,14 +298,14 @@ def _prepare_data_to_plot( ds_var = _prepare_aggreg_month_var(ds_clean[var]) if all_var[var] is not None: logger.warning( - f"`{all_var[var]}` present as value of include dict for {var} is overwritten as you put `aggreg_month=True`." + f"`{all_var[var]}` present as uncertainty for {var} is overwritten as you put `aggreg_month=True`." + " The uncertainty plotted is the 0.159 and 0.851 percentile of the variable for the corresponding month." ) else: ds_var = _prepare_var(ds_clean, var, unc_var, model=m) if unc_var: - if plot_type == "diff": + if plot_type == PlotTypes.DIFF: raise ValueError( f"Option plot_type='diff' does not accept uncertainties. Replace '{unc_var}' by None." 
) @@ -294,6 +313,11 @@ def _prepare_data_to_plot( [data_to_plot[m], ds_var], compat="no_conflicts", join="outer" ) + # Add some global attributes to each var + attrs_global = {attr: ds_all[m].attrs[attr] for attr in ["exp_name", "species"]} + for var in data_to_plot[m].data_vars: + data_to_plot[m][var].attrs.update(attrs_global) + return data_to_plot @@ -301,10 +325,10 @@ def _set_labels_and_colors( ds_dict: dict[str, xr.Dataset], model_labels: dict[str, str], model_colors: dict[str, list], - plot_type: Literal["separate", "together", "diff"], + plot_type: PlotTypes, ) -> dict[str, xr.Dataset]: """ - Set labels and colors that will be used by plot_timeseries and plot_histogram as attributes of the variables dataset. + Set labels and colors, that will be used by add_line_scatter_plot and plot_histogram, as attributes of the variables dataset. For variables "mf_observed" and "observed_above_BC", the color will be black (and not one of model_colors) if more than one variable is plotted. Args: ds_dict: dictionnary containing the dataset with the variables to be plotted (and only them). @@ -315,12 +339,14 @@ def _set_labels_and_colors( Return: ds_dict: dictionnary of datasets where the label and color have been added as attributes of the dataset / dataset variables. 
""" + plot_type = PlotTypes(plot_type) + if model_colors is None: model_colors = config.set_model_colors(ds_dict.keys()) for m in ds_dict.keys(): vars_to_plot = ds_dict[m].data_vars.keys() - if plot_type == "diff": + if plot_type == PlotTypes.DIFF: mdiff0, mdiff1 = m.split("--") model_label = f"{model_labels[mdiff0]} - {model_labels[mdiff1]}" model_color = model_colors[mdiff0] @@ -347,7 +373,7 @@ def _set_labels_and_colors( def _create_figure( models: list[str], - plot_type: Literal["separate", "together", "diff"], + plot_type: PlotTypes, histogram_type: str | None, aggreg_month: bool, ) -> tuple[Figure, Axes]: @@ -362,9 +388,12 @@ def _create_figure( Return: fig, ax: matplotlib figure and axes object with appropriate sizes. """ - if plot_type == "separate": + + plot_type = PlotTypes(plot_type) + + if plot_type == PlotTypes.SEPARATE: nrows = len(models) - elif plot_type in ["together", "diff"]: + elif plot_type in [PlotTypes.TOGETHER, PlotTypes.DIFF]: nrows = 1 else: raise ValueError( @@ -406,19 +435,22 @@ def _get_unit(ds_dict: dict[str, xr.Dataset]) -> str: plot_units.append(ds_dict[m][var].attrs["units"]) plot_units = list(set(plot_units)) - if len(plot_units) != 1: + if len(plot_units) > 1: + variables_set = {ds_dict[m].data_vars.keys() for m in ds_dict.keys()} raise ValueError( - f"{ds_dict[m].data_vars.keys()} in {ds_dict.keys()} do not have the same units. So far, the following were found: {plot_units}." + f"{variables_set} do not have the same units. So far, the following were found: {plot_units}." + "Select only one model to plot, or run 'slice_mf' with 'mf_units_print' equal to a valid mole fraction unit before running 'plot_timeseries'." 
) + elif len(plot_units) == 0: + return "" return plot_units[0] -def add_xlims_and_ticks( +def _add_xlims_and_ticks( ax: Axes, yearly_freq: bool, - res_dict: dict[str, dict[str, xr.Dataset] | xr.Dataset], + plotted_data_df: pd.DataFrame, aggreg_month: bool, rotate_xticks: bool = False, ): @@ -436,18 +468,8 @@ def add_xlims_and_ticks( ax.set_xticklabels(list(month_abbr)[1:]) return - min_x, max_x = date(2100, 1, 1), date(1900, 1, 1) - for key_1 in res_dict.keys(): - if isinstance(res_dict[key_1], dict): - for key_2 in res_dict[key_1].keys(): - for key_3 in res_dict[key_1][key_2].keys(): - time = res_dict[key_1][key_2][key_3]["time"] - min_x = np.nanmin([*time, min_x]) - max_x = np.nanmax([*time, max_x]) - else: - time = res_dict[key_1].time.values.astype("datetime64[D]").tolist() - min_x = np.nanmin([*time, min_x]) - max_x = np.nanmax([*time, max_x]) + min_x = np.nanmin(plotted_data_df["time"]) + max_x = np.nanmax(plotted_data_df["time"]) # set xticks year_range = date(max_x.year, 1, 1) - date(min_x.year, 1, 1) @@ -460,7 +482,9 @@ def add_xlims_and_ticks( min_x = date(min_x.year, 1, 1) max_x = date(max_x.year + 1, 1, 1) step = (max_x.year - min_x.year) // 8 + 1 - xticks = np.array([date(year, 1, 1) for year in range(min_x.year, max_x.year, step)]) + xticks = np.array( + [date(year, 1, 1) for year in range(min_x.year, max_x.year, step)] + ) if (max_x.year - min_x.year) % step == 0: xticks = np.append(xticks, max_x) ax.set_xticks(xticks) @@ -475,6 +499,141 @@ def add_xlims_and_ticks( ax.set_xlim(xlim) +def get_minmax_unc(ds: xr.Dataset) -> tuple[list, list] | tuple[None, None]: + """Get uncertainty min and max values in the dataset. + + Args: + ds: dataset containing posterior/prior data + Return: + min_unc: list of the same size as ds.time containing + the minimum values of the uncertainty band to plot. + max_unc: same as min_unc but for the maximum values. 
+ """ + + if ds.percentile.size == 3: + min_unc = ds.sel(percentile="lower").values + max_unc = ds.sel(percentile="upper").values + + elif ds.percentile.size == 2: + std_values = ds.sel(percentile="std").values + min_unc, max_unc = std_values, std_values + else: + min_unc, max_unc = None, None + + return min_unc, max_unc + + +def add_unc_plot( + ax: Axes, + min_unc: list[np.float64], + max_unc: list[np.float64], + da: xr.DataArray, + plot_type: Literal["Errorbar", "FillBetween"] | None = None, +): + """ + Get uncertainty min and max values to plot the uncertainty band in add_line_scatter_plot. + Args: + ax: axes on which to plot + min_unc: list of the same size as da.time containing the minimum values of the uncertainty band to plot. + max_unc: same as min_unc but for the maximum values. + da: dataarray to plot. + plot_type: type of plot to use for uncertainty ("Errorbar" or "FillBetween") + """ + time_as_datetime = da.time.values.astype("datetime64[D]").tolist() + + if plot_type is None: + return + + min_unc = np.array(min_unc, dtype=float) + max_unc = np.array(max_unc, dtype=float) + + mean = da.sel(percentile="mean").values + + if plot_type == "FillBetween": + ax.fill_between( + time_as_datetime, + y1=min_unc, + y2=max_unc, + alpha=0.2, + color=da.attrs["plot_color"], + ) + + elif plot_type == "Errorbar": + ax.errorbar( + time_as_datetime, + y=da.sel(percentile="mean"), + yerr=[mean - min_unc, max_unc - mean], + alpha=0.4, + fmt="none", + color=da.attrs["plot_color"], + ) + else: + raise ValueError( + f"Uncertainty plot type {plot_type} not implemented. Set plot_type to 'Errorbar' or 'FillBetween'." + ) + + +def add_line_scatter_plot( + ax: Axes, + da: xr.DataArray, + plot_type: PlotTypes, + add_unc: bool = True, + marker: str = "s", + unc_type: Literal["Errorbar", "FillBetween"] | None = None, +) -> pd.DataFrame: + """ + Add line/scatter plot with uncertainty if requested. 
+ Args: + ax: axes on which to plot + da: dataarray containing mole fraction data to plot + plot_type: used to determine labels to display + add_unc: if True, plot uncertainty + unc_type: type of plot to use for uncertainty ("Errorbar" or "FillBetween") + Return: + res: dataframe with data plotted + """ + + plot_type = PlotTypes(plot_type) + + time_as_datetime = da.time.values.astype("datetime64[D]").tolist() + + kwargs = { + "alpha": 0.8, + "color": da.attrs["plot_color"], + "label": da.attrs["plot_label"], + } + + if plot_type == PlotTypes.MULTIPLE_SITES: + type_plot = da.attrs["plot_label"] + else: + type_plot = da.name + + if da.name in ["mf_observed", "observed_above_BC"] or plot_type == PlotTypes.DIFF: + plot_func = ax.scatter + kwargs.update({"s": 8, "marker": marker}) + else: + plot_func = ax.plot + kwargs.update({"linewidth": 2.0, "marker": "o", "markersize": 1.5}) + + plot_func(time_as_datetime, da.sel(percentile="mean"), **kwargs) + + res = pd.DataFrame( + { + "time": time_as_datetime, + "mean_val": da.sel(percentile="mean").values, + } + ) + res["type"] = type_plot + res["model"] = da.attrs["exp_name"] + res["species"] = da.attrs["species"] + + if add_unc: + res["min_unc"], res["max_unc"] = get_minmax_unc(da) + add_unc_plot(ax, res["min_unc"], res["max_unc"], da, unc_type) + + return res + + def plot_timeseries( ds_all: dict[str, xr.Dataset], include: VariableType, @@ -485,7 +644,7 @@ def plot_timeseries( config_data: dict[str, dict] = {}, annotate_coords: dict[int, list] = {}, presentation_mode: bool = False, - plot_type: Literal["separate", "together", "diff"] = "separate", + plot_type: PlotTypes = PlotTypes.SEPARATE, diff_include: list[str] | None = None, y_lim: None | tuple[float | None, float | None] = None, n_bins: int = 30, @@ -494,6 +653,7 @@ def plot_timeseries( histogram_type: Literal["hist", "violin", "none"] | None = "hist", hist_kwargs: dict[str, any] = {}, aggreg_month: bool = False, + unc_type: Literal["Errorbar", "FillBetween"] = 
"FillBetween", ): """ Timeseries plots of observations, modelled mole fractions, baseline mf and/or @@ -514,6 +674,8 @@ def plot_timeseries( Obs site, e.g. 'MHD'. model_colors (dict of str): Models and corresponding colours used to plot the model. + model_labels (dict of str): + Labels to use for each model config_data (dict of dict): Dictionary with settings read from json file. Use json filenames as keys. @@ -533,15 +695,21 @@ def plot_timeseries( line. If the frequency is lower than this, the line will be discontinous. see :py:func:`fluxy.operators.select.clean_timeseries_missing_data` for more information. + unc_type (str, optional): + Type of plot to use for uncertainty ("Errorbar" or "FillBetween"). Default is "FillBetween". Returns: fig (figure): A timeseries and histogram plot for each model included. """ + plot_type = PlotTypes(plot_type) + models = ds_all.keys() species_info = config_data.get("species_info", {}).get(species, {}) + plotted_data_df = pd.DataFrame() + # Check the include dictionary data_to_plot = _prepare_data_to_plot( ds_all, include, diff_include, aggreg_month, time_freq_min, plot_type @@ -565,59 +733,14 @@ def plot_timeseries( for i, m in enumerate(models): # Define plot_type specific settings - iax = i if plot_type == "separate" else 0 + iax = i if plot_type == PlotTypes.SEPARATE else 0 # Loop over all variables to plot for var in include.keys(): - - ds_plot = data_to_plot[m][var] - - # Define plotting color - x, y = ds_plot.time.values, ds_plot.sel(percentile="mean").values - kwargs = { - "alpha": 0.8, - "color": ds_plot.attrs["plot_color"], - "label": ds_plot.attrs["plot_label"], - } - - if var in ["mf_observed", "observed_above_BC"] or plot_type == "diff": - # Make scatter plot - ax[iax, 0].scatter( - x, - y, - s=8, - marker="s", - **kwargs, - ) - - else: - # Make line plot - ax[iax, 0].plot( - x, - y, - linewidth=2.0, - marker="o", - markersize=1.5, - **kwargs, - ) - - if ds_plot.percentile.size == 3: - ax[iax, 0].fill_between( - x, 
- y1=ds_plot.sel(percentile="lower"), - y2=ds_plot.sel(percentile="upper"), - alpha=0.2, - color=ds_plot.attrs["plot_color"], - ) - elif ds_plot.percentile.size == 2: - ax[iax, 0].errorbar( - x, - y=ds_plot.sel(percentile="mean"), - yerr=ds_plot.sel(percentile="std"), - alpha=0.4, - fmt="none", - color=ds_plot.attrs["plot_color"], - ) + res = add_line_scatter_plot( + ax[iax, 0], data_to_plot[m][var], plot_type, unc_type=unc_type + ) + plotted_data_df = pd.concat([plotted_data_df, res], ignore_index=True) # Plot histogram if ax.shape[1] == 2: @@ -641,10 +764,14 @@ def plot_timeseries( max_mf = max(max_mf, ax[iax, 0].get_ylim()[1]) # Set timeseries title - if plot_type in ["separate", "diff"]: + if plot_type in [PlotTypes.SEPARATE, PlotTypes.MULTIPLE_SITES]: plot_title = data_to_plot[m].attrs["label"] - elif plot_type == "together": + elif plot_type == PlotTypes.TOGETHER: plot_title = "All models" + elif plot_type == PlotTypes.DIFF: + plot_title = f"Diff" + else: + raise ValueError(f"Option {plot_type} not implemented.") ax[iax, 0].set_title(plot_title) # Set timeseries y-axis label and legend @@ -676,10 +803,10 @@ def plot_timeseries( if len(data_to_plot[m].time) <= 1: continue - add_xlims_and_ticks( + _add_xlims_and_ticks( ax[iax, 0], yearly_freq=False, - res_dict=data_to_plot, + plotted_data_df=plotted_data_df, aggreg_month=aggreg_month, rotate_xticks=presentation_mode, ) @@ -845,7 +972,7 @@ def plot_histogram( presentation_mode: bool, annotate_coords: dict[int, list], annotate_index: int, - plot_type: Literal["separate", "together", "diff"], + plot_type: PlotTypes | str, n_bins: int = 30, violin: bool = False, **kwargs, @@ -873,11 +1000,13 @@ def plot_histogram( Coordinates to annotate histogram. annotate_index (int): Model index. Used to specify annotation location if plot_type == "together". - plot_type (str): + plot_type (PlotTypes | str): Type of timeseries plot in which the histogram will be plotted. Options for "separate", "together" and "diff". 
""" + plot_type = PlotTypes(plot_type) + if not annotate_coords: annotate_coords = config.set_print_settings(presentation_mode) @@ -918,9 +1047,9 @@ def plot_histogram( if diff_include: ax.vlines(0, 0, np.max(a), color="dimgrey", linewidth=3.0) - if plot_type in ["separate", "diff"]: + if plot_type in [PlotTypes.SEPARATE, PlotTypes.DIFF]: index = v - elif plot_type == "together": + elif plot_type == PlotTypes.TOGETHER: index = annotate_index # Compute and format mean and std of the histogram @@ -931,7 +1060,7 @@ def plot_histogram( # Write mean/std to histogram # If plot_type = togehter, print only mean/std of the first variable - if not (plot_type == "together" and v != 0): + if not (plot_type == PlotTypes.TOGETHER and v != 0): xcoord = annotate_coords["x"] ycoord = annotate_coords["ytop"] - index * annotate_coords["dy"] ax.annotate( @@ -942,7 +1071,7 @@ def plot_histogram( ) # Write number of obs - if plot_type == "separate" and len(hist_to_plot) == 1: + if plot_type == PlotTypes.SEPARATE and len(hist_to_plot) == 1: var = list(vars_to_plot)[0] values = ds[var].values mask_not_nan = ~np.isnan(values) @@ -960,3 +1089,179 @@ def plot_histogram( ax.set_xlabel(legend_hist) return None + + +def plot_sites_list_mf( + ds_all: dict[str, xr.Dataset], + sites: list[str] | None = None, + species: str | None = None, + include: VariableType = {"mf_posterior": None}, + model_labels: dict[str, dict] = {}, + model_colors: dict[str, list] = {}, + aggreg_month: bool = False, + config_data: dict[str, dict] = {}, + data_on_single_graph: Literal["models", "sites"] = "sites", + unc_type: Literal["Errorbar", "FillBetween"] = "FillBetween", + obs_markers: list[str] = ["s", "v", "^", "<", ">", "p", "P", "*", "+"], +) -> tuple[Figure, pd.DataFrame]: + """ + Plot timeseries of multiple site, with subplots separated by site or model. + Args: + ds_all: xarray datasets, scaled and sliced between chosen dates and for + chosen site. + sites: Obs sites list, e.g. ['MHD', 'CBW']. 
+ species: Gas species, e.g. 'ch4'. + include (dict of str): + Dictionary keys are variables to include in the plot. + The respective values are the uncertainty variables to plot as error bar/uncertainty band. + model_labels: labels to use for each model + aggreg_month: if True, plot the data aggregated by month. + Used to study seasonal cycle. + config_data: Dictionary with settings read from json file. + Use json filenames as keys. + data_on_single_graph: str, "sites" or "models", determines whether each axis has multiple sites or models. + unc_type: type of plot to use for uncertainty ("Errorbar" or "FillBetween"). Default is "FillBetween". + obs_markers: list of markers to use for observed data. + Returns: + fig: figure created + plotted_data_df: data plotted on the figure + + + """ + ds_all_p = ds_all.copy() + + plot_type = PlotTypes.MULTIPLE_SITES + models = ds_all_p.keys() + species_info = config_data.get("species_info", {}).get(species, {}) + plotted_data_df = pd.DataFrame() + + # Look for sites + sites = check_site_list(sites, ds_all_p) + + if data_on_single_graph == "models": + axes_looper = models + elif data_on_single_graph == "sites": + axes_looper = sites + else: + raise ValueError("data_on_single_graph should be 'models' or 'sites'") + + ncols = 1 + nrows = len(axes_looper) + length = 8 if aggreg_month else 15 + fig, ax = plt.subplots( + nrows, + ncols, + figsize=(length, nrows * 3), + constrained_layout=True, + sharey="row", + sharex="col", + squeeze=False, + ) + + ax = ax.flatten() + + for isite, site in enumerate(sites): + + ds_all_site = slice_site(ds_all_p, site, raise_error=False) + # Prepare data to plot + data_to_plot = _prepare_data_to_plot( + ds_all_site, + include, + diff_include=None, + time_freq_min=None, + aggreg_month=aggreg_month, + plot_type=plot_type, + ) + unit = _get_unit(data_to_plot) + for im, m in enumerate(models): + if m not in data_to_plot or len(data_to_plot[m].dims) == 0: + logger.warning(f"Model {m} not found for site 
{site}, skipping.") + continue + # Select site + if m not in model_colors: + model_colors[m] = config.get_default_colors() + attrs = { + "sites": { + "plot_label": model_labels.get(m, m), + "plot_color": model_colors[m][0], + }, + "models": {"plot_label": site, "plot_color": f"C{isite:02d}"}, + } + + if data_on_single_graph == "models": + ax_index = im + marker_index = isite + elif data_on_single_graph == "sites": + ax_index = isite + marker_index = im + + if marker_index >= len(obs_markers): + logger.warning( + f"More than {len(obs_markers)} sites or models to plot, " + "some markers will be reused." + ) + marker_index = marker_index % len(obs_markers) + + for variable in include.keys(): + attrs_update = attrs[data_on_single_graph].copy() + if ( + variable in ["mf_observed", "observed_above_BC"] + and len(include.keys()) > 1 + ): + attrs_update["plot_color"] = "black" + label_add = f" {config.mf_labels.get(variable, variable)}" + attrs_update["plot_label"] += label_add + data_to_plot[m][variable].attrs.update(attrs_update) + res = add_line_scatter_plot( + ax[ax_index], + data_to_plot[m][variable], + plot_type=plot_type, + marker=obs_markers[marker_index], + add_unc=include[variable], + unc_type=unc_type, + ) + plotted_data_df = pd.concat([plotted_data_df, res], ignore_index=True) + + for iaxes, label in enumerate(axes_looper): + if data_on_single_graph == "models": + label = model_labels[label] + ax[iaxes].set_title(label) + ax[iaxes].set_ylabel( + " ".join( + [ + species_info.get("species_print", ""), + f"({unit})", + ] + ) + ) + + handles, labels = fig.axes[-1].get_legend_handles_labels() + if len(labels) == 0: + continue + fig.legend( + handles, + labels, + loc="upper center", + ncol=( + len(labels) if len(labels) <= 6 else len(labels) // 2 + len(labels) % 2 + ), + borderpad=0.4, + columnspacing=1.0, + bbox_to_anchor=[0.5, 0.96], + ) + + _add_xlims_and_ticks( + ax[iaxes], + yearly_freq=False, + plotted_data_df=plotted_data_df, + aggreg_month=aggreg_month, + ) + 
+ ax[iaxes].grid(color="lightgrey", linestyle="-", linewidth=0.7) + ax[iaxes].set_axisbelow(True) + + logger.info( + "If annotations in the histograms are not displaying correctly, adjust annotate_coords." + ) + + return fig, plotted_data_df diff --git a/scripts/PARIS_inversion_results.ipynb b/scripts/PARIS_inversion_results.ipynb index a6d3768d..cd090512 100644 --- a/scripts/PARIS_inversion_results.ipynb +++ b/scripts/PARIS_inversion_results.ipynb @@ -538,6 +538,61 @@ " y_lim=None,intake_height=intake_height)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Timeseries plot, separated by model, all sites together:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fluxy.operators.select import slice_mf\n", + "from fluxy.plots.mf_timeseries import plot_sites_list_mf\n", + "\n", + "###################################\n", + "### edit slicing options in this block\n", + "sites = ['MHD','TAC']\n", + "intake_height = None\n", + "mf_units_print = 'ppt'\n", + "start_date = '2018-01-01' #inclusive\n", + "end_date = '2019-01-01' #not inclusive\n", + "baseline_site = None #'MHD', 'JFJ' or 'CMN'. If None, does not mask by baseline time\n", + "baseline_filename = 'InTEM_baseline_timestamps'\n", + "###################################\n", + "\n", + "### edit plotting options in this block\n", + "include = {'mf_observed' : None,\n", + " 'mf_posterior' : 'percentile_mf_posterior'\n", + " }\n", + "\n", + "aggreg_month = False\n", + "data_on_single_graph=\"sites\" # if \"models\", plots all models on the same graph. 
If \"sites\", plots all sites on the same graph.\n", + "uncertainty_type=\"FillBetween\"\n", + "###################################\n", + "\n", + "ds_all_mf_sliced_multisites = slice_mf(ds_all_mf.copy(),start_date,end_date,sites,baseline_site=baseline_site,\n", + " baseline_filename=baseline_filename,data_dir=data_dir,\n", + " mf_units_print=mf_units_print,intake_height=intake_height)\n", + "\n", + "fig = plot_sites_list_mf(ds_all_mf_sliced_multisites,\n", + " sites,\n", + " species,\n", + " include,\n", + " model_labels,\n", + " model_colors,\n", + " aggreg_month,\n", + " config_data,\n", + " data_on_single_graph,\n", + " unc_type=uncertainty_type)\n", + "\n", + "\n" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/scripts/example_basics.ipynb b/scripts/example_basics.ipynb index 63edcb25..3f4407d9 100644 --- a/scripts/example_basics.ipynb +++ b/scripts/example_basics.ipynb @@ -542,6 +542,27 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### MF across sites\n", + "\n", + "This function can plot the timeseries of multiple sites together." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from fluxy.plots.mf_timeseries import plot_sites_list_mf\n", + "\n", + "\n", + "fig, plotted_data_df = plot_sites_list_mf(ds_all_mf)\n" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/tests/test_operators.py b/tests/test_operators.py index 96e392c1..388f7132 100644 --- a/tests/test_operators.py +++ b/tests/test_operators.py @@ -20,10 +20,10 @@ test_models_with_inlet, ) -ds_all_mf = get_loaded_models(test_models,"concentration") -ds_all_flux = get_loaded_models(test_models,"flux") +ds_all_mf = get_loaded_models(test_models, "concentration") +ds_all_flux = get_loaded_models(test_models, "flux") -ds_all_mf_with_inlet = get_loaded_models(test_models_with_inlet,"concentration") +ds_all_mf_with_inlet = get_loaded_models(test_models_with_inlet, "concentration") # Test the difference between all available models @@ -80,11 +80,12 @@ def test_slice_height(model): ds_all_mf__with_inlet_sliced = slice_site(ds_all_mf_with_inlet[model], site="TAC") ds_sliced = slice_height(ds_all_mf__with_inlet_sliced, intake_height=185) - + print(ds_sliced) assert np.unique(ds_sliced["intake_height"].values == 185.0) + ds_test = xr.Dataset( { "var1": (["index"], [1, 2, None, 4, 5]), diff --git a/tests/test_plots.py b/tests/test_plots.py index 0e1c8844..83cd1fc6 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -14,10 +14,11 @@ plot_flux_map_combined_models_comparison, plot_flux_map_period_comparison, ) -from fluxy.plots.flux_timeseries import plot_country_flux,plot_country_sector_flux_bar +from fluxy.plots.flux_timeseries import plot_country_flux, plot_country_sector_flux_bar from fluxy.plots.mf_timeseries import ( plot_mf_timeseries, - plot_sites_timeseries + plot_sites_timeseries, + plot_sites_list_mf, ) from fluxy.operators.mf import compute_mf_difference from fluxy.plots.mf_stats import plot_stats_mf, plot_taylor_diagram @@ -59,7 +60,13 @@ data_dir, 
"flux", species, models, config_data, period=period ) ds_all_flux_with_sites = read_model_output( - data_dir, "flux", species, models, config_data, period=period, add_sites_to_flux=True, + data_dir, + "flux", + species, + models, + config_data, + period=period, + add_sites_to_flux=True, ) for m in models: @@ -79,10 +86,11 @@ end_date, species=species, country_flux_units_print=country_flux_units_print, - )[m] + )[m] site = "MHD" +multiple_sites = ["MHD", "TAC"] baseline_site = None mf_units_print = "ppt" ds_all_mf = read_model_output( @@ -104,6 +112,16 @@ mf_units_print=mf_units_print, ) +ds_all_mf_sliced_multiple_sites = slice_mf( + ds_all_mf.copy(), + start_date, + end_date, + multiple_sites, + baseline_site=baseline_site, + data_dir=data_dir, + mf_units_print=mf_units_print, +) + model_colors = set_model_colors(models) model_labels = set_model_labels(models, config_data, get_labels_from_file) @@ -128,37 +146,38 @@ def test_country_flux_default(): """Test country flux with default settings.""" - plot_country_flux( - ds_all_flux_scaled, - species - ) + plot_country_flux(ds_all_flux_scaled, species) + def test_country_flux_with_inventory_raises_no_datadir(): """Test that ValueError is raised if plot_inventory=True and data_dir is not provided.""" - with pytest.raises(ValueError, match="data_dir must be provided to plot inventory data."): + with pytest.raises( + ValueError, match="data_dir must be provided to plot inventory data." 
+ ): plot_country_flux( ds_all_flux_scaled, species, plot_inventory=True, ) + def test_flux_timeseries(): kwargs = dict( - data_dir = data_dir, - plot_inventory = False, - inventory_years = None, - fix_y_axes = False, - add_prior = True, - add_prior_unc = False, - set_global_leg = True, - country_codes_as_titles = False, - plot_separate = True, - plot_combined = False, - resample = None, - resample_uncert_correlation = False, - plot_resample_and_original = False, - annex_mode = False, - rolling_mean = False, + data_dir=data_dir, + plot_inventory=False, + inventory_years=None, + fix_y_axes=False, + add_prior=True, + add_prior_unc=False, + set_global_leg=True, + country_codes_as_titles=False, + plot_separate=True, + plot_combined=False, + resample=None, + resample_uncert_correlation=False, + plot_resample_and_original=False, + annex_mode=False, + rolling_mean=False, ) plot_country_flux( @@ -170,29 +189,29 @@ def test_flux_timeseries(): model_labels, start_date, end_date, - **kwargs + **kwargs, ) def test_flux_timeseries_combined_unc(): kwargs = dict( - data_dir = data_dir, - plot_inventory = False, - inventory_years = None, - fix_y_axes = False, - add_prior = True, - add_prior_unc = False, - set_global_leg = True, - country_codes_as_titles = False, - plot_separate = True, - plot_separate_unc = False, - plot_combined = True, - plot_combined_unc = True, - resample = None, - resample_uncert_correlation = False, - plot_resample_and_original = False, - annex_mode = False, - rolling_mean = False, + data_dir=data_dir, + plot_inventory=False, + inventory_years=None, + fix_y_axes=False, + add_prior=True, + add_prior_unc=False, + set_global_leg=True, + country_codes_as_titles=False, + plot_separate=True, + plot_separate_unc=False, + plot_combined=True, + plot_combined_unc=True, + resample=None, + resample_uncert_correlation=False, + plot_resample_and_original=False, + annex_mode=False, + rolling_mean=False, ) plot_country_flux( @@ -204,34 +223,35 @@ def 
test_flux_timeseries_combined_unc(): model_labels, start_date, end_date, - **kwargs + **kwargs, ) + def test_flux_timeseries_multi_combined(): combined_models_dict = { - 'All Mean': models, - 'Partial Mean': models[0:2], + "All Mean": models, + "Partial Mean": models[0:2], } kwargs = dict( - data_dir = data_dir, - plot_inventory = False, - inventory_years = None, - fix_y_axes = False, - add_prior = True, - add_prior_unc = False, - set_global_leg = True, - country_codes_as_titles = False, - plot_separate = True, - plot_separate_unc = False, - plot_combined = True, - plot_combined_unc = True, - resample = None, - resample_uncert_correlation = False, - plot_resample_and_original = False, - annex_mode = False, - rolling_mean = False, - combined_models_dict = combined_models_dict, + data_dir=data_dir, + plot_inventory=False, + inventory_years=None, + fix_y_axes=False, + add_prior=True, + add_prior_unc=False, + set_global_leg=True, + country_codes_as_titles=False, + plot_separate=True, + plot_separate_unc=False, + plot_combined=True, + plot_combined_unc=True, + resample=None, + resample_uncert_correlation=False, + plot_resample_and_original=False, + annex_mode=False, + rolling_mean=False, + combined_models_dict=combined_models_dict, ) plot_country_flux( @@ -243,7 +263,7 @@ def test_flux_timeseries_multi_combined(): model_labels, start_date, end_date, - **kwargs + **kwargs, ) @@ -257,7 +277,7 @@ def test_mf_timeseries(): end_date, model_colors, model_labels, - config_data + config_data, ) @@ -276,6 +296,7 @@ def test_obs_modelled_separate(): y_lim=None, ) + def test_obs_modelled_aggreg_month(): fig = plot_mf_timeseries( ds_all_mf_sliced, @@ -286,11 +307,17 @@ def test_obs_modelled_aggreg_month(): config_data, annotate_coords, plot_type="separate", - include={"mf_observed": None, "mf_posterior": None, "posterior_above_BC":None, "observed_posterior_diff":None}, + include={ + "mf_observed": None, + "mf_posterior": None, + "posterior_above_BC": None, + 
"observed_posterior_diff": None, + }, y_lim=None, - aggreg_month=True + aggreg_month=True, ) + def test_mf_timeseries_no_hist(): fig = plot_mf_timeseries( ds_all_mf_sliced, @@ -324,7 +351,6 @@ def test_mf_timeseries_bad_uncertainty_var(): ) - def test_obs_modelled_together(): fig = plot_mf_timeseries( @@ -361,6 +387,43 @@ def test_mole_fraction_diff(): ) +def test_plot_sites_list_mf(): + + fig = plot_sites_list_mf(ds_all_mf_sliced_multiple_sites) + + +def test_mf_timeseries_multiple_site_axes(): + aggreg_month = False + fig = plot_sites_list_mf( + ds_all_mf_sliced_multiple_sites, + multiple_sites, + species, + {"mf_observed": None, "mf_posterior": "percentile_mf_posterior"}, + model_labels, + model_colors, + aggreg_month, + config_data, + "sites", + unc_type="FillBetween", + ) + + +def test_mf_timeseries_multiple_model_axes(): + aggreg_month = False + fig = plot_sites_list_mf( + ds_all_mf_sliced_multiple_sites, + multiple_sites, + species, + {"mf_observed": None, "mf_posterior": "percentile_mf_posterior"}, + model_labels, + model_colors, + aggreg_month, + config_data, + "models", + unc_type="FillBetween", + ) + + def test_plot_stats(): ds_all_allsites = slice_mf( ds_all_mf.copy(), @@ -525,7 +588,7 @@ def test_plot_country_sector_flux_bar(): regions=regions, sector_file=sector_file, create_region_sector_totals=create_region_sector_totals, - cell_area_test_file=True + cell_area_test_file=True, ) fig = plot_country_sector_flux_bar( @@ -541,6 +604,7 @@ def test_plot_country_sector_flux_bar(): sectors=["agriculture", "waste", "energy", "industry"], ) + def test_plot_flux_map_combined_models_comparison(): var = "flux_total_posterior" @@ -565,12 +629,13 @@ def test_plot_flux_map_combined_models_comparison(): set_fluxlim_percentile=set_fluxlim_percentile, ) + def test_plot_flux_map_period_comparison(): var = "flux_total_posterior" - start_dates = ['2018-01-01', '2020-01-01'] - end_dates = ['2021-01-01', '2024-01-01'] + start_dates = ["2018-01-01", "2020-01-01"] + end_dates 
= ["2021-01-01", "2024-01-01"] fig = plot_flux_map_period_comparison( ds_all=ds_all_flux_with_sites_scaled, @@ -588,6 +653,4 @@ def test_plot_flux_map_period_comparison(): add_markers=add_markers, set_fluxlim=set_fluxlim, set_fluxlim_percentile=set_fluxlim_percentile, - ) -