-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
117 lines (83 loc) · 3.11 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import os
import sys
import numpy as np
import tensorflow as tf
from dataset import UserInteractionDataset, load_dataset, load_item_embeddings
from framework.logger import TensorBoardLogger
from model import PreferenceElicitationModel
def create_feed_map(
        batch_item_indices: np.ndarray,
        batch_targets: np.ndarray,
        model: PreferenceElicitationModel):
    """
    Bind one batch of data to the model's input nodes.

    The result is passed as ``feed_dict`` to ``session.run`` by the caller.

    :param batch_item_indices: batch fed to ``model.item_indices``
    :param batch_targets: batch fed to ``model.targets``
    :param model: model exposing ``item_indices`` and ``targets`` attributes
        (presumably TensorFlow placeholders -- confirm against the model class)
    :return: dict mapping the two model input nodes to the batch values
    """
    feed_map = {}
    feed_map[model.item_indices] = batch_item_indices
    feed_map[model.targets] = batch_targets
    return feed_map
def train(
        model: PreferenceElicitationModel,
        dataset: UserInteractionDataset,
        path: str,
        batch_size: int = 64,
        num_epochs: int = 100,
        learning_rate: float = 1e-3,
        momentum: float = 0.99,
        early_stopping_threshold: int = 30) -> None:
    """
    Run the training loop with per-epoch validation, checkpointing and
    early stopping.

    Every epoch runs the optimizer over all training batches, then evaluates
    the loss over all validation batches. The session is checkpointed to
    ``<path>/preference_elicitation.ckpt`` whenever the validation loss is
    strictly lower than the best seen so far.

    :param model: model providing ``loss``, ``create_optimizer`` and the
        input nodes used by ``create_feed_map``
    :param dataset: source of training and validation batches
    :param path: location handed to the TensorBoard logger and under which
        the checkpoint is saved
    :param batch_size: batch size requested from the dataset
    :param num_epochs: maximum number of epochs to run
    :param learning_rate: forwarded to ``model.create_optimizer``
    :param momentum: forwarded to ``model.create_optimizer``
    :param early_stopping_threshold: number of consecutive non-improving
        epochs that are tolerated; training stops on the next non-improving
        epoch after that (i.e. after ``early_stopping_threshold + 1`` of them)
    :return: None
    """
    train_op = model.create_optimizer(learning_rate, momentum)
    # The saver is created after the optimizer, as in the original ordering,
    # so any variables the optimizer adds exist when the saver is built.
    checkpoint_saver = tf.train.Saver()

    with tf.Session() as session:
        with TensorBoardLogger(session, path) as logger:
            session.run(tf.global_variables_initializer())

            best_validation_loss = np.inf
            stale_epochs = 0

            for epoch in range(num_epochs):
                # Training pass: one optimizer step per batch.
                for batch_inputs, batch_targets in dataset.get_train_batches(batch_size):
                    train_feed = create_feed_map(batch_inputs, batch_targets, model)
                    batch_loss, _ = session.run([model.loss, train_op], feed_dict=train_feed)
                    # '\r' rewrites the same console line for each batch.
                    print(f'\rEpoch: {epoch}, loss: {batch_loss}', end='')
                    sys.stdout.flush()
                    logger.add_training_metadata(batch_loss)

                # Validation pass: loss only, no parameter update.
                for batch_inputs, batch_targets in dataset.get_validation_batches(batch_size):
                    validation_feed = create_feed_map(batch_inputs, batch_targets, model)
                    batch_loss = session.run(model.loss, feed_dict=validation_feed)
                    logger.add_validation_metadata(batch_loss)

                # Presumably the aggregate of the validation losses recorded
                # above -- confirm against TensorBoardLogger.
                validation_loss = logger.get_validation_cost()
                logger.flush()
                print(f'\nValidation set epoch: {epoch}, loss: {validation_loss}')

                if validation_loss < best_validation_loss:
                    # New best: reset the patience counter and checkpoint.
                    best_validation_loss = validation_loss
                    stale_epochs = 0
                    checkpoint_file = os.path.join(path, 'preference_elicitation.ckpt')
                    checkpoint_saver.save(session, checkpoint_file)
                    continue

                if stale_epochs >= early_stopping_threshold:
                    print('Stopping after no improvement.')
                    return

                stale_epochs += 1
def main(path: str, train_ratio: float = 0.9):
    """
    Load the data, build the preference elicitation model and train it.

    :param path: location passed to the dataset and item-embedding loaders
        and to ``train``
    :param train_ratio: forwarded to ``load_dataset`` (presumably the
        fraction of the data used for training -- confirm in ``dataset``)
    :return: None
    """
    interactions = load_dataset(path, train_ratio)
    embeddings = load_item_embeddings(path, interactions.num_items)

    # One episode covers the items of a single user.
    episode_length = interactions.num_items_per_user
    elicitation_model = PreferenceElicitationModel(embeddings, episode_length=episode_length)

    train(elicitation_model, interactions, path)
if __name__ == '__main__':
    # Command-line entry point: parse the data location and start training.
    parser = argparse.ArgumentParser()
    # --path is mandatory. Without it args.path would be None and the run
    # would only fail later with an unrelated error (e.g. os.path.join on
    # None when the checkpoint is saved), so reject it up front with a
    # proper usage message instead.
    parser.add_argument(
        '--path',
        type=str,
        required=True,
        help='location of the dataset and item embeddings; the model '
             'checkpoint is saved under it as well')
    args = parser.parse_args()
    main(args.path)