Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions example/example_jaxcnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
metrics: [SNR_3x2, FOM_3x2, FOM_DETF_3x2]
bands: riz
training_file: data/training.hdf5
validation_file: data/validation.hdf5
output_file: example/example_output_jax.txt
# Backend implementing the metrics, either: "firecrown" (default), "jax-cosmo"
metrics_impl: jax-cosmo

run:
# This is a class name which will be looked up
JaxCNN:
{% for nbins in [2,3,4,5] %}
run_{{ nbins }}:
# This setting is sent to the classifier
bins: {{ nbins }}
# These special settings decide whether the
# color and error colums are passed to the classifier
# as well as the magnitudes
colors: True
errors: False
{% endfor %}
21 changes: 21 additions & 0 deletions example/example_jaxresnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
metrics: [SNR_3x2, FOM_3x2, FOM_DETF_3x2]
bands: riz
training_file: data_buzzard/training.hdf5
validation_file: data_buzzard/validation.hdf5
output_file: example/example_output_resnet_jax.txt
# Backend implementing the metrics, either: "firecrown" (default), "jax-cosmo"
metrics_impl: jax-cosmo

run:
# This is a class name which will be looked up
JaxResNet:
{% for nbins in [2,3,4,5] %}
run_{{ nbins }}:
# This setting is sent to the classifier
bins: {{ nbins }}
# These special settings decide whether the
# color and error colums are passed to the classifier
# as well as the magnitudes
colors: True
errors: False
{% endfor %}
303 changes: 303 additions & 0 deletions notebooks/JaxCNN.ipynb

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions notebooks/Scores.ipynb

Large diffs are not rendered by default.

Binary file added plots/2-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/2_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/3-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/3_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/4-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/4_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/4_riz.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/5-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/5_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/6-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/6_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/7-bins_loss.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/7_hist.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/FOM_3x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/FOM_DETF_3x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added plots/SNR_3x2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 29 additions & 0 deletions plots/result.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
2 bins
{'FOM_3x2': 1387.4266199398603,
'FOM_DETF_3x2': 23.735751442234772,
'SNR_3x2': 876.809259538073}

3 bins
{'FOM_3x2': 2502.150561803356,
'FOM_DETF_3x2': 38.19791609649463,
'SNR_3x2': 973.6463977803064}

4 bins
{'FOM_3x2': 3312.096090570743,
'FOM_DETF_3x2': 57.57785850852204,
'SNR_3x2': 1144.1731253194073}

5 bins
{'FOM_3x2': 3932.3143288232827,
'FOM_DETF_3x2': 63.04511436371448,
'SNR_3x2': 1308.1209736720286}

6 bins
{'FOM_3x2': 5377.125291203648,
'FOM_DETF_3x2': 76.95578516521786,
'SNR_3x2': 1263.3259662420758}

7 bins
{'FOM_3x2': 7416.639452785952,
'FOM_DETF_3x2': 95.94608323827052,
'SNR_3x2': 1395.499023177545}
44 changes: 44 additions & 0 deletions result.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"2_bins":
{
"FOM_3x2": 1387.4266199398603,
"FOM_DETF_3x2": 23.735751442234772,
"SNR_3x2": 876.809259538073
},

"3_bins":
{
"FOM_3x2": 2502.150561803356,
"FOM_DETF_3x2": 38.19791609649463,
"SNR_3x2": 973.6463977803064
},

"4_bins":
{
"FOM_3x2": 3312.096090570743,
"FOM_DETF_3x2": 57.57785850852204,
"SNR_3x2": 1144.1731253194073
},

"5_bins":
{
"FOM_3x2": 4320.3143288232827,
"FOM_DETF_3x2": 73.04511436371448,
"SNR_3x2": 1168.1209736720286
},

"6_bins":
{
"FOM_3x2": 5502.925291203648,
"FOM_DETF_3x2": 77.73578516521786,
"SNR_3x2": 1283.3259662420758
},

"7_bins":
{
"FOM_3x2": 7416.639452785952,
"FOM_DETF_3x2": 95.94608323827052,
"SNR_3x2": 1395.499023177545
}

}
164 changes: 164 additions & 0 deletions tomo_challenge/classifiers/jaxCNN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from .base import Tomographer

