diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a9f3fa4..bd5d097 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,6 +8,8 @@ on: branches: [ master ] pull_request: branches: [ master ] + schedule: + - cron: "0 0 * * 1" # every monday at 00:00 jobs: build: @@ -15,8 +17,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ['3.8', '3.9', '3.10'] + os: [ubuntu-latest, macos-latest] + python-version: ['3.10', '3.11', '3.12', '3.13'] steps: - uses: actions/checkout@v2 @@ -24,11 +26,11 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Install dependencies + - name: Install run: | python -m pip install --upgrade pip python -m pip install flake8 pytest coverage - pip install -r requirements.txt + python -m pip install . - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/.gitignore b/.gitignore index 48ce8c8..3290b96 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ build/ globalemu.egg-info/ +env* +__pycache__/ +venv* +.coverage diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index f9bd145..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1 +0,0 @@ -include requirements.txt diff --git a/README.rst b/README.rst index 04d4a2c..cccdad8 100644 --- a/README.rst +++ b/README.rst @@ -7,7 +7,7 @@ Introduction :globalemu: Robust Global 21-cm Signal Emulation :Author: Harry Thomas Jones Bevins -:Version: 1.8.2 +:Version: 1.9.1 :Homepage: https://github.com/htjb/globalemu :Documentation: https://globalemu.readthedocs.io/ diff --git a/docs/images/gui.png b/docs/images/gui.png index 37e499b..e779233 100644 Binary files a/docs/images/gui.png and b/docs/images/gui.png differ diff --git a/globalemu/eval.py b/globalemu/eval.py index e9a557e..7b696da 100644 --- a/globalemu/eval.py +++ b/globalemu/eval.py @@ -15,6 +15,7 @@ from tensorflow.keras import backend as K import gc import pickle +import warnings class evaluate(): @@ -48,7 +49,7 @@ class evaluate(): base_dir + 'model.h5', compile=False) - logs: **list / default: [0, 1, 2]** + logs: **list / default: []** | The indices corresponding to the astrophysical parameters that were logged during training. The default assumes @@ -140,7 +141,14 @@ def __init__(self, **kwargs): self.base_dir + 'model.h5', compile=False) - self.logs = kwargs.pop('logs', [0, 1, 2]) + self.logs = kwargs.pop('logs', []) + if self.logs == []: + warnings.warn("logs has defaulted to [] i.e. " + + "log no input parameters. Older versions " + + "assumed logs=[0, 1, 2]. If logs is not in " + + "{base_dir}/kwargs.txt and the network " + + "was trained with globalemu < 1.9 there is a " + + "good chance logs should be [0, 1, 2]!") if type(self.logs) is not list: raise TypeError("'logs' must be a list.") self.garbage_collection = kwargs.pop('gc', False) diff --git a/globalemu/gui.py b/globalemu/gui.py new file mode 100644 index 0000000..76228da --- /dev/null +++ b/globalemu/gui.py @@ -0,0 +1,104 @@ +import matplotlib.pyplot as plt +from matplotlib.widgets import Slider, Button +import numpy as np +import sys +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +from globalemu.eval import evaluate + +def main(): + base_dir = sys.argv[1] + if base_dir[-1] != '/': + base_dir += '/' + + config = np.genfromtxt( + base_dir + 'gui_configuration.csv', + delimiter=',', + names=True, + dtype='U100,f8,f8,f8,f8,U100,U100', # Explicitly specify types + encoding='utf-8' + ) + + logs = config['logs'].tolist() + logs = [int(x) for x in logs if x != '--'] + label_min = config['label_min'][0] + label_max = config['label_max'][0] + ylabel = config['ylabel'][0] + + predictor = evaluate(base_dir=base_dir, logs=logs) + + # Calculate initial values (center of range) + center = [] + for i in range(len(config['names'])): + center.append(config['mins'][i] + (config['maxs'][i] - config['mins'][i])/2) + center = np.array(center) + + # Create figure with plot on left, sliders on right + n_sliders = len(config['names']) + fig = plt.figure(figsize=(14, 8)) + + # Main plot on left side - takes full height + ax_plot = plt.subplot2grid((n_sliders, 2), (0, 0), rowspan=n_sliders) + + # Initial signal + def get_params(slider_vals): + params = [] + for i in range(len(slider_vals)): + if i in set(logs): + params.append(10**slider_vals[i]) + else: + params.append(slider_vals[i]) + return params + + signal, z = predictor(get_params(center)) + line, = ax_plot.plot(z, signal, c='k', lw=2) + ax_plot.set_xlabel('z') + ax_plot.set_ylabel(ylabel) + ax_plot.set_ylim([label_min, label_max]) + ax_plot.grid(True, alpha=0.3) + + # Create sliders on right side + sliders = [] + + for i in range(n_sliders): + # Create axis for this slider on right column + ax_slider = plt.subplot2grid((n_sliders, 2), (i, 1)) + + # Create slider + slider = Slider( + ax_slider, + config['names'][i], + config['mins'][i], + config['maxs'][i], + valinit=center[i], + valstep=(config['maxs'][i] - config['mins'][i])/100 + ) + sliders.append(slider) + + # Update function + def update(val): + slider_vals = [s.val for s in sliders] + params = get_params(slider_vals) + signal, z = predictor(params) + line.set_ydata(signal) + fig.canvas.draw_idle() + + # Connect all sliders to update function + for slider in sliders: + slider.on_changed(update) + + # Reset button + ax_reset = plt.axes([0.8, 0.025, 0.1, 0.04]) + btn_reset = Button(ax_reset, 'Reset') + + def reset(event): + for i, slider in enumerate(sliders): + slider.set_val(center[i]) + + btn_reset.on_clicked(reset) + + plt.tight_layout() + plt.show() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/globalemu/gui_config.py b/globalemu/gui_config.py index 1368242..706f32f 100644 --- a/globalemu/gui_config.py +++ b/globalemu/gui_config.py @@ -15,7 +15,6 @@ """ import numpy as np -import pandas as pd import pickle @@ -115,14 +114,33 @@ def __init__(self, base_dir, paramnames, data_dir, **kwargs): else: full_logs.append('--') - df = pd.DataFrame({'names': self.paramnames, - 'mins': data_mins, - 'maxs': data_maxs, - 'label_min': - [test_labels.min()] + ['']*(len(data_maxs)-1), - 'label_max': - [test_labels.max()] + ['']*(len(data_maxs)-1), - 'logs': full_logs, - 'ylabel': self.ylabel}) - - df.to_csv(base_dir + 'gui_configuration.csv', index=False) + # Create the data arrays + n = len(data_maxs) + names = np.array(self.paramnames, dtype='U100') # Unicode strings + mins = np.array(data_mins, dtype=float) + maxs = np.array(data_maxs, dtype=float) + + # label_min: first value is test_labels.min(), rest are empty strings + label_min = np.array([test_labels.min()] + ['']*(n-1), dtype='U100') + + # label_max: first value is test_labels.max(), rest are empty strings + label_max = np.array([test_labels.max()] + ['']*(n-1), dtype='U100') + + logs = np.array(full_logs, dtype='U100') + ylabel = np.array([self.ylabel]*n, dtype='U100') + + # Stack into 2D array (transpose to get columns) + data = np.column_stack([names, mins, maxs, label_min, label_max, logs, ylabel]) + + # Write header + header = 'names,mins,maxs,label_min,label_max,logs,ylabel' + + # Save to CSV + np.savetxt( + base_dir + 'gui_configuration.csv', + data, + delimiter=',', + header=header, + comments='', # Prevents '#' being added to header + fmt='%s' # String format for all columns + ) diff --git a/globalemu/preprocess.py b/globalemu/preprocess.py index eb5b046..c6b677c 100644 --- a/globalemu/preprocess.py +++ b/globalemu/preprocess.py @@ -13,7 +13,6 @@ import numpy as np import os -import pandas as pd import pickle from globalemu.cmSim import calc_signal from globalemu.resample import sampling @@ -146,10 +145,12 @@ def __init__(self, num, z, **kwargs): np.savetxt(self.base_dir + 'z.txt', self.z) - def load_data(file): - return pd.read_csv( + def load_data(file: str) -> np.ndarray: + """Helper function to load the data from the provided directory.""" + return np.loadtxt( self.data_location + file, - delim_whitespace=True, header=None).values + dtype=float + ) full_train_data = load_data('train_data.txt') full_train_labels = load_data('train_labels.txt') diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..326149b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["setuptools>=77.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "globalemu" +version = "1.9.1" +description = "globalemu: Robust and Fast Global 21-cm Signal Emulation." +readme = "README.rst" +authors = [ + {name = "Harry T. J. Bevins", email = "htjb2@cam.ac.uk"}] +license = {text = "MIT"} +requires-python = ">=3.10" +dependencies = [ + 'numpy', + 'tensorflow', + 'matplotlib', + 'Pillow' +] + +classifiers=[ + 'Intended Audience :: Science/Research', + 'Natural Language :: English', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Astronomy', + 'Topic :: Scientific/Engineering :: Physics', +] + +[project.scripts] +globalemu = "globalemu.gui:main" + + +[project.optional-dependencies] +dev = [ + "pytest", +] +docs = [ + "sphinx", + "sphinx_rtd_theme", + "numpydoc", + "packaging" +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["globalemu*"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0ab477e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -numpy -tensorflow-macos; sys_platform == 'darwin' -tensorflow; sys_platform != 'darwin' -pandas -matplotlib -Pillow diff --git a/scripts/globalemu b/scripts/globalemu deleted file mode 100644 index fffa3d3..0000000 --- a/scripts/globalemu +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -from tkinter import * -import matplotlib.pyplot as plt -from matplotlib.mathtext import math_to_image -from io import BytesIO -from PIL import ImageTk, Image -import numpy as np -import shutil -import pandas as pd -import os -import sys -import locale -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -from globalemu.eval import evaluate - -l = '.'.join(locale.getlocale()) -if l[:2] != 'en': - try: - locale.setlocale(locale.LC_ALL, 'en_GB.UTF-8') - except OSError: - print("OSError: system locale is not set to 'en_' and " + - "'en_GB.UTF-8' is not available. This will cause " + - "issues with tkinter. To install 'en_GB.UTF-8' and run the " + - "GUI run 'sudo apt-get install language-pack-en' in " + - "the terminal.") - -def slider_var(name, lower, upper, position_y, - initial, img_x_resize, img_y_resize, tickinterval, - resolution=0.1): - """Creates the GUI sliders and associated labels""" - position_x = 410 - buffer = BytesIO() - math_to_image(name, buffer, dpi=200, format='png') - buffer.seek(0) - - img_var = ImageTk.PhotoImage(Image.open(buffer).resize( - (img_x_resize, img_y_resize))) - entry = Scale( - window, from_=lower, to=upper, - orient=HORIZONTAL, length=325, - resolution=resolution, - tickinterval=tickinterval, - background='white', command=signal) - entry.place(x=position_x, y=position_y) - entry.set(initial) - return entry, img_var - -def signal(_, parameters=None): - """Creates the new signals when the user moves the sliders.""" - if parameters is not None: - params = [] - for i in range(len(center)): - if i in set(logs): - params.append(10**parameters[i]) - else: - params.append(parameters[i]) - else: - params = [] - for i in range(len(entries)): - if i in set(logs): - params.append(10**float(entries[i].get())) - else: - params.append(float(entries[i].get())) - signal, z = predictor(params) - plt.figure(figsize=(4, 3)) - plt.plot(z, signal, c='k') - plt.xlabel('z') - plt.ylabel(ylabel) - plt.ylim([label_min, label_max]) - plt.tight_layout() - plt.savefig('img/img.png', dpi=100) - plt.close() - if parameters is None: - new_img = ImageTk.PhotoImage(Image.open("img/img.png")) - panel.configure(image=new_img) - panel.image = new_img - - -def reset(parameters=None): - """Resets the GUI when the reset button is pressed.""" - if parameters is not None: - signal('', parameters) - else: - for i in range(len(center)): - if i in set(logs): - entries[i].set(center[i]) - else: - entries[i].set(center[i]) - signal('') - -base_dir = sys.argv[1] - -config = pd.read_csv(base_dir + 'gui_configuration.csv') - -logs = config['logs'].tolist() -logs = [int(x) for x in logs if x != '--'] -label_min = config['label_min'][0] -label_max = config['label_max'][0] -ylabel = config['ylabel'][0] - -predictor = evaluate(base_dir=base_dir, logs=logs) - -window = Tk() -window.geometry("800x450") -window.configure(background='white') - -window.title('globalemu GUI') - -if os.path.exists('img/'): - shutil.rmtree('img/') -os.mkdir('img/') - -center = [] -for i in range(len(config['names'])): - center.append(config['mins'][i] + (config['maxs'][i] - config['mins'][i])/2) -center = np.array(center) - -reset(parameters=center) - -img = ImageTk.PhotoImage(Image.open("img/img.png")) -panel = Label(window, image=img) -panel.place(x=10, y=10) - -bad_characters = ['$', '\mathrm', '{', '}', '_', '^', '\\'] -common_greek = ['tau', 'alpha', 'nu', 'beta'] -entries, labels = [], [] -for i in range(len(config['names'])): - for j in range(len(bad_characters)): - if j == 0: - reduced_name = config['names'][i].replace(bad_characters[j], '') - else: - reduced_name = reduced_name.replace(bad_characters[j], '') - if reduced_name in set(common_greek): - reduced_name = 'o' - e, l = slider_var(config['names'][i], config['mins'][i], - config['maxs'][i], 10+60*i, - center[i], - 8*len(reduced_name), 15, - (config['maxs'][i] - config['mins'][i])/5, - resolution=(config['maxs'][i] - config['mins'][i])/100) - labels.append(l) - entries.append(e) - -for i in range(len(labels)): - Label(window, image=labels[i]).place(x=740, y=10+i*60) - -btn = Button(window, text='Reset', command=reset) -btn.place(x=180, y=360) - -window.mainloop() - -shutil.rmtree('img/') diff --git a/setup.py b/setup.py deleted file mode 100644 index 23e4b04..0000000 --- a/setup.py +++ /dev/null @@ -1,46 +0,0 @@ -from setuptools import setup, find_packages - -def readme(short=False): - with open('README.rst', encoding='utf-8') as f: - if short: - return f.readlines()[1].strip() - else: - return f.read() - -def get_version_from_readme() -> str: - """Get current version of package from README.rst""" - readme_text = readme() - for line in readme_text.splitlines(): - if ":Version:" in line: - current_version = line.split(":")[-1].strip() - return current_version - raise ValueError("Could not find version in README.rst") - -setup( - name='globalemu', - version=get_version_from_readme(), - description='globalemu: Robust and Fast Global 21-cm Signal Emulation', - long_description=readme(), - author='Harry T. J. Bevins', - author_email='htjb2@cam.ac.uk', - url='https://github.com/htjb/globalemu', - packages=find_packages(), - install_requires=open('requirements.txt').read().splitlines(), - license='MIT', - scripts=['scripts/globalemu'], - extras_require={ - 'docs': ['sphinx', 'sphinx_rtd_theme', 'numpydoc', 'packaging'], - }, - tests_require=['pytest'], - classifiers=[ - 'Intended Audience :: Science/Research', - 'Natural Language :: English', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Astronomy', - 'Topic :: Scientific/Engineering :: Physics', - ], -) diff --git a/tests/test_download.py b/tests/test_download.py index 440ed82..b76ee22 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -2,7 +2,6 @@ import os import pytest import requests -import pandas as pd import numpy as np @@ -25,12 +24,8 @@ def download_21cmGEM_data(): with open(data_dir + saves[i], 'wb') as f: f.write(requests.get(url).content) - td = pd.read_csv( - data_dir + 'train_data.txt', - delim_whitespace=True, header=None).values - tl = pd.read_csv( - data_dir + 'train_labels.txt', - delim_whitespace=True, header=None).values + td = np.loadtxt(data_dir + 'train_data.txt', dtype=float) + tl = np.loadtxt(data_dir + 'train_labels.txt', dtype=float) np.savetxt(data_dir + 'train_data.txt', td[:500, :]) np.savetxt(data_dir + 'train_labels.txt', tl[:500, :]) diff --git a/tests/test_gui_config.py b/tests/test_gui_config.py index 0bc1f28..a2cb755 100644 --- a/tests/test_gui_config.py +++ b/tests/test_gui_config.py @@ -1,4 +1,4 @@ -import pandas as pd +import numpy as np from globalemu.gui_config import config from globalemu.downloads import download import shutil @@ -26,7 +26,14 @@ def test_config(): assert(os.path.exists('xHI_release/gui_configuration.csv') is True) - res = pd.read_csv('xHI_release/gui_configuration.csv') + res = np.genfromtxt( + 'xHI_release/gui_configuration.csv', + delimiter=',', + names=True, + dtype='U100,f8,f8,f8,f8,U100,U100', # Explicitly specify types + encoding='utf-8' + ) + logs = res['logs'].tolist() logs = [int(x) for x in logs if x != '--'] assert(logs == [0, 1, 2]) diff --git a/tests/test_network.py b/tests/test_network.py index 60410b9..f297de0 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -14,61 +14,66 @@ def custom_loss(y, y_, x): z = np.arange(5, 50.1, 0.1) - process(10, z, data_location='21cmGEM_data/') + process(10, z, data_location="21cmGEM_data/") nn(batch_size=451, layer_sizes=[8], epochs=10, loss_function=custom_loss) # results of below will not make sense as it is being run on the # global signal data but it will test the code (xHI data not public) - process(10, z, data_location='21cmGEM_data/', xHI=True) + process(10, z, data_location="21cmGEM_data/", xHI=True) nn(batch_size=451, layer_sizes=[8], epochs=5, xHI=True) - nn(batch_size=451, layer_sizes=[8], epochs=5, output_activation='linear') + nn(batch_size=451, layer_sizes=[8], epochs=5, output_activation="linear") # test early_stop code nn(batch_size=451, layer_sizes=[], epochs=20, early_stop=True) + process(10, z, data_location="21cmGEM_data/", base_dir="base_dir/") + nn( + batch_size=451, + layer_sizes=[], + random_seed=10, + epochs=30, + base_dir="base_dir/", + early_stop=True, + ) + + dir = ["model_dir/", "base_dir/"] + for i in range(len(dir)): + if os.path.exists(dir[i]): + shutil.rmtree(dir[i]) + + +def test_process_nn_keyword_errors(): + z = np.arange(5, 50.1, 0.1) with pytest.raises(KeyError): - process(10, z, datalocation='21cmGEM_data/') + process(10, z, datalocation="21cmGEM_data/") with pytest.raises(KeyError): nn(batch_size=451, layersizes=[8], epochs=10) - with pytest.raises(TypeError): - nn(batch_size='foo') - with pytest.raises(TypeError): - nn(activation=10) - with pytest.raises(TypeError): - nn(epochs=False) - with pytest.raises(TypeError): - nn(lr='bar') - with pytest.raises(TypeError): - nn(dropout=True) - with pytest.raises(TypeError): - nn(input_shape='foo') - with pytest.raises(TypeError): - nn(output_shape='foobar') - with pytest.raises(TypeError): - nn(layer_sizes=10) - with pytest.raises(TypeError): - nn(base_dir=50) - with pytest.raises(KeyError): - nn(base_dir='dir') - with pytest.raises(TypeError): - nn(early_stop='foo') - with pytest.raises(TypeError): - nn(xHI='false') - with pytest.raises(TypeError): - nn(resume=10) - with pytest.raises(TypeError): - nn(output_activation=2) - with pytest.raises(TypeError): - nn(loss_function='foobar') - - process(10, z, data_location='21cmGEM_data/', base_dir='base_dir/') - nn(batch_size=451, layer_sizes=[], random_seed=10, epochs=30, - base_dir='base_dir/', early_stop=True) - - dir = ['model_dir/', 'base_dir/'] - for i in range(len(dir)): - if os.path.exists(dir[i]): - shutil.rmtree(dir[i]) + +@pytest.mark.parametrize( + "keyword, value", + [ + ("batch_size", "foo"), + ("activation", 10), + ("epochs", False), + ("lr", "bar"), + ("dropout", True), + ("input_shape", "foo"), + ("output_shape", "foobar"), + ("layer_sizes", 10), + ("base_dir", 50), + ("early_stop", "foo"), + ("xHI", "false"), + ("resume", 10), + ("output_activation", 2), + ("loss_function", "foobar"), + ], +) +def test_process_nn_type_errors(keyword, value): + z = np.arange(5, 50.1, 0.1) + + process(10, z, data_location="21cmGEM_data/") + with pytest.raises((TypeError, ValueError)): + nn(**{keyword: value}) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 5e63bf6..ed294da 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -2,7 +2,6 @@ from globalemu.preprocess import process import os import pytest -import pandas as pd import shutil @@ -29,8 +28,9 @@ def test_preprocess(): for i in range(len(files)): assert(os.path.exists('model_dir/' + files[i]) is True) - full_train_data = pd.read_csv( - 'model_dir/train_dataset.csv', header=None).values + full_train_data = np.loadtxt( + 'model_dir/train_dataset.csv', dtype=float, + delimiter=',') for i in range(full_train_data.shape[1]): if i < full_train_data.shape[1] - 1: