Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 161 additions & 40 deletions spikeinterface_gui/basescatterview.py

Large diffs are not rendered by default.

93 changes: 87 additions & 6 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,14 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save

spike_vector2 = self.analyzer.sorting.to_spike_vector(concatenated=False)
# this is dict of list because per segment spike_indices[segment_index][unit_id]
spike_indices_abs = spike_vector_to_indices(spike_vector2, unit_ids, absolute_index=True)
spike_indices = spike_vector_to_indices(spike_vector2, unit_ids)
# this is flatten
spike_per_seg = [s.size for s in spike_vector2]
# dict[unit_id] -> all indices for this unit across segments
self._spike_index_by_units = {}
# dict[seg_index][unit_id] -> all indices for this unit for one segment
self._spike_index_by_segment_and_units = spike_indices
self._spike_index_by_segment_and_units = spike_indices_abs
for unit_id in unit_ids:
inds = []
for seg_ind in range(num_seg):
Expand Down Expand Up @@ -335,10 +336,17 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
raise ValueError("Curation data format version is missing and is required in the curation data.")
try:
validate_curation_dict(curation_data)
self.curation_data = curation_data
except Exception as e:
raise ValueError(f"Invalid curation data.\nError: {e}")

if curation_data.get("merges") is None:
curation_data["merges"] = []
if curation_data.get("splits") is None:
curation_data["splits"] = []
if curation_data.get("removed") is None:
curation_data["removed"] = []
self.curation_data = curation_data

self.has_default_quality_labels = False
if "label_definitions" not in self.curation_data:
if label_definitions is not None:
Expand All @@ -355,6 +363,8 @@ def __init__(self, analyzer=None, backend="qt", parent=None, verbose=False, save
print('Curation quality labels are the default ones')
self.has_default_quality_labels = True

# this is used to store the active split unit
self.active_split = None

def check_is_view_possible(self, view_name):
from .viewlist import possible_class_views
Expand Down Expand Up @@ -460,6 +470,13 @@ def set_visible_unit_ids(self, visible_unit_ids):
if len(visible_unit_ids) > lim:
visible_unit_ids = visible_unit_ids[:lim]
self._visible_unit_ids = list(visible_unit_ids)
self.active_split = None
if len(visible_unit_ids) == 1 and self.curation:
# check if unit is split
for split in self.curation_data['splits']:
if visible_unit_ids[0] == split['unit_id']:
self.active_split = split
break

def get_visible_unit_ids(self):
"""Get list of visible unit_ids"""
Expand Down Expand Up @@ -524,10 +541,21 @@ def get_indices_spike_visible(self):
return self._spike_visible_indices

def get_indices_spike_selected(self):
if self.active_split is not None:
# select the splitted spikes in the active split
split_unit_id = self.active_split['unit_id']
spike_inds = self.get_spike_indices(split_unit_id, seg_index=None)
split_indices = self.active_split['indices']
self._spike_selected_indices = np.array(spike_inds[split_indices], dtype='int64')
return self._spike_selected_indices

def set_indices_spike_selected(self, inds):
self._spike_selected_indices = np.array(inds)
if len(self._spike_selected_indices) == 1:
# set time info
segment_index = self.spikes['segment_index'][self._spike_selected_indices[0]]
sample_index = self.spikes['sample_index'][self._spike_selected_indices[0]]
self.set_time(time=sample_index / self.sampling_frequency, segment_index=segment_index)

def get_spike_indices(self, unit_id, seg_index=None):
if seg_index is None:
Expand Down Expand Up @@ -767,15 +795,68 @@ def make_manual_merge_if_possible(self, merge_unit_ids):
if self.verbose:
print(f"Merged unit group: {[str(u) for u in merge_unit_ids]}")
return True

def make_manual_split_if_possible(self, unit_id, indices):
"""
Check if the a unit_id can be split into a new split in the curation_data.

If unit_id is already in the removed list then the split is skipped.
If unit_id is already in some other split then the split is skipped.
"""
if not self.curation:
return False

if unit_id in self.curation_data["removed"]:
return False

# check if unit_id is already in a split
for split in self.curation_data["splits"]:
if split["unit_id"] == unit_id:
return False

new_split = {
"unit_id": unit_id,
"mode": "indices",
"indices": indices
}
self.curation_data["splits"].append(new_split)
if self.verbose:
print(f"Split unit {unit_id} with {len(indices)} spikes")
return True

def make_manual_restore_merge(self, merge_group_indices):
def make_manual_restore_merge(self, merge_indices):
if not self.curation:
return
for merge_index in merge_group_indices:
for merge_index in merge_indices:
if self.verbose:
print(f"Unmerged merge group {self.curation_data['merge_unit_groups'][merge_index]['unit_ids']}")
print(f"Unmerged {self.curation_data['merges'][merge_index]['unit_ids']}")
self.curation_data["merges"].pop(merge_index)

def make_manual_restore_split(self, split_indices):
if not self.curation:
return
for split_index in split_indices:
if self.verbose:
print(f"Unsplitting {self.curation_data['splits'][split_index]['unit_id']}")
self.curation_data["splits"].pop(split_index)

def set_active_split_unit(self, unit_id):
"""
Set the active split unit_id.
This is used to set the label for the split unit.
"""
if not self.curation:
return
if unit_id is None:
self.active_split = None
else:
if unit_id in self.curation_data["removed"]:
print(f"Unit {unit_id} is removed, cannot set as active split unit")
return
active_split = [s for s in self.curation_data["splits"] if s["unit_id"] == unit_id]
if len(active_split) == 1:
self.active_split = active_split[0]

def get_curation_label_definitions(self):
# give only label definition with exclusive
label_definitions = {}
Expand Down
130 changes: 92 additions & 38 deletions spikeinterface_gui/crosscorrelogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ class CrossCorrelogramView(ViewBase):
_supported_backend = ['qt', 'panel']
_depend_on = ["correlograms"]
_settings = [
{'name': 'window_ms', 'type': 'float', 'value' : 50. },
{'name': 'bin_ms', 'type': 'float', 'value' : 1.0 },
{'name': 'display_axis', 'type': 'bool', 'value' : True },
{'name': 'max_visible', 'type': 'int', 'value' : 8 },
]
{'name': 'window_ms', 'type': 'float', 'value' : 50. },
{'name': 'bin_ms', 'type': 'float', 'value' : 1.0 },
{'name': 'display_axis', 'type': 'bool', 'value' : True },
]
_need_compute = True

