-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 37185cb
Showing
34 changed files
with
2,231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import pickle | ||
import random | ||
import torch | ||
from torch.autograd import Variable | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Autoencoder(nn.Module): | ||
def __init__(self): | ||
super(Autoencoder, self).__init__() | ||
|
||
self.ec1 = nn.Conv2d(1, 50, kernel_size=(3, 5)) | ||
self.ep1 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True) | ||
self.ec2 = nn.Conv2d(50, 100, kernel_size=(2, 5)) | ||
self.ep2 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True) | ||
self.ec3 = nn.Conv2d(100, 200, kernel_size=(1, 5)) | ||
self.ep3 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True) | ||
self.ec4 = nn.Conv2d(200, 100, kernel_size=(1, 3)) | ||
|
||
self.el1 = nn.Linear(400, 200) | ||
self.el2 = nn.Linear(200, 100) | ||
|
||
self.dl1 = nn.Linear(100, 200) | ||
self.dl2 = nn.Linear(200, 400) | ||
|
||
self.dc1 = nn.ConvTranspose2d(100, 200, kernel_size=(1, 3)) | ||
self.dp1 = nn.MaxUnpool2d(kernel_size=(1, 2)) | ||
self.dc2 = nn.ConvTranspose2d(200, 100, kernel_size=(1, 5)) | ||
self.dp2 = nn.MaxUnpool2d(kernel_size=(1, 2)) | ||
self.dc3 = nn.ConvTranspose2d(100, 50, kernel_size=(2, 5)) | ||
self.dp3 = nn.MaxUnpool2d(kernel_size=(1, 2)) | ||
self.dc4 = nn.ConvTranspose2d(50, 1, kernel_size=(3, 5)) | ||
|
||
self.train() | ||
|
||
def encode(self, readouts, return_indices=True): | ||
x = F.relu(self.ec1(readouts)) | ||
x, p1 = self.ep1(x) | ||
x = F.relu(self.ec2(x)) | ||
x, p2 = self.ep2(x) | ||
x = F.relu(self.ec3(x)) | ||
x, p3 = self.ep3(x) | ||
x = F.relu(self.ec4(x)) | ||
|
||
x = F.tanh(self.el1(x.view(-1))) | ||
x = F.tanh(self.el2(x)) | ||
|
||
if return_indices: | ||
return x, (p3, p2, p1) | ||
else: | ||
return x | ||
|
||
def decode(self, encoding, indices): | ||
x = F.tanh(self.dl1(encoding)) | ||
x = F.tanh(self.dl2(x)) | ||
|
||
x = x.view(1, 100, 2, 2) | ||
|
||
x = F.relu(self.dc1(x)) | ||
x = self.dp1(x, indices[0]) | ||
x = F.relu(self.dc2(x)) | ||
x = self.dp2(x, indices[1]) | ||
x = F.relu(self.dc3(x)) | ||
x = self.dp3(x, indices[2]) | ||
x = self.dc4(x) | ||
|
||
return x | ||
|
||
def forward(self, readouts): | ||
encoding, indices = self.encode(readouts) | ||
return self.decode(encoding, indices) | ||
|
||
|
||
def main(): | ||
net = Autoencoder().cuda() | ||
net.optim = torch.optim.Adam(net.parameters()) | ||
|
||
with open('readouts.pickle', 'rb') as f: | ||
readouts = pickle.load(f) | ||
|
||
train = [torch.stack(readouts[i - 5: i]) for i in range(5, 10000)] | ||
test = [torch.stack(readouts[i - 5: i]) for i in range(10005, len(readouts), 5)] | ||
|
||
for epoch in range(5): | ||
print('epoch %d' % epoch) | ||
|
||
random.shuffle(train) | ||
|
||
loss_f = nn.MSELoss() | ||
|
||
for i, sample in enumerate(train): | ||
net.optim.zero_grad() | ||
|
||
sample = sample.cuda().view(1, 1, 5, 60) | ||
restored = net.forward(sample) | ||
|
||
loss = loss_f(restored, sample) | ||
loss.backward() | ||
|
||
print('sample %d' % i, ', loss %f' % float(loss)) | ||
|
||
net.optim.step() | ||
|
||
torch.save(net, 'net.torch') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import ast | ||
|
||
import torch | ||
import pickle | ||
import numpy as np | ||
|
||
from simulator.core import Network, poisson_spike_train | ||
from simulator.model.custom import LiquidStateMachine | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
import pickle | ||
import pandas | ||
|
||
|
||
load = False | ||
use_cuda = True and torch.cuda.is_available() | ||
|
||
|
||
def main(): | ||
|
||
if load: | ||
net = torch.load('net.spiking_torch') | ||
lsm = net.lsm | ||
else: | ||
net = Network(log_limit=1000) | ||
lsm = LiquidStateMachine( | ||
net, "lsm", | ||
input_size=662, | ||
output_size=150, | ||
cuda=use_cuda | ||
) | ||
net.lsm = lsm | ||
|
||
vectors = pandas.read_csv("result.csv") | ||
|
||
net.time = 0 | ||
# vectors = [vec for vec in vectors if len(vec) > 3] | ||
|
||
toggled = False | ||
toggle_time = 0 | ||
readouts = [] | ||
|
||
for s, vector in vectors.iterrows(): | ||
spikes = torch.FloatTensor(poisson_spike_train(np.array(ast.literal_eval(vector[0])), 10, 50)).cuda() | ||
|
||
for i in range(spikes.shape[0]): | ||
input = spikes[i] | ||
|
||
net.step({'lsm': input}) | ||
|
||
print(net.time / 1000, "seconds of simulation") | ||
|
||
if net.time > toggle_time and not toggled: | ||
lsm.toggle_learning(False) | ||
|
||
# if int(net.time) % 1000 == 0 and net.time > 1: | ||
# group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose() | ||
# plt.matshow(group_1_activity) | ||
# plt.title('Group 1 activity') | ||
# plt.savefig('plots/lsm/%d.png' % int(net.time)) | ||
# plt.close() | ||
# | ||
# del group_1_activity | ||
|
||
if int(net.time) % 50000 == 0: | ||
torch.save(net, 'net.spiking_torch') | ||
|
||
print("step ", s, " done") | ||
|
||
if net.time >= 400 and int(net.time) % 400 == 0: | ||
readout = lsm.get_readout_vector(400) | ||
|
||
readouts.append(readout) | ||
|
||
# if 10000 >= net.time > 200: | ||
# readout = lsm.get_readout_vector(200) | ||
# half_readout = lsm.get_readout_vector(100) | ||
# avg_readout += torch.cat((readout, half_readout), dim=0) | ||
# | ||
# readouts.append(torch.cat((readout, half_readout), dim=0)) | ||
# | ||
# if net.time > 10000 and int(net.time) % 50 == 0: | ||
# readout = lsm.get_readout_vector(100) | ||
# half_readout = lsm.get_readout_vector(50) | ||
# | ||
# readouts.append(torch.cat((readout, half_readout), dim=0)) | ||
|
||
# distances.append(( | ||
# int(net.time / 25), | ||
# distance(avg_readout, torch.cat((readout, half_readout), dim=0)) | ||
# )) | ||
|
||
if int(net.time) % 50000 == 0 and net.time > 1: | ||
with open('readouts.pickle', 'wb') as f: | ||
pickle.dump(readouts, f) | ||
|
||
group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose() | ||
# plt.figure(figsize=(40, 10)) | ||
plt.matshow(group_1_activity) | ||
plt.title('Group 1 activity') | ||
plt.savefig('plots/liquid_state_machine_activity.png') | ||
plt.close() | ||
|
||
net.lsm = lsm | ||
torch.save(net, 'net.spiking_torch') | ||
|
||
with open('readouts.pickle', 'wb') as f: | ||
pickle.dump(readouts, f) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import ast | ||
|
||
import torch | ||
import pickle | ||
import numpy as np | ||
|
||
from simulator.core import Network, poisson_spike_train | ||
from simulator.model.custom import LiquidStateMachine, DimensionReductor | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
import pickle | ||
import pandas | ||
|
||
|
||
load = False | ||
use_cuda = True and torch.cuda.is_available() | ||
|
||
|
||
def main(): | ||
|
||
if load: | ||
net = torch.load('net.spiking_torch') | ||
|
||
lsm = net.lsm | ||
else: | ||
net = Network(log_limit=500) | ||
|
||
lsm = DimensionReductor( | ||
net, "lsm", | ||
input_size=2475, | ||
hidden_size=625, | ||
output_size=50, | ||
cuda=use_cuda | ||
) | ||
|
||
net.lsm = lsm | ||
|
||
vectors = pandas.read_csv("ac_mse_inputs.csv") | ||
|
||
net.time = 0 | ||
readouts = [] | ||
vectors = vectors.as_matrix() | ||
lsm.toggle_learning(False) | ||
|
||
for s in range(vectors.shape[0]): | ||
spikes = torch.cuda.FloatTensor(poisson_spike_train(vectors[s], 1.1, 100)) | ||
|
||
for i in range(spikes.shape[0]): | ||
input = spikes[i] | ||
|
||
net.step({'lsm_input': input}) | ||
|
||
print(net.time / 1000, "seconds of simulation") | ||
|
||
# if int(net.time) % 1000 == 0 and net.time > 1: | ||
# group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose() | ||
# plt.matshow(group_1_activity) | ||
# plt.title('Group 1 activity') | ||
# plt.savefig('plots/lsm/%d.png' % int(net.time)) | ||
# plt.close() | ||
# | ||
# del group_1_activity | ||
|
||
if int(net.time) % 50000 == 0: | ||
torch.save(net, 'net.spiking_torch') | ||
|
||
print("step ", s, " done") | ||
|
||
if net.time >= 200 and int(net.time) % 200 == 0: | ||
readout = lsm.get_readout_vector(200) | ||
|
||
readouts.append(readout.cpu()) | ||
|
||
# if 10000 >= net.time > 200: | ||
# readout = lsm.get_readout_vector(200) | ||
# half_readout = lsm.get_readout_vector(100) | ||
# avg_readout += torch.cat((readout, half_readout), dim=0) | ||
# | ||
# readouts.append(torch.cat((readout, half_readout), dim=0)) | ||
# | ||
# if net.time > 10000 and int(net.time) % 50 == 0: | ||
# readout = lsm.get_readout_vector(100) | ||
# half_readout = lsm.get_readout_vector(50) | ||
# | ||
# readouts.append(torch.cat((readout, half_readout), dim=0)) | ||
|
||
# distances.append(( | ||
# int(net.time / 25), | ||
# distance(avg_readout, torch.cat((readout, half_readout), dim=0)) | ||
# )) | ||
|
||
if int(net.time) % 5000 == 0 and net.time > 1: | ||
with open('readouts.pickle', 'wb') as f: | ||
pickle.dump(readouts, f) | ||
|
||
with open('readouts.pickle', 'wb') as f: | ||
pickle.dump(readouts, f) | ||
|
||
group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose() | ||
# plt.figure(figsize=(40, 10)) | ||
plt.matshow(group_1_activity) | ||
plt.title('Group 1 activity') | ||
plt.xlabel("Время") | ||
plt.ylabel("Номер нейрона") | ||
plt.savefig('plots/liquid_state_machine_activity.png') | ||
plt.close() | ||
|
||
net.lsm = lsm | ||
torch.save(net, 'net.spiking_torch') | ||
|
||
with open('readouts.pickle', 'wb') as f: | ||
pickle.dump(readouts, f) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.