From 52686d90e445654abfdff83f54ae06222e39185c Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Thu, 3 Oct 2024 23:49:42 +0200 Subject: [PATCH] fix(datasets) Fix the scale of value axis when plotting in absolute sizes (#4255) --- .../comparison_label_distribution.py | 65 +++++++++++++++---- 1 file changed, 51 insertions(+), 14 deletions(-) diff --git a/datasets/flwr_datasets/visualization/comparison_label_distribution.py b/datasets/flwr_datasets/visualization/comparison_label_distribution.py index 8a15452fb86d..17b9a9aec251 100644 --- a/datasets/flwr_datasets/visualization/comparison_label_distribution.py +++ b/datasets/flwr_datasets/visualization/comparison_label_distribution.py @@ -15,7 +15,7 @@ """Comparison of label distribution plotting.""" -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union import matplotlib.colors as mcolors import matplotlib.pyplot as plt @@ -33,10 +33,10 @@ def plot_comparison_label_distribution( partitioner_list: list[Partitioner], label_name: Union[str, list[str]], - plot_type: str = "bar", - size_unit: str = "percent", - max_num_partitions: Optional[Union[int]] = 30, - partition_id_axis: str = "y", + plot_type: Literal["bar", "heatmap"] = "bar", + size_unit: Literal["percent", "absolute"] = "percent", + max_num_partitions: Optional[int] = 30, + partition_id_axis: Literal["x", "y"] = "y", figsize: Optional[tuple[float, float]] = None, subtitle: str = "Comparison of Per Partition Label Distribution", titles: Optional[list[str]] = None, @@ -55,14 +55,14 @@ def plot_comparison_label_distribution( List of partitioners to be compared. label_name : Union[str, List[str]] Column name or list of column names identifying labels for each partitioner. - plot_type : str + plot_type : Literal["bar", "heatmap"] Type of plot, either "bar" or "heatmap". 
- size_unit : str + size_unit : Literal["percent", "absolute"] "absolute" for raw counts, or "percent" to normalize values to 100%. max_num_partitions : Optional[int] Maximum number of partitions to include in the plot. If None, all partitions are included. - partition_id_axis : str + partition_id_axis : Literal["x", "y"] Axis on which the partition IDs will be marked, either "x" or "y". figsize : Optional[Tuple[float, float]] Size of the figure. If None, a default size is calculated. @@ -151,7 +151,10 @@ def plot_comparison_label_distribution( f"{type(label_name)}" ) figsize = _initialize_comparison_figsize(figsize, num_partitioners) - fig, axes = plt.subplots(1, num_partitioners, layout="constrained", figsize=figsize) + axes_sharing = _initialize_axis_sharing(size_unit, plot_type, partition_id_axis) + fig, axes = plt.subplots( + 1, num_partitioners, layout="constrained", figsize=figsize, **axes_sharing + ) if titles is None: titles = ["" for _ in range(num_partitioners)] @@ -201,11 +204,12 @@ def plot_comparison_label_distribution( axis.set_xlabel("") axis.set_ylabel("") axis.set_title(titles[idx]) - for axis in axes[1:]: - axis.set_yticks([]) + _set_tick_on_value_axes(axes, partition_id_axis, size_unit) # Set up figure xlabel and ylabel - xlabel, ylabel = _initialize_comparison_xy_labels(plot_type, partition_id_axis) + xlabel, ylabel = _initialize_comparison_xy_labels( + plot_type, size_unit, partition_id_axis + ) fig.supxlabel(xlabel) fig.supylabel(ylabel) fig.suptitle(subtitle) @@ -226,11 +230,13 @@ def _initialize_comparison_figsize( def _initialize_comparison_xy_labels( - plot_type: str, partition_id_axis: str + plot_type: Literal["bar", "heatmap"], + size_unit: Literal["percent", "absolute"], + partition_id_axis: Literal["x", "y"], ) -> tuple[str, str]: if plot_type == "bar": xlabel = "Partition ID" - ylabel = "Class distribution" + ylabel = "Class distribution" if size_unit == "percent" else "Class Count" elif plot_type == "heatmap": xlabel = "Partition 
ID" ylabel = "Label" @@ -243,3 +249,34 @@ def _initialize_comparison_xy_labels( xlabel, ylabel = ylabel, xlabel return xlabel, ylabel + + +def _initialize_axis_sharing( + size_unit: Literal["percent", "absolute"], + plot_type: Literal["bar", "heatmap"], + partition_id_axis: Literal["x", "y"], +) -> dict[str, bool]: + # Do not intervene when the size_unit is percent or the plot_type is heatmap + if size_unit == "percent": + return {} + if plot_type == "heatmap": + return {} + if partition_id_axis == "x": + return {"sharey": True} + if partition_id_axis == "y": + return {"sharex": True} + return {"sharex": False, "sharey": False} + + +def _set_tick_on_value_axes( + axes: list[Axes], + partition_id_axis: Literal["x", "y"], + size_unit: Literal["percent", "absolute"], +) -> None: + if partition_id_axis == "x" and size_unit == "absolute": + # Exclude this case due to sharing of y-axis (and thus y-ticks) + # They must remain set; the numbers are displayed only on the first plot + pass + else: + for axis in axes[1:]: + axis.set_yticks([])