Collaborator

@JoeZiminski JoeZiminski commented Jul 18, 2024

This PR adds an 'inter-session displacement' ground-truth recording generator, to act as test data for inter-session alignment (e.g. #2626, #3126). The idea is to create separate recordings with the same templates but shifted unit locations across recordings. There are options to model:

  • rigid shift in (x, y)
  • nonrigid shift in (x, y)
  • neurons dropping out across recordings: recording_amplitude_scalings allows scaling the injected template amplitudes to different sizes across recordings.
  • additional neurons being introduced into the recording. By default, when the templates shift, the space they previously occupied on the probe is left empty. In reality, new neurons would be shifted into the recording. shift_units_outside_probe=True introduces new neurons into the recording after the probe shift.

This PR tries to use the existing motion machinery where possible, and makes some refactorings with this aim. The main changes are the refactorings, introduction of a new function for generating multi-session drifting recordings, and associated tests. Below are some sections to highlight usage, and example scripts for quick-running or deeper debugging.

I'm not sure about the name 'session displacement'. Maybe just 'generate_multi_session_recording()' is easier to understand. Also, please let me know if any other variable names are unclear.

Examples

Example 1

Set the amplitude of 2 units to zero between sessions, alongside a 100 um shift. The units are shifted across sessions, and the top / bottom units are removed (e.g. simulating these neurons disappearing between sessions).

Screenshot 2024-09-03 at 17 00 57

Example 2

Set shift_units_outside_probe to True which introduces new units into the recording due to shift, alongside a 250 um shift. Note the top unit has been shifted out of the probe, the middle units are shifted up, and 2 new units are introduced at the bottom of the probe.

Screenshot 2024-09-03 at 17 02 53

Note

When running with num_units set to n, it is initially surprising that you do not always see exactly n units clearly (e.g. the raster below, which was run with num_units=5). The main driver of this is that the default z range in generate_unit_locations_kwargs (minimum_z to maximum_z) is between 5 and 45. For low-amplitude neurons, a far z position will not generate enough signal to reach the probe (this is of course more realistic for simulation purposes). If you want to be sure you see n units in the raster, set margin_um and maximum_z lower in generate_unit_locations_kwargs, e.g.:

image

generate_unit_locations_kwargs=dict(
    margin_um=0.0,
    minimum_z=5.0,
    maximum_z=10.0,
    minimum_distance=18.0,
    max_iteration=100,
    distance_strict=False,
)
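
As a rough illustration of why distant units can vanish from the raster, here is a toy sketch. It assumes, purely for illustration, that the peak amplitude falls off exponentially with distance from the probe. The decay model, the detection threshold and the function names are hypothetical and are not taken from the generator's code. The alpha and spatial_decay values are picked from within the ranges used in the debugging script.

```python
import math

# Hypothetical detection threshold, in the same arbitrary units as alpha.
DETECTION_THRESHOLD = 50.0


def peak_amplitude(alpha, z_um, spatial_decay_um):
    """Assumed toy model: amplitude decays exponentially with distance."""
    return alpha * math.exp(-z_um / spatial_decay_um)


def is_detectable(alpha, z_um, spatial_decay_um=20.0):
    """True if the toy peak amplitude exceeds the hypothetical threshold."""
    return peak_amplitude(alpha, z_um, spatial_decay_um) > DETECTION_THRESHOLD


# Compare a low-amplitude and a high-amplitude unit at near and far z.
for alpha in (100.0, 500.0):
    for z_um in (5.0, 10.0, 45.0):
        amplitude = peak_amplitude(alpha, z_um, 20.0)
        print(alpha, z_um, round(amplitude, 1), is_detectable(alpha, z_um))
```

Under these assumptions the low-amplitude unit drops below the threshold at the far end of the z range while the high-amplitude unit does not, which is the same trade-off as either lowering maximum_z or raising alpha.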
Quick Run Code
import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

rec_list, _ = generate_session_displacement_recordings(
    non_rigid_gradient=None,  # Note this will set nonlinearity to both x and y (the same)
    num_units=5,
    recording_durations=[25, 25, 25],
    recording_shifts=[
        (0, 0),
        (0, 75),
        (0, -150),
    ],
    shift_units_outside_probe=False,
    seed=None,
)

# Plot the raster maps.

for rec in rec_list:

    peaks = detect_peaks(rec, method="locally_exclusive")
    peak_locs = localize_peaks(rec, peaks, method="grid_convolution")

    si.plot_drift_raster_map(
        peaks=peaks,
        peak_locations=peak_locs,
        recording=rec,
        clim=(-300, 0)  # fix clim for comparability across plots
    )
    plt.show()
Full Debugging Code
import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
import numpy as np
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

# Generate a ground-truth recording where every unit fires a lot,
# with high amplitude, and is close to the probe, so all are picked up.
# This just makes it easy to play around with the units, e.g.
# if specifying 5 units, 5 unit peaks are clearly visible and none are lost
# because their position is too far from the probe.

default_unit_params_range = dict(
    alpha=(100.0, 500.0),
    depolarization_ms=(0.09, 0.14),
    repolarization_ms=(0.5, 0.8),
    recovery_ms=(1.0, 1.5),
    positive_amplitude=(0.1, 0.25),
    smooth_ms=(0.03, 0.07),
    spatial_decay=(20, 40),
    propagation_speed=(250.0, 350.0),
    b=(0.1, 1),
    c=(0.1, 1),
    x_angle=(0, np.pi),
    y_angle=(0, np.pi),
    z_angle=(0, np.pi),
)

default_unit_params_range["alpha"] = (500, 500)  # do this or change the margin on generate_unit_locations_kwargs
default_unit_params_range["b"] = (0.5, 1)        # and make the units fatter, easier to receive signal!
default_unit_params_range["c"] = (0.5, 1)

scale_ = [np.array([0.25, 0.5, 1, 1, 0])] * 2
scale_ = [np.ones(5)] + scale_

rec_list, _ = generate_session_displacement_recordings(
    non_rigid_gradient=None,  # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
    num_units=5,
    recording_durations=(25, 25, 25),  # TODO: checks on inputs
    recording_shifts=(
        (0, 0),
        (0, 0),
        (0, 0),
    ),
    recording_amplitude_scalings={
        "method": "by_amplitude_and_firing_rate",
        "scalings": scale_,
    },
    shift_units_outside_probe=False,
    generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0),
    generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
    seed=None,
    generate_unit_locations_kwargs=dict(
        margin_um=0.0,  # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
        minimum_z=5.0,
        maximum_z=45.0,
        minimum_distance=18.0,
        max_iteration=100,
        distance_strict=False,
    ),
)

# Iterate through each recording, plotting the raw traces then
# detecting and plotting the peaks.

for rec in rec_list:

    si.plot_traces(rec, time_range=(0, 1))
    plt.show()

    peaks = detect_peaks(rec, method="locally_exclusive")
    peak_locs = localize_peaks(rec, peaks, method="grid_convolution")

    si.plot_drift_raster_map(
        peaks=peaks,
        peak_locations=peak_locs,
        recording=rec,
        clim=(-300, 0)  # fix clim for comparability across plots
    )
    plt.show()

TODO:

  • tests are randomly failing every now and then, only on macOS

@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 089cbc0 to 60c8e5e Compare July 29, 2024 15:39
@alejoe91 alejoe91 added motion correction Questions related to motion correction generators Related to generator tools labels Aug 27, 2024
@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 72ca7fa to 33254b9 Compare September 3, 2024 14:13
@JoeZiminski JoeZiminski marked this pull request as ready for review September 3, 2024 16:08
@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from 33254b9 to e996dee Compare September 3, 2024 16:35
Collaborator

@cwindolf cwindolf left a comment

Cool! This seems really nice. I think my only question is whether there can be more logic sharing with the single-session generate_drifting_recording, since some of the changes you've made seem like they could be helpful there. But, not sure how much sense that makes.

return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps


def calculate_displacement_unit_factor(
Collaborator

maybe this could be called something like "simulate_linear_gradient_drift"? i was a bit confused reading it, but it seems to be generating drift which is 0 at the top of the probe and something not zero at the bottom?

maybe someone can help explain what exactly

displacement_unit_factor = factors * (1 - f) + f

ends up producing... is it like there is some global drift plus per-unit linear drift?

Collaborator Author

Hey @cwindolf, thanks a lot for this review. This function is a refactoring of this code. I like simulate_linear_gradient_drift, but for this PR I will keep the current naming for consistency with the old code. However, I'll make an issue based on some of the points you raise in this PR (e.g. including some things in the within-session drift) and add a note on this there.

I agree, I got quite confused the first (few) times looking through the use of non_rigid_gradient. I think the easiest way to see it is with some example values. In the first part of the function, the dot-product of the displacement vector and the unit location is taken (the unit location is expressed as a vector from the probe origin; I am not sure where that is, maybe bottom-left). In the y-displacement-only case, this is just the unit y position. These unit positions are scaled to [0, 1] and called factors.

The f term and the expression you show ensure that the 'largest' unit location (e.g. near the top of the probe if the origin is bottom-left) is scaled by 1 (no change). The scaling is linear across all unit positions; it's kind of like a linspace where the max value is 1 and f sets the min value. For example, looking at the units with the smallest, middle and largest locations (i.e. factors 0, 0.5 and 1):

non_rigid_gradient=0.8

0 * 0.2 + 0.8 = 0.8
0.5 * 0.2 + 0.8 = 0.9
1 * 0.2 + 0.8 = 1

non_rigid_gradient=0.2

0 * 0.8 + 0.2 = 0.2
0.5 * 0.8 + 0.2 = 0.6
1 * 0.8 + 0.2 = 1

So in the first case, the scaling of the units is only in the range [0.8, 1] of the normalised position of the unit. But for the smaller non_rigid_gradient=0.2, the scaling is between [0.2, 1].
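
For anyone who wants to play with the numbers, here is a minimal sketch that evaluates the expression factors * (1 - f) + f for both gradients. It assumes f is simply the value of non_rigid_gradient and that the units do not all share the same position. The function name and the example positions are made up for illustration; this is not the code in the PR.

```python
def displacement_unit_factor(unit_positions, non_rigid_gradient):
    """Per-unit scaling of the displacement (illustrative sketch only)."""
    lowest, highest = min(unit_positions), max(unit_positions)
    # Normalise the unit positions along the displacement direction to [0, 1].
    factors = [(pos - lowest) / (highest - lowest) for pos in unit_positions]
    # Assumption: f is the non_rigid_gradient value itself.
    f = non_rigid_gradient
    # The unit with the largest position gets 1; the smallest gets f.
    return [factor * (1 - f) + f for factor in factors]


# Hypothetical unit y positions: smallest, middle and largest.
positions = [0.0, 50.0, 100.0]

for gradient in (0.8, 0.2):
    scaled = displacement_unit_factor(positions, gradient)
    print(gradient, [round(value, 3) for value in scaled])
```

A gradient close to 1 gives nearly rigid motion (all factors near 1), while a gradient close to 0 gives the strongest difference between the two ends of the probe.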

I re-wrote the docstring of the function, let me know if it's any clearer. I think there is still room for improvement; I am also not sure how much depth to go into.

)


def _update_kwargs_for_extended_units(
Collaborator

this is a cool feature! in the single-session case, i have previously dealt with this by adding extra units that are off the probe to start. i'm wondering how this fits into the single session case... would this function be useful there too?

Collaborator Author

@JoeZiminski JoeZiminski Jan 8, 2025

Cheers! The approach here is exactly the same: some additional simulated unit locations are generated that are off the probe. As the displacement is applied, these out-of-probe units are moved into the probe. I think this would also be useful in the single-session case: as the probe drifts within a session, more and more 'new' signal (i.e. signal that was not detected in the probe's original position) will be introduced and could affect the correction.

I guess this is a difficult problem, as the nature of the 'new' units introduced into the region in which the probe is measuring will be random and presumably highly variable across preparations.
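
The idea can be sketched in a few lines. This is a toy illustration only: the function name, the probe extent and the unit positions are all hypothetical, and the real implementation works on full unit locations rather than bare y positions.

```python
def units_on_probe(unit_y_positions, probe_y_range, shift_um):
    """Shift the units in y and keep only those still within the probe."""
    y_min, y_max = probe_y_range
    shifted = [y + shift_um for y in unit_y_positions]
    return [y for y in shifted if y_min <= y <= y_max]


probe_y_range = (0.0, 500.0)          # hypothetical probe extent in um
on_probe_units = [50.0, 200.0, 400.0]  # units visible before the shift
extra_units = [-200.0, -100.0]         # extra units generated below the probe
shift_um = 250.0

# Without extra units, the bottom of the probe is left empty after the shift.
without_extra = units_on_probe(on_probe_units, probe_y_range, shift_um)

# With extra units, the off-probe units move into the vacated space.
with_extra = units_on_probe(extra_units + on_probe_units, probe_y_range, shift_um)

print(without_extra)
print(with_extra)
```

In this toy case the top unit leaves the probe, the other two move up, and the two off-probe units appear at the bottom, which mirrors what Example 2 shows.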

@JoeZiminski
Collaborator Author

Hey @samuelgarcia @alejoe91, hope you both are good! I know things are very busy at the moment, but I was wondering if anyone might be available to give this a quick review? It would be useful for #3231 to merge this, and it is (relatively) orthogonal to existing code.

spike_frames = spike_frames[sort_indices]

- return spike_frames, unit_indices
+ return spike_frames, unit_indices, firing_rates
Member

I'm not sure it's worth changing the output of this public function? Alternatively, you could add

firing_rates_array = _ensure_firing_rates(firing_rates, num_units, seed)

at line 154 and keep it as before.

I know @h-mayorquin uses this function.

Collaborator

Agree on this one.

Plus, the firing rates are already passed to synthesize_poisson_spike_vector, why return them? You can make an array of the input in one line if necessary.

Collaborator Author

Yes, great points, that's a big oversight. Thanks both!

Collaborator

Thanks for working on this Joe!

Collaborator Author

Thanks @chrishalcrow @h-mayorquin! I think this is resolved now. In the end I didn't need to change anything at all related to firing_rates in generate.py 😄

@JoeZiminski JoeZiminski force-pushed the add_session_displacement_generation branch from dfe5ac9 to 8bceeaa Compare March 14, 2025 18:04
JoeZiminski and others added 8 commits March 14, 2025 18:06
Labels

generators Related to generator tools motion correction Questions related to motion correction

5 participants