Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ widgets = [
"matplotlib",
"ipympl",
"ipywidgets",
"sortingview>=0.12.0",
"figpack"
]

metrics = [
Expand Down Expand Up @@ -183,6 +183,9 @@ test = [
"skops",
"huggingface_hub",

# widgets
"sortingview",

# for github test : probeinterface and neo from master
# for release we need pypi, so this need to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
Expand Down
32 changes: 22 additions & 10 deletions src/spikeinterface/widgets/amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class AmplitudesWidget(BaseRasterWidget):
If equal to n, each nth spike is kept for plotting.
hide_unit_selector : bool, default: False
If True the unit selector is not displayed
(sortingview backend)
(figpack backend)
plot_histogram : bool, default: False
If True, an histogram of the amplitudes is plotted on the right axis
(matplotlib backend)
Expand Down Expand Up @@ -94,10 +94,10 @@ def __init__(
segment_indices = validate_segment_indices(segment_indices, sorting)

# Check for SortingView backend
is_sortingview = backend == "sortingview"
is_figpack = backend in ("figpack", "sortingview")

# For SortingView, ensure we're only using a single segment
if is_sortingview and len(segment_indices) > 1:
if is_figpack and len(segment_indices) > 1:
warn("SortingView backend currently supports only single segment. Using first segment.")
segment_indices = [segment_indices[0]]

Expand Down Expand Up @@ -158,8 +158,8 @@ def __init__(
scatter_decimate=scatter_decimate,
)

# If using SortingView, extract just the first segment's data as flat dicts
if is_sortingview:
# If using Figpack, extract just the first segment's data as flat dicts
if is_figpack:
first_segment = segment_indices[0]
plot_data["spike_train_data"] = {first_segment: spiketrains_by_segment[first_segment]}
plot_data["y_axis_data"] = {first_segment: amplitudes_by_segment[first_segment]}
Expand All @@ -172,27 +172,39 @@ def __init__(
BaseRasterWidget.__init__(self, **plot_data, backend=backend, **backend_kwargs)

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

if not use_sortingview:
data_plot["hide_unit_selector"] = True # force hide unit selector for figpack
dp = to_attr(data_plot)

unit_ids = make_serializable(dp.unit_ids)

sa_items = [
vv.SpikeAmplitudesItem(
vv_views.SpikeAmplitudesItem(
unit_id=u,
spike_times_sec=dp.spike_train_data[u].astype("float32"),
spike_amplitudes=dp.y_axis_data[u].astype("float32"),
)
for u in unit_ids
]

self.view = vv.SpikeAmplitudes(
self.view = vv_views.SpikeAmplitudes(
start_time_sec=0,
end_time_sec=np.sum(dp.durations),
plots=sa_items,
hide_unit_selector=dp.hide_unit_selector,
# hide_unit_selector=dp.hide_unit_selector,
)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)
18 changes: 14 additions & 4 deletions src/spikeinterface/widgets/autocorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax.set_title(str(unit_id))

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)
unit_ids = make_serializable(dp.unit_ids)
Expand All @@ -52,14 +62,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
for j in range(i, len(unit_ids)):
if i == j:
ac_items.append(
vv.AutocorrelogramItem(
vv_views.AutocorrelogramItem(
unit_id=unit_ids[i],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
)
)

self.view = vv.Autocorrelograms(autocorrelograms=ac_items)
self.view = vv_views.Autocorrelograms(autocorrelograms=ac_items)

self.url = handle_display_and_url(self, self.view, **backend_kwargs)

Expand Down
16 changes: 13 additions & 3 deletions src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def set_default_plotter_backend(backend):
"figsize": "Size of matplotlib figure, default: None",
"figtitle": "The figure title, default: None",
},
"sortingview": {
"figpack": {
"generate_url": "If True, the figurl URL is generated and printed, default: True",
"display": "If True and in jupyter notebook/lab, the widget is displayed in the cell, default: True.",
"figlabel": "The figurl figure label, default: None",
"height": "The height of the sortingview View in jupyter, default: None",
"inline": "If True, the widget is displayed inline in the cell, default: None",
"height": "The height of the figpack View in jupyter, default: None",
},
"ipywidgets": {
"width_cm": "Width of the figure in cm, default: 10",
Expand All @@ -47,14 +48,23 @@ def set_default_plotter_backend(backend):
},
"ephyviewer": {},
"spikeinterface_gui": {},
# deprecated
"sortingview": {
"generate_url": "If True, the figurl URL is generated and printed, default: True",
"display": "If True and in jupyter notebook/lab, the widget is displayed in the cell, default: True.",
"figlabel": "The figurl figure label, default: None",
"height": "The height of the sortingview View in jupyter, default: None",
},
}

default_backend_kwargs = {
"matplotlib": {"figure": None, "ax": None, "axes": None, "ncols": 5, "figsize": None, "figtitle": None},
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
"figpack": {"generate_url": True, "display": True, "figlabel": None, "inline": None, "height": None},
"ipywidgets": {"width_cm": 25, "height_cm": 10, "display": True, "controllers": None},
"ephyviewer": {},
"spikeinterface_gui": {},
# deprecated
"sortingview": {"generate_url": True, "display": True, "figlabel": None, "height": None},
}


Expand Down
24 changes: 16 additions & 8 deletions src/spikeinterface/widgets/crosscorrelograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CrossCorrelogramsWidget(BaseWidget):
unit_ids list or None, default: None
List of unit ids
min_similarity_for_correlograms : float, default: 0.2
For sortingview backend. Threshold for computing pair-wise cross-correlograms.
For figpack backend. Threshold for computing pair-wise cross-correlograms.
If template similarity between two units is below this threshold, the cross-correlogram is not displayed.
For auto-correlograms plot, this is automatically set to None.
window_ms : float, default: 100.0
Expand All @@ -30,7 +30,7 @@ class CrossCorrelogramsWidget(BaseWidget):
Bin size in ms. If correlograms are already computed (e.g. with SortingAnalyzer),
this argument is ignored
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
unit_colors : dict | None, default: None
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
Expand Down Expand Up @@ -127,8 +127,18 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.axes[-1, i].set_xlabel("CCG (ms)")

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)

Expand All @@ -144,14 +154,12 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
for j in range(i, len(unit_ids)):
if similarity[i, j] >= dp.min_similarity_for_correlograms:
cc_items.append(
vv.CrossCorrelogramItem(
vv_views.CrossCorrelogramItem(
unit_id1=unit_ids[i],
unit_id2=unit_ids[j],
bin_edges_sec=(dp.bins / 1000.0).astype("float32"),
bin_counts=dp.correlograms[i, j].astype("int32"),
)
)

self.view = vv.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector)

self.view = vv_views.CrossCorrelograms(cross_correlograms=cc_items, hide_unit_selector=dp.hide_unit_selector)
self.url = handle_display_and_url(self, self.view, **backend_kwargs)
32 changes: 22 additions & 10 deletions src/spikeinterface/widgets/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MetricsBaseWidget(BaseWidget):
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
include_metrics_data : bool, default: True
If True, metrics data are included in unit table
"""
Expand Down Expand Up @@ -258,8 +258,18 @@ def _update_ipywidget(self, change):
self.figure.canvas.flush_events()

def plot_sortingview(self, data_plot, **backend_kwargs):
import sortingview.views as vv
from .utils_sortingview import generate_unit_table_view, make_serializable, handle_display_and_url
self.plot_figpack(data_plot, use_sortingview=True, **backend_kwargs)

def plot_figpack(self, data_plot, **backend_kwargs):
from .utils_figpack import (
make_serializable,
handle_display_and_url,
import_figpack_or_sortingview,
generate_unit_table_view,
)

use_sortingview = backend_kwargs.get("use_sortingview", False)
vv_base, vv_views = import_figpack_or_sortingview(use_sortingview)

dp = to_attr(data_plot)

Expand All @@ -275,7 +285,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
metrics_sv = []
for col in metric_names:
dtype = np.array(metrics.iloc[0][col]).dtype
metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str)
metric = vv_views.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str)
metrics_sv.append(metric)

units_m = []
Expand All @@ -290,8 +300,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
continue
values_skip_nans[k] = v

units_m.append(vv.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans))
v_metrics = vv.UnitMetricsGraph(units=units_m, metrics=metrics_sv)
units_m.append(vv_views.UnitMetricsGraphUnit(unit_id=unit_id, values=values_skip_nans))
v_metrics = vv_views.UnitMetricsGraph(units=units_m, metrics=metrics_sv)

if not dp.hide_unit_selector:
if dp.include_metrics_data:
Expand All @@ -301,12 +311,14 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
if col not in sorting_copy.get_property_keys():
sorting_copy.set_property(col, metrics[col].values)
# generate table with properties
v_units_table = generate_unit_table_view(sorting_copy, unit_properties=metric_names)
v_units_table = generate_unit_table_view(
sorting_copy, unit_properties=metric_names, use_sortingview=use_sortingview
)
else:
v_units_table = generate_unit_table_view(dp.sorting)
v_units_table = generate_unit_table_view(dp.sorting, use_sortingview=use_sortingview)

self.view = vv.Splitter(
direction="horizontal", item1=vv.LayoutItem(v_units_table), item2=vv.LayoutItem(v_metrics)
self.view = vv_views.Splitter(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be using vv_base instead of vv_views?

direction="horizontal", item1=vv_views.LayoutItem(v_units_table), item2=vv_views.LayoutItem(v_metrics)
)
else:
self.view = v_metrics
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/quality_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class QualityMetricsWidget(MetricsBaseWidget):
Dict of colors with unit ids as keys and colors as values. Colors can be any type accepted
by matplotlib. If None, default colors are chosen using the `get_some_colors` function.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class BaseRasterWidget(BaseWidget):
y_ticks : dict | None, default: None
Ticks on y-axis, passed to `set_yticks`. If None, default ticks are used.
hide_unit_selector : bool, default: False
For sortingview backend, if True the unit selector is not displayed
For figpack backend, if True the unit selector is not displayed
segment_boundary_kwargs : dict | None, default: None
Additional arguments for the segment boundary lines, passed to `matplotlib.axvline`
backend : str | None, default None
Expand Down
Loading
Loading