From a28ef3bb35b59a83809b3de30a784a3dbe255fc8 Mon Sep 17 00:00:00 2001 From: mittal Date: Wed, 11 Mar 2026 09:07:16 +0000 Subject: [PATCH 1/4] SM: adding the feature to save the model after user choice of epochs. Also, save the data for quality checks. --- globalemu/network.py | 8 ++++++-- globalemu/plotter.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/globalemu/network.py b/globalemu/network.py index c54167f..c9d046c 100644 --- a/globalemu/network.py +++ b/globalemu/network.py @@ -35,6 +35,9 @@ class nn(): large for every sample to have influenced an update of the network hyperparameters. + save_after **int / default: 10** + | After how many epochs do you want to save the trained model? + activation: **string / default: 'tanh'** | The type of activation function used in the neural networks hidden layers. The activation function effects the way that the @@ -140,7 +143,7 @@ def __init__(self, **kwargs): for key, values in kwargs.items(): if key not in set( - ['batch_size', 'activation', 'epochs', + ['batch_size', 'activation', 'epochs','save_after', 'lr', 'dropout', 'input_shape', 'output_shape', 'layer_sizes', 'base_dir', 'early_stop', 'xHI', 'resume', @@ -166,6 +169,7 @@ def __init__(self, **kwargs): if type(self.activation) is not str: raise TypeError("'activation' must be a string.") self.epochs = kwargs.pop('epochs', 10) + self.save_after = kwargs.pop('save_after',10) self.lr = kwargs.pop('lr', 1e-3) self.drop_val = kwargs.pop('dropout', 0) self.input_shape = kwargs.pop('input_shape', 8) @@ -321,7 +325,7 @@ def grad(model, inputs, targets): ' Epochs used = ' + str(epoch)) break - if (epoch + 1) % 10 == 0: + if (epoch + self.save_after) % 1 == 0: model.save(self.base_dir + 'model.h5') np.savetxt( self.base_dir + 'loss_history.txt', train_loss_results) diff --git a/globalemu/plotter.py b/globalemu/plotter.py index 0d9dc91..63962e9 100644 --- a/globalemu/plotter.py +++ b/globalemu/plotter.py @@ -209,3 +209,21 @@ def __init__(self, parameters, labels, loss_type, plt.subplots_adjust(hspace=0, wspace=0) plt.savefig(self.base_dir + 'eval_plot.pdf') plt.close() + + #--------------------------------------------- + #Shikhar: saving the data for the quality figure for customisation purpose. + print('\nSaving the predicted and true signals.\n') + save_dir = self.base_dir + '/checks/' + np.save(save_dir+'mean_true',mean_label) + np.save(save_dir+'mean_pred',mean_pred) + np.save(save_dir+'95_true',limit_label) + np.save(save_dir+'95_pred',limit_pred) + np.save(save_dir+'worst_true',worst_label) + np.save(save_dir+'worst_pred',worst_pred) + + + filenm = open(save_dir+'loss_quality.txt', "w") + filenm.write(f'{loss[np.where(np.isclose(loss, loss.mean(), rtol=self.rtol, atol=self.atol))[0][0]]}\n') + filenm.write(f'{limit95}\n') + filenm.write(f'{loss.max()}\n') + filenm.close() From 3c66db07195107dadb293795ec0d5cf24b4eba64 Mon Sep 17 00:00:00 2001 From: mittal Date: Wed, 11 Mar 2026 09:16:59 +0000 Subject: [PATCH 2/4] SM: correcting the save_after feature --- globalemu/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/globalemu/network.py b/globalemu/network.py index c9d046c..652c5cc 100644 --- a/globalemu/network.py +++ b/globalemu/network.py @@ -325,7 +325,7 @@ def grad(model, inputs, targets): ' Epochs used = ' + str(epoch)) break - if (epoch + self.save_after) % 1 == 0: + if (epoch + 1) % self.save_after == 0: model.save(self.base_dir + 'model.h5') np.savetxt( self.base_dir + 'loss_history.txt', train_loss_results) From 3c9a2f1c551c9666b974e83d5042c4365005db7e Mon Sep 17 00:00:00 2001 From: mittal Date: Thu, 12 Mar 2026 10:39:24 +0000 Subject: [PATCH 3/4] update version number --- README.rst | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index cccdad8..7f613f0 100644 --- a/README.rst +++ b/README.rst @@ -7,7 +7,7 @@ Introduction :globalemu: Robust Global 21-cm Signal Emulation :Author: Harry Thomas Jones Bevins -:Version: 1.9.1 +:Version: 1.10.0 :Homepage: https://github.com/htjb/globalemu :Documentation: https://globalemu.readthedocs.io/ diff --git a/pyproject.toml b/pyproject.toml index 326149b..7ce9f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "globalemu" -version = "1.9.1" +version = "1.10.0" description = "globalemu: Robust and Fast Global 21-cm Signal Emulation." readme = "README.rst" authors = [ From 374eaf88efc634b0d848a56c6b3f02ec6e482716 Mon Sep 17 00:00:00 2001 From: mittal Date: Thu, 12 Mar 2026 11:08:15 +0000 Subject: [PATCH 4/4] Create checks directory --- globalemu/plotter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/globalemu/plotter.py b/globalemu/plotter.py index 63962e9..00980e5 100644 --- a/globalemu/plotter.py +++ b/globalemu/plotter.py @@ -11,7 +11,7 @@ import numpy as np from globalemu.losses import loss_functions import matplotlib.pyplot as plt - +import os class signal_plot(): @@ -214,6 +214,8 @@ def __init__(self, parameters, labels, loss_type, #Shikhar: saving the data for the quality figure for customisation purpose. print('\nSaving the predicted and true signals.\n') save_dir = self.base_dir + '/checks/' + if not os.path.isdir(save_dir): + os.makedirs(save_dir) np.save(save_dir+'mean_true',mean_label) np.save(save_dir+'mean_pred',mean_pred) np.save(save_dir+'95_true',limit_label)