diff --git a/README.rst b/README.rst index cccdad8..7f613f0 100644 --- a/README.rst +++ b/README.rst @@ -7,7 +7,7 @@ Introduction :globalemu: Robust Global 21-cm Signal Emulation :Author: Harry Thomas Jones Bevins -:Version: 1.9.1 +:Version: 1.10.0 :Homepage: https://github.com/htjb/globalemu :Documentation: https://globalemu.readthedocs.io/ diff --git a/globalemu/network.py b/globalemu/network.py index c54167f..652c5cc 100644 --- a/globalemu/network.py +++ b/globalemu/network.py @@ -35,6 +35,9 @@ class nn(): large for every sample to have influenced an update of the network hyperparameters. + save_after **int / default: 10** + | After how many epochs do you want to save the trained model? + activation: **string / default: 'tanh'** | The type of activation function used in the neural networks hidden layers. The activation function effects the way that the @@ -140,7 +143,7 @@ def __init__(self, **kwargs): for key, values in kwargs.items(): if key not in set( - ['batch_size', 'activation', 'epochs', + ['batch_size', 'activation', 'epochs','save_after', 'lr', 'dropout', 'input_shape', 'output_shape', 'layer_sizes', 'base_dir', 'early_stop', 'xHI', 'resume', @@ -166,6 +169,7 @@ def __init__(self, **kwargs): if type(self.activation) is not str: raise TypeError("'activation' must be a string.") self.epochs = kwargs.pop('epochs', 10) + self.save_after = kwargs.pop('save_after',10) self.lr = kwargs.pop('lr', 1e-3) self.drop_val = kwargs.pop('dropout', 0) self.input_shape = kwargs.pop('input_shape', 8) @@ -321,7 +325,7 @@ def grad(model, inputs, targets): ' Epochs used = ' + str(epoch)) break - if (epoch + 1) % 10 == 0: + if (epoch + 1) % self.save_after == 0: model.save(self.base_dir + 'model.h5') np.savetxt( self.base_dir + 'loss_history.txt', train_loss_results) diff --git a/globalemu/plotter.py b/globalemu/plotter.py index 0d9dc91..00980e5 100644 --- a/globalemu/plotter.py +++ b/globalemu/plotter.py @@ -11,7 +11,7 @@ import numpy as np from globalemu.losses import loss_functions import matplotlib.pyplot as plt - +import os class signal_plot(): @@ -209,3 +209,23 @@ def __init__(self, parameters, labels, loss_type, plt.subplots_adjust(hspace=0, wspace=0) plt.savefig(self.base_dir + 'eval_plot.pdf') plt.close() + + #--------------------------------------------- + #Shikhar: saving the data for the quality figure for customisation purpose. + print('\nSaving the predicted and true signals.\n') + save_dir = self.base_dir + '/checks/' + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + np.save(save_dir+'mean_true',mean_label) + np.save(save_dir+'mean_pred',mean_pred) + np.save(save_dir+'95_true',limit_label) + np.save(save_dir+'95_pred',limit_pred) + np.save(save_dir+'worst_true',worst_label) + np.save(save_dir+'worst_pred',worst_pred) + + + filenm = open(save_dir+'loss_quality.txt', "w") + filenm.write(f'{loss[np.where(np.isclose(loss, loss.mean(), rtol=self.rtol, atol=self.atol))[0][0]]}\n') + filenm.write(f'{limit95}\n') + filenm.write(f'{loss.max()}\n') + filenm.close() diff --git a/pyproject.toml b/pyproject.toml index 326149b..7ce9f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "globalemu" -version = "1.9.1" +version = "1.10.0" description = "globalemu: Robust and Fast Global 21-cm Signal Emulation." readme = "README.rst" authors = [