def __init__(self, controller=None, parent=None, backend="qt"):
Expand All @@ -26,6 +25,36 @@ def _on_settings_changed(self):
def _compute(self):
self.ccg, self.bins = self.controller.compute_correlograms(
self.settings['window_ms'], self.settings['bin_ms'])

def _compute_split_ccg(self):
"""
This method is used to compute the cross-correlogram for a split unit.
It is called when the user selects a split unit in the controller.
"""
from spikeinterface import NumpySorting
from spikeinterface.postprocessing import compute_correlograms

if self.controller.active_split is None:
raise ValueError("No active split unit selected.")

split_unit_id = self.controller.active_split["unit_id"]
spike_inds = self.controller.get_spike_indices(split_unit_id, seg_index=None)
split_indices = self.controller.active_split['indices']
spikes_split_unit = self.controller.spikes[spike_inds]
unit_index = spikes_split_unit[0]["unit_index"]
# change unit_index for split indices
spikes_split_unit["unit_index"][split_indices] = unit_index + 1
split_sorting = NumpySorting(
spikes=spikes_split_unit,
sampling_frequency=self.controller.sampling_frequency,
unit_ids=[f"{split_unit_id}-0", f"{split_unit_id}-1"]
)
ccg, bins = compute_correlograms(
split_sorting,
window_ms=self.settings['window_ms'],
bin_ms=self.settings['bin_ms']
)
return ccg, bins

## Qt ##

Expand All @@ -51,18 +80,32 @@ def _qt_refresh(self):
return

visible_unit_ids = self.controller.get_visible_unit_ids()
visible_unit_ids = visible_unit_ids[:self.settings['max_visible']]

n = len(visible_unit_ids)

