diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 82b68d86b0..bd81bc74a7 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -309,7 +309,9 @@ def generate_drifting_recording( duration=600.0, sampling_frequency=30000.0, probe_name="Neuropixels1-128", + probe=None, generate_probe_kwargs=None, + unit_locations=None, generate_unit_locations_kwargs=dict( margin_um=20.0, minimum_z=5.0, @@ -321,6 +323,7 @@ def generate_drifting_recording( # distribution="multimodal", # num_modes=2, ), + displacement_data=None, generate_displacement_vector_kwargs=dict( displacement_sampling_frequency=5.0, drift_start_um=[0, 20], @@ -347,7 +350,9 @@ def generate_drifting_recording( ellipse_angle=(0, np.pi * 2), ), ), + sorting=None, generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0), + noise=None, generate_noise_kwargs=dict(noise_levels=(6.0, 8.0), spatial_decay=25.0), extra_outputs=False, seed=None, @@ -364,20 +369,30 @@ def generate_drifting_recording( The duration in seconds. sampling_frequency : float, dfault: 30000. The sampling frequency. + probe: Probe object, default None + If provided, the Probe geometry to consider probe_name : str, default: "Neuropixels1-128" - The probe type if generate_probe_kwargs is None. + The probe type if generate_probe_kwargs is None and probe is None. generate_probe_kwargs : None or dict A dict to generate the probe, this supersede probe_name when not None. + unit_locations: array, default None + The unit locations of the cells generate_unit_locations_kwargs : dict - Parameters given to generate_unit_locations(). + Parameters given to generate_unit_locations() if unit_locations is None + displacement_data: tuple of arrays, default None + The output of generate_displacement_vector(), if precomputed by the user generate_displacement_vector_kwargs : dict - Parameters given to generate_displacement_vector(). + Parameters given to generate_displacement_vector() if displacement_data is None generate_templates_kwargs : dict Parameters given to generate_templates() + sorting: NumpySorting, default None + The sorting to generate data from generate_sorting_kwargs : dict - Parameters given to generate_sorting(). + Parameters given to generate_sorting() if sorting is None + noise: NoiseGenerator, default None + Noise generator used to generate background noise generate_noise_kwargs : dict - Parameters given to generate_noise(). + Parameters given to generate_noise() if no noise is None extra_outputs : bool, default False Return optionaly a dict with more variables. seed : None ot int @@ -406,12 +421,31 @@ def generate_drifting_recording( seed = _ensure_seed(seed) + if sorting is None: + sorting = generate_sorting( + num_units=num_units, + sampling_frequency=sampling_frequency, + durations=[ + duration, + ], + **generate_sorting_kwargs, + seed=seed, + ) + else: + num_units = sorting.get_num_units() + sampling_frequency = sorting.sampling_frequency + if sorting._recording is not None: + duration = sorting.get_total_duration() + # probe - if generate_probe_kwargs is None: - generate_probe_kwargs = _toy_probes[probe_name] - probe = generate_multi_columns_probe(**generate_probe_kwargs) - num_channels = probe.get_contact_count() - probe.set_device_channel_indices(np.arange(num_channels)) + if probe is None: + if generate_probe_kwargs is None: + generate_probe_kwargs = _toy_probes[probe_name] + + probe = generate_multi_columns_probe(**generate_probe_kwargs) + num_channels = probe.get_contact_count() + probe.set_device_channel_indices(np.arange(num_channels)) + channel_locations = probe.contact_positions # import matplotlib.pyplot as plt # import probeinterface.plotting @@ -420,20 +454,34 @@ def generate_drifting_recording( # plt.show() # unit locations - unit_locations = generate_unit_locations( - num_units, - channel_locations, - seed=seed, - **generate_unit_locations_kwargs, - ) - - ( - unit_displacements, - displacement_vectors, - displacement_unit_factor, - displacement_sampling_frequency, - displacements_steps, - ) = generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs) + if unit_locations is None: + unit_locations = generate_unit_locations( + num_units, + channel_locations, + seed=seed, + **generate_unit_locations_kwargs, + ) + else: + assert len(unit_locations) == num_units, "We should have num_units unit locations" + + if displacement_data is None: + ( + unit_displacements, + displacement_vectors, + displacement_unit_factor, + displacement_sampling_frequency, + displacements_steps, + ) = generate_displacement_vector( + duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs + ) + else: + ( + unit_displacements, + displacement_vectors, + displacement_unit_factor, + displacement_sampling_frequency, + displacements_steps, + ) = displacement_data # unit_params need to be fixed before the displacement steps generate_templates_kwargs = generate_templates_kwargs.copy() @@ -470,16 +518,6 @@ def generate_drifting_recording( drifting_templates = DriftingTemplates.from_static_templates(templates) - sorting = generate_sorting( - num_units=num_units, - sampling_frequency=sampling_frequency, - durations=[ - duration, - ], - **generate_sorting_kwargs, - seed=seed, - ) - sorting.set_property("gt_unit_locations", unit_locations) distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2) @@ -493,13 +531,18 @@ def generate_drifting_recording( drifting_templates.templates_array_moved = templates_array_moved drifting_templates.displacements = displacements_steps - noise = generate_noise( - probe=probe, - sampling_frequency=sampling_frequency, - durations=[duration], - seed=seed, - **generate_noise_kwargs, - ) + if noise is None: + noise = generate_noise( + probe=probe, + sampling_frequency=sampling_frequency, + durations=[duration], + seed=seed, + **generate_noise_kwargs, + ) + else: + assert noise.sampling_frequency == sampling_frequency, "Noise sampling frequency mismatch" + assert noise.probe.get_contact_count() == probe.get_contact_count(), "Noise num channels mismatch" + assert noise.get_total_duration() == duration, "Noise duration should be the same as the recording duration" static_recording = InjectDriftingTemplatesRecording( sorting=sorting, @@ -531,6 +574,7 @@ def generate_drifting_recording( displacement_unit_factor=displacement_unit_factor, unit_displacements=unit_displacements, templates=templates, + generate_templates_kwargs=generate_templates_kwargs, ) return static_recording, drifting_recording, sorting, extra_infos else: