-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdrive_model_and_save_outputs.py
98 lines (79 loc) · 4.28 KB
/
drive_model_and_save_outputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from scipy.signal import butter, filtfilt, find_peaks, sosfiltfilt

from zrnn.datasets import generate_stimuli
from zrnn.models import ZemlianovaRNN
# Number of extra stimulus periods sampled at random in main() and appended to
# the periods listed in the config's training section.
N_RANDOM_PERIODS = 100
# Sampling interval (s) assumed for both the model output and the generated stimulus
# (fs = 1000 Hz).
_SAMPLE_DT = 0.001
# Half-width (s) of the band of periods kept by the peak-detection band-pass filter.
_BAND_HALF_WIDTH = 0.1


def _bandpass_edges(period):
    """Return (low, high) band-pass edges in Hz covering `period` +/- _BAND_HALF_WIDTH s.

    Raises:
        ValueError: if `period` does not exceed the band half-width, in which case the
            upper edge would be infinite or negative.
    """
    if period <= _BAND_HALF_WIDTH:
        raise ValueError(
            f'period must be greater than {_BAND_HALF_WIDTH} s to build the '
            f'band-pass filter, got {period}'
        )
    return 1 / (period + _BAND_HALF_WIDTH), 1 / (period - _BAND_HALF_WIDTH)


def _run_model(model, input_tensor, device, time_steps, discard_steps):
    """Step `model` through the input one time step at a time and collect its activity.

    Returns:
        Tuple ``(outputs, hidden_states)``. ``outputs`` holds all `time_steps`
        outputs, squeezed. ``hidden_states`` holds the hidden state produced at
        every step ``>= discard_steps``.
    """
    hidden = model.initHidden(1).to(device)
    outputs = []
    hidden_states = []
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for step in range(time_steps):
            output, hidden = model(input_tensor[:, step, :], hidden)
            outputs.append(output.cpu().numpy())
            if step >= discard_steps:
                hidden_states.append(hidden.cpu().numpy())
    return np.array(outputs).squeeze(), np.array(hidden_states)


def _find_true_peaks(valid_outputs, period, low, high):
    """Return indices of one raw-signal maximum per detected oscillation cycle.

    Candidate peaks come from a zero-phase band-pass of `valid_outputs` between
    `low` and `high` Hz. Each candidate is then moved to the largest raw sample
    within half a period on either side. Indices may repeat when two candidates
    snap to the same raw maximum.
    """
    fs = 1 / _SAMPLE_DT
    # Second-order sections instead of [b, a]: SciPy documents the transfer-function
    # form as numerically fragile, and these pass bands sit very low relative to fs.
    sos = butter(N=2, Wn=[low, high], btype='band', fs=fs, output='sos')
    filtered = sosfiltfilt(sos, valid_outputs)
    candidates, _ = find_peaks(filtered, height=0)

    half_window = int((period / 2) / _SAMPLE_DT)  # half a period, in samples
    n_samples = len(valid_outputs)
    true_peaks = []
    for peak in candidates:
        start = max(0, peak - half_window)
        end = min(n_samples, peak + half_window)
        true_peaks.append(start + int(np.argmax(valid_outputs[start:end])))
    # Explicit integer dtype so an empty result is still a valid index array.
    return np.asarray(true_peaks, dtype=int)


def _save_period_outputs(save_dir, period, valid_outputs, true_peaks, hidden_states):
    """Write the peak plot and the output / peak / hidden-state arrays for one period."""
    fig, ax = plt.subplots(figsize=(25, 5))
    try:
        ax.plot(valid_outputs, label='Raw Output for Period: ' + str(period) + 's')
        ax.plot(true_peaks, valid_outputs[true_peaks], "x", label='Highest Peaks in Raw Output')
        ax.set_xlabel('Time Steps (after discard)')
        ax.set_ylabel('Model Activity')
        ax.set_title('Model Output for Period ' + str(period) + 's')
        ax.legend()
        fig.savefig(os.path.join(save_dir, f'plot_period_{period}.png'))
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
    np.save(os.path.join(save_dir, f'valid_output_{period}.npy'), valid_outputs)
    np.save(os.path.join(save_dir, f'peaks_{period}.npy'), true_peaks)
    np.save(os.path.join(save_dir, f'hidden_states_{period}.npy'), hidden_states)


