@@ -95,6 +95,7 @@ def generate_sorting(
9595 add_spikes_on_borders = False ,
9696 num_spikes_per_border = 3 ,
9797 border_size_samples = 20 ,
98+ extra_outputs = False ,
9899 seed = None ,
99100):
100101 """
@@ -135,10 +136,14 @@ def generate_sorting(
135136 num_segments = len (durations )
136137 unit_ids = np .arange (num_units )
137138
139+ extra_outputs_dict = {
140+ "firing_rates" : [],
141+ }
142+
138143 spikes = []
139144 for segment_index in range (num_segments ):
140145 num_samples = int (sampling_frequency * durations [segment_index ])
141- samples , labels = synthesize_poisson_spike_vector (
146+ samples , labels , firing_rates_array = synthesize_poisson_spike_vector (
142147 num_units = num_units ,
143148 sampling_frequency = sampling_frequency ,
144149 duration = durations [segment_index ],
@@ -172,12 +177,17 @@ def generate_sorting(
172177 )
173178 spikes .append (spikes_on_borders )
174179
180+ extra_outputs_dict ["firing_rates" ].append (firing_rates_array )
181+
175182 spikes = np .concatenate (spikes )
176183 spikes = spikes [np .lexsort ((spikes ["sample_index" ], spikes ["segment_index" ]))]
177184
178185 sorting = NumpySorting (spikes , sampling_frequency , unit_ids )
179186
180- return sorting
187+ if extra_outputs :
188+ return sorting , extra_outputs_dict
189+ else :
190+ return sorting
181191
182192
183193def add_synchrony_to_sorting (sorting , sync_event_ratio = 0.3 , seed = None ):
@@ -776,7 +786,7 @@ def synthesize_poisson_spike_vector(
776786 unit_indices = unit_indices [sort_indices ]
777787 spike_frames = spike_frames [sort_indices ]
778788
779- return spike_frames , unit_indices
789+ return spike_frames , unit_indices , firing_rates
780790
781791
782792def synthesize_random_firings (
0 commit comments