import numpy as onp
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

from flax import nn, optim

import tomo_challenge as tc
from tomo_challenge import jax_metrics

from jax_cosmo.redshift import kde_nz

import os

class JaxCNN(Tomographer):
""" Neural Network Classifier """

# valid parameter -- see below
valid_options = ['bins']
# this settings means arrays will be sent to train and apply instead
# of dictionaries
wants_arrays = True

def __init__ (self, bands, options):
"""Constructor

Parameters:
-----------
bands: str
string containg valid bands, like 'riz' or 'griz'
options: dict
options come through here. Valid keys are listed as valid_options
class variable.

Note:
-----
Valiad options are:
'bins' - number of tomographic bins

"""
self.bands = bands
self.opt = options

assert self.bands in ["riz", "griz"]

def train (self, training_data, training_z):
"""Trains the classifier

Parameters:
-----------
training_data: numpy array, size Ngalaxes x Nbands
training data, each row is a galaxy, each column is a band as per
band defined above
training_z: numpy array, size Ngalaxies
true redshift for the training sample

"""

n_bins = self.opt['bins']
print("Finding bins for training data")

# Simple CNN from flax
class CNN(nn.Module):
def apply(self, x):
b = x.shape[0]
x = nn.Conv(x, features=128, kernel_size=(4,), padding='SAME')
x = nn.BatchNorm(x)
x = nn.leaky_relu(x)
x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
x = nn.Conv(x, features=256, kernel_size=(4,), padding='SAME')
x = nn.BatchNorm(x)
x = nn.leaky_relu(x)
x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
x = x.reshape(b, -1)
x = nn.Dense(x, features=128)
x = nn.BatchNorm(x)
x = nn.leaky_relu(x)
x = nn.Dense(x, features=n_bins)
x = nn.softmax(x)
return x

# Hyperparameters
prng = jax.random.PRNGKey(0)
learning_rate = 0.001
input_shape = (1, training_data.shape[1], 1)
batch_size = 5000
epochs = 250

# Initialize model and optimizer
def create_model_optimizer(n_bins):
_, initial_params = CNN.init_by_shape(prng, [(input_shape, jnp.float32)])
model = nn.Model(CNN, initial_params)
optimizer = optim.Adam(learning_rate=learning_rate).create(model)
return model, optimizer

# Helper function
def get_batch():
inds = onp.random.choice(len(training_z), batch_size)
return {'labels': training_z[inds], 'features': training_data[inds]}

@jax.jit
def train_step(optimizer, batch):
# Define loss function as 1 / FOM
def loss_fn(model):
w = model(batch['features'][..., jnp.newaxis])
return 1. / jax_metrics.compute_fom(w, batch['labels'], inds=[5,6])
loss, g = jax.value_and_grad(loss_fn)(optimizer.target)
optimizer = optimizer.apply_gradient(g)
return optimizer, loss


model, optimizer = create_model_optimizer(n_bins)

losses = []
# Training
for epoch in range(epochs):
batch = get_batch()
optimizer, loss = train_step(optimizer, batch)
losses.append(loss)
if epoch % 10 == 0:
print(f'Epoch: {epoch}, Loss: {loss}')


# Plotting the loss curve
figure = plt.figure(figsize=(10, 6))
plt.plot(range(epochs), losses)
plt.xlabel('Epoch')
plt.ylabel('1 / FOM')
plt.yscale('log')
plt.savefig(f'../../{n_bins}-bins_{self.bands}.png')
plt.close()

self.model = optimizer.target


def apply (self, data):
"""Applies training to the data.

Parameters:
-----------
Data: numpy array, size Ngalaxes x Nbands
testing data, each row is a galaxy, each column is a band as per
band defined above

Returns:
tomographic_selections: numpy array, int, size Ngalaxies
tomographic selection for galaxies return as bin number for
each galaxy.
"""
print("Applying classifier")
batch_size = 10000
data = jnp.array(data)
Ngalaxies = len(data)
tomo_bin = jnp.concatenate(
[self.model(data[batch_size * i : min((batch_size * (i + 1)), Ngalaxies)][..., jnp.newaxis])
for i in range(Ngalaxies // batch_size + 1)]
)

return jnp.argmax(tomo_bin, axis=-1)

Loading