Skip to content

Commit e931756

Browse files
committed
added eMouse simulation. Fixed channel mappings for Phy.
1 parent d96b928 commit e931756

14 files changed

+517
-20
lines changed

eMouse/benchmark_simulation.m

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
function benchmark_simulation(rez, GTfilepath)
2+
3+
load(GTfilepath)
4+
5+
try
6+
testClu = 1 + rez.st3(:,5) ; % if the auto merges were performed
7+
flag = 1;
8+
catch
9+
testClu = rez.st3(:,2) ;% no attempt to merge clusters
10+
flag = 0;
11+
end
12+
13+
testRes = rez.st3(:,1) ;
14+
15+
[allScores, allFPrates, allMissRates, allMerges] = ...
16+
compareClustering2(gtClu, gtRes, testClu, testRes, []);
17+
18+
%
19+
clid = unique(gtClu);
20+
clear gtimes
21+
for k = 1:length(clid)
22+
gtimes{k} = double(gtRes(gtClu==clid(k)));
23+
end
24+
%%
25+
26+
figure
27+
28+
plot(sort(cellfun(@(x) x(1), allFPrates)), '-*b', 'Linewidth', 2)
29+
hold all
30+
plot(sort(cellfun(@(x) x(1), allMissRates)), '-*r', 'Linewidth', 2)
31+
plot(sort(cellfun(@(x) x(end), allFPrates)), 'b', 'Linewidth', 2)
32+
plot(sort(cellfun(@(x) x(end), allMissRates)), 'r', 'Linewidth', 2)
33+
34+
box off
35+
36+
finalScores = cellfun(@(x) x(end), allScores);
37+
fprintf('%d / %d good cells, score > 0.8 (pre-merge) \n', sum(cellfun(@(x) x(1), allScores)>.8), numel(allScores))
38+
fprintf('%d / %d good cells, score > 0.8 (post-merge) \n', sum(cellfun(@(x) x(end), allScores)>.8), numel(allScores))
39+
40+
nMerges = cellfun(@(x) numel(x)-1, allMerges);
41+
fprintf('Mean merges per good cell %2.2f \n', mean(nMerges(finalScores>.8)))
42+
43+
% disp(cellfun(@(x) x(end), allScores))
44+
45+
xlabel('ground truth cluster')
46+
ylabel('fractional error')
47+
48+
legend('false positives (initial)', 'miss rates (initial)', 'false positives (best)', 'miss rates (best)')
49+
legend boxoff
50+
set(gca, 'Fontsize', 20)
51+
set(gcf, 'Color', 'w')
52+
53+
if flag==1
54+
title('After Kilosort AUTO merges')
55+
else
56+
title('Before Kilosort AUTO merges')
57+
end

eMouse/compareClustering2.m

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
2+
3+
function [allScores, allFPs, allMisses, allMerges] = compareClustering2(cluGT, resGT, cluTest, resTest, datFilename)
4+
% function compareClustering(cluGT, resGT, cluTest, resTest[, datFilename])
5+
% - clu and res variables are length nSpikes, for ground truth (GT) and for
6+
% the clustering to be evaluated (Test).
7+
8+
9+
if nargin<5
10+
datFilename = [];
11+
end
12+
13+
GTcluIDs = unique(cluGT);
14+
testCluIDs = unique(cluTest);
15+
jitter = 12;
16+
17+
nSp = zeros(max(testCluIDs), 1);
18+
for j = 1:max(testCluIDs);
19+
nSp(j) = max(1, sum(cluTest==j));
20+
end
21+
nSp0 = nSp;
22+
23+
for cGT = 1:length(GTcluIDs)
24+
% fprintf(1,'ground truth cluster ID = %d (%d spikes)\n', GTcluIDs(cGT), sum(cluGT==GTcluIDs(cGT)));
25+
26+
rGT = int32(resGT(cluGT==GTcluIDs(cGT)));
27+
28+
% S = sparse(numel(rGT), max(testCluIDs));
29+
S = spalloc(numel(rGT), max(testCluIDs), numel(rGT) * 10);
30+
% find the initial best match
31+
mergeIDs = [];
32+
scores = [];
33+
falsePos = [];
34+
missRate = [];
35+
36+
igt = 1;
37+
38+
nSp = nSp0;
39+
nrGT = numel(rGT);
40+
flag = false;
41+
for j = 1:numel(cluTest)
42+
while (resTest(j) > rGT(igt) + jitter)
43+
% the curent spikes is now too large compared to GT, advance the GT
44+
igt = igt + 1;
45+
if igt>nrGT
46+
flag = true;
47+
break;
48+
end
49+
end
50+
if flag
51+
break;
52+
end
53+
54+
if resTest(j)>rGT(igt)-jitter
55+
% we found a match, add a tick to the right cluster
56+
% numMatch(cluTest(j)) = numMatch(cluTest(j)) + 1;
57+
S(igt, cluTest(j)) = 1;
58+
end
59+
end
60+
numMatch = sum(S,1)';
61+
misses = (nrGT-numMatch)/nrGT; % missed these spikes, as a proportion of the total true spikes
62+
fps = (nSp-numMatch)./nSp; % number of comparison spikes not near a GT spike, as a proportion of the number of guesses
63+
%
64+
% for cTest = 1:length(testCluIDs)
65+
% rTest = int32(resTest(cluTest==testCluIDs(cTest)));
66+
%
67+
% [miss, fp] = compareSpikeTimes(rTest, rGT);
68+
% misses(cTest) = miss;
69+
% fps(cTest) = fp;
70+
%
71+
% end
72+
%
73+
sc = 1-(fps+misses);
74+
best = find(sc==max(sc),1);
75+
mergeIDs(end+1) = best;
76+
scores(end+1) = sc(best);
77+
falsePos(end+1) = fps(best);
78+
missRate(end+1) = misses(best);
79+
80+
% fprintf(1, ' found initial best %d: score %.2f (%d spikes, %.2f FP, %.2f miss)\n', ...
81+
% mergeIDs(1), scores(1), sum(cluTest==mergeIDs(1)), fps(best), misses(best));
82+
83+
S0 = S(:, best);
84+
nSp = nSp + nSp0(best);
85+
while scores(end)>0 && (length(scores)==1 || ( scores(end)>(scores(end-1) + 1*0.01) && scores(end)<=0.99 ))
86+
% find the best match
87+
S = bsxfun(@max, S, S0);
88+
89+
numMatch = sum(S,1)';
90+
misses = (nrGT-numMatch)/nrGT; % missed these spikes, as a proportion of the total true spikes
91+
fps = (nSp-numMatch)./nSp; % number of comparison spikes not near a GT spike, as a proportion of the number of guesses
92+
93+
sc = 1-(fps+misses);
94+
best = find(sc==max(sc),1);
95+
mergeIDs(end+1) = best;
96+
scores(end+1) = sc(best);
97+
falsePos(end+1) = fps(best);
98+
missRate(end+1) = misses(best);
99+
100+
% fprintf(1, ' best merge with %d: score %.2f (%d/%d new/total spikes, %.2f FP, %.2f miss)\n', ...
101+
% mergeIDs(end), scores(end), nSp0(best), nSp(best), fps(best), misses(best));
102+
103+
S0 = S(:, best);
104+
nSp = nSp + nSp0(best);
105+
106+
end
107+
108+
if length(scores)==1 || scores(end)>(scores(end-1)+0.01)
109+
% the last merge did help, so include it
110+
allMerges{cGT} = mergeIDs(1:end);
111+
allScores{cGT} = scores(1:end);
112+
allFPs{cGT} = falsePos(1:end);
113+
allMisses{cGT} = missRate(1:end);
114+
else
115+
% the last merge actually didn't help (or didn't help enough), so
116+
% exclude it
117+
allMerges{cGT} = mergeIDs(1:end-1);
118+
allScores{cGT} = scores(1:end-1);
119+
allFPs{cGT} = falsePos(1:end-1);
120+
allMisses{cGT} = missRate(1:end-1);
121+
end
122+
123+
end
124+
125+
initScore = zeros(1, length(GTcluIDs));
126+
finalScore = zeros(1, length(GTcluIDs));
127+
numMerges = zeros(1, length(GTcluIDs));
128+
fprintf(1, '\n\n--Results Summary--\n')
129+
for cGT = 1:length(GTcluIDs)
130+
%
131+
% fprintf(1,'ground truth cluster ID = %d (%d spikes)\n', GTcluIDs(cGT), sum(cluGT==GTcluIDs(cGT)));
132+
% fprintf(1,' initial score: %.2f\n', allScores{cGT}(1));
133+
% fprintf(1,' best score: %.2f (after %d merges)\n', allScores{cGT}(end), length(allScores{cGT})-1);
134+
%
135+
initScore(cGT) = allScores{cGT}(1);
136+
finalScore(cGT) = allScores{cGT}(end);
137+
numMerges(cGT) = length(allScores{cGT})-1;
138+
end
139+
140+
fprintf(1, 'median initial score: %.2f; median best score: %.2f\n', median(initScore), median(finalScore));
141+
fprintf(1, 'total merges required: %d\n', sum(numMerges));

eMouse/config_eMouse.m

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
ops.GPU = useGPU; % whether to run this code on an Nvidia GPU (much faster, mexGPUall first)
2+
ops.parfor = 0; % whether to use parfor to accelerate some parts of the algorithm
3+
ops.verbose = 1; % whether to print command line progress
4+
ops.showfigures = 1; % whether to plot figures during optimization
5+
6+
ops.datatype = 'dat'; % binary ('dat', 'bin') or 'openEphys'
7+
ops.fbinary = fullfile(fpath, 'sim_binary.dat'); % will be created for 'openEphys'
8+
ops.fproc = fullfile(fpath, 'temp_wh.dat'); % residual from RAM of preprocessed data
9+
ops.root = fpath; % 'openEphys' only: where raw files are
10+
% define the channel map as a filename (string) or simply an array
11+
ops.chanMap = fullfile(fpath, 'chanMap.mat'); % make this file using createChannelMapFile.m
12+
% ops.chanMap = 1:ops.Nchan; % treated as linear probe if unavailable chanMap file
13+
14+
ops.fs = 25000; % sampling rate
15+
ops.NchanTOT = 34; % total number of channels
16+
ops.Nchan = 32; % number of active channels
17+
ops.Nfilt = 64; % number of filters to use (2-4 times more than Nchan, should be a multiple of 32)
18+
ops.nNeighPC = 12; % visualization only (Phy): number of channnels to mask the PCs, leave empty to skip (12)
19+
ops.nNeigh = 16; % visualization only (Phy): number of neighboring templates to retain projections of (16)
20+
21+
% options for channel whitening
22+
ops.whitening = 'full'; % type of whitening (default 'full', for 'noSpikes' set options for spike detection below)
23+
ops.nSkipCov = 1; % compute whitening matrix from every N-th batch (1)
24+
ops.whiteningRange = 32; % how many channels to whiten together (Inf for whole probe whitening, should be fine if Nchan<=32)
25+
26+
ops.criterionNoiseChannels = 0.2; % fraction of "noise" templates allowed to span all channel groups (see createChannelMapFile for more info).
27+
28+
% other options for controlling the model and optimization
29+
ops.Nrank = 3; % matrix rank of spike template model (3)
30+
ops.nfullpasses = 6; % number of complete passes through data during optimization (6)
31+
ops.maxFR = 20000; % maximum number of spikes to extract per batch (20000)
32+
ops.fshigh = 200; % frequency for high pass filtering
33+
% ops.fslow = 2000; % frequency for low pass filtering (optional)
34+
ops.ntbuff = 64; % samples of symmetrical buffer for whitening and spike detection
35+
ops.scaleproc = 200; % int16 scaling of whitened data
36+
ops.NT = 128*1024+ ops.ntbuff;% this is the batch size (try decreasing if out of memory)
37+
% for GPU should be multiple of 32 + ntbuff
38+
39+
% the following options can improve/deteriorate results.
40+
% when multiple values are provided for an option, the first two are beginning and ending anneal values,
41+
% the third is the value used in the final pass.
42+
ops.Th = [4 10 10]; % threshold for detecting spikes on template-filtered data ([6 12 12])
43+
ops.lam = [5 5 5]; % large means amplitudes are forced around the mean ([10 30 30])
44+
ops.nannealpasses = 4; % should be less than nfullpasses (4)
45+
ops.momentum = 1./[20 400]; % start with high momentum and anneal (1./[20 1000])
46+
ops.shuffle_clusters = 1; % allow merges and splits during optimization (1)
47+
ops.mergeT = .1; % upper threshold for merging (.1)
48+
ops.splitT = .1; % lower threshold for splitting (.1)
49+
50+
% options for initializing spikes from data
51+
ops.initialize = 'no'; %'fromData' or 'no'
52+
ops.spkTh = -6; % spike threshold in standard deviations (4)
53+
ops.loc_range = [3 1]; % ranges to detect peaks; plus/minus in time and channel ([3 1])
54+
ops.long_range = [30 6]; % ranges to detect isolated peaks ([30 6])
55+
ops.maskMaxChannels = 5; % how many channels to mask up/down ([5])
56+
ops.crit = .65; % upper criterion for discarding spike repeates (0.65)
57+
ops.nFiltMax = 10000; % maximum "unique" spikes to consider (10000)
58+
59+
% load predefined principal components (visualization only (Phy): used for features)
60+
dd = load('PCspikes2.mat'); % you might want to recompute this from your own data
61+
ops.wPCA = dd.Wi(:,1:7); % PCs
62+
63+
% options for posthoc merges (under construction)
64+
ops.fracse = 0.1; % binning step along discriminant axis for posthoc merges (in units of sd)
65+
ops.epu = Inf;
66+
67+
ops.ForceMaxRAMforDat = 20e9; % maximum RAM the algorithm will try to use; on Windows it will autodetect.

eMouse/make_eMouseChannelMap.m

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
function make_eMouseChannelMap(fpath)
2+
% create a channel Map file for simulated data (eMouse)
3+
4+
% here I know a priori what order my channels are in. So I just manually
5+
% make a list of channel indices (and give
6+
% an index to dead channels too). chanMap(1) is the row in the raw binary
7+
% file for the first channel. chanMap(1:2) = [33 34] in my case, which happen to
8+
% be dead channels.
9+
10+
chanMap = [33 34 8 10 12 14 16 18 20 22 24 26 28 30 32 ...
11+
7 9 11 13 15 17 19 21 23 25 27 29 31 1 2 3 4 5 6];
12+
13+
% the first thing Kilosort does is reorder the data with data = data(chanMap, :).
14+
% Now we declare which channels are "connected" in this normal ordering,
15+
% meaning not dead or used for non-ephys data
16+
17+
connected = true(34, 1); connected(1:2) = 0;
18+
19+
% now we define the horizontal (x) and vertical (y) coordinates of these
20+
% 34 channels. For dead or nonephys channels the values won't matter. Again
21+
% I will take this information from the specifications of the probe. These
22+
% are in um here, but the absolute scaling doesn't really matter in the
23+
% algorithm.
24+
25+
xcoords = 20 * [NaN NaN 1 0 0 1 0 1 0 1 0 1 0 1 0 1 0 1 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0];
26+
ycoords = 20 * [NaN NaN 7 8 9 9 10 10 11 11 12 12 13 13 14 14 15 15 16 ...
27+
17 17 18 18 19 19 20 20 21 21 22 22 23 23 24];
28+
29+
% Often, multi-shank probes or tetrodes will be organized into groups of
30+
% channels that cannot possibly share spikes with the rest of the probe. This helps
31+
% the algorithm discard noisy templates shared across groups. In
32+
% this case, we set kcoords to indicate which group the channel belongs to.
33+
% In our case all channels are on the same shank in a single group so we
34+
% assign them all to group 1.
35+
36+
kcoords = [NaN NaN 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1];
37+
38+
% at this point in Kilosort we do data = data(connected, :), ycoords =
39+
% ycoords(connected), xcoords = xcoords(connected) and kcoords =
40+
% kcoords(connected) and no more channel map information is needed (in particular
41+
% no "adjacency graphs" like in KlustaKwik).
42+
% Now we can save our channel map for the eMouse.
43+
44+
save(fullfile(fpath, 'chanMap.mat'), 'chanMap', 'connected', 'xcoords', 'ycoords', 'kcoords')

0 commit comments

Comments
 (0)