Skip to content

Commit

Permalink
Merge pull request #115 from amcadmus/devel
Browse files Browse the repository at this point in the history
polarizability: add option to only fit the diag part
  • Loading branch information
amcadmus authored Nov 3, 2019
2 parents ca61f69 + 9a61ff9 commit 7f8afdd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
5 changes: 3 additions & 2 deletions examples/water/train/polar_se_a.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
},
"fitting_net": {
"type": "polar",
"pol_type": [0],
"sel_type": [0],
"fit_diag": true,
"neuron": [100, 100, 100],
"resnet_dt": true,
"seed": 1,
Expand All @@ -28,7 +29,7 @@

"learning_rate" :{
"type": "exp",
"start_lr": 0.001,
"start_lr": 0.01,
"decay_steps": 5000,
"decay_rate": 0.95,
"_comment": "that's all"
Expand Down
1 change: 1 addition & 0 deletions source/tests/polar_se_a.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"fitting_net": {
"type": "polar",
"pol_type": [0],
"fit_diag": false,
"neuron": [100, 100, 100],
"resnet_dt": true,
"seed": 1,
Expand Down
22 changes: 16 additions & 6 deletions source/train/Fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,14 @@ def __init__ (self, jdata, descrpt) :
args = ClassArg()\
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
.add('resnet_dt', bool, default = True)\
.add('fit_diag', bool, default = True)\
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
.add('seed', int)
class_data = args.parse(jdata)
self.n_neuron = class_data['neuron']
self.resnet_dt = class_data['resnet_dt']
self.sel_type = class_data['sel_type']
self.fit_diag = class_data['fit_diag']
self.seed = class_data['seed']
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
self.dim_rot_mat = self.dim_rot_mat_1 * 3
Expand Down Expand Up @@ -400,12 +402,20 @@ def build (self,
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
else :
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
# (nframes x natoms) x (naxis x naxis)
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
# (nframes x natoms) x naxis x naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1, self.dim_rot_mat_1])
# (nframes x natoms) x naxis x naxis
final_layer = final_layer + tf.transpose(final_layer, perm = [0,2,1])
if self.fit_diag :
# (nframes x natoms) x naxis
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
# (nframes x natoms) x naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1])
# (nframes x natoms) x naxis x naxis
final_layer = tf.matrix_diag(final_layer)
else :
# (nframes x natoms) x (naxis x naxis)
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
# (nframes x natoms) x naxis x naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1, self.dim_rot_mat_1])
# (nframes x natoms) x naxis x naxis
final_layer = final_layer + tf.transpose(final_layer, perm = [0,2,1])
# (nframes x natoms) x naxis x 3(coord)
final_layer = tf.matmul(final_layer, rot_mat_i)
# (nframes x natoms) x 3(coord) x 3(coord)
Expand Down

0 comments on commit 7f8afdd

Please sign in to comment.