Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
kraglik committed Nov 8, 2018
0 parents commit 37185cb
Show file tree
Hide file tree
Showing 34 changed files with 2,231 additions and 0 deletions.
109 changes: 109 additions & 0 deletions autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pickle
import random
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()

self.ec1 = nn.Conv2d(1, 50, kernel_size=(3, 5))
self.ep1 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True)
self.ec2 = nn.Conv2d(50, 100, kernel_size=(2, 5))
self.ep2 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True)
self.ec3 = nn.Conv2d(100, 200, kernel_size=(1, 5))
self.ep3 = nn.MaxPool2d(kernel_size=(1, 2), return_indices=True)
self.ec4 = nn.Conv2d(200, 100, kernel_size=(1, 3))

self.el1 = nn.Linear(400, 200)
self.el2 = nn.Linear(200, 100)

self.dl1 = nn.Linear(100, 200)
self.dl2 = nn.Linear(200, 400)

self.dc1 = nn.ConvTranspose2d(100, 200, kernel_size=(1, 3))
self.dp1 = nn.MaxUnpool2d(kernel_size=(1, 2))
self.dc2 = nn.ConvTranspose2d(200, 100, kernel_size=(1, 5))
self.dp2 = nn.MaxUnpool2d(kernel_size=(1, 2))
self.dc3 = nn.ConvTranspose2d(100, 50, kernel_size=(2, 5))
self.dp3 = nn.MaxUnpool2d(kernel_size=(1, 2))
self.dc4 = nn.ConvTranspose2d(50, 1, kernel_size=(3, 5))

self.train()

def encode(self, readouts, return_indices=True):
x = F.relu(self.ec1(readouts))
x, p1 = self.ep1(x)
x = F.relu(self.ec2(x))
x, p2 = self.ep2(x)
x = F.relu(self.ec3(x))
x, p3 = self.ep3(x)
x = F.relu(self.ec4(x))

x = F.tanh(self.el1(x.view(-1)))
x = F.tanh(self.el2(x))

if return_indices:
return x, (p3, p2, p1)
else:
return x

def decode(self, encoding, indices):
x = F.tanh(self.dl1(encoding))
x = F.tanh(self.dl2(x))

x = x.view(1, 100, 2, 2)

x = F.relu(self.dc1(x))
x = self.dp1(x, indices[0])
x = F.relu(self.dc2(x))
x = self.dp2(x, indices[1])
x = F.relu(self.dc3(x))
x = self.dp3(x, indices[2])
x = self.dc4(x)

return x

def forward(self, readouts):
encoding, indices = self.encode(readouts)
return self.decode(encoding, indices)


def main():
net = Autoencoder().cuda()
net.optim = torch.optim.Adam(net.parameters())

with open('readouts.pickle', 'rb') as f:
readouts = pickle.load(f)

train = [torch.stack(readouts[i - 5: i]) for i in range(5, 10000)]
test = [torch.stack(readouts[i - 5: i]) for i in range(10005, len(readouts), 5)]

for epoch in range(5):
print('epoch %d' % epoch)

random.shuffle(train)

loss_f = nn.MSELoss()

for i, sample in enumerate(train):
net.optim.zero_grad()

sample = sample.cuda().view(1, 1, 5, 60)
restored = net.forward(sample)

loss = loss_f(restored, sample)
loss.backward()

print('sample %d' % i, ', loss %f' % float(loss))

net.optim.step()

torch.save(net, 'net.torch')


if __name__ == '__main__':
main()
113 changes: 113 additions & 0 deletions liquid_state_machine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import ast

import torch
import pickle
import numpy as np

from simulator.core import Network, poisson_spike_train
from simulator.model.custom import LiquidStateMachine
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle
import pandas


load = False
use_cuda = True and torch.cuda.is_available()


def main():

if load:
net = torch.load('net.spiking_torch')
lsm = net.lsm
else:
net = Network(log_limit=1000)
lsm = LiquidStateMachine(
net, "lsm",
input_size=662,
output_size=150,
cuda=use_cuda
)
net.lsm = lsm

vectors = pandas.read_csv("result.csv")

net.time = 0
# vectors = [vec for vec in vectors if len(vec) > 3]

toggled = False
toggle_time = 0
readouts = []

for s, vector in vectors.iterrows():
spikes = torch.FloatTensor(poisson_spike_train(np.array(ast.literal_eval(vector[0])), 10, 50)).cuda()

for i in range(spikes.shape[0]):
input = spikes[i]

net.step({'lsm': input})

print(net.time / 1000, "seconds of simulation")

if net.time > toggle_time and not toggled:
lsm.toggle_learning(False)

