Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction

:globalemu: Robust Global 21-cm Signal Emulation
:Author: Harry Thomas Jones Bevins
:Version: 1.9.1
:Version: 1.10.0
:Homepage: https://github.com/htjb/globalemu
:Documentation: https://globalemu.readthedocs.io/

Expand Down
8 changes: 6 additions & 2 deletions globalemu/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class nn():
large for every sample to have influenced an update of the
network hyperparameters.

save_after **int / default: 10**
| After how many epochs do you want to save the trained model?

activation: **string / default: 'tanh'**
| The type of activation function used in the neural networks
hidden layers. The activation function effects the way that the
Expand Down Expand Up @@ -140,7 +143,7 @@ def __init__(self, **kwargs):

for key, values in kwargs.items():
if key not in set(
['batch_size', 'activation', 'epochs',
['batch_size', 'activation', 'epochs','save_after',
'lr', 'dropout', 'input_shape',
'output_shape', 'layer_sizes', 'base_dir',
'early_stop', 'xHI', 'resume',
Expand All @@ -166,6 +169,7 @@ def __init__(self, **kwargs):
if type(self.activation) is not str:
raise TypeError("'activation' must be a string.")
self.epochs = kwargs.pop('epochs', 10)
self.save_after = kwargs.pop('save_after',10)
self.lr = kwargs.pop('lr', 1e-3)
self.drop_val = kwargs.pop('dropout', 0)
self.input_shape = kwargs.pop('input_shape', 8)
Expand Down Expand Up @@ -321,7 +325,7 @@ def grad(model, inputs, targets):
' Epochs used = ' + str(epoch))
break

if (epoch + 1) % 10 == 0:
if (epoch + 1) % self.save_after == 0:
model.save(self.base_dir + 'model.h5')
np.savetxt(
self.base_dir + 'loss_history.txt', train_loss_results)
Expand Down
22 changes: 21 additions & 1 deletion globalemu/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from globalemu.losses import loss_functions
import matplotlib.pyplot as plt

import os

class signal_plot():

Expand Down Expand Up @@ -209,3 +209,23 @@ def __init__(self, parameters, labels, loss_type,
plt.subplots_adjust(hspace=0, wspace=0)
plt.savefig(self.base_dir + 'eval_plot.pdf')
plt.close()

#---------------------------------------------
#Shikhar: saving the data for the quality figure for customisation purpose.
print('\nSaving the predicted and true signals.\n')
save_dir = self.base_dir + '/checks/'
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
np.save(save_dir+'mean_true',mean_label)
np.save(save_dir+'mean_pred',mean_pred)
np.save(save_dir+'95_true',limit_label)
np.save(save_dir+'95_pred',limit_pred)
np.save(save_dir+'worst_true',worst_label)
np.save(save_dir+'worst_pred',worst_pred)


filenm = open(save_dir+'loss_quality.txt', "w")
filenm.write(f'{loss[np.where(np.isclose(loss, loss.mean(), rtol=self.rtol, atol=self.atol))[0][0]]}\n')
filenm.write(f'{limit95}\n')
filenm.write(f'{loss.max()}\n')
filenm.close()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "globalemu"
version = "1.9.1"
version = "1.10.0"
description = "globalemu: Robust and Fast Global 21-cm Signal Emulation."
readme = "README.rst"
authors = [
Expand Down
Loading