diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 7bd81c1b2c..9be89ed8d5 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -280,8 +280,9 @@ def plot_unit_counts( ncol = len(columns) width = 1 / (ncol + 2) - colors = get_some_colors(columns, color_engine="auto", map_name="hot") - colors["num_well_detected"] = "green" + if colors is None: + colors = get_some_colors(columns, color_engine="auto", map_name="hot") + colors["num_well_detected"] = "green" case_colors = study.get_colors(levels_to_group_by=levels_to_group_by) @@ -320,6 +321,8 @@ def plot_unit_counts( ymax = max(ymax, y + yerr[0]) if with_rectangle: + if revert_bad: + ymin = 0 spacing = width * 0.3 for i, key in enumerate(keys_mapping): rect = plt.Rectangle( @@ -413,6 +416,7 @@ def _plot_performances_vs_metric( with_sigmoid_fit=False, show_average_by_bin=True, scatter_size=4, + scatter_alpha=1.0, num_bin_average=20, axs=None, ): @@ -517,7 +521,7 @@ def _plot_performances_vs_metric( all_xs = np.concatenate(all_xs) all_ys = np.concatenate(all_ys) - ax.scatter(all_xs, all_ys, marker=".", label=label, color=color, s=scatter_size) + ax.scatter(all_xs, all_ys, marker=".", label=label, color=color, s=scatter_size, alpha=scatter_alpha) ax.set_ylabel(performance_name) ax.set_ylim(-0.05, 1.05) @@ -542,6 +546,7 @@ def plot_performances_vs_snr( with_sigmoid_fit=False, show_average_by_bin=True, scatter_size=4, + scatter_alpha=1.0, num_bin_average=20, axs=None, ): @@ -572,6 +577,8 @@ def plot_performances_vs_snr( Instead of the sigmoid an average by bins can be plotted. scatter_size : int, default 4 scatter size + scatter_alpha : float, default 1.0 + scatter alpha num_bin_average : int, default 2 Num bin for average axs : matplotlib.axes.Axes | None, default: None @@ -596,6 +603,7 @@ def plot_performances_vs_snr( with_sigmoid_fit=with_sigmoid_fit, show_average_by_bin=show_average_by_bin, scatter_size=scatter_size, + scatter_alpha=scatter_alpha, num_bin_average=num_bin_average, axs=axs, ) @@ -613,6 +621,7 @@ def plot_performances_vs_firing_rate( with_sigmoid_fit=False, show_average_by_bin=True, scatter_size=4, + scatter_alpha=1.0, num_bin_average=20, axs=None, ): @@ -643,6 +652,8 @@ def plot_performances_vs_firing_rate( Instead of the sigmoid an average by bins can be plotted. scatter_size : int, default 4 scatter size + scatter_alpha : float, default 1.0 + scatter alpha num_bin_average : int, default 2 Num bin for average axs : matplotlib.axes.Axes | None, default: None @@ -667,6 +678,7 @@ def plot_performances_vs_firing_rate( with_sigmoid_fit=with_sigmoid_fit, show_average_by_bin=show_average_by_bin, scatter_size=scatter_size, + scatter_alpha=scatter_alpha, num_bin_average=num_bin_average, axs=axs, ) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 6a0e043932..55762782f7 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -41,6 +41,8 @@ class UnitWaveformsWidget(BaseWidget): displayed per waveform, (matplotlib backend) scale : float, default: 1 Scale factor for the waveforms/templates (matplotlib backend) + abs_y_scale : None | float, default None + Absolut y_scale that override the auto y scale mechanism widen_narrow_scale : float, default: 1 Scale factor for the x-axis of the waveforms/templates (matplotlib backend) axis_equal : bool, default: False @@ -93,6 +95,7 @@ def __init__( sparsity=None, ncols=5, scale=1, + abs_y_scale=None, widen_narrow_scale=1, lw_waveforms=1, lw_templates=2, @@ -207,6 +210,7 @@ def __init__( unit_colors=unit_colors, channel_locations=channel_locations, scale=scale, + abs_y_scale=abs_y_scale, widen_narrow_scale=widen_narrow_scale, templates=templates, templates_shading=templates_shading, @@ -256,6 +260,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( dp.templates, dp.channel_locations, dp.nbefore, dp.x_offset_units, dp.widen_narrow_scale ) + if dp.abs_y_scale is not None: + y_scale = dp.abs_y_scale for i, unit_id in enumerate(dp.unit_ids): if dp.same_axis: