Skip to content

Commit 51c6c8b

Browse files
committed
generate.py - add extra outputs.
1 parent 132dbe4 commit 51c6c8b

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/spikeinterface/core/generate.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def generate_sorting(
9595
add_spikes_on_borders=False,
9696
num_spikes_per_border=3,
9797
border_size_samples=20,
98+
extra_outputs=False,
9899
seed=None,
99100
):
100101
"""
@@ -135,10 +136,14 @@ def generate_sorting(
135136
num_segments = len(durations)
136137
unit_ids = np.arange(num_units)
137138

139+
extra_outputs_dict = {
140+
"firing_rates": [],
141+
}
142+
138143
spikes = []
139144
for segment_index in range(num_segments):
140145
num_samples = int(sampling_frequency * durations[segment_index])
141-
samples, labels = synthesize_poisson_spike_vector(
146+
samples, labels, firing_rates_array = synthesize_poisson_spike_vector(
142147
num_units=num_units,
143148
sampling_frequency=sampling_frequency,
144149
duration=durations[segment_index],
@@ -172,12 +177,17 @@ def generate_sorting(
172177
)
173178
spikes.append(spikes_on_borders)
174179

180+
extra_outputs_dict["firing_rates"].append(firing_rates_array)
181+
175182
spikes = np.concatenate(spikes)
176183
spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]
177184

178185
sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
179186

180-
return sorting
187+
if extra_outputs:
188+
return sorting, extra_outputs_dict
189+
else:
190+
return sorting
181191

182192

183193
def add_synchrony_to_sorting(sorting, sync_event_ratio=0.3, seed=None):
@@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector(
776786
unit_indices = unit_indices[sort_indices]
777787
spike_frames = spike_frames[sort_indices]
778788

779-
return spike_frames, unit_indices
789+
return spike_frames, unit_indices, firing_rates
780790

781791

782792
def synthesize_random_firings(

0 commit comments

Comments
 (0)