Skip to content

Commit 35ba589

Browse files
committed
Improve dSSIM rd curves
1 parent b031d13 commit 35ba589

1 file changed

Lines changed: 27 additions & 4 deletions

File tree

src/climatebenchpress/compressor/plotting/plot_metrics.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def plot_metrics(
146146
# )
147147

148148
df = _rename_compressors(df)
149-
normalized_df = _normalize(df)
149+
normalized_df, normalized_mean_std = _normalize(df)
150150
_plot_bound_violations(
151151
normalized_df, bound_names, plots_path / "bound_violations.pdf"
152152
)
@@ -164,6 +164,7 @@ def plot_metrics(
164164
normalized_df,
165165
compression_metric="Relative CR",
166166
distortion_metric=metric,
167+
mean_std=normalized_mean_std[metric],
167168
outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf",
168169
agg="mean",
169170
bound_names=bound_names,
@@ -173,6 +174,7 @@ def plot_metrics(
173174
normalized_df,
174175
compression_metric="Relative CR",
175176
distortion_metric=metric,
177+
mean_std=normalized_mean_std[metric],
176178
outfile=plots_path
177179
/ f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf",
178180
agg="mean",
@@ -224,6 +226,7 @@ def _normalize(data):
224226
dssim_unreliable = normalized["Variable"].isin(["ta", "tos"])
225227
normalized.loc[dssim_unreliable, "DSSIM"] = np.nan
226228

229+
normalize_mean_std = dict()
227230
for col, new_col in normalize_vars:
228231
mean_std = dict()
229232
for var in variables:
@@ -239,7 +242,9 @@ def _normalize(data):
239242
axis=1,
240243
)
241244

242-
return normalized
245+
normalize_mean_std[new_col] = mean_std
246+
247+
return normalized, normalize_mean_std
243248

244249

245250
def _plot_per_variable_metrics(
@@ -434,6 +439,7 @@ def _plot_aggregated_rd_curve(
434439
normalized_df,
435440
compression_metric,
436441
distortion_metric,
442+
mean_std,
437443
outfile: None | Path = None,
438444
agg="median",
439445
bound_names=["low", "mid", "high"],
@@ -546,10 +552,27 @@ def _plot_aggregated_rd_curve(
546552

547553
arrow_color = "black"
548554
if "dSSIM" in distortion_metric:
555+
# Annotate dSSIM = 1, accounting for the normalization
556+
dssim_one = getattr(np, f"nan{agg}")(
557+
[(1 - ms[0]) / ms[1] for ms in mean_std.values()]
558+
)
559+
plt.axhline(dssim_one, c="k", ls="--")
560+
plt.text(
561+
np.percentile(plt.xlim(), 63),
562+
dssim_one,
563+
"dSSIM = 1",
564+
fontsize=16,
565+
fontweight="bold",
566+
color="black",
567+
ha="center",
568+
va="center",
569+
bbox=dict(edgecolor="none", facecolor="w", alpha=0.85),
570+
)
571+
549572
# Add an arrow pointing into the top right corner
550573
plt.annotate(
551574
"",
552-
xy=(0.95, 0.95),
575+
xy=(0.95, 0.875 if remove_outliers else 0.9),
553576
xycoords="axes fraction",
554577
xytext=(-60, -50),
555578
textcoords="offset points",
@@ -562,7 +585,7 @@ def _plot_aggregated_rd_curve(
562585
# Attach the text to the lower left of the arrow
563586
plt.text(
564587
0.83,
565-
0.92,
588+
0.845 if remove_outliers else 0.87,
566589
"Better",
567590
transform=plt.gca().transAxes,
568591
fontsize=16,

0 commit comments

Comments
 (0)