-
Notifications
You must be signed in to change notification settings - Fork 2
Neuron dropout #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Neuron dropout #165
Changes from 15 commits
4fdc0b8
05ad857
0dc9ff0
a0ab11b
00a81b6
180f8a6
497c220
97ab280
965860e
faeeebf
8b3459c
2cda3a4
b3e28e3
c374c76
675c183
db97b49
ca21277
abc25e3
1bf6524
61e3283
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,76 @@ | |
import time | ||
import numpy as np | ||
from riglib.experiment import traits, experiment | ||
from riglib.bmi import clda | ||
|
||
###### CONSTANTS | ||
sec_per_min = 60 | ||
|
||
######################################################################################################## | ||
# Decoder/BMISystem add-ons | ||
######################################################################################################## | ||
class RandomUnitDropout(traits.HasTraits):
    '''
    Randomly removes units from the decoder. Does not work with CLDA turned on. Units are removed at the
    end of the delay period on each trial and replaced when the trial ends (either in reward or penalty).
    The same units will be dropped on repeated trials. The units to drop are specified in the
    `unit_drop_groups` attribute by a list of lists of unit indices. The `unit_drop_targets` attribute
    specifies the target indices on which to drop each group of units.

    NOTE(review): this is a mixin — it assumes the class it is combined with provides
    `self.decoder`, `self.trial_dtype`, `self.trial_record`, `self.gen_indices`,
    `self.target_index`, `self.chain_length`, and the `_start_*` / `_increment_tries`
    state-machine hooks. Confirm against the concrete task class.
    '''

    # Per-trial probability of applying the current drop group (0 disables dropping).
    unit_drop_prob = traits.Float(0, desc="Probability of dropping a group of units from the decoder")
    # Groups of decoder-unit indices; groups are cycled round-robin, one per dropped trial.
    unit_drop_groups = traits.Array(value=[[0, 1], [2]], desc="Groups of unit indices to drop from the decoder one at a time")
    # One target index per group: a group may only be dropped on trials toward its target.
    unit_drop_targets = traits.Array(value=[1, 2], desc="Target indices on which to drop groups of units from the decoder")

    def init(self):
        '''
        Extend the per-trial record dtype with a boolean mask ('decoder_units_dropped',
        one flag per decoder unit) and snapshot the decoder so it can be restored after
        each trial. Must run after the decoder exists on `self`.
        '''
        super().init()
        # NOTE(review): initialized all-True here; _start_wait overwrites it before the
        # first trial record is written, so only the shape/dtype matter — confirm intent.
        self.decoder_units_dropped = np.ones((len(self.decoder.units),), dtype='bool')
        new_dtype = np.dtype(self.trial_dtype.descr + [('decoder_units_dropped', '?', self.decoder_units_dropped.shape)])
        self.trial_dtype = new_dtype
        # Round-robin pointer into unit_drop_groups / unit_drop_targets.
        self.unit_drop_group_idx = 0

        # Save a copy of the decoder so dropped weights can be restored (_reset_decoder)
        self.decoder_orig = self.decoder.copy()

    def _start_wait(self):
        '''
        At trial start, decide whether this trial drops the current unit group:
        only when the upcoming target matches the group's assigned target index
        AND a coin flip against `unit_drop_prob` succeeds. Records the decision
        in the trial record; the actual zeroing happens in _start_targ_transition.
        '''
        # Decide which units to drop in this trial but don't actually drop them yet
        if (self.gen_indices[self.target_index] == self.unit_drop_targets[self.unit_drop_group_idx] and
            np.random.rand() < self.unit_drop_prob):
            # Boolean mask: True for every decoder unit in the selected drop group.
            self.decoder_units_dropped = np.isin(range(len(self.decoder.units)), self.unit_drop_groups[self.unit_drop_group_idx])

            # Update the group for next trial
            self.unit_drop_group_idx = (self.unit_drop_group_idx + 1) % len(self.unit_drop_groups)
        else:
            # No drop this trial: all-False mask.
            self.decoder_units_dropped = np.zeros((len(self.decoder.units),), dtype='bool')

        # Update the trial record
        self.trial_record['decoder_units_dropped'] = self.decoder_units_dropped
        super()._start_wait()

    def _start_targ_transition(self):
        '''
        Override the decoder to drop random units. Keep a record of what's going on in the `trial` data.
        '''
        super()._start_targ_transition()
        # Only zero the units mid-chain (there is a next target) and only if a drop
        # was actually scheduled for this trial.
        if self.target_index + 1 < self.chain_length and np.any(self.decoder_units_dropped):
            # Zero the dropped units' contribution in whichever filter layout exists:
            # KF-style observation matrix rows (C) or a units-to-state mapping's columns.
            if hasattr(self.decoder.filt, 'C'):
                self.decoder.filt.C[self.decoder_units_dropped, :] = 0
            elif hasattr(self.decoder.filt, 'unit_to_state'):
                self.decoder.filt.unit_to_state[:, self.decoder_units_dropped] = 0

    def _reset_decoder(self):
        # Restore the pristine decoder saved in init(), undoing any zeroed weights.
        self.decoder = self.decoder_orig.copy()

    def _increment_tries(self):
        # Trial ended in a penalty: restore the decoder for the next attempt.
        super()._increment_tries()
        self._reset_decoder()

    def _start_reward(self):
        # Trial ended in reward: restore the decoder for the next trial.
        super()._start_reward()
        self._reset_decoder()
||
class NormFiringRates(traits.HasTraits): | ||
''' Docstring ''' | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of specifying the unit indices to drop, I think this should be based on the unit name/number.