Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ widgets = [
"matplotlib",
"ipympl",
"ipywidgets",
"sortingview>=0.12.0",
"figpack",
"figpack-spike-sorting"
]

metrics = [
Expand Down Expand Up @@ -183,6 +184,9 @@ test = [
"skops",
"huggingface_hub",

# widgets
"sortingview",

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
Expand Down
32 changes: 22 additions & 10 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AmplitudesWidget(BaseRasterWidget):
If equal to n, each nth spike is kept for plotting.
hide_unit_selector : bool, default: False
If True the unit selector is not displayed
(sortingview backend)
(figpack backend)
plot_histogram : bool, default: False
If True, an histogram of the amplitudes is plotted on the right axis
(matplotlib backend)
Expand Down Expand Up @@ -94,10 +94,10 @@ def __init__(
segment_indices = validate_segment_indices(segment_indices, sorting)

# Check for SortingView backend
is_sortingview = backend == "sortingview"
is_figpack = backend in ("figpack", "sortingview")

# For SortingView, ensure we're only using a single segment
if is_sortingview and len(segment_indices) > 1:
if is_figpack and len(segment_indices) > 1:
warn("SortingView backend currently supports only single segment. Using first segment.")
segment_indices = [segment_indices[0]]

Expand Down Expand Up @@ -158,8 +158,8 @@ def __init__(
scatter_decimate=scatter_decimate,
)

# If using SortingView, extract just the first segment's data as flat dicts
if is_sortingview:
# If using Figpack, extract just the first segment's data as flat dicts
if is_figpack:
first_segment = segment_indices[0]
plot_data["spike_train_data"] = {first_segment: spiketrains_by_segment[first_segment]}
plot_data["y_axis_data"] = {first_segment: amplitudes_by_segment[first_segment]}
Expand All @@ -172,27 +172,39 @@ def __init__(
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

if not use_sortingview:
data_plot["hide_unit_selector"] = True # force hide unit selector for figpack
dp = to_attr(data_plot)

unit_ids = make_serializable(dp.unit_ids)

sa_items = [
vv.SpikeAmplitudesItem(
vv_views.SpikeAmplitudesItem(
unit_id=u,
spike_times_sec=dp.spike_train_data[u].astype("float32"),
spike_amplitudes=dp.y_axis_data[u].astype("float32"),
)
for u in unit_ids
]

self.view = vv.SpikeAmplitudes(
self.view = vv_views.SpikeAmplitudes(
start_time_sec=0,
end_time_sec=np.sum(dp.durations),
plots=sa_items,
hide_unit_selector=dp.hide_unit_selector,
# hide_unit_selector=dp.hide_unit_selector,
)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
18 changes: 14 additions & 4 deletions src/spikeinterface/widgets/autocorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax.set_title(str(unit_id))

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)
unit_ids = make_serializable(dp.unit_ids)
Expand All @@ -52,14 +62,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
for j in range(i, len(unit_ids)):
if i == j:
ac_items.append(
vv.AutocorrelogramItem(
vv_views.AutocorrelogramItem(
unit_id=unit_ids[i],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
)
)

self.view = vv.Autocorrelograms(autocorrelograms=ac_items)
self.view = vv_views.Autocorrelograms(autocorrelograms=ac_items)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)

Expand Down
16 changes: 13 additions & 3 deletions src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def set_default_plotter_backend(backend):
"figsize": "Size of matplotlib figure, default: None",
"figtitle": "The figure title, default: None",
},
"sortingview": {
"figpack": {
"generate_url": "If True, the figurl URL is generated and printed, default: True",
"display": "If True and in jupyter notebook/lab, the widget is displayed in the cell, default: True.",
"figlabel": "The figurl figure label, default: None",
"height": "The height of the sortingview View in jupyter, default: None",
"inline": "If True, the widget is displayed inline in the cell, default: None",
"height": "The height of the figpack View in jupyter, default: None",
},
"ipywidgets": {
"width_cm": "Width of the figure in cm, default: 10",
Expand All @@ -47,14 +48,23 @@ def set_default_plotter_backend(backend):
},
"ephyviewer": {},
"spikeinterface_gui": {},
# deprecated
"sortingview": {
"generate_url": "If True, the figurl URL is generated and printed, default: True",
"display": "If True and in jupyter notebook/lab, the widget is displayed in the cell, default: True.",
"figlabel": "The figurl figure label, default: None",
"height": "The height of the sortingview View in jupyter, default: None",
},
}

default_backend_kwargs = {
"matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None},
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
"figpack": {"generate_url": True, "display": True, "figlabel": None, "inline": None, "height": None},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None},
"ephyviewer": {},
"spikeinterface_gui": {},
# deprecated
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
}


Expand Down
24 changes: 16 additions & 8 deletions src/spikeinterface/widgets/crosscorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CrossCorrelogramsWidget(BaseWidget):
unit_ids list or None, default: None
List of unit ids
min_similarity_for_correlograms : float, default: 0.2
For sortingview backend. Threshold for computing pair-wise cross-correlograms.
For figpack backend. Threshold for computing pair-wise cross-correlograms.
If template similarity between two units is below this threshold, the cross-correlogram is not displayed.
For auto-correlograms plot, this is automatically set to None.
window_ms : float, default: 100.0
Expand All @@ -30,7 +30,7 @@ class CrossCorrelogramsWidget(BaseWidget):
Bin size in ms. If correlograms are already computed (e.g. with SortingAnalyzer),
this argument is ignored
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
unit_colors : dict | None, default: None
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
Expand Down Expand Up @@ -127,8 +127,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.axes[-1, i].set_xlabel("CCG (ms)")

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)

Expand All @@ -144,14 +154,12 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
for j in range(i, len(unit_ids)):
if similarity[i, j] >= dp.min_similarity_for_correlograms:
cc_items.append(
vv.CrossCorrelogramItem(
vv_views.CrossCorrelogramItem(
unit_id1=unit_ids[i],
unit_id2=unit_ids[j],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
)
)

self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector)

self.view = vv_views.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector)
self.url = handle_display_and_url(self, self.view, **backend_kwargs)
32 changes: 22 additions & 10 deletions src/spikeinterface/widgets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MetricsBaseWidget(BaseWidget):
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
include_metrics_data : bool, default: True
If True, metrics data are included in unit table
"""
Expand Down Expand Up @@ -258,8 +258,18 @@ def _update_ipywidget(self, change):
self.figure.canvas.flush_events()

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)

Expand All @@ -275,7 +285,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
metrics_sv = []
for col in metric_names:
dtype = np.array(metrics.iloc[0][col]).dtype
metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str)
metric = vv_views.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str)
metrics_sv.append(metric)

units_m = []
Expand All @@ -290,8 +300,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
continue
values_skip_nans[k] = v

units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans))
v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv)
units_m.append(vv_views.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans))
v_metrics = vv_views.UnitMetricsGraph(units=units_m, metrics=metrics_sv)

if not dp.hide_unit_selector:
if dp.include_metrics_data:
Expand All @@ -301,12 +311,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
if col not in sorting_copy.get_property_keys():
sorting_copy.set_property(col, metrics[col].values)
# generate table with properties
v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names)
v_units_table = generate_unit_table_view(
sorting_copy, unit_properties=metric_names, use_sortingview=use_sortingview
)
else:
v_units_table = generate_unit_table_view(dp.sorting)
v_units_table = generate_unit_table_view(dp.sorting, use_sortingview=use_sortingview)

self.view = vv.Splitter(
direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics)
self.view = vv_base.Splitter(
direction="horizontal", item1=vv_base.LayoutItem(v_units_table), item2=vv_base.LayoutItem(v_metrics)
)
else:
self.view = v_metrics
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/quality_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class QualityMetricsWidget(MetricsBaseWidget):
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class BaseRasterWidget(BaseWidget):
y_ticks : dict | None, default: None
Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
segment_boundary_kwargs : dict | None, default: None
Additional arguments for the segment boundary lines, passed to `matplotlib.axvline`
backend : str | None, default None
Expand Down
Loading
Loading