-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathDistanceCellModel.m
More file actions
138 lines (120 loc) · 6.64 KB
/
DistanceCellModel.m
File metadata and controls
138 lines (120 loc) · 6.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
function[Starts Goals Decoded Vec_l RateDiff Error] = DistanceCellModel(GC_cpm,DrawFig)
%% Distance cell / number line model of vector navigation with grid cells
% Daniel Bush, UCL Institute of Cognitive Neuroscience
% Reference: Using Grid Cells for Navigation (2015) Neuron (in press)
% Contact: drdanielbush@gmail.com
%
% Inputs:
% GC_cpm = Number of grid cells per unique phase, in each module
% DrawFig = Plot figure of errors (0 / 1)
%
% Outputs:
% Starts = Random 2D start locations (m)
% Goals = Random 2D goal locations (m)
% Decoded = 2D translation vector decoded from grid cell activity (m)
% Vec_l = Length of decoded 2D translation vector (m)
% RateDiff = Difference in firing rate between read-out cells on each axis
% (Hz) - should be proportional to translation vector length
% Error = Error in decoded translation vector (m)
% Provide some parameters for the simulation
iterations = 1000; % How many iterations to run?
Range = 500; % Range of distance cells (m)
GC_mps = 20; % Unique grid cell phases on each axis, per module
GC_scales = 0.25.*1.4.^(0:9); % Grid cell scales (m)
GC_r = 30; % Peak grid cell firing rate (Hz)
dt = 0.1; % Time window of grid cell firing
Emax_k = 0.01; % Emax WTA k parameter (see de Almeida et al., 2009)
Dist_step = 0.04; % Displacement increment for distance cells (m)
Normalise = 1; % Normalise distance cell firing rates? (0 / 1)
Distances = Dist_step : Dist_step : Range; % Assign the locations along each axis that distance cells code for
N_distance = length(Distances); % Total number of distance cells
N_grid = length(GC_scales)*GC_mps; % Total number of grid cells scale / phase values per directional axis
clear scale
% Generate grid cell to distance cell synaptic weights
Grid_Dist_w = zeros(N_grid,N_distance);
for scale = 1 : length(GC_scales)
for offset = 0 : GC_mps-1
Grid_Dist_w((scale-1)*GC_mps + offset + 1, 1 : N_distance) = (cos((mod(Distances-(offset/GC_mps)*GC_scales(scale),GC_scales(scale))/GC_scales(scale))*2*pi)+1)/2;
end
end
clear scale offset
% Generate distance cell to read out cell synaptic weights
Dist_Out_w(1,:) = linspace(0,100,N_distance);
Dist_Out_w(2,:) = linspace(100,0,N_distance);
% Assign some memory
Starts = nan(iterations,2); % Log start positions
Goals = nan(iterations,2); % Log goal positions
RateDiff = nan(iterations,2); % Log decoded vectors
% For each iteration...
for i = 1 : iterations
% Update the user
if mod(i,iterations/10)==0
disp([int2str(i/iterations*100) '% complete...']);
drawnow
end
% Randomly assign start and goal locations
Starts(i,:) = [rand*Range rand*Range];
Goals(i,:) = [rand*Range rand*Range];
% For each axis...
for ax = 1 : 2
% Identify the mean firing rates of grid cells at the start and goal locations
StartRates = (1+cos((repmat(((mod(Starts(i,ax),GC_scales)./GC_scales)*GC_mps)',1,GC_mps) - (meshgrid(1:GC_mps,1:length(GC_scales))-1))/GC_mps*2*pi))/2*GC_r*dt;
GoalRates = (1+cos((repmat(((mod(Goals(i,ax), GC_scales)./GC_scales)*GC_mps)',1,GC_mps) - (meshgrid(1:GC_mps,1:length(GC_scales))-1))/GC_mps*2*pi))/2*GC_r*dt;
% Convert that to Poisson firing in each of the grid cells encoding
% each phase offset
StartRates = sum(poissrnd(repmat(StartRates,[1 1 GC_cpm])),3);
GoalRates = sum(poissrnd(repmat(GoalRates,[1 1 GC_cpm])),3);
StartRates = reshape(StartRates',length(GC_scales)*GC_mps,1);
GoalRates = reshape(GoalRates',length(GC_scales)*GC_mps,1);
% Compute the firing rate of distance cells
StartDist = StartRates' * Grid_Dist_w;
GoalDist = GoalRates' * Grid_Dist_w;
clear StartRates GoalRates
% Apply the E-max WTA algorithm
StartDist = StartDist .* (StartDist>((1-Emax_k).*max(StartDist)));
GoalDist = GoalDist .* (GoalDist>((1-Emax_k).* max(GoalDist)));
% Normalise the distance cell firing rates, if required
if Normalise
StartDist = StartDist ./ sum(StartDist) * 10;
GoalDist = GoalDist ./ sum(GoalDist) * 10;
end
% Compute the firing rate of output cells
RateDiff(i,ax) = (sum(StartDist.*Dist_Out_w(2,:)) + sum(GoalDist.*Dist_Out_w(1,:)) - sum(StartDist.*Dist_Out_w(1,:)) - sum(GoalDist.*Dist_Out_w(2,:)))/100;
clear StartDist GoalDist
end
end
% Compute the error
Vec_l = sqrt(sum((Goals - Starts).^2,2));
Actual = Goals - Starts;
b = regress(RateDiff(:),Actual(:));
Decoded = RateDiff./b;
Error = sqrt(sum((Decoded - Actual).^2,2));
% Plot the firing rate difference against true vector, if required
if DrawFig
figure
subplot(2,2,1)
scatter(Actual(:),RateDiff(:),'k.')
set(gca,'FontSize',14)
xlabel('True Translation Vector (m)','FontSize',14)
ylabel('Read-Out Firing Rate Difference (Hz)','FontSize',14)
axis square
subplot(2,2,2)
temp = histc(Error,linspace(0,ceil(max(Error)*100)/100,100)) ./ iterations;
bar(linspace(0,ceil(max(Error)*100),100),temp,'FaceColor','k','EdgeColor','k')
set(gca,'FontSize',14)
xlabel('Error in Decoded Translation Vector (cm)','FontSize',14)
ylabel('Relative Frequency','FontSize',14)
axis square
subplot(2,2,3)
scatter(Vec_l,Error*100,'k.')
set(gca,'FontSize',14)
xlabel('Decoded Translation Vector Length (m)','FontSize',14)
ylabel('Error in Decoded Translation Vector (cm)','FontSize',14)
hold on
b2 = regress(Error*100,[Vec_l ones(size(Vec_l,1),1)]);
plot(linspace(0,max(Vec_l),10),b2(2) + b2(1).*linspace(0,max(Vec_l),10),'r','LineWidth',3)
hold off
axis square
clear b2
end
clear i ax Dist_Out_w Distances Grid_Dist_w iterations b N_grid N_distance Dist_step