# if int(net.time) % 1000 == 0 and net.time > 1:
# group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose()
# plt.matshow(group_1_activity)
# plt.title('Group 1 activity')
# plt.savefig('plots/lsm/%d.png' % int(net.time))
# plt.close()
#
# del group_1_activity

if int(net.time) % 50000 == 0:
torch.save(net, 'net.spiking_torch')

print("step ", s, " done")

if net.time >= 400 and int(net.time) % 400 == 0:
readout = lsm.get_readout_vector(400)

readouts.append(readout)

# if 10000 >= net.time > 200:
# readout = lsm.get_readout_vector(200)
# half_readout = lsm.get_readout_vector(100)
# avg_readout += torch.cat((readout, half_readout), dim=0)
#
# readouts.append(torch.cat((readout, half_readout), dim=0))
#
# if net.time > 10000 and int(net.time) % 50 == 0:
# readout = lsm.get_readout_vector(100)
# half_readout = lsm.get_readout_vector(50)
#
# readouts.append(torch.cat((readout, half_readout), dim=0))

# distances.append((
# int(net.time / 25),
# distance(avg_readout, torch.cat((readout, half_readout), dim=0))
# ))

if int(net.time) % 50000 == 0 and net.time > 1:
with open('readouts.pickle', 'wb') as f:
pickle.dump(readouts, f)

group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose()
# plt.figure(figsize=(40, 10))
plt.matshow(group_1_activity)
plt.title('Group 1 activity')
plt.savefig('plots/liquid_state_machine_activity.png')
plt.close()

net.lsm = lsm
torch.save(net, 'net.spiking_torch')

with open('readouts.pickle', 'wb') as f:
pickle.dump(readouts, f)


if __name__ == '__main__':
main()
117 changes: 117 additions & 0 deletions lsm_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import ast

import torch
import pickle
import numpy as np

from simulator.core import Network, poisson_spike_train
from simulator.model.custom import LiquidStateMachine, DimensionReductor
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle
import pandas


load = False
use_cuda = True and torch.cuda.is_available()


def main():

if load:
net = torch.load('net.spiking_torch')

lsm = net.lsm
else:
net = Network(log_limit=500)

lsm = DimensionReductor(
net, "lsm",
input_size=2475,
hidden_size=625,
output_size=50,
cuda=use_cuda
)

net.lsm = lsm

vectors = pandas.read_csv("ac_mse_inputs.csv")

net.time = 0
readouts = []
vectors = vectors.as_matrix()
lsm.toggle_learning(False)

for s in range(vectors.shape[0]):
spikes = torch.cuda.FloatTensor(poisson_spike_train(vectors[s], 1.1, 100))

for i in range(spikes.shape[0]):
input = spikes[i]

net.step({'lsm_input': input})

print(net.time / 1000, "seconds of simulation")

# if int(net.time) % 1000 == 0 and net.time > 1:
# group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose()
# plt.matshow(group_1_activity)
# plt.title('Group 1 activity')
# plt.savefig('plots/lsm/%d.png' % int(net.time))
# plt.close()
#
# del group_1_activity

if int(net.time) % 50000 == 0:
torch.save(net, 'net.spiking_torch')

print("step ", s, " done")

if net.time >= 200 and int(net.time) % 200 == 0:
readout = lsm.get_readout_vector(200)

readouts.append(readout.cpu())

# if 10000 >= net.time > 200:
# readout = lsm.get_readout_vector(200)
# half_readout = lsm.get_readout_vector(100)
# avg_readout += torch.cat((readout, half_readout), dim=0)
#
# readouts.append(torch.cat((readout, half_readout), dim=0))
#
# if net.time > 10000 and int(net.time) % 50 == 0:
# readout = lsm.get_readout_vector(100)
# half_readout = lsm.get_readout_vector(50)
#
# readouts.append(torch.cat((readout, half_readout), dim=0))

# distances.append((
# int(net.time / 25),
# distance(avg_readout, torch.cat((readout, half_readout), dim=0))
# ))

if int(net.time) % 5000 == 0 and net.time > 1:
with open('readouts.pickle', 'wb') as f:
pickle.dump(readouts, f)

with open('readouts.pickle', 'wb') as f:
pickle.dump(readouts, f)

group_1_activity = torch.stack(lsm.get_cache()).cpu().numpy().transpose()
# plt.figure(figsize=(40, 10))
plt.matshow(group_1_activity)
plt.title('Group 1 activity')
plt.xlabel("Время")
plt.ylabel("Номер нейрона")
plt.savefig('plots/liquid_state_machine_activity.png')
plt.close()

net.lsm = lsm
torch.save(net, 'net.spiking_torch')

with open('readouts.pickle', 'wb') as f:
pickle.dump(readouts, f)


if __name__ == '__main__':
main()
Loading

0 comments on commit 37185cb

Please sign in to comment.