def drive_model_and_save_outputs(model, periods, device, time_steps=52000, discard_steps=2000, save_dir='model_outputs'):
    """Drive `model` with a generated stimulus at each period and save its responses.

    For every period, the function:

    - builds the two-channel input ``[I_stim, I_cc]`` with ``generate_stimuli``;
    - runs the model for `time_steps` steps and drops the first `discard_steps` outputs;
    - finds one peak per cycle in the remaining trace.

    It writes the following files into `save_dir`:

    - ``plot_period_{period}.png``: the trimmed output with the detected peaks marked;
    - ``valid_output_{period}.npy``: the output after the discarded prefix;
    - ``peaks_{period}.npy``: peak indices into that trimmed output;
    - ``hidden_states_{period}.npy``: the hidden state after each kept step.

    Args:
        model: Recurrent model exposing ``initHidden(batch_size)`` and called as
            ``model(x_t, hidden) -> (output, hidden)``. Presumably it produces one
            scalar per step (TODO confirm), since peak detection needs a 1-D trace.
        periods: Iterable of stimulus periods in seconds; each must exceed 0.1 s.
        device: Torch device for the input tensor and hidden state.
        time_steps: Number of 1 ms steps to simulate. This also sets the stimulus
            duration (52.0 s for the default).
        discard_steps: Leading steps excluded from saved outputs and hidden states.
        save_dir: Output directory; created if missing.

    Raises:
        ValueError: raised in any of these cases:
            - `discard_steps` is outside ``[0, time_steps)``;
            - a period is too short for the band-pass filter;
            - the generated stimulus is shorter than `time_steps`.
    """
    if not 0 <= discard_steps < time_steps:
        raise ValueError(
            f'discard_steps must be in [0, time_steps), got {discard_steps} with time_steps={time_steps}'
        )
    os.makedirs(save_dir, exist_ok=True)
    # Tie the stimulus length to the simulated span instead of a separate constant.
    duration = time_steps * _SAMPLE_DT

    for period in periods:
        # Validate before the expensive simulation rather than failing after it.
        low, high = _bandpass_edges(period)

        # NOTE(review): the meaning of the positional 0.05 argument is defined in
        # zrnn.datasets.generate_stimuli; it is passed through unchanged.
        _, I_stim, I_cc, _ = generate_stimuli(period, 0.05, duration=duration, I_stim_without_clicks=True)
        input_signal = np.stack([I_stim, I_cc], axis=1)
        # Add a batch dimension of 1 and move to the model's device.
        input_tensor = torch.tensor(input_signal, dtype=torch.float32).unsqueeze(0).to(device)
        if input_tensor.shape[1] < time_steps:
            raise ValueError(
                f'stimulus for period {period} has {input_tensor.shape[1]} steps, '
                f'fewer than time_steps={time_steps}'
            )

        outputs, hidden_states = _run_model(model, input_tensor, device, time_steps, discard_steps)
        valid_outputs = outputs[discard_steps:]  # drop the initial transient
        true_peaks = _find_true_peaks(valid_outputs, period, low, high)
        _save_period_outputs(save_dir, period, valid_outputs, true_peaks, hidden_states)
        print(f'finished driving model and saving activity for period {period}')
def sample_exponential_skew(num_samples, lam=5):
    """Draw `num_samples` values ``exp(lam * (u - 1))`` with ``u ~ Uniform[0, 1)``.

    Equivalently, the natural log of each sample is uniform on ``[-lam, 0)``. For
    ``lam > 0`` the values therefore lie between ``exp(-lam)`` and 1 and are densest
    near the lower end. Draws from NumPy's global random state (``np.random.rand``).

    Args:
        num_samples: Number of values to draw.
        lam: Width of the log-uniform range (default 5).

    Returns:
        1-D array of length `num_samples`.
    """
    uniform_draws = np.random.rand(num_samples)
    log_values = lam * (uniform_draws - 1)
    return np.exp(log_values)
def main(config_path='config.yaml', model_type=None):
    """Load a trained ZemlianovaRNN from a YAML config and drive it at many periods.

    The model is driven at every period in ``config['training']['periods']``. It is
    also driven at ``N_RANDOM_PERIODS`` random periods, each drawn as
    ``sample_exponential_skew() + 0.1`` seconds and rounded to 3 decimals. Repeated
    periods are driven only once; first-occurrence order is kept.

    Args:
        config_path: Path to the YAML configuration file. It must provide the ``model``
            hyper-parameters plus ``training.save_path`` and ``training.periods``.
        model_type: Unused; kept so the command-line interface stays unchanged.
    """
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    model_cfg = config['model']
    train_cfg = config['training']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Using device:', device)

    model = ZemlianovaRNN(
        model_cfg['input_dim'],
        model_cfg['hidden_dim'],
        model_cfg['output_dim'],
        model_cfg['dt'],
        model_cfg['tau'],
        model_cfg['excit_percent'],
        sigma_rec=model_cfg['sigma_rec'],
    ).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only point save_path at trusted files.
    model.load_state_dict(torch.load(train_cfg['save_path'], map_location=device))
    model.eval()

    # Add extra randomly generated periods (0.1 s offset keeps them above the
    # band-pass filter's 0.1 s half-width).
    random_periods = [float(round(v + 0.1, 3)) for v in sample_exponential_skew(N_RANDOM_PERIODS)]
    # Rounding to 3 decimals makes repeats likely, and outputs are saved under
    # period-keyed file names. A repeat would only redo a 52 000-step run and
    # overwrite its own files, so drop duplicates while keeping order.
    periods = list(dict.fromkeys(list(train_cfg['periods']) + random_periods))
    drive_model_and_save_outputs(model, periods, device)
if __name__ == "__main__":
import fire
fire.Fire(main)