@@ -149,14 +149,14 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
149149 virial_hat = label_dict ["virial" ]
150150 atom_ener_hat = label_dict ["atom_ener" ]
151151 atom_pref = label_dict ["atom_pref" ]
152- find_energy = label_dict ["find_energy" ]
153- find_force = label_dict ["find_force" ]
154- find_virial = label_dict ["find_virial" ]
155- find_atom_ener = label_dict ["find_atom_ener" ]
156- find_atom_pref = label_dict ["find_atom_pref" ]
152+ find_energy = global_cvt_2_ener_float ( label_dict ["find_energy" ])
153+ find_force = global_cvt_2_tf_float ( label_dict ["find_force" ])
154+ find_virial = global_cvt_2_tf_float ( label_dict ["find_virial" ])
155+ find_atom_ener = global_cvt_2_tf_float ( label_dict ["find_atom_ener" ])
156+ find_atom_pref = global_cvt_2_tf_float ( label_dict ["find_atom_pref" ])
157157 if self .has_gf :
158158 drdq = label_dict ["drdq" ]
159- find_drdq = label_dict ["find_drdq" ]
159+ find_drdq = global_cvt_2_tf_float ( label_dict ["find_drdq" ])
160160 else :
161161 find_drdq = 0.0
162162
@@ -589,10 +589,10 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
589589 virial_label = label_dict ["virial" ]
590590 atom_ener_label = label_dict ["atom_ener" ]
591591 atom_pref = label_dict ["atom_pref" ]
592- find_energy = label_dict ["find_energy" ]
593- find_force = label_dict ["find_force" ]
594- find_virial = label_dict ["find_virial" ]
595- find_atom_ener = label_dict ["find_atom_ener" ]
592+ find_energy = global_cvt_2_ener_float ( label_dict ["find_energy" ])
593+ find_force = global_cvt_2_tf_float ( label_dict ["find_force" ])
594+ find_virial = global_cvt_2_tf_float ( label_dict ["find_virial" ])
595+ find_atom_ener = global_cvt_2_tf_float ( label_dict ["find_atom_ener" ])
596596
597597 if self .enable_atom_ener_coeff :
598598 # when ener_coeff (\nu) is defined, the energy is defined as
@@ -932,8 +932,8 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
932932
933933 energy_hat = label_dict ["energy" ]
934934 ener_dipole_hat = label_dict ["energy_dipole" ]
935- find_energy = label_dict ["find_energy" ]
936- find_ener_dipole = label_dict ["find_energy_dipole" ]
935+ find_energy = global_cvt_2_ener_float ( label_dict ["find_energy" ])
936+ find_ener_dipole = global_cvt_2_ener_float ( label_dict ["find_energy_dipole" ])
937937
938938 l2_ener_loss = tf .reduce_mean (
939939 tf .square (energy - energy_hat ), name = "l2_" + suffix
0 commit comments