Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,29 @@ on:
branches: [ master ]
pull_request:
branches: [ master ]
schedule:
- cron: "0 0 * * 1" # every monday at 00:00

jobs:
build:

runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10']
os: [ubuntu-latest, macos-latest]
python-version: ['3.10', '3.11', '3.12', '3.13']

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest coverage
pip install -r requirements.txt
python -m pip install .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
build/
globalemu.egg-info/
env*
__pycache__/
venv*
.coverage
1 change: 0 additions & 1 deletion MANIFEST.in

This file was deleted.

2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction

:globalemu: Robust Global 21-cm Signal Emulation
:Author: Harry Thomas Jones Bevins
:Version: 1.8.2
:Version: 1.9.1
:Homepage: https://github.com/htjb/globalemu
:Documentation: https://globalemu.readthedocs.io/

Expand Down
Binary file modified docs/images/gui.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 10 additions & 2 deletions globalemu/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tensorflow.keras import backend as K
import gc
import pickle
import warnings


class evaluate():
Expand Down Expand Up @@ -48,7 +49,7 @@ class evaluate():
base_dir + 'model.h5',
compile=False)

logs: **list / default: [0, 1, 2]**
logs: **list / default: []**
| The indices corresponding to the astrophysical
parameters that
were logged during training. The default assumes
Expand Down Expand Up @@ -140,7 +141,14 @@ def __init__(self, **kwargs):
self.base_dir + 'model.h5',
compile=False)

self.logs = kwargs.pop('logs', [0, 1, 2])
self.logs = kwargs.pop('logs', [])
if self.logs == []:
warnings.warn("logs has defaulted to [] i.e. " +
"log no input parameters. Older versions " +
"assumed logs=[0, 1, 2]. If logs is not in " +
"{base_dir}/kwargs.txt and the network " +
"was trained with globalemu < 1.9 there is a " +
"good chance logs should be [0, 1, 2]!")
if type(self.logs) is not list:
raise TypeError("'logs' must be a list.")
self.garbage_collection = kwargs.pop('gc', False)
Expand Down
104 changes: 104 additions & 0 deletions globalemu/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import numpy as np
import sys
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from globalemu.eval import evaluate

def main():
base_dir = sys.argv[1]
if base_dir[-1] != '/':
base_dir += '/'

config = np.genfromtxt(
base_dir + 'gui_configuration.csv',
delimiter=',',
names=True,
dtype='U100,f8,f8,f8,f8,U100,U100', # Explicitly specify types
encoding='utf-8'
)

logs = config['logs'].tolist()
logs = [int(x) for x in logs if x != '--']
label_min = config['label_min'][0]
label_max = config['label_max'][0]
ylabel = config['ylabel'][0]

predictor = evaluate(base_dir=base_dir, logs=logs)

# Calculate initial values (center of range)
center = []
for i in range(len(config['names'])):
center.append(config['mins'][i] + (config['maxs'][i] - config['mins'][i])/2)
center = np.array(center)

# Create figure with plot on left, sliders on right
n_sliders = len(config['names'])
fig = plt.figure(figsize=(14, 8))

# Main plot on left side - takes full height
ax_plot = plt.subplot2grid((n_sliders, 2), (0, 0), rowspan=n_sliders)

# Initial signal
def get_params(slider_vals):
params = []
for i in range(len(slider_vals)):
if i in set(logs):
params.append(10**slider_vals[i])
else:
params.append(slider_vals[i])
return params

signal, z = predictor(get_params(center))
line, = ax_plot.plot(z, signal, c='k', lw=2)
ax_plot.set_xlabel('z')
ax_plot.set_ylabel(ylabel)
ax_plot.set_ylim([label_min, label_max])
ax_plot.grid(True, alpha=0.3)

# Create sliders on right side
sliders = []

for i in range(n_sliders):
# Create axis for this slider on right column
ax_slider = plt.subplot2grid((n_sliders, 2), (i, 1))

# Create slider
slider = Slider(
ax_slider,
config['names'][i],
config['mins'][i],
config['maxs'][i],
valinit=center[i],
valstep=(config['maxs'][i] - config['mins'][i])/100
)
sliders.append(slider)

