Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions finn/track_application_menus/main_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

import finn
from finn.track_application_menus.menu_widget import MenuWidget
from finn.track_data_views.views.layers.track_labels import TrackLabels
from finn.track_data_views.views.layers.track_points import TrackPoints
from finn.track_data_views.views.ortho_views import (
_get_manager,
paint_event_hook,
point_data_hook,
sync_filters,
track_layers_hook,
)
from finn.track_data_views.views.tree_view.tree_widget import TreeWidget


Expand All @@ -22,4 +31,12 @@ def __init__(self, viewer: finn.Viewer):
layout = QVBoxLayout()
layout.addWidget(self.menu_widget)

orth_view_manager = _get_manager(viewer)
orth_view_manager.register_layer_hook(
(TrackLabels, TrackPoints), track_layers_hook
)
orth_view_manager.register_layer_hook((TrackLabels), paint_event_hook)
orth_view_manager.register_layer_hook((TrackPoints), point_data_hook)
orth_view_manager.set_sync_filters(sync_filters)

self.setLayout(layout)
86 changes: 58 additions & 28 deletions finn/track_data_views/views/layers/track_labels.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import random
import time
from typing import TYPE_CHECKING

import numpy as np

import finn
from finn.layers import Labels
from finn.utils import DirectLabelColormap
from finn.utils.action_manager import action_manager
from finn.utils.notifications import show_info, show_warning
Expand Down Expand Up @@ -99,34 +101,56 @@ def __init__(
self.bind_key("z")(self.tracks_viewer.undo)
self.bind_key("r")(self.tracks_viewer.redo)

# Connect click events to node selection
@self.mouse_drag_callbacks.append
def click(_, event):
if (
event.type == "mouse_press"
and self.mode == "pan_zoom"
and not (
self.tracks_viewer.mode == "lineage"
and self.viewer.dims.ndisplay == 3
)
): # disable selecting in lineage mode in 3D
# Listen to click, paint events and changing the selected label
self.mouse_drag_callbacks.append(self.click)
self.events.paint.connect(self._on_paint)
self.tracks_viewer.selected_nodes.list_updated.connect(self.update_selected_label)
self.events.selected_label.connect(self._ensure_valid_label)
self.events.mode.connect(self._check_mode)
self.viewer.dims.events.current_step.connect(self._ensure_valid_label)

# Connect click events to node selection
def click(self, _, event):
if (
event.type == "mouse_press"
and self.mode == "pan_zoom"
and not (
self.tracks_viewer.mode == "lineage" and self.viewer.dims.ndisplay == 3
)
): # disable selecting in lineage mode in 3D
# differentiate between click and drag
mouse_press_time = time.time()
dragged = False
yield
# on move
while event.type == "mouse_move":
dragged = True
yield
if dragged and time.time() - mouse_press_time < 0.5:
dragged = False # suppress micro drag events and treat them as click
# on release
if not dragged:
label = self.get_value(
event.position,
view_direction=event.view_direction,
dims_displayed=event.dims_displayed,
world=True,
)
# check opacity (=visibility) in the colormap
if label is not None and label != 0 and self.colormap.map(label)[-1] != 0:
append = "Shift" in event.modifiers
self.tracks_viewer.selected_nodes.add(label, append)
self.process_click(event, label)

# Listen to paint events and changing the selected label
self.events.paint.connect(self._on_paint)
self.tracks_viewer.selected_nodes.list_updated.connect(self.update_selected_label)
self.events.selected_label.connect(self._ensure_valid_label)
self.events.mode.connect(self._check_mode)
self.viewer.dims.events.current_step.connect(self._ensure_valid_label)
def assign_new_label(self, event):
"""Function for orthoviews to connect to so the 'm' event can be processed here"""

new_label(self)

def process_click(self, event: Event, label: int):
"""Process the click event to update the selected nodes"""

if (
label is not None and label != 0 and self.colormap.map(label)[-1] != 0
): # check opacity (=visibility) in the colormap
append = "Shift" in event.modifiers
self.tracks_viewer.selected_nodes.add(label, append)

def _get_colormap(self) -> DirectLabelColormap:
"""Get a DirectLabelColormap that maps node ids to their track ids, and then
Expand Down Expand Up @@ -208,20 +232,26 @@ def _parse_paint_event(self, event_val):
mask = concatenated_values == old_value
indices = tuple(concatenated_indices[dim][mask] for dim in range(ndim))
time_points = np.unique(indices[0])
for time in time_points:
time_mask = indices[0] == time
for time_point in time_points:
time_mask = indices[0] == time_point
actions.append(
(tuple(indices[dim][time_mask] for dim in range(ndim)), old_value)
)
return new_value, actions

def _revert_paint(self, event):
def _revert_paint(self, _, source_layer: Labels | None = None):
"""Revert a paint event after it fails validation (no motile tracker Actions have
been created). This keeps the view synced with the backend data.
been created). If a source_layer is provided, the paint event will be reverted on
this layer (this is necessary for orthoviews). This keeps the view synced with
the backend data.
"""
super().undo()

def _on_paint(self, event):
if source_layer is not None:
source_layer.undo() # revert on the orthoview
else:
super().undo()

def _on_paint(self, event, source_layer: Labels | None = None):
"""Listen to the paint event and check which track_ids have changed"""

with self.events.selected_label.blocker():
Expand Down Expand Up @@ -267,7 +297,7 @@ def _on_paint(self, event):
" If you want to update the track id of the node, please edit the "
"edges directly instead."
)
self._revert_paint(event)
self._revert_paint(event, source_layer)
self.refresh()
return
self.tracks_viewer.tracks_controller.update_segmentations(
Expand Down
71 changes: 55 additions & 16 deletions finn/track_data_views/views/layers/track_points.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import time
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -13,12 +14,18 @@
if TYPE_CHECKING:
from finn.track_data_views.views_coordinator.tracks_viewer import TracksViewer

from psygnal import Signal

from finn.utils.events import Event


class TrackPoints(finn.layers.Points):
"""Extended points layer that holds the track information and emits and
responds to dynamics visualization signals
"""

data_updated = Signal()

@property
def _type_string(self) -> str:
return (
Expand Down Expand Up @@ -77,17 +84,25 @@ def __init__(
@self.mouse_drag_callbacks.append
def click(layer, event):
if event.type == "mouse_press":
# is the value passed from the click event?
point_index = layer.get_value(
event.position,
view_direction=event.view_direction,
dims_displayed=event.dims_displayed,
world=True,
)
if point_index is not None:
node_id = self.nodes[point_index]
append = "Shift" in event.modifiers
self.tracks_viewer.selected_nodes.add(node_id, append)
# differentiate between click and drag
mouse_press_time = time.time()
dragged = False
yield
# on move
while event.type == "mouse_move":
dragged = True
yield
if dragged and time.time() - mouse_press_time < 0.5:
dragged = False # suppress micro drag events and treat them as click
if not dragged:
# is the value passed from the click event?
point_index = layer.get_value(
event.position,
view_direction=event.view_direction,
dims_displayed=event.dims_displayed,
world=True,
)
self.process_point_click(point_index, event)

# listen to updates of the data
self.events.data.connect(self._update_data)
Expand All @@ -101,11 +116,22 @@ def click(layer, event):
# to update the nodes in self.tracks_viewer.selected_nodes
self.selected_data.events.items_changed.connect(self._update_selection)

def process_point_click(self, point_index: int | None, event: Event):
"""Select the clicked point(s)"""

if point_index is None:
self.tracks_viewer.selected_nodes.reset()
else:
node_id = self.nodes[point_index]
append = "Shift" in event.modifiers
self.tracks_viewer.selected_nodes.add(node_id, append)

def set_point_size(self, size: int) -> None:
"""Sets a new default point size"""

self.default_size = size
self._refresh()
self.size = self.default_size
self.border_color = self.border_color # emits border color event which triggers updating the sizes as well (size does not have its own event)

def _refresh(self):
"""Refresh the data in the points layer"""
Expand All @@ -121,16 +147,21 @@ def _refresh(self):
self.tracks_viewer.tracks.graph.nodes[node][NodeAttr.TRACK_ID.value]
for node in self.nodes
]
# this submits two events one where the action is 'ongoing' and one when it is finished
self.data = self.tracks_viewer.tracks.get_positions(self.nodes, incl_time=True)
self.data_updated.emit() # emit update signal for the orthogonal views to connect to

self.symbol = self.get_symbols(
self.tracks_viewer.tracks, self.tracks_viewer.symbolmap
)
self.face_color = [
self.tracks_viewer.colormap.map(track_id) for track_id in track_ids
]
self.properties = {"node_id": self.nodes, "track_id": track_ids}
self.size = self.default_size
self.border_color = [1, 1, 1, 1]

with self.events.border_color.blocker(): # no need to submit events for this
self.size = self.default_size
self.border_color = [1, 1, 1, 1]

self.events.data.connect(
self._update_data
Expand Down Expand Up @@ -227,8 +258,12 @@ def update_point_outline(self, visible: list[int] | str) -> None:
self.shown[indices] = True

# set border color for selected item
self.border_color = [1, 1, 1, 1]
self.size = self.default_size
with (
self.events.border_color.blocker()
): # block the event emitter here to not trigger update in orthogonal views
self.border_color = [1, 1, 1, 1]
with self.events.size.blocker():
self.size = self.default_size
for node in self.tracks_viewer.selected_nodes:
index = self.node_index_dict[node]
self.border_color[index] = (
Expand All @@ -238,4 +273,8 @@ def update_point_outline(self, visible: list[int] | str) -> None:
1,
)
self.size[index] = math.ceil(self.default_size + 0.3 * self.default_size)

# emit the event to trigger update in orthogonal views
self.border_color = self.border_color
self.size = self.size
self.refresh()
4 changes: 4 additions & 0 deletions finn/track_data_views/views/layers/tracks_layer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def center_view(self, node):
location[dim] + 0.5
) # use the world location, since the 'step' in viewer.dims.range
# already in world units
# Also update the step for the dims that are displayed, in order to sync with
# the orthogonal views
for dim in self.viewer.dims.displayed:
step[dim] = int(location[dim] + 0.5)
self.viewer.dims.current_step = step

# check whether the new coordinates are inside or outside the field of view,
Expand Down
Loading
Loading