Skip to content
122 changes: 81 additions & 41 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def generate_drifting_recording(
duration=600.0,
sampling_frequency=30000.0,
probe_name="Neuropixels1-128",
probe=None,
generate_probe_kwargs=None,
unit_locations=None,
generate_unit_locations_kwargs=dict(
margin_um=20.0,
minimum_z=5.0,
Expand All @@ -321,6 +323,7 @@ def generate_drifting_recording(
# distribution="multimodal",
# num_modes=2,
),
displacement_data=None,
generate_displacement_vector_kwargs=dict(
displacement_sampling_frequency=5.0,
drift_start_um=[0, 20],
Expand All @@ -347,7 +350,9 @@ def generate_drifting_recording(
ellipse_angle=(0, np.pi * 2),
),
),
sorting=None,
generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
noise=None,
generate_noise_kwargs=dict(noise_levels=(6.0, 8.0), spatial_decay=25.0),
extra_outputs=False,
seed=None,
Expand All @@ -364,20 +369,30 @@ def generate_drifting_recording(
The duration in seconds.
sampling_frequency : float, dfault: 30000.
The sampling frequency.
probe: Probe object, default None
If provided, the Probe geometry to consider
probe_name : str, default: "Neuropixels1-128"
The probe type if generate_probe_kwargs is None.
The probe type if generate_probe_kwargs is None and probe is None.
generate_probe_kwargs : None or dict
A dict to generate the probe, this supersede probe_name when not None.
unit_locations: array, default None
The unit locations of the cells
generate_unit_locations_kwargs : dict
Parameters given to generate_unit_locations().
Parameters given to generate_unit_locations() if unit_locations is None
displacement_data: tuple of arrays, default None
The output of generate_displacement_vector(), if precomputed by the user
generate_displacement_vector_kwargs : dict
Parameters given to generate_displacement_vector().
Parameters given to generate_displacement_vector() if displacement_data is None
generate_templates_kwargs : dict
Parameters given to generate_templates()
sorting: NumpySorting, default None
The sorting to generate data from
generate_sorting_kwargs : dict
Parameters given to generate_sorting().
Parameters given to generate_sorting() if sorting is None
noise: NoiseGenerator, default None
Noise generator used to generate background noise
generate_noise_kwargs : dict
Parameters given to generate_noise().
Parameters given to generate_noise() if no noise is None
extra_outputs : bool, default False
Return optionaly a dict with more variables.
seed : None ot int
Expand Down Expand Up @@ -406,12 +421,31 @@ def generate_drifting_recording(

seed = _ensure_seed(seed)

if sorting is None:
sorting = generate_sorting(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=[
duration,
],
**generate_sorting_kwargs,
seed=seed,
)
else:
num_units = sorting.get_num_units()
sampling_frequency = sorting.sampling_frequency
if sorting._recording is not None:
duration = sorting.get_total_duration()

# probe
if generate_probe_kwargs is None:
generate_probe_kwargs = _toy_probes[probe_name]
probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))
if probe is None:
if generate_probe_kwargs is None:
generate_probe_kwargs = _toy_probes[probe_name]

probe = generate_multi_columns_probe(**generate_probe_kwargs)
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))

channel_locations = probe.contact_positions
# import matplotlib.pyplot as plt
# import probeinterface.plotting
Expand All @@ -420,20 +454,34 @@ def generate_drifting_recording(
# plt.show()

# unit locations
unit_locations = generate_unit_locations(
num_units,
channel_locations,
seed=seed,
**generate_unit_locations_kwargs,
)

(
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
if unit_locations is None:
unit_locations = generate_unit_locations(
num_units,
channel_locations,
seed=seed,
**generate_unit_locations_kwargs,
)
else:
assert len(unit_locations) == num_units, "We should have num_units unit locations"

if displacement_data is None:
(
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = generate_displacement_vector(
duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs
)
else:
(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make a dict no ? Or at least a named tuple.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, this was to be consistent, but I can make it a dict if you prefer

unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = displacement_data

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
Expand Down Expand Up @@ -470,16 +518,6 @@ def generate_drifting_recording(

drifting_templates = DriftingTemplates.from_static_templates(templates)

sorting = generate_sorting(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=[
duration,
],
**generate_sorting_kwargs,
seed=seed,
)

sorting.set_property("gt_unit_locations", unit_locations)

distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2)
Expand All @@ -493,13 +531,14 @@ def generate_drifting_recording(
drifting_templates.templates_array_moved = templates_array_moved
drifting_templates.displacements = displacements_steps

noise = generate_noise(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
seed=seed,
**generate_noise_kwargs,
)
if noise is None:
noise = generate_noise(
probe=probe,
sampling_frequency=sampling_frequency,
durations=[duration],
seed=seed,
**generate_noise_kwargs,
)

static_recording = InjectDriftingTemplatesRecording(
sorting=sorting,
Expand Down Expand Up @@ -531,6 +570,7 @@ def generate_drifting_recording(
displacement_unit_factor=displacement_unit_factor,
unit_displacements=unit_displacements,
templates=templates,
generate_templates_kwargs=generate_templates_kwargs,
)
return static_recording, drifting_recording, sorting, extra_infos
else:
Expand Down