-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
32 lines (25 loc) · 877 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tensorflow as tf
import numpy as np
import json
from input import make_input_fn
from base_model import base_model
#from din_model import din_model
# Maximum history length handed to make_input_fn.
# NOTE(review): presumably sequences are truncated/padded to this length inside
# input.make_input_fn -- confirm there.
MAX_HIST_LEN = 500


def train(filename, batch_size, num_epochs, model_dir,
          dictionary_path='dictionary.json',
          category_path='category_list.npy'):
    """Train ``base_model`` on a TFRecord file with ``tf.estimator.Estimator``.

    Args:
        filename: Path of the TFRecord file passed to ``make_input_fn``.
        batch_size: Batch size forwarded to ``make_input_fn``.
        num_epochs: Number of epochs forwarded to ``make_input_fn``.
        model_dir: Directory the estimator uses for checkpoints and summaries.
        dictionary_path: JSON file containing a ``'product_id'`` entry; its
            length is used as ``item_len``.
        category_path: ``.npy`` file loaded with ``np.load``; it must have the
            same length as the ``'product_id'`` entry.

    Raises:
        ValueError: If the category array and the ``'product_id'`` entry
            differ in length.
    """
    with open(dictionary_path, encoding='utf-8') as f:
        dictionary = json.load(f)
    category = np.load(category_path).tolist()
    item_len = len(dictionary['product_id'])
    # Explicit check rather than ``assert`` so it is not stripped under -O.
    if item_len != len(category):
        raise ValueError(
            'category list length ({}) does not match product_id '
            'vocabulary size ({})'.format(len(category), item_len))

    # Use the caller-supplied file instead of a hard-coded path.
    train_input_fn = make_input_fn(filename, MAX_HIST_LEN, batch_size,
                                   num_epochs)

    # Swap in din_model (and re-enable its import) to train the DIN variant.
    # ``params`` is handed to the model_fn by the Estimator.
    estimator = tf.estimator.Estimator(
        base_model,
        model_dir=model_dir,
        params={
            'item_len': item_len,
            'categories': category,
            'num_hidden': 128,
            'learning_rate': 0.001,
        },
    )
    estimator.train(train_input_fn)
if __name__ == '__main__':
    # Script entry point: train on train.tfrecord, batch 32, one epoch,
    # writing checkpoints under ./model_log.
    train(
        'train.tfrecord',
        batch_size=32,
        num_epochs=1,
        model_dir='./model_log',
    )