diff --git a/compressed-datasets/.gitignore b/compressed-datasets/.gitignore index 0336a5a..0f3db69 100644 --- a/compressed-datasets/.gitignore +++ b/compressed-datasets/.gitignore @@ -1,2 +1,2 @@ -/*/*/decompressed.zarr -/*/*/measurements.json +/*/*/*/decompressed.zarr +/*/*/*/measurements.json diff --git a/metrics/.gitignore b/metrics/.gitignore index 581f497..dacf85c 100644 --- a/metrics/.gitignore +++ b/metrics/.gitignore @@ -1,3 +1,3 @@ -/*/*/metrics.csv -/*/*/tests.csv +/*/*/*/metrics.csv +/*/*/*/tests.csv /all_results.csv diff --git a/plots/.gitignore b/plots/.gitignore new file mode 100644 index 0000000..844f945 --- /dev/null +++ b/plots/.gitignore @@ -0,0 +1,3 @@ +/*.png +/*/*.png +/*/*/*.png diff --git a/pyproject.toml b/pyproject.toml index 8a92df2..80954e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,10 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "astropy~=7.0.1", + "cartopy~=0.24.1", "cf-xarray~=0.10", "dask>=2024.12.0,<2025.4", + "matplotlib~=3.8", "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.4", "numcodecs-observers~=0.1.1", @@ -24,6 +26,7 @@ dependencies = [ "numcodecs-wasm-zlib~=0.3.0", "pandas~=2.2", "scipy~=1.14", + "seaborn~=0.13.2", "tabulate~=0.9", "typed-classproperties~=1.1.0", "xarray>=2024.11.0,<2025.4", @@ -38,6 +41,7 @@ dev = [ "pre-commit~=4.0", "ruff~=0.9", "scipy-stubs~=1.15", + "types-seaborn~=0.13.2.20250111", ] [tool.setuptools.packages.find] @@ -47,5 +51,5 @@ where = ["src"] "climatebenchpress.compressor" = ["py.typed"] [[tool.mypy.overrides]] -module = ["numcodecs.*", "astropy.convolution.*"] +module = ["numcodecs.*", "astropy.convolution.*", "cartopy.*"] follow_untyped_imports = true diff --git a/src/climatebenchpress/compressor/plotting/__init__.py b/src/climatebenchpress/compressor/plotting/__init__.py new file mode 100644 index 0000000..c9c2ef6 --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py new file mode 100644 index 0000000..62eeabb --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -0,0 +1,405 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import xarray as xr + +from .variable_plotters import PLOTTERS + +COMPRESSOR2LINEINFO = { + "jpeg2000": ("#EE7733", "-"), + "zfp": ("#EE3377", "--"), + "sz3": ("#CC3311", "-."), + "bitround-pco-conservative-rel": ("#0077BB", ":"), + "bitround-conservative-rel": ("#33BBEE", "-"), + "stochround": ("#009988", "--"), + "tthresh": ("#BBBBBB", "-."), +} + +COMPRESSOR2LEGEND_NAME = { + "jpeg2000": "JPEG2000", + "zfp": "ZFP", + "sz3": "SZ3", + "bitround-pco-conservative-rel": "BitRound + PCO", + "bitround-conservative-rel": "BitRound + Zlib", + "stochround": "StochRound", + "tthresh": "TTHRESH", +} + + +def plot_metrics( + basepath: Path = Path(), bound_names: list[str] = ["low", "mid", "high"] +): + metrics_path = basepath / "metrics" + plots_path = basepath / "plots" + + df = pd.read_csv(metrics_path / "all_results.csv") + plot_per_variable_metrics( + basepath=basepath, + plots_path=plots_path, + all_results=df, + ) + + df = rename_error_bounds(df, bound_names) + normalized_df = normalize(df, bound_normalize="mid") + + plot_bound_violations( + normalized_df, bound_names, plots_path / "bound_violations.pdf" + ) + + for metric in ["Relative MAE", "Relative DSSIM", "Relative MaxAbsError"]: + plot_aggregated_rd_curve( + normalized_df, + plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", + compression_metric="Relative CR", + distortion_metric=metric, + agg="median", + bound_names=bound_names, + ) + + +def rename_error_bounds(df, bound_names): + """Give error bound consistent names between variables. By default the error bounds + have the pattern {variable_name}-{bound_type}={bound_value}.""" + # Get unique variables + variables = df["Variable"].unique() + + # Process each variable + for variable in variables: + var_selector = df["Variable"] == variable + var_data = df[var_selector] + + error_bounds = sorted( + var_data["Error Bound"].unique(), + key=lambda x: float(x.split("=")[1].split("_")[0]), + ) + + assert len(error_bounds) == len(bound_names) + for i in range(len(error_bounds)): + bound_selector = var_data["Error Bound"] == error_bounds[i] + df.loc[bound_selector & var_selector, "Error Bound"] = bound_names[i] + + return df + + +def normalize(data, bound_normalize="mid"): + """Generate normalized metrics for each compressor and variable. The normalization + first computes the 'best compressor' with the highest average rank over all variables (ranked by + compression ratio). + + For each metric, the normalization is done by dividing the metric by the value of the + 'best compressor' for the same variable and error bound, i.e.: + normalized_metric = metric[compressor, variable] / metric[best_compressor, variable]. + """ + # Group by Variable and rank compressors within each variable + ranked = data.copy() + ranked = ranked[ranked["Error Bound"] == bound_normalize] + ranked["CompRatio_Rank"] = ranked.groupby("Variable")[ + "Compression Ratio [raw B / enc B]" + ].rank(ascending=False) + + # Calculate average rank for each compressor across all variables + avg_ranks = ranked.groupby("Compressor")["CompRatio_Rank"].mean().reset_index() + avg_ranks.columns = ["Compressor", "Average_Rank"] + avg_ranks = avg_ranks.sort_values("Average_Rank") + + best_compressor = avg_ranks.iloc[0]["Compressor"] + + normalized = data.copy() + normalize_vars = [ + ("Compression Ratio [raw B / enc B]", "Relative CR"), + ("MAE", "Relative MAE"), + ("DSSIM", "Relative DSSIM"), + ("Max Absolute Error", "Relative MaxAbsError"), + ] + # Avoid negative values. By default, DSSIM is in the range [-1, 1]. + normalized["DSSIM"] = normalized["DSSIM"] + 1.0 + + def get_normalizer(row): + return normalized[ + (data["Compressor"] == best_compressor) + & (data["Variable"] == row["Variable"]) + & (data["Error Bound"] == bound_normalize) + ][col].item() + + for col, new_col in normalize_vars: + normalized[new_col] = normalized.apply( + lambda x: x[col] / get_normalizer(x), + axis=1, + ) + + return normalized + + +def plot_per_variable_metrics( + basepath: Path, plots_path: Path, all_results: pd.DataFrame +): + """Creates all the plots which only depend on a single variable.""" + for dataset in all_results["Dataset"].unique(): + df = all_results[all_results["Dataset"] == dataset] + dataset_plots_path = plots_path / dataset + dataset_plots_path.mkdir(parents=True, exist_ok=True) + + # For each variable and compressor, plot the input, output, and error fields. + variables = df["Variable"].unique() + for var in variables: + for dist_metric in ["Max Absolute Error", "MAE"]: + metric_name = dist_metric.lower().replace(" ", "_") + if df[df["Variable"] == var][dist_metric].isnull().all(): + continue + plot_variable_rd_curve( + df[df["Variable"] == var], + dataset_plots_path / f"{var}_compression_ratio_{metric_name}.pdf", + distortion_metric=dist_metric, + ) + + error_bounds = df[df["Variable"] == var]["Error Bound"].unique() + for err_bound in error_bounds: + compressors = df[ + (df["Variable"] == var) & (df["Error Bound"] == err_bound) + ]["Compressor"].unique() + + err_bound_path = dataset_plots_path / err_bound + err_bound_path.mkdir(parents=True, exist_ok=True) + for comp in compressors: + print(f"Plotting {var} error for {comp}...") + plot_variable_error( + basepath, + dataset, + err_bound, + comp, + var, + err_bound_path / f"{var}_{comp}.png", + ) + + +def plot_variable_error(basepath, dataset_name, error_bound, compressor, var, outfile): + if outfile.exists(): + # These plots can be quite expensive to generate, so we skip if they already exist. + return + + compressed = ( + basepath + / ".." + / "compressor" + / "compressed-datasets" + / dataset_name + / error_bound + / compressor + / "decompressed.zarr" + ) + input = ( + basepath + / ".." + / "data-loader" + / "datasets" + / dataset_name + / "standardized.zarr" + ) + + ds = xr.open_dataset(input, chunks=dict(), engine="zarr").compute() + ds_new = xr.open_dataset(compressed, chunks=dict(), engine="zarr").compute() + ds, ds_new = ds[var], ds_new[var] + + plotter = PLOTTERS.get(dataset_name, None) + if plotter: + plotter().plot(ds, ds_new, dataset_name, compressor, var, outfile) + else: + print(f"No plotter found for dataset {dataset_name}") + + +def plot_variable_rd_curve(df, outfile, distortion_metric): + plt.figure(figsize=(8, 6)) + compressors = df["Compressor"].unique() + for comp in compressors: + compressor_data = df[df["Compressor"] == comp] + sorting_ixs = np.argsort(compressor_data["Compression Ratio [raw B / enc B]"]) + compr_ratio = [ + compressor_data["Compression Ratio [raw B / enc B]"].iloc[i] + for i in sorting_ixs + ] + distortion = [compressor_data[distortion_metric].iloc[i] for i in sorting_ixs] + color, linestyle = COMPRESSOR2LINEINFO[comp] + plt.plot( + compr_ratio, + distortion, + label=COMPRESSOR2LEGEND_NAME[comp], + marker="s", + color=color, + linestyle=linestyle, + linewidth=4, + markersize=8, + ) + + plt.xlabel("Compression Ratio [raw B / enc B]", fontsize=14) + plt.xscale("log") + if distortion_metric != "PSNR": + # PSNR is already on log scale. + plt.yscale("log") + plt.ylabel(distortion_metric, fontsize=14) + + plt.legend( + title="Compressor", + fontsize=10, + title_fontsize=12, + ) + plt.tick_params( + axis="both", + which="major", + labelsize=14, + length=12, + direction="in", + top=True, + right=True, + ) + plt.tick_params( + axis="both", + which="minor", + length=6, + direction="in", + top=True, + right=True, + ) + + plt.tight_layout() + plt.savefig(outfile, dpi=300) + plt.close() + + +def plot_aggregated_rd_curve( + normalized_df, + outfile, + compression_metric, + distortion_metric, + agg="median", + bound_names=["low", "mid", "high"], +): + plt.figure(figsize=(8, 6)) + compressors = normalized_df["Compressor"].unique() + agg_distortion = normalized_df.groupby(["Error Bound", "Compressor"])[ + [compression_metric, distortion_metric] + ].agg(agg) + for comp in compressors: + compr_ratio = [ + agg_distortion.loc[(bound, comp), compression_metric] + for bound in bound_names + ] + distortion = [ + agg_distortion.loc[(bound, comp), distortion_metric] + for bound in bound_names + ] + color, linestyle = COMPRESSOR2LINEINFO[comp] + plt.plot( + compr_ratio, + distortion, + label=COMPRESSOR2LEGEND_NAME[comp], + marker="s", + color=color, + linestyle=linestyle, + linewidth=4, + markersize=8, + ) + + plt.xlabel(f"{agg.title()} {compression_metric}", fontsize=14) + plt.xscale("log") + if "PSNR" not in distortion_metric: + # PSNR is already on log scale. + plt.yscale("log") + plt.ylabel(f"{agg.title()} {distortion_metric}", fontsize=14) + + plt.legend( + title="Compressor", + fontsize=10, + title_fontsize=12, + ) + plt.tick_params( + axis="both", + which="major", + labelsize=14, + length=12, + direction="in", + top=True, + right=True, + ) + plt.tick_params( + axis="both", + which="minor", + length=6, + direction="in", + top=True, + right=True, + ) + + if "MAE" in distortion_metric: + plt.legend( + title="Compressor", + loc="upper right", + bbox_to_anchor=(0.95, 0.6), + fontsize=10, + title_fontsize=12, + ) + plt.xlabel("Median Compression Ratio Relative to SZ3", fontsize=14) + plt.ylabel("Median Mean Absolute Error Relative to SZ3", fontsize=14) + # Add an arrow pointing into the lower right corner + plt.annotate( + "", + xy=(0.97, 0.05), + xycoords="axes fraction", + xytext=(-60, 50), + textcoords="offset points", + arrowprops=dict(arrowstyle="-|>", color="grey", lw=2), + ) + plt.text( + 0.85, + 0.08, + "Better", + transform=plt.gca().transAxes, + fontsize=14, + fontweight="bold", + color="grey", + ha="center", + ) + + plt.tight_layout() + plt.savefig(outfile, dpi=300) + plt.close() + + +def plot_bound_violations(df, bound_names, outfile): + fig, axs = plt.subplots(1, 3, figsize=(len(bound_names) * 6, 6), sharey=True) + + for i, bound_name in enumerate(bound_names): + df_bound = df[df["Error Bound"] == bound_name] + pass_fail = df_bound.pivot( + index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" + ) + pass_fail = pass_fail.astype(np.float32) + fraction_fail = df_bound.pivot( + index="Compressor", columns="Variable", values="Satisfies Bound (Value)" + ) + annotations = fraction_fail.map( + lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" + ) + annotations[fraction_fail == 0.0] = "" + sns.heatmap( + pass_fail, + cbar=False, + cmap="vlag_r", + annot=annotations, + fmt="s", + linewidths=0.5, + ax=axs[i], + ) + axs[i].set_title(f"Bound: {bound_name}") + if i != 0: + axs[i].set_ylabel("") + + fig.tight_layout() + fig.savefig(outfile, dpi=300) + plt.close() + + +if __name__ == "__main__": + plot_metrics(basepath=Path()) diff --git a/src/climatebenchpress/compressor/plotting/variable_plotters.py b/src/climatebenchpress/compressor/plotting/variable_plotters.py new file mode 100644 index 0000000..ee202ff --- /dev/null +++ b/src/climatebenchpress/compressor/plotting/variable_plotters.py @@ -0,0 +1,227 @@ +from abc import ABC, abstractmethod + +import cartopy.crs as ccrs +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np + + +class Plotter(ABC): + datasets: list[str] + + def __init__(self): + self.projection = ccrs.Robinson() + + @abstractmethod + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + pass + + def plot(self, ds, ds_new, dataset_name, compressor, var, outfile): + fig, ax = plt.subplots( + nrows=1, + ncols=3, + figsize=(20, 7), + subplot_kw={"projection": self.projection}, + ) + self.plot_fields(fig, ax, ds, ds_new, dataset_name, var) + ax[0].coastlines() + ax[1].coastlines() + ax[2].coastlines() + ax[0].set_title("Original Dataset") + ax[1].set_title("Compressed Dataset") + ax[2].set_title("Error") + fig.suptitle(f"{var} Error for {dataset_name} ({compressor})") + fig.tight_layout() + fig.savefig(outfile, dpi=300) + plt.close() + + +class CmipAtmosPlotter(Plotter): + datasets = ["cmip6-access-ta-tiny", "cmip6-access-ta"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0, plev=3) + ds.isel(**selector).plot(ax=ax[0], transform=ccrs.PlateCarree()) + ds_new.isel(**selector).plot( + ax=ax[1], transform=ccrs.PlateCarree(), robust=True + ) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2], transform=ccrs.PlateCarree(), rasterized=True) + + +class CmipOceanPlotter(Plotter): + datasets = ["cmip6-access-tos-tiny", "cmip6-access-tos"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + pcm0 = ax[0].pcolormesh( + ds.longitude.values, + ds.latitude.values, + ds.isel(time=0).values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + rasterized=True, + ) + fig.colorbar( + pcm0, ax=ax[0], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + pcm1 = ax[1].pcolormesh( + ds_new.longitude.values, + ds_new.latitude.values, + ds_new.isel(time=0).values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + rasterized=True, + ) + fig.colorbar( + pcm1, ax=ax[1], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + error = ds.isel(time=0) - ds_new.isel(time=0) + pcm2 = ax[2].pcolormesh( + ds.longitude.values, + ds.latitude.values, + error.values.squeeze(), + transform=ccrs.PlateCarree(), + shading="auto", + cmap="coolwarm", + rasterized=True, + ) + fig.colorbar( + pcm2, ax=ax[2], orientation="vertical", fraction=0.046, pad=0.04 + ).set_label("degC") + + +class Era5Plotter(Plotter): + datasets = ["era5-tiny", "era5"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + error = ds.isel(**selector) - ds_new.isel(**selector) + + # Instead of using the inbuilt xarray plot method, we are manually doing + # the projection and calling pcolormesh. By doing so we can avoid having + # to do the projection three times and only have to do it once and re-use + # it between plots. + lons = ds.isel(**selector).longitude.values + lats = ds.isel(**selector).latitude.values + lon_grid, lat_grid = np.meshgrid(lons, lats) + xys = self.projection.transform_points(ccrs.PlateCarree(), lon_grid, lat_grid) + x, y = xys[..., 0], xys[..., 1] + # Wind variable plots coolwarm because they lie around 0 and change in sign + # signifies change in wind direction. + cmap = "coolwarm" if var.startswith("10m") else "viridis" + c1 = ax[0].pcolormesh(x, y, ds.isel(**selector).values.squeeze(), cmap=cmap) + c2 = ax[1].pcolormesh( + x, + y, + ds_new.isel(**selector).values.squeeze(), + cmap=cmap, + rasterized=True, + ) + c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + for i, c in enumerate([c1, c2, c3]): + fig.colorbar(c, ax=ax[i], shrink=0.6) + + +class NextGEMSPlotter(Plotter): + datasets = ["nextgems-icon-tiny", "nextgems-icon"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + error = ds.isel(**selector) - ds_new.isel(**selector) + + lons = ds.isel(**selector).lon.values + lats = ds.isel(**selector).lat.values + lon_grid, lat_grid = np.meshgrid(lons, lats) + xys = self.projection.transform_points(ccrs.PlateCarree(), lon_grid, lat_grid) + x, y = xys[..., 0], xys[..., 1] + + cmap = "Blues" + max_val = max( + ds.isel(**selector).max().values.item(), + ds_new.isel(**selector).max().values.item(), + ) + color_norm = mcolors.LogNorm(vmin=1e-12, vmax=max_val) if var == "pr" else None + # Avoid zero values for log transformation for precipitation + offset = 1e-12 if var == "pr" else 0 + c1 = ax[0].pcolormesh( + x, + y, + ds.isel(**selector).values.squeeze() + offset, + norm=color_norm, + cmap=cmap, + rasterized=True, + ) + c2 = ax[1].pcolormesh( + x, + y, + ds_new.isel(**selector).values.squeeze() + offset, + norm=color_norm, + cmap=cmap, + rasterized=True, + ) + c3 = ax[2].pcolormesh(x, y, error.values.squeeze(), cmap="coolwarm") + for i, c in enumerate([c1, c2, c3]): + fig.colorbar(c, ax=ax[i], shrink=0.6) + + +class CamsPlotter(Plotter): + datasets = ["cams-nitrogen-dioxide-tiny", "cams-nitrogen-dioxide"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(valid_time=0, pressure_level=3) + in_min = ds.isel(**selector).min().values.item() + in_max = ds.isel(**selector).max().values.item() + out_min = ds_new.isel(**selector).min().values.item() + out_max = ds_new.isel(**selector).max().values.item() + vmin, vmax = min(in_min, out_min), max(in_max, out_max) + vmin = max(vmin, 1e-14) # Avoid zero values for log transformation + color_norm = mcolors.LogNorm(vmin=vmin, vmax=vmax) + ds.isel(**selector).plot( + ax=ax[0], + transform=ccrs.PlateCarree(), + norm=color_norm, + cmap="gist_earth", + rasterized=True, + ) + ds_new.isel(**selector).plot( + ax=ax[1], + transform=ccrs.PlateCarree(), + norm=color_norm, + cmap="gist_earth", + rasterized=True, + ) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2], transform=ccrs.PlateCarree()) + + +class EsaBiomassPlotter(Plotter): + datasets = ["esa-biomass-cci-tiny", "esa-biomass-cci"] + + def plot_fields(self, fig, ax, ds, ds_new, dataset_name, var): + selector = dict(time=0) + ds.isel(**selector).plot(ax=ax[0]) + ds_new.isel(**selector).plot(ax=ax[1]) + error = ds.isel(**selector) - ds_new.isel(**selector) + error.plot(ax=ax[2], rasterized=True) + ax[0].set_title("Original Dataset") + ax[1].set_title("Compressed Dataset") + ax[2].set_title("Error") + + +plotter_clss: list[type[Plotter]] = [ + CamsPlotter, + CmipAtmosPlotter, + CmipOceanPlotter, + Era5Plotter, + EsaBiomassPlotter, + NextGEMSPlotter, +] +PLOTTERS: dict[str, type[Plotter]] = dict() +for plotter_cls in plotter_clss: + for dataset in plotter_cls.datasets: + assert dataset not in PLOTTERS, f"Duplicate dataset found: {dataset}" + PLOTTERS[dataset] = plotter_cls