Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,9 @@ def plot_unit_counts(
ncol = len(columns)
width = 1 / (ncol + 2)

colors = get_some_colors(columns, color_engine="auto", map_name="hot")
colors["num_well_detected"] = "green"
if colors is None:
colors = get_some_colors(columns, color_engine="auto", map_name="hot")
colors["num_well_detected"] = "green"

case_colors = study.get_colors(levels_to_group_by=levels_to_group_by)

Expand Down Expand Up @@ -320,6 +321,8 @@ def plot_unit_counts(
ymax = max(ymax, y + yerr[0])

if with_rectangle:
if revert_bad:
ymin = 0
spacing = width * 0.3
for i, key in enumerate(keys_mapping):
rect = plt.Rectangle(
Expand Down Expand Up @@ -413,6 +416,7 @@ def _plot_performances_vs_metric(
with_sigmoid_fit=False,
show_average_by_bin=True,
scatter_size=4,
scatter_alpha=1.0,
num_bin_average=20,
axs=None,
):
Expand Down Expand Up @@ -517,7 +521,7 @@ def _plot_performances_vs_metric(
all_xs = np.concatenate(all_xs)
all_ys = np.concatenate(all_ys)

ax.scatter(all_xs, all_ys, marker=".", label=label, color=color, s=scatter_size)
ax.scatter(all_xs, all_ys, marker=".", label=label, color=color, s=scatter_size, alpha=scatter_alpha)
ax.set_ylabel(performance_name)

ax.set_ylim(-0.05, 1.05)
Expand All @@ -542,6 +546,7 @@ def plot_performances_vs_snr(
with_sigmoid_fit=False,
show_average_by_bin=True,
scatter_size=4,
scatter_alpha=1.0,
num_bin_average=20,
axs=None,
):
Expand Down Expand Up @@ -572,6 +577,8 @@ def plot_performances_vs_snr(
Instead of the sigmoid an average by bins can be plotted.
scatter_size : int, default 4
scatter size
scatter_alpha : float, default 1.0
scatter alpha
num_bin_average : int, default 2
Num bin for average
axs : matplotlib.axes.Axes | None, default: None
Expand All @@ -596,6 +603,7 @@ def plot_performances_vs_snr(
with_sigmoid_fit=with_sigmoid_fit,
show_average_by_bin=show_average_by_bin,
scatter_size=scatter_size,
scatter_alpha=scatter_alpha,
num_bin_average=num_bin_average,
axs=axs,
)
Expand All @@ -613,6 +621,7 @@ def plot_performances_vs_firing_rate(
with_sigmoid_fit=False,
show_average_by_bin=True,
scatter_size=4,
scatter_alpha=1.0,
num_bin_average=20,
axs=None,
):
Expand Down Expand Up @@ -643,6 +652,8 @@ def plot_performances_vs_firing_rate(
Instead of the sigmoid an average by bins can be plotted.
scatter_size : int, default 4
scatter size
scatter_alpha : float, default 1.0
scatter alpha
num_bin_average : int, default 2
Num bin for average
axs : matplotlib.axes.Axes | None, default: None
Expand All @@ -667,6 +678,7 @@ def plot_performances_vs_firing_rate(
with_sigmoid_fit=with_sigmoid_fit,
show_average_by_bin=show_average_by_bin,
scatter_size=scatter_size,
scatter_alpha=scatter_alpha,
num_bin_average=num_bin_average,
axs=axs,
)
Expand Down
6 changes: 6 additions & 0 deletions src/spikeinterface/widgets/unit_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class UnitWaveformsWidget(BaseWidget):
displayed per waveform, (matplotlib backend)
scale : float, default: 1
Scale factor for the waveforms/templates (matplotlib backend)
abs_y_scale : None | float, default None
Absolut y_scale that override the auto y scale mechanism
widen_narrow_scale : float, default: 1
Scale factor for the x-axis of the waveforms/templates (matplotlib backend)
axis_equal : bool, default: False
Expand Down Expand Up @@ -93,6 +95,7 @@ def __init__(
sparsity=None,
ncols=5,
scale=1,
abs_y_scale=None,
widen_narrow_scale=1,
lw_waveforms=1,
lw_templates=2,
Expand Down Expand Up @@ -207,6 +210,7 @@ def __init__(
unit_colors=unit_colors,
channel_locations=channel_locations,
scale=scale,
abs_y_scale=abs_y_scale,
widen_narrow_scale=widen_narrow_scale,
templates=templates,
templates_shading=templates_shading,
Expand Down Expand Up @@ -256,6 +260,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
xvectors, y_scale, y_offset, delta_x = get_waveforms_scales(
dp.templates, dp.channel_locations, dp.nbefore, dp.x_offset_units, dp.widen_narrow_scale
)
if dp.abs_y_scale is not None:
y_scale = dp.abs_y_scale

for i, unit_id in enumerate(dp.unit_ids):
if dp.same_axis:
Expand Down