-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsave_progress.py
55 lines (37 loc) · 1.23 KB
/
save_progress.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from pathlib import Path

import torch

from params import config
from ppo import PPO
def save_progress(training_log, validation_log, validation_result, record, model: PPO):
    """Persist training/validation progress and checkpoint the model.

    Appends ``validation_result`` to ``validation_log`` (in place) and writes
    both logs as text files. It then saves the policy weights to
    ``best_weight.pth`` when ``validation_result`` improves on ``record``.
    If ``config.progress_config.save_training`` is set, it also saves a full
    checkpoint (policy + optimizer state, both logs, best record) to
    ``saved.pth``.

    All files go to ``config.progress_config.path_to_save_progress``, which is
    created (including parents) if it does not exist yet.

    Args:
        training_log: Training history; written to ``training_log.txt`` via
            ``str()`` and stored in the checkpoint.
        validation_log: List of validation results. ``validation_result`` is
            appended to it in place, so the caller's list is updated.
        validation_result: Metric from the latest validation run. Lower is
            treated as better (a strict ``<`` comparison against ``record``).
        record: Best validation result seen so far, or ``None`` if there is
            none yet (the first result is then always treated as the best).
        model: PPO agent; its ``policy`` and ``optimizer`` must expose
            ``state_dict()``.

    Returns:
        The updated best record: ``validation_result`` if it improved on
        ``record`` (or ``record`` was ``None``), otherwise ``record``.
    """
    out_dir = Path(config.progress_config.path_to_save_progress)
    # Create the experiment folder up front: writing into a missing
    # directory would otherwise fail with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)

    validation_log.append(validation_result)

    # ``None`` means no record exists yet; comparing against it would raise.
    is_new_best = record is None or validation_result < record
    best_record = validation_result if is_new_best else record

    (out_dir / "validation_log.txt").write_text(str(validation_log), encoding="utf-8")
    (out_dir / "training_log.txt").write_text(str(training_log), encoding="utf-8")

    if is_new_best:
        torch.save(model.policy.state_dict(), out_dir / "best_weight.pth")

    if config.progress_config.save_training:
        # Only collect the (potentially large) state dicts when they will be written.
        saved = {
            "model_state_dict": model.policy.state_dict(),
            "optimizer_state_dict": model.optimizer.state_dict(),
            "validation_log": validation_log,
            "training_log": training_log,
            "best_record": best_record,
        }
        torch.save(saved, out_dir / "saved.pth")

    return best_record
'''
Notes:
The top-level folder name will be MK01 / MK02.
You may choose the subfolder name yourself, but keep it in a fixed format.
Suggestion:
<Experiment_ID>
Things to save:
1. a serialized config object in the folder
2. training log
3. validation log
4. best_weight (best model weights so far)
5. last_weight (most recent model weights)
6. last_optimizer_weights (most recent optimizer state)
'''