
Commit f9063db

Update to add global norm
1 parent 2d37935 commit f9063db

File tree

3 files changed (+36 -9 lines)

controller.py (+17 -0)

@@ -187,6 +187,7 @@ def __init__(self, policy_session, num_layers, state_space,
                  exploration=0.8,
                  controller_cells=32,
                  embedding_dim=20,
+                 clip_norm=0.0,
                  restore_controller=False):
         self.policy_session = policy_session  # type: tf.Session

@@ -200,6 +201,7 @@ def __init__(self, policy_session, num_layers, state_space,
         self.discount_factor = discount_factor
         self.exploration = exploration
         self.restore_controller = restore_controller
+        self.clip_norm = clip_norm

         self.reward_buffer = []
         self.state_buffer = []
@@ -372,7 +374,15 @@ def build_policy_network(self):
             tf.summary.scalar('total_loss', self.total_loss)

             self.gradients = self.optimizer.compute_gradients(self.total_loss)
+
             with tf.name_scope('policy_gradients'):
+                # normalize gradients so that they dont explode if argument passed
+                if self.clip_norm is not None and self.clip_norm != 0.0:
+                    norm = tf.constant(self.clip_norm, dtype=tf.float32)
+                    gradients, vars = zip(*self.gradients)  # unpack the two lists of gradients and the variables
+                    gradients, _ = tf.clip_by_global_norm(gradients, norm)  # clip by the norm
+                    self.gradients = list(zip(gradients, vars))  # we need to set values later, convert to list
+
                 # compute policy gradients
                 for i, (grad, var) in enumerate(self.gradients):
                     if grad is not None:
@@ -489,3 +499,10 @@ def train_step(self):
         self.exploration *= 0.99

         return loss
+
+    def remove_files(self):
+        files = ['train_history.csv', 'buffers.txt']
+
+        for file in files:
+            if os.path.exists(file):
+                os.remove(file)
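
For reference, a minimal standalone sketch of the global-norm clipping pattern this commit introduces, written against the TF 1.x graph-mode API the controller already uses. The loss, variable, and optimizer below are placeholders invented for illustration, not the controller's actual policy graph:

import tensorflow as tf

# Toy graph standing in for the policy network's loss (illustrative only).
x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
w = tf.get_variable('w', shape=[4, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
grads_and_vars = optimizer.compute_gradients(loss)  # list of (gradient, variable) pairs

clip_norm = 1.0  # plays the same role as the new clip_norm constructor argument
if clip_norm is not None and clip_norm != 0.0:
    grads, variables = zip(*grads_and_vars)                        # split pairs into two tuples
    grads, global_norm = tf.clip_by_global_norm(grads, clip_norm)  # rescale all gradients jointly
    grads_and_vars = list(zip(grads, variables))                   # re-pair them for apply_gradients

train_op = optimizer.apply_gradients(grads_and_vars)

Unlike clipping each tensor separately, tf.clip_by_global_norm rescales every gradient by the same factor whenever their combined norm exceeds the threshold, so the direction of the policy update is preserved.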

manager.py (+8 -5)

@@ -9,7 +9,7 @@ class NetworkManager:
     '''
     Helper class to manage the generation of subnetwork training given a dataset
     '''
-    def __init__(self, dataset, epochs=5, child_batchsize=128, acc_beta=0.8, clip_rewards=False):
+    def __init__(self, dataset, epochs=5, child_batchsize=128, acc_beta=0.8, clip_rewards=0.0):
         '''
         Manager which is tasked with creating subnetworks, training them on a dataset, and retrieving
         rewards in the term of accuracy, which is passed to the controller RNN.
@@ -19,7 +19,7 @@ def __init__(self, dataset, epochs=5, child_batchsize=128, acc_beta=0.8, clip_re
             epochs: number of epochs to train the subnetworks
             child_batchsize: batchsize of training the subnetworks
             acc_beta: exponential weight for the accuracy
-            clip_rewards: whether to clip rewards in [-0.05, 0.05] range to prevent
+            clip_rewards: float - to clip rewards in [-range, range] to prevent
                 large weight updates. Use when training is highly unstable.
         '''
         self.dataset = dataset
@@ -89,9 +89,12 @@ def get_rewards(self, model_fn, actions):
                 reward = np.clip(reward, -0.05, 0.05)

             # update moving accuracy with bias correction for 1st update
-            self.moving_acc = self.beta * self.moving_acc + (1 - self.beta) * acc
-            self.moving_acc = self.moving_acc / (1 - self.beta_bias)
-            self.beta_bias = 0
+            if self.beta > 0.0 and self.beta < 1.0:
+                self.moving_acc = self.beta * self.moving_acc + (1 - self.beta) * acc
+                self.moving_acc = self.moving_acc / (1 - self.beta_bias)
+                self.beta_bias = 0
+
+                reward = np.clip(reward, -0.1, 0.1)

             print()
             print("Manager: EWA Accuracy = ", self.moving_acc)

train.py (+11 -4)

@@ -17,12 +17,14 @@
 NUM_LAYERS = 4  # number of layers of the state space
 MAX_TRIALS = 250  # maximum number of models generated

-MAX_EPOCHS = 10  # maximum number of epochs to train
+MAX_EPOCHS = 1  # maximum number of epochs to train
 CHILD_BATCHSIZE = 128  # batchsize of the child models
-EXPLORATION = 0.8  # high exploration for the first 1000 steps
+EXPLORATION = 0.9  # high exploration for the first 1000 steps
 REGULARIZATION = 1e-3  # regularization strength
 CONTROLLER_CELLS = 32  # number of cells in RNN controller
-CLIP_REWARDS = False  # clip rewards in the [-0.05, 0.05] range
+EMBEDDING_DIM = 20  # dimension of the embeddings for each state
+ACCURACY_BETA = 0.8  # beta value for the moving average of the accuracy
+CLIP_REWARDS = 0.0  # clip rewards in the [-0.05, 0.05] range
 RESTORE_CONTROLLER = True  # restore controller to continue training

 # construct a state space
@@ -53,17 +55,22 @@
                             reg_param=REGULARIZATION,
                             exploration=EXPLORATION,
                             controller_cells=CONTROLLER_CELLS,
+                            embedding_dim=EMBEDDING_DIM,
                             restore_controller=RESTORE_CONTROLLER)

 # create the Network Manager
-manager = NetworkManager(dataset, epochs=MAX_EPOCHS, child_batchsize=CHILD_BATCHSIZE, clip_rewards=CLIP_REWARDS)
+manager = NetworkManager(dataset, epochs=MAX_EPOCHS, child_batchsize=CHILD_BATCHSIZE, clip_rewards=CLIP_REWARDS,
+                         acc_beta=ACCURACY_BETA)

 # get an initial random state space if controller needs to predict an
 # action from the initial state
 state = state_space.get_random_state_space(NUM_LAYERS)
 print("Initial Random State : ", state_space.parse_state_space_list(state))
 print()

+# clear the previous files
+controller.remove_files()
+
 # train for number of trails
 for trial in range(MAX_TRIALS):
     with policy_sess.as_default():
