@@ -146,7 +146,7 @@ def plot_metrics(
146146 # )
147147
148148 df = _rename_compressors (df )
149- normalized_df = _normalize (df )
149+ normalized_df , normalized_mean_std = _normalize (df )
150150 _plot_bound_violations (
151151 normalized_df , bound_names , plots_path / "bound_violations.pdf"
152152 )
@@ -164,6 +164,7 @@ def plot_metrics(
164164 normalized_df ,
165165 compression_metric = "Relative CR" ,
166166 distortion_metric = metric ,
167+ mean_std = normalized_mean_std [metric ],
167168 outfile = plots_path / f"rd_curve_{ metric .lower ().replace (' ' , '_' )} .pdf" ,
168169 agg = "mean" ,
169170 bound_names = bound_names ,
@@ -173,6 +174,7 @@ def plot_metrics(
173174 normalized_df ,
174175 compression_metric = "Relative CR" ,
175176 distortion_metric = metric ,
177+ mean_std = normalized_mean_std [metric ],
176178 outfile = plots_path
177179 / f"full_rd_curve_{ metric .lower ().replace (' ' , '_' )} .pdf" ,
178180 agg = "mean" ,
@@ -224,6 +226,7 @@ def _normalize(data):
224226 dssim_unreliable = normalized ["Variable" ].isin (["ta" , "tos" ])
225227 normalized .loc [dssim_unreliable , "DSSIM" ] = np .nan
226228
229+ normalize_mean_std = dict ()
227230 for col , new_col in normalize_vars :
228231 mean_std = dict ()
229232 for var in variables :
@@ -239,7 +242,9 @@ def _normalize(data):
239242 axis = 1 ,
240243 )
241244
242- return normalized
245+ normalize_mean_std [new_col ] = mean_std
246+
247+ return normalized , normalize_mean_std
243248
244249
245250def _plot_per_variable_metrics (
@@ -434,6 +439,7 @@ def _plot_aggregated_rd_curve(
434439 normalized_df ,
435440 compression_metric ,
436441 distortion_metric ,
442+ mean_std ,
437443 outfile : None | Path = None ,
438444 agg = "median" ,
439445 bound_names = ["low" , "mid" , "high" ],
@@ -546,10 +552,27 @@ def _plot_aggregated_rd_curve(
546552
547553 arrow_color = "black"
548554 if "dSSIM" in distortion_metric :
555+ # Annotate dSSIM = 1, accounting for the normalization
556+ dssim_one = getattr (np , f"nan{ agg } " )(
557+ [(1 - ms [0 ]) / ms [1 ] for ms in mean_std .values ()]
558+ )
559+ plt .axhline (dssim_one , c = "k" , ls = "--" )
560+ plt .text (
561+ np .percentile (plt .xlim (), 63 ),
562+ dssim_one ,
563+ "dSSIM = 1" ,
564+ fontsize = 16 ,
565+ fontweight = "bold" ,
566+ color = "black" ,
567+ ha = "center" ,
568+ va = "center" ,
569+ bbox = dict (edgecolor = "none" , facecolor = "w" , alpha = 0.85 ),
570+ )
571+
549572 # Add an arrow pointing into the top right corner
550573 plt .annotate (
551574 "" ,
552- xy = (0.95 , 0.95 ),
575+ xy = (0.95 , 0.875 if remove_outliers else 0.9 ),
553576 xycoords = "axes fraction" ,
554577 xytext = (- 60 , - 50 ),
555578 textcoords = "offset points" ,
@@ -562,7 +585,7 @@ def _plot_aggregated_rd_curve(
562585 # Attach the text to the lower left of the arrow
563586 plt .text (
564587 0.83 ,
565- 0.92 ,
588+ 0.845 if remove_outliers else 0.87 ,
566589 "Better" ,
567590 transform = plt .gca ().transAxes ,
568591 fontsize = 16 ,
0 commit comments