unit_ids = list(self.controller.unit_ids)
if self.controller.active_split is None:
n = len(visible_unit_ids)
unit_ids = list(self.controller.unit_ids)
colors = {
unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids
}
ccg = self.ccg
bins = self.bins
else:
split_unit_id = visible_unit_ids[0]
n = 2
unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"]
visible_unit_ids = unit_ids
ccg, bins = self._compute_split_ccg()
split_unit_color = self.get_unit_color(split_unit_id)
colors = {
f"{split_unit_id}-0": split_unit_color,
f"{split_unit_id}-1": split_unit_color,
}

for r in range(n):
for c in range(r, n):

i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
count = self.ccg[i, j, :]
count = ccg[i, j, :]

plot = pg.PlotItem()
if not self.settings['display_axis']:
Expand All @@ -71,16 +114,15 @@ def _qt_refresh(self):

if r==c:
unit_id = visible_unit_ids[r]
color = self.get_unit_color(unit_id)
color = colors[unit_id]
else:
color = (120,120,120,120)

curve = pg.PlotCurveItem(self.bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
curve = pg.PlotCurveItem(bins, count, stepMode='center', fillLevel=0, brush=color, pen=color)
plot.addItem(curve)
self.grid.addItem(plot, row=r, col=c)

## panel ##

def _panel_make_layout(self):
import panel as pn
import bokeh.plotting as bpl
Expand Down Expand Up @@ -115,31 +157,40 @@ def _panel_refresh(self):
return

visible_unit_ids = self.controller.get_visible_unit_ids()

# Show warning above the plot if too many visible units
if len(visible_unit_ids) > self.settings['max_visible']:
warning_msg = f"Only showing first {self.settings['max_visible']} units out of {len(visible_unit_ids)} visible units"
insert_warning(self, warning_msg)
self.is_warning_active = True
return
if self.is_warning_active:
clear_warning(self)
self.is_warning_active = False

visible_unit_ids = visible_unit_ids[:self.settings['max_visible']]

n = len(visible_unit_ids)
unit_ids = list(self.controller.unit_ids)
if self.controller.active_split is None:
n = len(visible_unit_ids)
unit_ids = list(self.controller.unit_ids)
colors = {
unit_id: self.get_unit_color(unit_id) for unit_id in visible_unit_ids
}
ccg = self.ccg
bins = self.bins
else:
split_unit_id = visible_unit_ids[0]
n = 2
unit_ids = [f"{split_unit_id}-0", f"{split_unit_id}-1"]
visible_unit_ids = unit_ids
ccg, bins = self._compute_split_ccg()
split_unit_color = self.get_unit_color(split_unit_id)
colors = {
f"{split_unit_id}-0": split_unit_color,
f"{split_unit_id}-1": split_unit_color,
}

first_fig = None
for r in range(n):
row_plots = []
for c in range(r, n):

i = unit_ids.index(visible_unit_ids[r])
j = unit_ids.index(visible_unit_ids[c])
count = self.ccg[i, j, :]
count = ccg[i, j, :]

# Create Bokeh figure
p = bpl.figure(
if first_fig is not None:
extra_kwargs = dict(x_range=first_fig.x_range)
else:
extra_kwargs = dict()
fig = bpl.figure(
width=250,
height=250,
tools="pan,wheel_zoom,reset",
Expand All @@ -148,29 +199,32 @@ def _panel_refresh(self):
background_fill_color=_bg_color,
border_fill_color=_bg_color,
outline_line_color="white",
**extra_kwargs,
)
p.toolbar.logo = None
fig.toolbar.logo = None

# Get color from controller
if r == c:
unit_id = visible_unit_ids[r]
color = self.get_unit_color(unit_id)
color = colors[unit_id]
fill_alpha = 0.7
else:
color = "lightgray"
fill_alpha = 0.4

p.quad(
fig.quad(
top=count,
bottom=0,
left=self.bins[:-1],
right=self.bins[1:],
left=bins[:-1],
right=bins[1:],
fill_color=color,
line_color=color,
alpha=fill_alpha,
)
if first_fig is None:
first_fig = fig

row_plots.append(p)
row_plots.append(fig)
# Fill row with None for proper spacing
full_row = [None] * r + row_plots + [None] * (n - len(row_plots))
self.plots.append(full_row)
Expand Down
Loading