# train.py
# Training script: fits an RNN model to pulse-stimuli data and optionally plots example responses.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import yaml
import os
from zrnn.models import ZemlianovaRNN, RNN
from zrnn.datasets import PulseStimuliDataset
def _ensure_parent_dir(path):
    """Create the parent directory of *path* when it is a filesystem path."""
    # torch.save also accepts file-like objects; only real paths have a
    # parent directory that may need creating.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)


def train(model, dataloader, optimizer, criterion, config, device):
    """Train ``model`` step by step over each sequence and checkpoint the best epoch.

    Each batch is unrolled along axis 1 of ``inputs``: the model is called
    once per time step with the current slice and the running hidden state,
    and the stacked outputs (first output channel only) are compared with
    ``targets`` by ``criterion``.

    Args:
        model: Module exposing ``initHidden(batch_size)`` and a forward call
            ``model(step_input, hidden) -> (output, hidden)``.
        dataloader: Iterable yielding ``(inputs, targets, _)`` triples. It
            does not need to define ``__len__``; batches are counted while
            iterating.
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable applied to ``(outputs[..., 0], targets)``.
        config: Mapping whose ``'training'`` section provides ``'epochs'``,
            ``'save_path'`` and ``'early_stopping_loss'``.
        device: Device the batches and hidden state are moved to.

    Returns:
        The same ``model`` object, holding the weights of the last epoch run
        (the best-epoch weights are the ones written to ``save_path``).

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    model.train()  # Set the model to training mode
    min_loss = float('inf')  # Best (lowest) average epoch loss seen so far
    for epoch in range(config['training']['epochs']):
        total_loss = 0
        num_batches = 0
        for inputs, targets, _ in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.shape[0]
            hidden = model.initHidden(batch_size).to(device)
            outputs = []
            for t in range(inputs.size(1)):  # process each time step
                output, hidden = model(inputs[:, t], hidden)
                outputs.append(output)
            outputs = torch.stack(outputs, dim=1)
            loss = criterion(outputs[..., 0], targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Counting batches (rather than len(dataloader)) supports loaders
        # without __len__ and turns the empty case into a clear error.
        if num_batches == 0:
            raise ValueError(
                "dataloader yielded no batches; cannot compute an average loss"
            )
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{config['training']['epochs']}, Loss: {avg_loss}")

        if avg_loss < min_loss:
            min_loss = avg_loss
            save_path = config['training']['save_path']
            # Avoid losing a finished epoch because the checkpoint folder is missing.
            _ensure_parent_dir(save_path)
            torch.save(model.state_dict(), save_path)
            print(f"Model saved with improvement at epoch {epoch+1} with loss {min_loss}")

        if avg_loss <= config['training']['early_stopping_loss']:
            print("Early stopping threshold reached.")
            break
    return model
def plot_results(model, dataloader, config, device, plot_dir='plots'):
    """Save one prediction/target/input figure per distinct period in each batch.

    Does nothing unless ``config['plotting']['enable']`` is truthy. The model
    is evaluated without gradients, unrolled one time step at a time along
    axis 1 of ``inputs``, and a PNG named ``Period_<period>_seconds.png`` is
    written to ``plot_dir`` for the first sequence of every distinct period
    found in a batch. A later batch containing the same period overwrites the
    earlier file.

    Args:
        model: Module exposing ``initHidden(batch_size)`` and a forward call
            ``model(step_input, hidden) -> (output, hidden)``.
        dataloader: Iterable yielding ``(inputs, targets, periods)`` triples.
        config: Mapping with a ``'plotting'`` section containing ``'enable'``.
        device: Device the batches and hidden state are moved to.
        plot_dir: Directory the figures are written to; created if missing.

    Returns:
        None.
    """
    if not config['plotting']['enable']:
        return

    print('plotting some examples ...')

    # Ensure directory for plots exists
    os.makedirs(plot_dir, exist_ok=True)

    # Remember the current mode so plotting does not silently leave the
    # caller's model in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets, periods in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)

                # Initialize hidden states for the batch
                hidden = model.initHidden(inputs.size(0)).to(device)
                outputs = []

                # Process each timestep in the sequence
                for t in range(inputs.size(1)):
                    output, hidden = model(inputs[:, t, :], hidden)
                    outputs.append(output)

                outputs = torch.stack(outputs, dim=1)
                # squeeze(-1) only drops the trailing axis when its size is 1
                outputs = outputs.squeeze(-1).cpu().numpy()
                targets = targets.squeeze(-1).cpu().numpy()
                inputs = inputs.squeeze(-1).cpu().numpy()

                # Reset per batch: each period is plotted at most once per batch.
                plotted_periods = set()
                for i, period in enumerate(periods.cpu().numpy()):
                    if period in plotted_periods:
                        continue

                    fig, ax = plt.subplots(figsize=(10, 4))
                    try:
                        ax.plot(outputs[i], label='Predictions')
                        ax.plot(targets[i], label='Targets')
                        ax.plot(inputs[i], label='Inputs')
                        ax.legend()
                        ax.set_title(f"Responses for Period {period:.3f} Seconds")
                        # Save through the figure handle rather than the
                        # implicit "current figure".
                        fig.savefig(os.path.join(plot_dir, f"Period_{period:.3f}_seconds.png"))
                    finally:
                        # Always release the figure, even if plotting or saving fails.
                        plt.close(fig)
                    plotted_periods.add(period)
    finally:
        model.train(was_training)
    print('done plotting')
def main(config_path='config.yaml', model_type=None):
    """Load the YAML config, build data and model, train, then plot.

    Args:
        config_path: Path to the YAML configuration file.
        model_type: When truthy, replaces ``config['model']['type']``.
            ``"RNN"`` selects ``RNN``; any other value selects
            ``ZemlianovaRNN``.
    """
    with open(config_path, 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)

    # A truthy command-line value takes precedence over the YAML model type.
    if model_type:
        config['model']['type'] = model_type

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print('Training with device:', device)

    train_cfg = config['training']
    model_cfg = config['model']

    # The dataset size is set to the batch size, as in the config-driven call below.
    dataset = PulseStimuliDataset(
        train_cfg['periods'],
        size=train_cfg['batch_size'],
        dt=model_cfg['dt'],
    )
    dataloader = DataLoader(
        dataset,
        batch_size=train_cfg['batch_size'],
        shuffle=True,
        num_workers=0,
    )

    is_plain_rnn = model_cfg['type'] == "RNN"
    # Positional arguments shared by both model constructors.
    shared_args = (
        model_cfg['input_dim'],
        model_cfg['hidden_dim'],
        model_cfg['output_dim'],
        model_cfg['dt'],
        model_cfg['tau'],
    )
    if is_plain_rnn:
        network = RNN(*shared_args)
    else:
        network = ZemlianovaRNN(
            *shared_args,
            model_cfg['excit_percent'],
            sigma_rec=model_cfg['sigma_rec'],
        )
    model = network.to(device)

    optimizer = optim.Adam(model.parameters(), lr=train_cfg['learning_rate'])
    criterion = nn.MSELoss()

    model = train(model, dataloader, optimizer, criterion, config, device)
    plot_results(model, dataloader, config, device)
if __name__ == "__main__":
    # Imported only when run as a script; Fire builds the command-line
    # interface from main's parameters.
    from fire import Fire

    Fire(main)