|
| 1 | +""" |
| 2 | +TODO: some notes on this debugging script. |
| 3 | +""" |
| 4 | +import spikeinterface.full as si |
| 5 | +from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
| 8 | +from spikeinterface.sortingcomponents.peak_detection import detect_peaks |
| 9 | +from spikeinterface.sortingcomponents.peak_localization import localize_peaks |
| 10 | + |
| 11 | +# Generate a ground truth recording where every unit is firing a lot, |
| 12 | +# with high amplitude, and is close to the spike, so all picked up. |
| 13 | +# This just makes it easy to play around with the units, e.g. |
| 14 | +# if specifying 5 units, 5 unit peaks are clearly visible, none are lost |
| 15 | +# because their position is too far from probe. |
| 16 | + |
| 17 | +default_unit_params_range = dict( |
| 18 | + alpha=(100.0, 500.0), |
| 19 | + depolarization_ms=(0.09, 0.14), |
| 20 | + repolarization_ms=(0.5, 0.8), |
| 21 | + recovery_ms=(1.0, 1.5), |
| 22 | + positive_amplitude=(0.1, 0.25), |
| 23 | + smooth_ms=(0.03, 0.07), |
| 24 | + spatial_decay=(20, 40), |
| 25 | + propagation_speed=(250.0, 350.0), |
| 26 | + b=(0.1, 1), |
| 27 | + c=(0.1, 1), |
| 28 | + x_angle=(0, np.pi), |
| 29 | + y_angle=(0, np.pi), |
| 30 | + z_angle=(0, np.pi), |
| 31 | +) |
| 32 | + |
| 33 | +default_unit_params_range["alpha"] = (500, 500) # do this or change the margin on generate_unit_locations_kwargs |
| 34 | +default_unit_params_range["b"] = (0.5, 1) # and make the units fatter, easier to receive signal! |
| 35 | +default_unit_params_range["c"] = (0.5, 1) |
| 36 | + |
| 37 | +scale_ = [np.array([0.25, 0.5, 1, 1, 0])] * 2 |
| 38 | +scale_ = [np.ones(5)] + scale_ |
| 39 | + |
| 40 | +rec_list, _ = generate_session_displacement_recordings( |
| 41 | + non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same) |
| 42 | + num_units=5, |
| 43 | + recording_durations=(25, 25, 25), # TODO: checks on inputs |
| 44 | + recording_shifts=( |
| 45 | + (0, 0), |
| 46 | + (0, 0), |
| 47 | + (0, 0), |
| 48 | + ), |
| 49 | + recording_amplitude_scalings= { |
| 50 | + "method": "by_amplitude_and_firing_rate", |
| 51 | + "scalings": scale_, |
| 52 | + }, |
| 53 | + generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0), |
| 54 | + generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3), |
| 55 | + seed=None, |
| 56 | + generate_unit_locations_kwargs=dict( |
| 57 | + margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up. |
| 58 | + minimum_z=5.0, |
| 59 | + maximum_z=45.0, |
| 60 | + minimum_distance=18.0, |
| 61 | + max_iteration=100, |
| 62 | + distance_strict=False, |
| 63 | + ), |
| 64 | +) |
| 65 | + |
| 66 | +# Iterate through each recording, plotting the raw traces then |
| 67 | +# detecting and plotting the peaks. |
| 68 | + |
| 69 | +for rec in rec_list: |
| 70 | + |
| 71 | + si.plot_traces(rec, time_range=(0, 1)) |
| 72 | + plt.show() |
| 73 | + |
| 74 | + peaks = detect_peaks(rec, method="locally_exclusive") |
| 75 | + peak_locs = localize_peaks(rec, peaks, method="grid_convolution") |
| 76 | + |
| 77 | + si.plot_drift_raster_map( |
| 78 | + peaks=peaks, |
| 79 | + peak_locations=peak_locs, |
| 80 | + recording=rec, |
| 81 | + clim=(-300, 0) # fix clim for comparability across plots |
| 82 | + ) |
| 83 | + plt.show() |
0 commit comments