-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathself_play.py
158 lines (142 loc) · 5.75 KB
/
self_play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from concurrent.futures import ProcessPoolExecutor
import random
import nltk
from time import time, sleep
import _pickle as cPickle
from env import TextEnv
from player import Agent
from api import APIServer
import numpy as np
import h5py
from filelock import FileLock
# Inter-process lock (backed by the lock file "my_lock") that Utils holds
# around every open of the shared HDF5 "buffer" file, so concurrent
# self-play processes do not read/modify it at the same time.
lock = FileLock("my_lock")
class Utils:
    """Shared helpers for self-play workers.

    Provides vocabulary lookup for logging generated sentences and read/write
    access to the shared HDF5 file ``"buffer"`` (datasets ``/cur_row``, ``/s``
    and ``/pi``). Every access to that file happens while holding the
    module-level inter-process ``lock``.
    """

    def __init__(self, config):
        """
        :param config: run configuration; this class reads ``sample_file``,
            ``start_token``, ``buffer_size``, ``max_length`` and
            ``simulation_num_per_move`` from it.
        """
        self.config = config
        # The pickle holds a 2-tuple; only the first element is kept. It is
        # indexed by token ids in translate_store and must yield strings.
        # Use a context manager so the file handle is closed promptly.
        with open('save/vocab_cotra.pkl', 'rb') as vocab_file:
            self.translate, _ = cPickle.load(vocab_file)

    # TODO: do something to protect this writing -- it is not guarded by
    # ``lock``, so concurrent workers append to sample_file independently.
    def translate_store(self, iter, sentence):
        """Append ``sentence`` rendered as words to ``config.sample_file``.

        Writes an ``iter: <iter>`` header line, then the space-joined words
        (occurrences of ``config.start_token`` render as empty strings), then
        a blank separator line.
        """
        words = [self.translate[item] if item != self.config.start_token else ''
                 for item in sentence]
        with open(self.config.sample_file, "a") as f:
            f.write('iter: %d\n' % iter)
            f.write(' '.join(words) + '\n')
            f.write('\n')

    @staticmethod
    def _retry_locked(action):
        """Call ``action()`` while holding ``lock``, retrying until it succeeds.

        Any ``Exception`` raised by ``action`` (presumably the HDF5 file being
        temporarily unavailable to this process) is swallowed and the call is
        retried after a very short sleep; note this retries forever if the
        failure is permanent. ``KeyboardInterrupt`` and ``SystemExit`` are not
        ``Exception`` subclasses, so they propagate to the caller.
        """
        while True:
            try:
                with lock:
                    return action()
            except Exception:
                sleep(0.00001)

    def read_game(self):
        """Claim the next buffer row and return its prefix.

        Reads the global ``/cur_row`` counter, returns the first
        ``max_length - 1`` entries of ``/s`` at row ``cur_row % buffer_size``,
        and increments the counter, all under ``lock``.

        :returns: ``(s, cur_row)`` where ``cur_row`` is the un-wrapped counter
            value that was claimed (pass it back to :meth:`save_game`).
        :raises KeyboardInterrupt: propagated if the user interrupts while
            waiting for the buffer.
        """
        cfg = self.config

        def _read():
            with h5py.File("buffer", "r") as f:
                cur_row = f['/cur_row'][0]
                s = f['/s'][cur_row % cfg.buffer_size, :cfg.max_length - 1]
            with h5py.File("buffer", "a") as f:
                f['/cur_row'][0] = cur_row + 1
            return s, cur_row

        return self._retry_locked(_read)

    def save_game(self, s, pi, iter, cur_row):
        """Write a finished game's result into buffer row ``cur_row % buffer_size``.

        :param s: value stored at position ``max_length - 1`` of the ``/s`` row
            (the worker passes the last token of the finished game).
        :param pi: sequence of per-move policy rows; a row of
            ``simulation_num_per_move`` ``-1`` values is appended as padding and
            the result is written to positions ``max_length-2 : max_length`` of
            ``/pi``. The caller's ``pi`` is not modified.
        :param iter: game index; every 100th game the full stored row is read
            back and logged via :meth:`translate_store`.
        :param cur_row: counter value previously returned by :meth:`read_game`.
        :raises KeyboardInterrupt: propagated if the user interrupts while
            waiting for the buffer.
        """
        cfg = self.config
        # Build a new list rather than ``pi += ...`` so the caller's list is
        # left untouched.
        # [bs, length, sims] for mult, [bs, sims] for single
        pi = np.array(list(pi) + [[-1] * cfg.simulation_num_per_move])
        # TODO: padding for s is 0, which is problematic for non-progressive mode
        row = cur_row % cfg.buffer_size
        log_sample = (iter + 1) % 100 == 0

        def _write():
            with h5py.File("buffer", "a") as f:
                f['/s'][row, cfg.max_length - 1] = s
                f['/pi'][row, cfg.max_length - 2:cfg.max_length, ...] = pi
            if log_sample:
                # Read into a separate value so a retry never re-writes it as ``s``.
                with h5py.File("buffer", "r") as f:
                    return [int(token) for token in f['/s'][row, :cfg.max_length]]
            return None

        sample = self._retry_locked(_write)
        if log_sample:
            self.translate_store(iter, sample)