# Update function
def update(val):
slider_vals = [s.val for s in sliders]
params = get_params(slider_vals)
signal, z = predictor(params)
line.set_ydata(signal)
fig.canvas.draw_idle()

# Connect all sliders to update function
for slider in sliders:
slider.on_changed(update)

# Reset button
ax_reset = plt.axes([0.8, 0.025, 0.1, 0.04])
btn_reset = Button(ax_reset, 'Reset')

def reset(event):
for i, slider in enumerate(sliders):
slider.set_val(center[i])

btn_reset.on_clicked(reset)

plt.tight_layout()
plt.show()

if __name__ == '__main__':
main()
42 changes: 30 additions & 12 deletions globalemu/gui_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

import numpy as np
import pandas as pd
import pickle


Expand Down Expand Up @@ -115,14 +114,33 @@ def __init__(self, base_dir, paramnames, data_dir, **kwargs):
else:
full_logs.append('--')

df = pd.DataFrame({'names': self.paramnames,
'mins': data_mins,
'maxs': data_maxs,
'label_min':
[test_labels.min()] + ['']*(len(data_maxs)-1),
'label_max':
[test_labels.max()] + ['']*(len(data_maxs)-1),
'logs': full_logs,
'ylabel': self.ylabel})

df.to_csv(base_dir + 'gui_configuration.csv', index=False)
# Create the data arrays
n = len(data_maxs)
names = np.array(self.paramnames, dtype='U100') # Unicode strings
mins = np.array(data_mins, dtype=float)
maxs = np.array(data_maxs, dtype=float)

# label_min: first value is test_labels.min(), rest are empty strings
label_min = np.array([test_labels.min()] + ['']*(n-1), dtype='U100')

# label_max: first value is test_labels.max(), rest are empty strings
label_max = np.array([test_labels.max()] + ['']*(n-1), dtype='U100')

logs = np.array(full_logs, dtype='U100')
ylabel = np.array([self.ylabel]*n, dtype='U100')

# Stack into 2D array (transpose to get columns)
data = np.column_stack([names, mins, maxs, label_min, label_max, logs, ylabel])

# Write header
header = 'names,mins,maxs,label_min,label_max,logs,ylabel'

# Save to CSV
np.savetxt(
base_dir + 'gui_configuration.csv',
data,
delimiter=',',
header=header,
comments='', # Prevents '#' being added to header
fmt='%s' # String format for all columns
)
9 changes: 5 additions & 4 deletions globalemu/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import numpy as np
import os
import pandas as pd
import pickle
from globalemu.cmSim import calc_signal
from globalemu.resample import sampling
Expand Down Expand Up @@ -146,10 +145,12 @@ def __init__(self, num, z, **kwargs):

np.savetxt(self.base_dir + 'z.txt', self.z)

def load_data(file):
return pd.read_csv(
def load_data(file: str) -> np.ndarray:
"""Helper function to load the data from the provided directory."""
return np.loadtxt(
self.data_location + file,
delim_whitespace=True, header=None).values
dtype=float
)

full_train_data = load_data('train_data.txt')
full_train_labels = load_data('train_labels.txt')
Expand Down
50 changes: 50 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
[build-system]
requires = ["setuptools>=77.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "globalemu"
version = "1.9.1"
description = "globalemu: Robust and Fast Global 21-cm Signal Emulation."
readme = "README.rst"
authors = [
{name = "Harry T. J. Bevins", email = "htjb2@cam.ac.uk"}]
license = {text = "MIT"}
requires-python = ">=3.10"
dependencies = [
'numpy',
'tensorflow',
'matplotlib',
'Pillow'
]

classifiers=[
'Intended Audience :: Science/Research',
'Natural Language :: English',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: 3.13',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Astronomy',
'Topic :: Scientific/Engineering :: Physics',
]

[project.scripts]
globalemu = "globalemu.gui:main"


[project.optional-dependencies]
dev = [
"pytest",
]
docs = [
"sphinx",
"sphinx_rtd_theme",
"numpydoc",
"packaging"
]

[tool.setuptools.packages.find]
where = ["."]
include = ["globalemu*"]
6 changes: 0 additions & 6 deletions requirements.txt

This file was deleted.

Loading
Loading