Skip to content

Commit 0685e0d

Browse files
add ebound support and tests
1 parent 80a596f commit 0685e0d

File tree

3 files changed

+82
-10
lines changed

3 files changed

+82
-10
lines changed

stingray/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def read(cls, filename, fmt=None, rmf_file=None, **kwargs):
631631
additional_columns = kwargs.pop("additional_columns", None)
632632

633633
evt = FITSTimeseriesReader(
634-
filename, output_class=EventList, additional_columns=additional_columns
634+
filename, output_class=EventList, additional_columns=additional_columns, **kwargs
635635
)[:]
636636

637637
if rmf_file is not None:

stingray/io.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
from astropy.io import fits
11+
import astropy.units as u
1112
from astropy.table import Table
1213
from astropy.logger import AstropyUserWarning
1314
import matplotlib.pyplot as plt
@@ -741,6 +742,7 @@ def __init__(
741742
gtistring=None,
742743
additional_columns=None,
743744
data_kind="events",
745+
**kwargs,
744746
):
745747
self.fname = fname
746748
self._data = fits.open(self.fname)
@@ -757,6 +759,8 @@ def __init__(
757759
f"{data_kind} is an unknown data kind."
758760
)
759761
self.data_kind = data_kind
762+
self.high_precision = kwargs.pop("high_precision", None)
763+
760764
if additional_columns is None and self.detector_key != "NONE":
761765
additional_columns = [self.detector_key]
762766
elif self.detector_key != "NONE":
@@ -765,8 +769,12 @@ def __init__(
765769

766770
if self.energy_column == "EBOUNDS":
767771
self.edata_hdu = self._data["EBOUNDS"]
772+
self.emin = kwargs.pop("emin", None)
773+
self.emax = kwargs.pop("emax", None)
768774
self.gti_file = gti_file
769775
self._read_gtis(self.gti_file)
776+
if kwargs != {}:
777+
warnings.warn(f"Unrecognized keywords: {list(kwargs.keys())}")
770778

771779
@property
772780
def time(self):
@@ -843,12 +851,6 @@ def _transform_slice_into_events(self, data):
843851
if self._mission_specific_processing is not None:
844852
data = self._mission_specific_processing(data, header=self.header, hduname=self.hduname)
845853

846-
# Set the times
847-
setattr(
848-
new_ts,
849-
self.main_array_attr,
850-
data[self.time_column][:] + self.timezero,
851-
)
852854
# Get conversion function PI->Energy
853855
try:
854856
pi_energy_func = get_rough_conversion_function(
@@ -865,15 +867,32 @@ def _transform_slice_into_events(self, data):
865867
ehigher = self.edata_hdu.data["E_MAX"]
866868
emid = elower + (ehigher - elower) / 2.0
867869

868-
energy = np.array([emid[c] for c in channels])
870+
self.emin = np.min(elower) if self.emin is None else self.emin
871+
self.emax = np.max(ehigher) if self.emax is None else self.emax
869872

873+
energy = np.array([emid[c] for c in channels])
870874
if (
871875
hasattr(self.edata_hdu.columns["E_MIN"], "unit")
872876
and (unit := self.edata_hdu.columns["E_MIN"].unit) is not None
873877
):
874878
conversion = (1 * u.Unit(unit)).to(u.keV).value
875-
new_ts.energy = energy * conversion
876-
new_ts.pi = channels
879+
880+
if isinstance(self.emin, u.Quantity):
881+
self.emin = self.emin.to(u.keV).value
882+
if isinstance(self.emax, u.Quantity):
883+
self.emax = self.emax.to(u.keV).value
884+
885+
if self.emin > self.emax:
886+
self.emin, self.emax = self.emax, self.emin
887+
888+
mask = (energy > self.emin) & (energy < self.emax)
889+
energy = energy[mask]
890+
channels = channels[mask]
891+
892+
time_dtype = np.float128 if self.high_precision is True else np.float64
893+
894+
new_ts.energy = np.asanyarray(energy * conversion, dtype=np.float64)
895+
new_ts.pi = np.asanyarray(channels, dtype=time_dtype)
877896
else:
878897
if self.energy_column in data.dtype.names:
879898
conversion = 1
@@ -888,6 +907,16 @@ def _transform_slice_into_events(self, data):
888907
if pi_energy_func is not None:
889908
new_ts.energy = pi_energy_func(new_ts.pi)
890909

910+
if "mask" not in locals():
911+
mask = np.ones(data[self.time_column].shape, dtype=bool)
912+
913+
# Set the times
914+
setattr(
915+
new_ts,
916+
self.main_array_attr,
917+
data[self.time_column][mask] + self.timezero,
918+
)
919+
891920
det_numbers = None
892921
if self.detector_key is not None and self.detector_key in data.dtype.names:
893922
new_ts.detector_id = data[self.detector_key]

stingray/tests/test_events.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
77

88
from astropy.io import fits
9+
import astropy.units as u
910
from astropy.time import Time
1011

1112
from ..events import EventList
@@ -375,6 +376,7 @@ def setup_class(self):
375376
elower = hdul[1].data["E_MIN"]
376377
ehigher = hdul[1].data["E_MAX"]
377378
emid = elower + (ehigher - elower) / 2.0
379+
self.emin, self.emax = np.min(elower), np.max(ehigher)
378380
self.energy = np.array([emid[c] for c in self.pi])
379381

380382
def test_read_fermi(self):
@@ -386,6 +388,43 @@ def test_check_energy_pi(self):
386388
assert_allclose(evt.energy, self.energy, atol=1e-8)
387389
assert_array_equal(evt.pi, self.pi)
388390

391+
def test_high_precision(self):
392+
evt = EventList.read(self.fermi_file, high_precision=True)
393+
394+
assert np.issubdtype(evt.time.dtype, np.float128)
395+
assert np.issubdtype(evt.pi.dtype, np.float128)
396+
397+
@pytest.mark.parametrize(
398+
"emin, emax",
399+
[
400+
(2, 45.2),
401+
(2 * u.keV, 45.2 * u.keV),
402+
(2000 * u.eV, 45.2 * u.keV),
403+
(None, 45.2 * u.keV),
404+
(2000 * u.eV, None),
405+
(None, None),
406+
],
407+
)
408+
def test_mask_ebound(self, emin, emax):
409+
evt = EventList.read(self.fermi_file, emin=emin, emax=emax)
410+
411+
emin = emin if emin is not None else self.emin
412+
emax = emax if emax is not None else self.emax
413+
414+
if isinstance(emin, u.Quantity):
415+
emin = emin.to(u.keV).value
416+
if isinstance(emax, u.Quantity):
417+
emax = emax.to(u.keV).value
418+
419+
mask = (self.energy > emin) & (self.energy < emax)
420+
energy = self.energy[mask]
421+
time = self.time[mask]
422+
pi = self.pi[mask]
423+
424+
assert_array_equal(evt.time, time)
425+
assert_array_equal(evt.pi, pi)
426+
assert_array_equal(evt.energy, energy)
427+
389428
def test_check_time_gti(self):
390429
evt = EventList.read(self.fermi_file)
391430

@@ -394,6 +433,10 @@ def test_check_time_gti(self):
394433
assert_equal(evt.ncounts, self.time.shape[0])
395434
assert_allclose(evt.mjdref, self.header["MJDREFI"])
396435

436+
def test_kwargs(self):
437+
with pytest.warns(UserWarning, match="Unrecognized keywords:"):
438+
EventList.read(self.fermi_file, alpha=1)
439+
397440

398441
class TestJoinEvents:
399442
def test_join_without_times_simulated(self):

0 commit comments

Comments
 (0)