def start(config, process_ind):
    """Start the API server and run self-play workers in a process pool.

    Submits ``config.multi_process_num`` :class:`SelfPlayWorker` loops (each
    with its own API client and a shared :class:`Utils`) and blocks until all
    of them return. ``process_ind`` is accepted for interface compatibility
    but is not used.
    """
    from concurrent.futures import wait

    api_server = APIServer(config)
    utils = Utils(config)
    process_num = config.multi_process_num
    api_server.start_serve()
    futures = []
    with ProcessPoolExecutor(max_workers=process_num) as executor:
        try:
            for _ in range(process_num):
                play_worker = SelfPlayWorker(config, api=api_server.get_api_client(),
                                             utils=utils)
                futures.append(executor.submit(play_worker.start))
            # Block inside the try so a Ctrl-C while workers are running is
            # handled here (otherwise it would arrive in the ``with`` exit).
            wait(futures)
        except KeyboardInterrupt:
            print("Caught KeyboardInterrupt, terminating workers")
            # ProcessPoolExecutor has no terminate()/join(); cancel anything not
            # yet started. Leaving the ``with`` block then calls shutdown(wait=True),
            # which waits for already-running workers to return.
            for future in futures:
                future.cancel()
    # Exceptions raised inside workers are otherwise stored silently on the futures.
    for idx, future in enumerate(futures):
        if not future.cancelled() and future.exception() is not None:
            print("Self-play worker %d failed: %r" % (idx, future.exception()))
class SelfPlayWorker:
    """Loop that repeatedly claims a buffer row, plays one game, and saves it.

    Each game uses a fresh ``Agent`` and ``TextEnv``; buffer I/O goes through
    the supplied ``utils`` object (``read_game`` / ``save_game``).
    """

    def __init__(self, config, api, utils):
        """
        :param config: run configuration; ``period_token`` and ``blank_token``
            are read here and the whole object is passed to ``Agent``/``TextEnv``.
        :param api: API client handed to each ``Agent`` (``start`` passes
            ``APIServer.get_api_client()``).
        :param utils: buffer helper providing ``read_game`` and ``save_game``.
        """
        self.config = config
        self.env = None
        self.agent = None
        self.api = api
        self.false_positive_count_of_resign = 0
        self.resign_test_game_count = 0
        self.utils = utils

    def _leading_tokens(self, s):
        """Return ``(prefix, stopped)``.

        ``prefix`` holds the tokens of ``s`` (as ints) before the first
        ``period_token`` or ``blank_token``. ``stopped`` is True when such a
        token was found.
        """
        prefix = []
        for token in s:
            if token == self.config.period_token or token == self.config.blank_token:
                return prefix, True
            prefix.append(int(token))
        return prefix, False

    def start(self):
        """Play games until interrupted with Ctrl-C.

        Rows whose prefix contains a period or blank token are skipped: they
        are not played, not saved, and do not advance ``game_idx``.
        """
        game_idx = 0
        while True:
            try:
                t = time()
                claimed = self.utils.read_game()
                if claimed is None:
                    # read_game may report an interrupt by returning None; there
                    # is no row to unpack, so stop the worker.
                    break
                s, self.cur_row = claimed
                prefix, stopped = self._leading_tokens(s)
                if stopped:
                    continue
                # NOTE(review): ini_state is not read by start_game or passed to
                # Agent/TextEnv in the visible code -- confirm it is consumed elsewhere.
                self.ini_state = prefix
                self.start_game()
                s = self.env.string
                pi = [row[1] for row in self.agent.moves]
                v = [row[2] for row in self.agent.moves]
                self.utils.save_game(s[-1], pi, game_idx, self.cur_row)
                print(game_idx, " done: %.3f" % (time() - t), len(s), v)
                game_idx += 1
            except KeyboardInterrupt:
                print("Caught KeyboardInterrupt")
                break

    def start_game(self):
        """Play one game from a fresh env/agent.

        Stops when ``env.done`` becomes true or the agent returns ``None`` as
        its action.
        """
        self.agent = Agent(self.config, api=self.api)
        self.env = TextEnv(self.config)
        state = self.env.state()
        while not self.env.done:
            action = self.agent.action(state)
            if action is None:
                break
            self.env.add(action)
            state = self.env.state()