Skip to content
This repository was archived by the owner on Jun 27, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ env:
# Try all python versions with the latest numpy
- SETUP_CMD='test'


matrix:
include:

Expand All @@ -73,7 +72,6 @@ matrix:
- python: 3.5
env: NUMPY_VERSION=1.10


install:

# We now use the ci-helpers package to set up our testing environment.
Expand Down
1 change: 0 additions & 1 deletion docs/gen_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from astropy.modeling.fitting import SherpaFitter
from astropy.modeling.models import Gaussian1D, Gaussian2D

import numpy as np
import matplotlib.pyplot as plt

Expand Down
7 changes: 6 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ To make use of the entry points plugin registry which automatically makes the |S
Otherwise one can just use the latest stable ``astropy`` via::
conda install astropy


Next install Sherpa_ using the conda ``sherpa`` channel. Note that Sherpa
currently needs to be installed after astropy on Mac OSX.

Expand Down Expand Up @@ -236,4 +237,8 @@ API/Reference
Credit
------

The development of this package was made possible by the generous support of the `Google Summer of Code <https://summerofcode.withgoogle.com/>`_ program in 2016 under the `OpenAstronomy <http://openastronomy.org/>`_ by `Michele Costa <https://github.com/nocturnalastro>`_ with the support and advice of mentors `Tom Aldcroft <https://github.com/taldcroft>`_, `Omar Laurino <https://github.com/olaurino>`_, `Moritz Guenther <https://github.com/hamogu>`_, and `Doug Burke <https://github.com/DougBurke>`_.
The development of this package was made possible by the generous support of the `Google Summer of Code <https://summerofcode.withgoogle.com/>`_ program in 2016
under the `OpenAstronomy <http://openastronomy.org/>`_
by `Michele Costa <https://github.com/nocturnalastro>`_ with the support and advice of mentors
`Tom Aldcroft <https://github.com/taldcroft>`_, `Omar Laurino <https://github.com/olaurino>`_,
`Moritz Guenther <https://github.com/hamogu>`_, and `Doug Burke <https://github.com/DougBurke>`_.
89 changes: 79 additions & 10 deletions saba/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from __future__ import (absolute_import, unicode_literals, division,
print_function)
import numpy as np
from collections import OrderedDict
import numpy as np
import warnings
import copy

from sherpa.fit import Fit
from sherpa.data import Data1D, Data1DInt, Data2D, Data2DInt, DataSimulFit
from sherpa.data import BaseData
from sherpa.models import UserModel, Parameter, SimulFitModel
from sherpa.instrument import PSFModel
from sherpa.stats import Chi2, Chi2ConstVar, Chi2DataVar, Chi2Gehrels
from sherpa.stats import Chi2ModVar, Chi2XspecVar, LeastSq
from sherpa.stats import CStat, WStat, Cash
from sherpa.optmethods import GridSearch, LevMar, MonCar, NelderMead
from sherpa.estmethods import Confidence, Covariance, Projection
from sherpa.sim import MCMC
import warnings

from astropy.utils import format_doc
from astropy.utils.exceptions import AstropyUserWarning
Expand All @@ -27,8 +30,6 @@
if "SherpaFitter" not in w.message.message:
warnings.warn(w)

# from astropy.modeling

__all__ = ('SherpaFitter', 'SherpaMCMC')


Expand Down Expand Up @@ -268,6 +269,50 @@ def stat_values(self):
return self._stat_values


def make_rsp(data,rsp):
"""
Take an array as a response which is then convolved with the model output.
Parameters
----------
data: a sherpa dataset
rsp : an array which represets rsp
"""
def wrap_rsp(data, rsp):
rsp = np.asarray(rsp)
rdata = copy.deepcopy(data)
rdata.y = rsp
psf = PSFModel("user_rsp", rdata)
psf.fold(data)
return psf

try:
ndims = len(data.data.datasets[0].get_dims())
except AttributeError:
ndims = len(data.data.get_dims())

if ndims > 1:
return None
else:
rsp = np.asarray(rsp)

if data.ndata > 1:
if rsp.ndim > 1 or rsp.dtype == np.object:
if rsp.shape[0] == data.ndata:
zipped = zip(data.data.datasets, rsp)
else:
raise AstropyUserWarning("There is more than 1 but not"
" ndata responses")
else:
zipped = zip(data.data.datasets,
[rsp for _ in xrange(data.ndata)])

rsp = []
for da, rr in zipped:
rsp.append(wrap_rsp(da, rr))
else:
return wrap_rsp(data.data, rsp)


class SherpaFitter(Fitter):
__doc__ = """
Sherpa Fitter for astropy models.
Expand Down Expand Up @@ -321,7 +366,7 @@ def __init__(self, optimizer="levmar", statistic="leastsq", estmethod="covarianc
setattr(self.__class__, 'est_config', property(lambda s: s._est_config, doc=self._est_method.__doc__))


def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, **kwargs):
def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, rsp=None, **kwargs):
"""
Fit the astropy model with a the sherpa fit routines.

Expand All @@ -347,6 +392,9 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
bkg_sale : float or list of floats (optional)
the scaling factor for the dataset if a single value
is supplied it will be copied for each dataset
rsp: array or list of arrays
this is convolved with the model output when fitting the model
N.B only 1D is currently supported.
**kwargs :
keyword arguments will be passed on to sherpa fit routine

Expand All @@ -364,10 +412,15 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,

self._data = Dataset(n_inputs, x, y, z, xbinsize, ybinsize, err, bkg, bkg_scale)

if rsp is not None:
self._rsp = make_rsp(self._data, rsp)
else:
self._rsp = None

if self._data.ndata > 1:

if len(models) == 1:
self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list)
self._fitmodel = ConvertedModel([models.copy() for _ in xrange(self._data.ndata)], tie_list, rsp=self._rsp)
# Copy the model so each data set has the same model!
elif len(models) == self._data.ndata:
self._fitmodel = ConvertedModel(models, tie_list)
Expand All @@ -377,9 +430,10 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
else:
if len(models) > 1:
self._data.make_simfit(len(models))
self._fitmodel = ConvertedModel(models, tie_list)
self._fitmodel = ConvertedModel(models, tie_list,
rsp=self._rsp)
else:
self._fitmodel = ConvertedModel(models)
self._fitmodel = ConvertedModel(models, rsp=self._rsp)

self._fitter = Fit(self._data.data, self._fitmodel.sherpa_model, self._stat_method, self._opt_method, self._est_method, **kwargs)
self.fit_info = self._fitter.fit()
Expand Down Expand Up @@ -633,16 +687,31 @@ class ConvertedModel(object):
e.g. [(modelB.y, modelA.x)] will mean that y in modelB will be tied to x of modelA
"""

def __init__(self, models, tie_list=None):
def __init__(self, models, tie_list=None, rsp=None):
self.model_dict = OrderedDict()
try:
models.parameters # does it quack
self.sherpa_model = self._astropy_to_sherpa_model(models)
self.rsp = rsp
if rsp is not None:
self.sherpa_model = rsp(self.sherpa_model)

self.model_dict[models] = self.sherpa_model
except AttributeError:
for mod in models:
try:
n_rsp = len(rsp)
assert len(models) == n_rsp, AstropyUserWarning("The number of responses must be either 1 or the numeber of models %i" % len(models))
zipped = zip(models, rsp)

except TypeError:
zipped = zip(models, [rsp for _ in range(len(models))])

for mod, rsp in zipped:
self.model_dict[mod] = self._astropy_to_sherpa_model(mod)

if rsp is not None:
self.sherpa_model[mod] = rsp(self.sherpa_model[mod])

if tie_list is not None:
for par1, par2 in tie_list:
getattr(self.model_dict[par1._model], par1.name).link = getattr(self.model_dict[par2._model], par2.name)
Expand Down
2 changes: 1 addition & 1 deletion saba/tests/coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ exclude_lines =
pragma: no cover

# Don't complain about packages we have installed
# except ImportError
except ImportError

# Don't complain if tests don't hit assertions
raise AssertionError
Expand Down
15 changes: 15 additions & 0 deletions saba/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,21 @@ def test_bkg_doesnt_explode(self):
sfit(m, x, y, bkg=bkg)
# TODO: Make this better!


def test_rsp1d_doesnt_explode(self):
"""
Check this goes through the motions
"""

self.fitter(self.model1d.copy(), self.x1, self.y1, err=self.dy1, rsp=self.rsp1)

def test_rsp1d_multi_doesnt_explode(self):
"""
Check this goes through the motions
"""

self.fitter([self.model1d.copy(), self.model1d_2.copy()], [self.x1, self.x2], [self.y1, self.y2], err=[self.dy1, self.dy2], rsp=[self.rsp1, self.rsp2])

def test_entry_points(self):
# a little to test that entry points can be loaded!
from pkg_resources import iter_entry_points
Expand Down