diff --git a/stixpy/calibration/detector.py b/stixpy/calibration/detector.py index 2fa6d38..0252d58 100644 --- a/stixpy/calibration/detector.py +++ b/stixpy/calibration/detector.py @@ -22,7 +22,9 @@ def get_srm(): ------- """ - drm_save = read_genx("/Users/shane/Projects/STIX/git/stix_drm_20220713.genx") + # drm_save = read_genx("/Users/shane/Projects/STIX/git/stix_drm_20220713.genx") + drm_save = np.load('/home/jmitchell/software/stixpy-dev/stixpy/config/data/detector/') + drm = drm_save["SAVEGEN0"]["SMATRIX"] * u.count / u.keV / u.photon energies_in = drm_save["SAVEGEN0"]["EDGES_IN"] * u.keV energies_in_width = np.diff(energies_in) diff --git a/stixpy/calibration/elut.py b/stixpy/calibration/elut.py new file mode 100644 index 0000000..3351754 --- /dev/null +++ b/stixpy/calibration/elut.py @@ -0,0 +1,107 @@ +from pathlib import Path + +import numpy as np + +from stixpy.calibration.detector import get_sci_channels +from stixpy.io.readers import read_elut, read_elut_index +import datetime + +__all__ = ["get_elut", "get_elut_correction"] + +def get_elut(date): + r""" + Get the energy lookup table (ELUT) for the given date + + Combines the ELUT with the science energy channels for the same date. + + Parameters + ---------- + date `astropy.time.Time` + Date to look up the ELUT. 
+ + Returns + ------- + + """ + + # date = datetime.datetime(2023,11,15,0,0,0) + + root = Path(__file__).parent.parent + elut_index_file = Path(root, *["config", "data", "elut", "elut_index.csv"]) + + elut_index = read_elut_index(elut_index_file) + elut_info = elut_index.at(date) + if len(elut_info) == 0: + raise ValueError(f"No ELUT for for date {date}") + elif len(elut_info) > 1: + raise ValueError(f"Multiple ELUTs for for date {date}") + start_date, end_date, elut_file = list(elut_info)[0] + sci_channels = get_sci_channels(date) + + print('ELUT_FILENAME = ' , elut_file) + + elut_table = read_elut(elut_file, sci_channels) + + return elut_table + +def get_elut_correction(e_ind, pixel_data): + r""" + Get ELUT correction factors + + Only correct the low and high energy edges as internal bins are contiguous. + + Parameters + ---------- + e_ind : `np.ndarray` + Energy channel indices + pixel_data : `~stixpy.products.sources.CompressedPixelData` + Pixel data + + Returns + ------- + + """ + + energy_mask = pixel_data.energy_masks.energy_mask.astype(bool) + elut = get_elut(pixel_data.time_range.center) + ebin_edges_low = np.zeros((32, 12, 32), dtype=float) + ebin_edges_low[..., 1:] = elut.e_actual + ebin_edges_low = ebin_edges_low[..., energy_mask] + ebin_edges_high = np.zeros((32, 12, 32), dtype=float) + ebin_edges_high[..., 0:-1] = elut.e_actual + ebin_edges_high[..., -1] = np.nan + ebin_edges_high = ebin_edges_high[..., energy_mask] + ebin_widths = ebin_edges_high - ebin_edges_low + ebin_sci_edges_low = elut.e[..., 0:-1].value + ebin_sci_edges_low = ebin_sci_edges_low[..., energy_mask] + ebin_sci_edges_high = elut.e[..., 1:].value + ebin_sci_edges_high = ebin_sci_edges_high[..., energy_mask] + e_cor_low = (ebin_edges_high[..., e_ind[0]] - ebin_sci_edges_low[..., e_ind[0]]) / ebin_widths[..., e_ind[0]] + e_cor_high = (ebin_sci_edges_high[..., e_ind[-1]] - ebin_edges_low[..., e_ind[-1]]) / ebin_widths[..., e_ind[-1]] + + numbers = np.array([ + 0, 1, 2, 3, 4, 5, 6, 7, 
13, 14, 15, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31 + ]) + + bins = ebin_sci_edges_high - ebin_sci_edges_low + + det_indices_top24 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + + det_indices_full = np.where(pixel_data.detector_masks.__dict__['masks'] == 1 )[1] + + det_indices = [d for i,d in enumerate(det_indices_top24) if d in det_indices_full] + + # det_indices = np.where(self.detector_masks.__dict__['masks'] == 1 )[1] + + pix_indices = np.where(pixel_data.pixel_masks.__dict__['masks'] == 1 )[1] + + bins_actual_1 = ebin_widths[det_indices, :, :] + bins_actual = bins_actual_1[:, pix_indices, :].mean(axis=1).mean(axis=0) + + + cor = bins / bins_actual + print('ELUT_COR = ',cor) + + return e_cor_high, e_cor_low, cor diff --git a/stixpy/calibration/energy.py b/stixpy/calibration/energy.py index 3a070d9..2f064fd 100644 --- a/stixpy/calibration/energy.py +++ b/stixpy/calibration/energy.py @@ -38,6 +38,10 @@ def get_elut(date): sci_channels = get_sci_channels(date) elut_table = read_elut(elut_file, sci_channels) + + print('science_channels = ',np.shape(sci_channels)) + print('elut_table_channels = ',np.shape(elut_table)) + return elut_table diff --git a/stixpy/calibration/flare_location.py b/stixpy/calibration/flare_location.py new file mode 100644 index 0000000..f7776d0 --- /dev/null +++ b/stixpy/calibration/flare_location.py @@ -0,0 +1,79 @@ + +import logging + +import astropy.units as u +import matplotlib.pyplot as plt +import numpy as np +from astropy.coordinates import SkyCoord +from sunpy.coordinates import HeliographicStonyhurst, Helioprojective, frames +from sunpy.map import Map, make_fitswcs_header +from sunpy.time import TimeRange +from xrayvision.clean import vis_clean +from xrayvision.imaging import vis_to_image, vis_to_map +from xrayvision.mem import mem, resistant_mean + +from stixpy.calibration.visibility import calibrate_visibility, create_meta_pixels, create_visibility +from 
stixpy.coordinates.frames import STIXImaging +from stixpy.coordinates.transforms import get_hpc_info +from stixpy.imaging.em import em +from stixpy.map.stix import STIXMap # noqa +from stixpy.product import Product + + + +__all__ = [ + "estimate_flare_location", +] + +def estimate_flare_location(sci_file,time_range_sci): + + energy_range = [6.,10.] * u.keV + imsize = [512,512] * u.pixel + pixel = [10, 10] * u.arcsec / u.pixel + + cpd_sci = Product(sci_file) + + meta_pixels_sci = create_meta_pixels( + cpd_sci, time_range=time_range_sci, energy_range=energy_range, flare_location=[0, 0] * u.arcsec, no_shadowing=True + ) + + vis = create_visibility(meta_pixels_sci) + + vis_tr = TimeRange(vis.meta["time_range"]) + roll, solo_xyz, pointing = get_hpc_info(vis_tr.start, vis_tr.end) + solo = HeliographicStonyhurst(*solo_xyz, obstime=vis_tr.center, representation_type="cartesian") + center_hpc = SkyCoord(0 * u.deg, 0 * u.deg, frame=Helioprojective(obstime=vis_tr.center, observer=solo)) + + cal_vis = calibrate_visibility(vis, flare_location=center_hpc) + + # order by sub-collimator e.g. 10a, 10b, 10c, 9a, 9b, 9c .... 
+ isc_10_7 = [3, 20, 22, 16, 14, 32, 21, 26, 4, 24, 8, 28] + idx = np.argwhere(np.isin(cal_vis.meta["isc"], isc_10_7)).ravel() + + vis10_7 = cal_vis[idx] + + bp_image = vis_to_image(vis10_7, imsize, pixel_size=pixel) + + vis_tr = TimeRange(vis.meta["time_range"]) + roll, solo_xyz, pointing = get_hpc_info(vis_tr.start, vis_tr.end) + solo = HeliographicStonyhurst(*solo_xyz, obstime=vis_tr.center, representation_type="cartesian") + coord_stix = center_hpc.transform_to(STIXImaging(obstime=vis_tr.start, obstime_end=vis_tr.end, observer=solo)) + header = make_fitswcs_header( + bp_image, coord_stix, telescope="STIX", observatory="Solar Orbiter", scale=[10, 10] * u.arcsec / u.pix + ) + fd_bp_map = Map((bp_image, header)) + + max_pixel = np.argwhere(fd_bp_map.data == fd_bp_map.data.max()).ravel() * u.pixel + # because WCS axes and array are reversed + max_stix = fd_bp_map.pixel_to_world(max_pixel[1], max_pixel[0]) + + hpc_x = max_stix.transform_to(frames.Helioprojective).Tx.value + hpc_y = max_stix.transform_to(frames.Helioprojective).Ty.value + + stx_x = max_stix.Tx.value + stx_y = max_stix.Ty.value + + dictionary = {'stx':np.array([stx_x,stx_y]), + 'hpc':np.array([hpc_x,hpc_y])} + + return dictionary diff --git a/stixpy/calibration/grid.py b/stixpy/calibration/grid.py index 3fadaf4..a6d4598 100644 --- a/stixpy/calibration/grid.py +++ b/stixpy/calibration/grid.py @@ -7,13 +7,14 @@ import astropy.units as u import numpy as np from astropy.table import Table +import xraydb __all__ = ["get_grid_transmission", "_calculate_grid_transmission"] from stixpy.coordinates.frames import STIXImaging -def get_grid_transmission(flare_location: STIXImaging): +def get_grid_transmission(ph_energy, flare_location: STIXImaging): r""" Return the grid transmission for the 32 sub-collimators corrected for internal shadowing. 
@@ -32,18 +33,141 @@ def get_grid_transmission(flare_location: STIXImaging): front = Table.read(grid_info / "grid_param_front.txt", format="ascii", names=column_names) rear = Table.read(grid_info / "grid_param_rear.txt", format="ascii", names=column_names) - transmission_front = _calculate_grid_transmission(front, flare_location) - transmission_rear = _calculate_grid_transmission(rear, flare_location) - total_transmission = transmission_front * transmission_rear + # ;; Orientation of the slits of the grid as seen from the detector side + grid_orient_front_all = front['o'] + grid_orient_rear_all = rear['o'] + + pitch_front_all = front['p'] + pitch_rear_all = rear['p'] + + thickness_front_all = front['thick'] + thickness_rear_all = rear['thick'] + + sc = front['sc'] + + + # fpath = loc_file( 'CFL_subcoll_transmission.txt', path = getenv('STX_GRID') ) + column_names_cfl = ["subc_n", "subc_label", "intercept", "slope[1/deg]"] + + subcol_transmission = Table.read(grid_info / "CFL_subcoll_transmission.txt", format="ascii", names=column_names_cfl) + + subc_n_all = subcol_transmission['subc_n'] + subc_label = subcol_transmission['subc_label'] + intercept_all = subcol_transmission['intercept'] + slope_all = subcol_transmission['slope[1/deg]'] + + + muvals = xraydb.material_mu('W', ph_energy * 1e3, density=19.30, kind='total') / 10 # in units of mm^-1 + L = 1 / muvals + print('L = ', L) + # trans = np.exp(-0.4 / L) + subc_transm=L + + det_indices_top24 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + # det_all = np.arange(0,32,1) + + # det_indices_top24 = [0, 1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, + # 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + idx_full = det_indices_top24 + idx = [i for i,x in enumerate(sc-1) if x in det_indices_top24] + print('idx = ', idx) + print('lenidx = ',len(idx)) + # idx = det_all + + # for i,idx in enumerate(idx_full): + print(grid_orient_front_all) + + grid_orient_front = 
grid_orient_front_all[idx] + pitch_front = pitch_front_all[idx] + thickness_front = thickness_front_all[idx] + + # ;;------ Rear grid + grid_orient_rear = grid_orient_rear_all[idx] + pitch_rear = pitch_rear_all[idx] + thickness_rear = thickness_rear_all[idx] + + grid_orient_avg = (grid_orient_front + grid_orient_rear) / 2 + + flare_loc_deg = flare_location / 3600 #. ;; Convert coordinates to deg + theta = flare_loc_deg[0] * np.cos(np.deg2rad(grid_orient_avg)) + flare_loc_deg[1] * np.sin(np.deg2rad(grid_orient_avg)) + print('THETA = ', theta) + +# ;;------ Subcollimator tranmsission at low energies +# idx = np.where(subc_n_all eq (subc_n+1)) + + intercept = intercept_all[det_indices_top24] + slope = slope_all[det_indices_top24] + subc_transm_low_e = intercept + slope * theta + + # ;;------ Transmission of front and rear grid + slit_to_pitch = np.sqrt(subc_transm_low_e) + + slit_front = slit_to_pitch*pitch_front + slit_rear = slit_to_pitch*pitch_rear + + transm_front = stx_grid_transmission(pitch_front, slit_front, thickness_front, L) + transm_rear = stx_grid_transmission(pitch_rear, slit_rear, thickness_rear, L) + + # subc_transm.append(transm_front * transm_rear) + + subc_transm = transm_front * transm_rear + + # transmission_front = _calculate_grid_transmission(front, flare_location) + # transmission_rear = _calculate_grid_transmission(rear, flare_location) + # total_transmission = transmission_front * transmission_rear # The finest grids are made from multiple layers for the moment remove these and set 1 - final_transmission = np.ones(32) - sc = front["sc"] - finest_scs = [11, 12, 13, 17, 18, 19] # 1a, 2a, 1b, 2c, 1c, 2b - idx = np.argwhere(np.isin(sc, finest_scs, invert=True)).ravel() - final_transmission[sc[idx] - 1] = total_transmission[idx] - - return final_transmission + # final_transmission = np.ones(32) + # sc = front["sc"] + # finest_scs = [11, 12, 13, 17, 18, 19] # 1a, 2a, 1b, 2c, 1c, 2b + # idx = np.argwhere(np.isin(sc, finest_scs, invert=True)).ravel() 
+ # final_transmission[sc[idx] - 1] = total_transmission[idx] + + print('subc = ',subc_transm) + return subc_transm + + +def stx_grid_transmission(pitch, slit, thickness, L): + + ds = 5e-3 + dh = 5e-2 + + # n_energies = np.shape(L)[0] + # n_subc = np.shape(pitch) + + slit_rep = slit.reshape(1, len(slit)) + pitch_rep = pitch.reshape(1, len(pitch)) + H_rep = thickness.reshape(1, len(thickness)) + L_rep = L.reshape(len(L), 1) + + print('slit = ',slit[0]) + print('pitch = ',pitch[0]) + print('h_rep = ',thickness[0]) + + # slit_rep = np.tile(slit, (n_energies, 1)) + # pitch_rep = np.tile(pitch, (n_energies, 1)) + # H_rep = np.tile(thickness, (n_energies, 1)) + # L_rep = np.tile(L, (n_subc, 1)).T + + print(np.shape(slit_rep)) + print(np.shape(pitch_rep)) + print(np.shape(H_rep)) + print(np.shape(L_rep)) + + # ;; Transmission for a wedge shape model for grid imperfections + g0 = slit_rep / pitch_rep + (pitch_rep - slit_rep) / pitch_rep * np.exp( - H_rep / L_rep ) + ttt = L_rep / dh * ( 1. - np.exp(- dh / L_rep ) ) + g1 = 2. 
* ds / pitch_rep * (ttt - np.exp( - H_rep / L_rep )) + + print('g0 = ',g0[:,0]) + print('g1 = ',g1[:,0]) + + g_transmission = g0 + g1 + + print('transmission = ',g_transmission) + + return g_transmission def _calculate_grid_transmission(grid_params, flare_location): diff --git a/stixpy/calibration/livetime.py b/stixpy/calibration/livetime.py index 556967b..1ac63c2 100644 --- a/stixpy/calibration/livetime.py +++ b/stixpy/calibration/livetime.py @@ -62,7 +62,7 @@ import astropy.units as u import numpy as np -__all__ = ["pileup_correction_factor", "get_livetime_fraction"] +__all__ = ["pileup_correction_factor", "get_livetime_fraction","livetime_counts_corr"] from stixpy.io.readers import read_subc_params @@ -93,7 +93,7 @@ def pileup_correction_factor(): return prob_diff_pix -def get_livetime_fraction(trigger_rate, *, eta=1.10 * u.us, tau=10.1 * u.us): +def get_livetime_fraction(trigger_rate,*, eta=1.1e-6 *u.s, tau=10.1e-6 * u.s): """ Return the live time fraction for the given trigger rate. 
@@ -113,7 +113,16 @@ def get_livetime_fraction(trigger_rate, *, eta=1.10 * u.us, tau=10.1 * u.us): """ beta = 0.94059104 # pileup_correction_factor() + # tau = tau / time_del + # eta = eta / time_del + photons_in = trigger_rate / (1.0 - trigger_rate * (tau + eta)) livetime_fraction = 1 / (1.0 + (tau + eta) * photons_in) two_photon = np.exp(-eta * beta * photons_in) * livetime_fraction + + return livetime_fraction, two_photon, photons_in + + + + diff --git a/stixpy/calibration/visibility.py b/stixpy/calibration/visibility.py index 652826b..c35ebee 100644 --- a/stixpy/calibration/visibility.py +++ b/stixpy/calibration/visibility.py @@ -177,12 +177,15 @@ def create_meta_pixels( t_ind = np.argwhere(t_mask).ravel() e_ind = np.argwhere(e_mask).ravel() + print('e_ind = ', e_ind) + time_range = TimeRange( pixel_data.times[t_ind[0]] - pixel_data.duration[t_ind[0]] / 2, pixel_data.times[t_ind[-1]] + pixel_data.duration[t_ind[-1]] / 2, ) changed = [] + for column in ["rcr", "pixel_masks", "detector_masks"]: if np.unique(pixel_data.data[column][t_ind], axis=0).shape[0] != 1: changed.append(column) @@ -201,21 +204,36 @@ def create_meta_pixels( pixel_data.data["livefrac"] = livefrac + print('e_ind = ',e_ind) + e_cor_high, e_cor_low = get_elut_correction(e_ind, pixel_data) + print('e_cor_high_shape = ',np.shape(e_cor_high)) + print('counts shape = ',np.shape(pixel_data.data["counts"])) + # Get counts and other data. idx_pix = _PIXEL_SLICES.get(pixels.lower(), None) + + print('pix = ',np.shape(e_cor_low[..., idx_pix])) + if idx_pix is None: raise ValueError(f"Unrecognised input for 'pixels': {pixels}. 
Supported values: {list(_PIXEL_SLICES.keys())}") counts = pixel_data.data["counts"].astype(float) count_errors = np.sqrt(pixel_data.data["counts_comp_err"].astype(float).value ** 2 + counts.value) * u.ct ct = counts[t_ind][..., idx_pix, e_ind] + # print('ct = ',np.shape(ct)) + ct_or = ct ct[..., 0] = ct[..., 0] * e_cor_low[..., idx_pix] ct[..., -1] = ct[..., -1] * e_cor_high[..., idx_pix] ct_error = count_errors[t_ind][..., idx_pix, e_ind] ct_error[..., 0] = ct_error[..., 0] * e_cor_low[..., idx_pix] ct_error[..., -1] = ct_error[..., -1] * e_cor_high[..., idx_pix] + # print(np.where( ((ct / ct_or) != 1) & (np.isnan(ct / ct_or) == False) )[0]) + + # indices = np.where(ct[...,0] != ct_or[...,0]) + # print(indices) + lt = (livefrac * pixel_data.data["timedel"].reshape(-1, 1).to("s"))[t_ind].sum(axis=0) ct_summed = ct.sum(axis=(0, 3)) # .astype(float) @@ -290,6 +308,8 @@ def get_elut_correction(e_ind, pixel_data): ebin_sci_edges_high = ebin_sci_edges_high[..., energy_mask] e_cor_low = (ebin_edges_high[..., e_ind[0]] - ebin_sci_edges_low[..., e_ind[0]]) / ebin_widths[..., e_ind[0]] e_cor_high = (ebin_sci_edges_high[..., e_ind[-1]] - ebin_edges_low[..., e_ind[-1]]) / ebin_widths[..., e_ind[-1]] + + return e_cor_high, e_cor_low diff --git a/stixpy/io/readers.py b/stixpy/io/readers.py index e5e3d00..2de2723 100644 --- a/stixpy/io/readers.py +++ b/stixpy/io/readers.py @@ -134,8 +134,19 @@ def read_elut(elut_file, sci_channels): elut = type("ELUT", (object,), dict()) elut.file = elut_file.name - elut.offset = elut_table["Offset"].reshape(32, 12) - elut.gain = elut_table["Gain keV/ADC"].reshape(32, 12) + try: + elut.offset = elut_table["Offset"].reshape(32, 12) + elut.gain = elut_table["Gain keV/ADC"].reshape(32, 12) + except KeyError: + try: + elut.offset = elut_table["Offset (ADC)"].reshape(32, 12) + elut.gain = elut_table["Gain (keV/ADC)"].reshape(32, 12) + + except KeyError: + elut.offset = elut_table["Offset (ADC)"].reshape(32, 12) + elut.gain = elut_table["Gain 
(ADC/keV)"].reshape(32, 12) + + elut.pixel = elut_table["Pixel"].reshape(32, 12) elut.detector = elut_table["Detector"].reshape(32, 12) adc = np.vstack(list(elut_table.columns[4:].values())).reshape(31, 32, 12) diff --git a/stixpy/product/sources/science.py b/stixpy/product/sources/science.py index fd90034..405ec00 100644 --- a/stixpy/product/sources/science.py +++ b/stixpy/product/sources/science.py @@ -15,9 +15,23 @@ from matplotlib.patches import Patch from matplotlib.widgets import Slider from sunpy.time.timerange import TimeRange +from datetime import timedelta + +from sunkit_spex.spectrum.spectrum import Spectrum, SpectralAxis +from sunkit_spex.spectrum.uncertainty import PoissonUncertainty +from ndcube import NDMeta +from ndcube.extra_coords import QuantityTableCoordinate, TimeTableCoordinate + +from stixpy.calibration.livetime import get_livetime_fraction +from stixpy.calibration.detector import get_srm from stixpy.io.readers import read_subc_params +from stixpy.calibration.elut import get_elut_correction from stixpy.product.product import L1Product +from stixpy.config.instrument import STIX_INSTRUMENT, _get_uv_points_data +from stixpy.calibration.transmission import Transmission +from stixpy.calibration.grid import get_grid_transmission +# from stixpy.calibration.flare_location import estimate_flare_location __all__ = [ "ScienceData", @@ -34,6 +48,7 @@ "DetectorMasks", "PixelMasks", "EnergyEdgeMasks", + "calc_count_rate" ] @@ -854,7 +869,14 @@ def duration(self): return self.data["timedel"] def get_data( - self, time_indices=None, energy_indices=None, detector_indices=None, pixel_indices=None, sum_all_times=False + self, + time_indices=None, + energy_indices=None, + detector_indices=None, + pixel_indices=None, + sum_all_times=False, + livetime_correction=True, + elut_correction=True ): """ Return the counts, errors, times, durations and energies for selected data. 
@@ -888,12 +910,69 @@ def get_data( Counts, errors, times, deltatimes, energies """ + + e_norm = self.dE + t_norm = self.data["timedel"] + + # print('time_dels = ',t_norm) + counts = self.data["counts"] + + # print('counts = ',np.shape(counts)) + try: counts_var = self.data["counts_comp_err"] ** 2 except KeyError: counts_var = self.data["counts_comp_comp_err"] ** 2 shape = counts.shape + + # print('shape_counts = ',np.shape(counts_var) ) + + if livetime_correction: + + trigger_to_detector = STIX_INSTRUMENT.subcol_adc_mapping + triggers = self.data["triggers"][:, trigger_to_detector].astype(float)[...] + + triggers_error = self.data["triggers_comp_err"][:, trigger_to_detector].astype(float)[...] + triggers_lower = triggers - triggers_error + triggers_upper = triggers + triggers_error + + _, livefrac, _ = get_livetime_fraction(triggers/ self.data["timedel"].to("s").reshape(-1, 1)) + _, livefrac_lower, _ = get_livetime_fraction(triggers_lower/ self.data["timedel"].to("s").reshape(-1, 1)) + _, livefrac_upper, _ = get_livetime_fraction(triggers_upper/ self.data["timedel"].to("s").reshape(-1, 1)) + + # if t_norm.size != 1: + + t_norm = t_norm.reshape(-1, 1, 1, 1) + livefrac = livefrac.reshape(livefrac.shape + (1, 1)) + livefrac_lower = livefrac_lower.reshape(livefrac_lower.shape + (1, 1)) + livefrac_upper = livefrac_upper.reshape(livefrac_upper.shape + (1, 1)) + + t_norm_original = t_norm + + t_norm = t_norm * livefrac + t_norm_lower = t_norm * livefrac_lower + t_norm_upper = t_norm * livefrac_upper + + # print('tnorm = ',t_norm) + + # print('td*lf = ',t_norm.mean(axis=1).squeeze()) + # print('lf = ',np.shape(livefrac)) + + if elut_correction: + + _, _, elut_cor_fac = get_elut_correction(np.array(self.energies['channel']), self) + + print('elut_corr_fac = ', elut_cor_fac) + e_norm_energies = e_norm + e_norm = e_norm / elut_cor_fac + + else: + e_norm_energies = e_norm + elut_cor_fac = 1 + + + if len(shape) < 4: counts = counts.reshape(shape[0], 1, 1, shape[-1]) counts_var 
= counts_var.reshape(shape[0], 1, 1, shape[-1]) @@ -901,6 +980,8 @@ def get_data( energies = self.energies[:] times = self.times + # print('TIMES = ',times) + if detector_indices is not None: detecor_indices = np.asarray(detector_indices) if detecor_indices.ndim == 1: @@ -908,6 +989,7 @@ def get_data( detector_mask[detecor_indices] = True counts = counts[:, detector_mask, ...] counts_var = counts_var[:, detector_mask, ...] + t_norm = t_norm[..., detector_mask] elif detecor_indices.ndim == 2: counts = np.hstack( [np.sum(counts[:, dl : dh + 1, ...], axis=1, keepdims=True) for dl, dh in detecor_indices] @@ -925,6 +1007,7 @@ def get_data( pixel_mask[pixel_indices] = True counts = counts[..., pixel_mask, :] counts_var = counts_var[..., pixel_mask, :] + t_norm = t_norm[...,pixel_mask,:] elif pixel_indices.ndim == 2: counts = np.concatenate( [np.sum(counts[..., pl : ph + 1, :], axis=2, keepdims=True) for pl, ph in pixel_indices], axis=2 @@ -934,7 +1017,7 @@ def get_data( [np.sum(counts_var[..., pl : ph + 1, :], axis=2, keepdims=True) for pl, ph in pixel_indices], axis=2 ) - e_norm = self.dE + if energy_indices is not None: energy_indices = np.asarray(energy_indices) if energy_indices.ndim == 1: @@ -944,6 +1027,7 @@ def get_data( counts_var = counts_var[..., energy_mask] e_norm = self.dE[energy_mask] energies = self.energies[energy_mask] + elif energy_indices.ndim == 2: counts = np.concatenate( [np.sum(counts[..., el : eh + 1], axis=-1, keepdims=True) for el, eh in energy_indices], axis=-1 @@ -963,7 +1047,7 @@ def get_data( ) energies = QTable(energies * u.keV, names=["e_low", "e_high"]) - t_norm = self.data["timedel"] + if time_indices is not None: time_indices = np.asarray(time_indices) if time_indices.ndim == 1: @@ -972,6 +1056,7 @@ def get_data( counts = counts[time_mask, ...] counts_var = counts_var[time_mask, ...] t_norm = self.data["timedel"][time_mask] + livefrac = livefrac[time_mask, ...] 
times = times[time_mask] # dT = self.data['timedel'][time_mask] elif time_indices.ndim == 2: @@ -999,15 +1084,528 @@ def get_data( t_norm = np.sum(dt) if e_norm.size != 1: + e_norm_energies = e_norm_energies.reshape(1, 1, 1, -1) e_norm = e_norm.reshape(1, 1, 1, -1) - if t_norm.size != 1: - t_norm = t_norm.reshape(-1, 1, 1, 1) + if np.isnan(np.array(e_norm.value)).any(): + + valid_mask = np.flatnonzero(~np.isnan(e_norm)) + + e_norm_energies = e_norm_energies[...,valid_mask] + e_norm = e_norm[...,valid_mask] + counts = counts[...,valid_mask] + counts_var = counts_var[...,valid_mask] + + if elut_correction: + elut_cor_fac = elut_cor_fac[valid_mask] + + energies = energies[valid_mask] + + counts_err = np.sqrt(counts*u.ct + counts_var) + + counts_corr = counts / (e_norm * t_norm) + + counts_lower = counts / (t_norm_lower) + counts_upper = counts / (t_norm_upper) + + livetime_error = (counts_upper - counts_lower) / 2 + + + + counts_err = np.sqrt(((counts_err/t_norm)**2) + (livetime_error**2)) / (e_norm) + + + return counts_corr, counts_err, times, t_norm, livefrac, energies, elut_cor_fac + + + def get_spectrum(self, bkg_prod=None): + + det_indices_top24 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + + det_indices_full = np.where(self.detector_masks.__dict__['masks'] == 1 )[1] + + det_indices = [d for i,d in enumerate(det_indices_top24) if d in det_indices_full] + + pix_indices = np.where(self.pixel_masks.__dict__['masks'] == 1 )[1] + + rate, rate_err, times, t_norm_cs, livefrac, energies, elut_cor_fac = self.get_data() + + # energies = self.energies + + de = np.array(energies['e_high'] - energies['e_low']) * u.keV + + + t_diff = t_norm_cs.to(u.s) + + counts_kev = rate * t_norm_cs + counts = counts_kev * de + + dt = (t_diff / livefrac) + dt = dt.squeeze().mean(axis=1) + + result_count_rate_full = counts + result_count_rate_det = result_count_rate_full[:, det_indices, :, :] + result_count_rate_det_pix = 
result_count_rate_det[:, :, pix_indices, :] + result_count_rate = result_count_rate_det_pix.sum(axis=(1,2)) + + result_count_rate_full_corr = counts / livefrac + result_count_rate_det_corr = result_count_rate_full_corr[:, det_indices, :, :] + result_count_rate_det_pix_corr = result_count_rate_det_corr[:, :, pix_indices, :] + result_count_rate_corr = result_count_rate_det_pix_corr.sum(axis=(1,2)) + + counts_err_kev = rate_err * t_norm_cs + counts_err = counts_err_kev * de + result_count_err_rate_full = counts_err / livefrac + result_count_err_rate_det =result_count_err_rate_full[:, det_indices, :, :] + result_count_err_rate_det_pix =result_count_err_rate_det[:, :, pix_indices, :] + result_count_err_rate = np.sqrt(((result_count_err_rate_det_pix**2).sum(axis=(1,2)) ) ) - counts_err = np.sqrt(counts * u.ct + counts_var) / (e_norm * t_norm) - counts = counts / (e_norm * t_norm) + if bkg_prod: + + rate_bkg, rate_err_bkg, times_bkg, t_norm_cs_bkg, livefrac_bkg, energies_bkg, _ = bkg_prod.get_data(elut_correction=False) + + # energies_bkg = bkg_prod.energies + + de_bkg = np.array(energies_bkg['e_high'] - energies_bkg['e_low']) * u.keV + + t_diff_bkg = t_norm_cs_bkg.to(u.s) + + dt_bkg = (t_diff_bkg / livefrac_bkg) + + _, _, indices_sub = np.intersect1d( + energies['e_low'], + energies_bkg['e_low'], + return_indices=True + ) + + print('ind_sub = ',indices_sub) + print('e = ',energies['e_low']) + print('ebkg = ',energies_bkg['e_low']) + print('ebkg_ind = ',energies_bkg['e_low'][indices_sub]) + + rate_bkg = rate_bkg[:,:,:,indices_sub] * elut_cor_fac + rate_err_bkg = rate_err_bkg[:,:,:,indices_sub]* elut_cor_fac + de_bkg = de_bkg[indices_sub] + + counts_kev_bkg = rate_bkg * t_norm_cs_bkg + counts_bkg = counts_kev_bkg * de_bkg + + # dt = dt.squeeze().mean(axis=1) + dt_bkg = dt_bkg.squeeze() + + print('dt_bkg =',dt_bkg.shape) + print('dt_shape =',dt.shape) + # result_count_rate_full_bkg = (counts_bkg / dt_bkg) * dt + result_count_rate_full_bkg = counts_bkg + 
result_count_rate_det_bkg = result_count_rate_full_bkg[:, det_indices, :, :] + result_count_rate_det_pix_bkg = result_count_rate_det_bkg[:, :, pix_indices, :] + result_count_rate_bkg = result_count_rate_det_pix_bkg.sum(axis=(1,2)) + step = result_count_rate_bkg.sum(axis=0) / dt_bkg.mean() + result_count_rate_bkg = dt.reshape(len(dt),1) * step.reshape(1,len(step)) + + # result_count_rate_full_corr_bkg = (counts_bkg / (livefrac_bkg * dt_bkg)) * dt + result_count_rate_full_corr_bkg = counts_bkg / livefrac_bkg + result_count_rate_det_corr_bkg = result_count_rate_full_corr_bkg[:, det_indices, :, :] + result_count_rate_det_pix_corr_bkg = result_count_rate_det_corr_bkg[:, :, pix_indices, :] + result_count_rate_corr_bkg = result_count_rate_det_pix_corr_bkg.sum(axis=(1,2)) + step_corr = result_count_rate_corr_bkg.sum(axis=0) / dt_bkg.mean() + result_count_rate_corr_bkg = dt.reshape(len(dt),1) * step_corr.reshape(1,len(step)) + + counts_err_kev_bkg = rate_err_bkg * t_norm_cs_bkg + counts_err_bkg = counts_err_kev_bkg * de_bkg + + result_count_err_rate_full_bkg = counts_err_bkg + result_count_err_rate_det_bkg =result_count_err_rate_full_bkg[:, det_indices, :, :] + result_count_err_rate_det_pix_bkg =result_count_err_rate_det_bkg[:, :, pix_indices, :] + result_count_err_rate_bkg = np.sqrt(((result_count_err_rate_det_pix_bkg**2).sum(axis=(1,2)) ) ) + step_err = np.sqrt((result_count_err_rate_bkg**2).sum(axis=0)) / dt_bkg.mean() + result_count_err_rate_bkg = dt.reshape(len(dt),1) * step_err.reshape(1,len(step)) + + result_count_err_rate_full_corr_bkg = counts_err_bkg / livefrac_bkg + result_count_err_rate_det_corr_bkg =result_count_err_rate_full_corr_bkg[:, det_indices, :, :] + result_count_err_rate_det_corr_pix_bkg =result_count_err_rate_det_corr_bkg[:, :, pix_indices, :] + result_count_err_rate_bkg_corr = np.sqrt(((result_count_err_rate_det_corr_pix_bkg**2).sum(axis=(1,2)) ) ) + step_err_corr = np.sqrt((result_count_err_rate_bkg_corr**2).sum(axis=0)) / dt_bkg.mean() + 
            # Background count-error rate rescaled onto this observation's time bins.
            result_count_err_rate_bkg_corr = dt.reshape(len(dt),1) * step_err_corr.reshape(1,len(step))

            # Background-subtracted rates (livetime-corrected and raw) with
            # errors combined in quadrature.
            spec_in_corr = result_count_rate_corr - result_count_rate_corr_bkg
            spec_in = result_count_rate - result_count_rate_bkg

            spec_in_corr_err = np.sqrt(result_count_err_rate**2 + result_count_err_rate_bkg**2)

            # "_lvt" copies are kept WITHOUT background subtraction; they feed the
            # effective-livetime ratio computed below.
            spec_in_corr_lvt = result_count_rate_corr
            spec_in_lvt = result_count_rate
            spec_in_corr_err_lvt= result_count_err_rate

            # Drop the lowest science channel when its lower edge is 0 keV
            # (open-ended first channel carries no usable calibration).
            if energies['e_low'][0].value == 0:

                spec_in = spec_in[:,1:]
                spec_in_corr = spec_in_corr[:,1:]
                spec_in_corr_err = spec_in_corr_err[:,1:]
                energies = energies[1:]

        else:
            # No background product supplied: use the rates as-is.
            # NOTE(review): this `else` pairs with a branch header above this
            # chunk (presumably `if bkg_prod:` — get_spec_obj calls
            # get_spectrum(bkg_prod=...)); confirm against the full file.
            spec_in_corr = result_count_rate_corr
            spec_in = result_count_rate
            spec_in_corr_err= result_count_err_rate

            spec_in_corr_lvt = result_count_rate_corr
            spec_in_lvt = result_count_rate
            spec_in_corr_err_lvt= result_count_err_rate


            # Same channel-0 trim as in the background branch.
            if energies['e_low'][0].value == 0:

                spec_in = spec_in[:,1:]
                spec_in_corr = spec_in_corr[:,1:]
                spec_in_corr_err = spec_in_corr_err[:,1:]
                energies = energies[1:]

        # Mean time-bin width over the selected detectors.
        t_diff = t_diff[:,det_indices].mean(axis=1).squeeze()

        # eff_livefrac = result_count_rate.sum(axis=(1)) / result_count_rate_corr.sum(axis=(1))
        #

        # Effective live fraction per time bin: ratio of the uncorrected to the
        # livetime-corrected, summed-over-energy rates.
        eff_livefrac = spec_in_lvt.sum(axis=(1)) / spec_in_corr_lvt.sum(axis=(1))

        # eff_livefrac = spec_in.sum(axis=(1)

        # Scale the corrected spectrum back by the effective live fraction.
        spec_in_final = spec_in_corr * eff_livefrac[:,None]
        spec_in_corr_err_final = spec_in_corr_err * eff_livefrac[:,None]

        # dt = dt.squeeze().mean(axis=1)

        data_dictionary = {'rate':spec_in_final,
                           'rate_err':spec_in_corr_err_final,
                           'times':times,
                           'time_bin':dt,
                           'livefrac':eff_livefrac,
                           'energies':energies}


        return data_dictionary


    def get_spec_obj(self,event_time_range,flare_angle,srm_dictionary=None,bkg_data=None,flare_location=None):
        r"""
        Build a 1-D count `Spectrum` for the given event time range.

        Time bins fully contained in ``event_time_range`` are summed and their
        uncertainties combined in quadrature. When ``bkg_data`` is supplied the
        spectrum is taken from the background-subtracted product, negative
        rates are clipped to zero, and an energy-dependent systematic error
        (7% below 7 keV, 5% for 7--10 keV, 3% above 10 keV) is added in
        quadrature. The SRM is trimmed to photon energies above ~3.7 keV and
        attached to the spectrum metadata.

        Parameters
        ----------
        event_time_range : length-2 sequence
            Start/end times (parseable by `astropy.time.Time`) selecting bins.
        flare_angle :
            Flare angle; stored in metadata in degrees.
        srm_dictionary : dict, optional
            Precomputed output of `get_masked_srm`; computed here if omitted.
        bkg_data : optional
            Background product forwarded to ``get_spectrum(bkg_prod=...)``.
        flare_location : optional
            Forwarded to `get_masked_srm` when the SRM is computed here.

        Returns
        -------
        `Spectrum`
            1-D counts spectrum with SRM, exposure and geometry in ``meta``.
        """

        if not bkg_data:

            data_dict = self.get_spectrum()
            # Contiguous count-energy bin edges: all lower edges + last upper edge.
            counts_axis = np.concatenate([data_dict['energies']['e_low'],[data_dict['energies']['e_high'][-1]]])

            de = np.diff(counts_axis)[np.newaxis,:]
            dt = data_dict['time_bin'][:, np.newaxis]

            times_full = data_dict['times']


            counts = data_dict['rate']
            counts_uncertainity = data_dict['rate_err']

            # Select time bins wholly inside the requested event interval.
            times_start = times_full - (data_dict['time_bin']/2)
            times_end = times_full + (data_dict['time_bin']/2)

            inds = np.where( (times_start >= Time(event_time_range[0]) ) & (times_end <= Time(event_time_range[-1]) ) )[0]

            # print('len inds = ', len(inds))

            counts_final = counts[inds].sum(axis=0)
            counts_uncertainity_final = np.sqrt((counts_uncertainity[inds]**2).sum(axis=0))

            counts_uncertainity_pu = PoissonUncertainty(counts_uncertainity_final)
            counts_spectral_axis = SpectralAxis(counts_axis, bin_specification='edges')

            # Exposure normalisation: livetime-weighted duration of selected bins.
            t_norm = data_dict['time_bin'][inds] * data_dict['livefrac'][inds]

        else:

            data_dict = self.get_spectrum(bkg_prod=bkg_data)
            counts_axis = np.concatenate([data_dict['energies']['e_low'],[data_dict['energies']['e_high'][-1]]])

            de = np.diff(counts_axis)[np.newaxis, :]
            dt = data_dict['time_bin'][:, np.newaxis]

            times_full = data_dict['times']

            counts = data_dict['rate']

            print('ct_check_prev = ',len(counts[counts<0]))

            # Background subtraction can over-subtract; clip negative rates to 0.
            counts[np.nonzero(counts < 0)] = 0

            print('ct_shape = ',counts.shape)
            print('ct_check_post = ',len(counts[counts<0]))

            counts_uncertainity = data_dict['rate_err']

            times_start = times_full - (data_dict['time_bin']/2)
            times_end = times_full + (data_dict['time_bin']/2)

            inds = np.where( (times_start >= Time(event_time_range[0]) ) & (times_end <= Time(event_time_range[-1]) ) )[0]

            counts_final = counts[inds].sum(axis=0)
            counts_uncertainity_final = np.sqrt((counts_uncertainity[inds]**2).sum(axis=0))

            # Energy-dependent systematic error fraction: 7% (<7 keV),
            # 5% (7-10 keV), 3% (>=10 keV).
            e_low = data_dict['energies']['e_low'].value
            energy_conditions = [e_low < 7, (e_low < 10) & (e_low >= 7), e_low>= 10]
            percentage = [0.07, 0.05, 0.03]

            systematic_err_percentage = np.select(energy_conditions, percentage)

            print(systematic_err_percentage)

            # Calculating systematic error
            systematic_err = systematic_err_percentage * counts_final

            counts_err_final_final = np.sqrt(counts_uncertainity_final**2 + systematic_err**2)

            counts_uncertainity_pu = PoissonUncertainty(counts_err_final_final)
            counts_spectral_axis = SpectralAxis(counts_axis, bin_specification='edges')

            t_norm = data_dict['time_bin'][inds] * data_dict['livefrac'][inds]


        # Use the supplied SRM dictionary if given, otherwise compute one.
        if not srm_dictionary:
            srm_dict = self.get_masked_srm(flare_location=flare_location)
        else:
            srm_dict=srm_dictionary

        # Observer distance from the FITS header (metres -> AU).
        distance = (self.meta['DSUN_OBS'] * u.m).to(u.AU)

        meta = NDMeta()

        # meta.add("exposure_time", np.sum(t_norm))
        # meta.add("geo_area", srm_dict['geo_area'])
        # meta.add("date-obs", data_dict['times'])
        # meta.add("angle",flare_angle)
        # meta.add("distance",distance)
        # meta.add("srm",srm_dict['srm'])
        # meta.add("ph_axis",srm_dict['ph_axis'])

        # spec_1d = Spectrum(data=counts_final,uncertainty=counts_uncertainity_pu, spectral_axis=counts_spectral_axis, meta=meta)

        # return spec_1d

        # Rescale the SRM by count-bin widths, then trim photon bins whose
        # mid-energy is at or below ~3.7 keV (below the usable STIX range).
        ct_de = np.diff(counts_axis.value)

        srm = srm_dict['srm'] * ct_de[None,:]
        ph_ax_mids = srm_dict['ph_axis'][:-1] + 0.5*np.diff(srm_dict['ph_axis'])

        index = np.where(ph_ax_mids <= 3.7)[0]

        srm_trim = srm[index[-1]:]

        ph_ax_bins = np.column_stack((srm_dict['ph_axis'][:-1], srm_dict['ph_axis'][1:]))

        ph_ax_bins_trim = ph_ax_bins[index[-1]:]

        ph_energies_trim = np.concatenate([ph_ax_bins_trim[:,0], ph_ax_bins_trim[:,1][-1:]])
        # return srm_trim, ph_energies_trim

        print(ph_ax_bins_trim)


        meta.add("exposure_time", np.sum(t_norm))
        meta.add("geo_area", srm_dict['geo_area'])
        meta.add("date-obs", data_dict['times'])
        meta.add("angle",flare_angle*u.deg)
        meta.add("distance",distance)
        meta.add("srm",srm_trim)
        meta.add("ph_axis",ph_energies_trim*u.keV)

        spec_1d = Spectrum(data=counts_final,uncertainty=counts_uncertainity_pu, spectral_axis=counts_spectral_axis, meta=meta)

        return spec_1d

# NEED TO ADD IN THE DIST AND ANGLE READING AND BKG_SUBTRACT, GETTING THERE THOUGH!!!
THOUGH!!! + + def bkg_subtract(self, bkg_data): + + spec_unsub = self.get_spectrum() + + times = spec_unsub['times'] + time_bins = spec_unsub['time_bin'] + rate= spec_unsub['rate'] + rate_err = spec_unsub['rate_err'] + energies = spec_unsub['energies'] + + bkg_time_bin = bkg_data['time_bin'] + # print(bkg_data['rate']) + bkg_rate = (bkg_data['rate'] / bkg_time_bin) * time_bins[:,np.newaxis] + +# LIVETIME CORRECTED + EFFECTIVE LIVEWTIME MIGHT WORK + + # print(bkg_rate) + bkg_rate_err = (bkg_data['rate_err'] / bkg_time_bin) * time_bins[:,np.newaxis] + bkg_energies = bkg_data['energies'] + + _, _, indices_sub = np.intersect1d( + energies['e_low'], + bkg_energies['e_low'], + return_indices=True + ) + + # print(energies) + # print(bkg_energies) + + print('e_bkg = ',bkg_energies) + + print('e_dat = ',energies) + print('in_sub = ',indices_sub) + + bkg_rate_sub = bkg_rate[:,indices_sub] + bkg_rate_err_sub = bkg_rate_err[:,indices_sub] + + rate_sub = rate - bkg_rate_sub + rate_err_sub = np.sqrt((rate_err.value**2) + (bkg_rate_err_sub.value**2)) *u.ct + + spec_sub = {'rate_bkg_sub':rate_sub, + 'rate_err_bkg_sub':rate_err_sub, + 'rate':rate, + 'rate_err':rate_err, + 'rate_bkg':bkg_rate, + 'rate_err_bkg':bkg_rate_err, + 'times':times, + 'time_bin':time_bins, + 'energies':energies} + + return spec_sub + + def get_masked_srm(self,flare_location): + + PATH_DRM = '/home/jmitchell/software/stixpy-dev/stixpy/config/data/detector/' + drm = np.load(PATH_DRM+'stx_drm_energy.npy') + ph_energies = np.load(PATH_DRM+'stx_ph_edges.npy') + ct_energies = np.load(PATH_DRM+'stx_ct_edges.npy') + + # max_stix = estimate_flare_location(self,time_range) + + energies = self.energies + e_low = np.array(energies['e_low']) + + if e_low[0] == 0: + e_low = e_low[1:] + + e_high = np.array(energies['e_high']) + + e_high = e_high[~np.isnan(e_high)] + + # e_index = np.where((ph_energies >= e_low[0]) & + # (ph_energies <= e_high[-1]) )[0] + + if e_high[-1] == 150: + e_edges = e_low + ct_e_diff = 
np.diff(e_edges) + else: + e_edges = np.concatenate([e_low,[e_high[-1]]]) + ct_e_diff = np.diff(e_edges) + + # print('ct_shape = ',len(ct_e_diff)) + + # drm_clipped = drm[1:-1, 1:-1] + # ph_energies_clipped = ph_energies[1:-1] + + # drm_clipped = drm + # ph_energies_clipped = ph_energies + + epsilon = 1e-4 + + mask_not_in_e = ~np.isclose( + ct_energies[:, None], + e_edges[None, :], + atol=epsilon + ).any(axis=1) + + values_to_remove = ct_energies[mask_not_in_e] + + indices_to_remove = np.where( + np.isclose( + ph_energies[:, None], + values_to_remove[None, :], + atol=epsilon + ).any(axis=1) + )[0] + + + + + drm_clipped = np.delete(drm, indices_to_remove, axis=0) + drm_clipped = np.delete(drm_clipped, indices_to_remove, axis=1) + + ph_energies_clipped = np.delete(ph_energies, indices_to_remove) + + ph_e_diff = np.diff(ph_energies_clipped) + + + pixel_areas = STIX_INSTRUMENT.pixel_config["Area"].to("cm2") + + det_indices_top24 = np.array([0, 1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + + det_indices_full = np.where(self.detector_masks.__dict__['masks'] == 1 )[1] + + det_indices = [d for i,d in enumerate(det_indices_top24) if d in det_indices_full] + + pix_indices = np.where(self.pixel_masks.__dict__['masks'] == 1 )[1] + + pixel_areas = pixel_areas[pix_indices].value + + area_scale = len(det_indices)*np.sum(pixel_areas) + + energy_widths = np.diff(ph_energies_clipped) + + e_mids = ph_energies_clipped[:-1] + (energy_widths / 2) + + trans = Transmission() + + tot_trans = trans.get_transmission(energies=e_mids * u.keV) + + attenuation = np.zeros(len(tot_trans["det-1"])) + + for i,det in enumerate(det_indices): + attenuation += tot_trans[f'det-{det}'] + + attenuation = attenuation / len(det_indices) + + # drm_clipped = ((drm_clipped * attenuation[:,None] ) * area_scale) + drm_clipped = ((drm_clipped * ph_e_diff[None,:] * attenuation[:,None] )) + + drm_new = [] + + for j in range(np.shape(drm_clipped)[0]): + + working = [] + + 
for i in range(len(e_edges)-1): + + indices_sum = np.where((ph_energies_clipped >= e_edges[i]) & + (ph_energies_clipped < e_edges[i+1]))[0] + + tot = drm_clipped[j,indices_sum].sum(axis=0) + + working.append(tot) + + drm_new.append(working) + + drm_new = np.array(drm_new) + + # tr = TimeRange(vis.meta.time_range) + # roll, solo_heeq, stix_pointing = get_hpc_info(vis.meta.time_range[0], vis.meta.time_range[1]) + # solo_coord = HeliographicStonyhurst(solo_heeq, representation_type="cartesian", obstime=tr.center) + # flare_location = flare_location.transform_to(STIXImaging(obstime=tr.center, observer=solo_coord)) + + grid_transmission = get_grid_transmission(e_mids, flare_location) + + grid_transmission = grid_transmission.mean(axis=1) + + srm = (drm_new * grid_transmission[:,None]) / ct_e_diff[None,:] + # srm = (drm_new * grid_transmission[:,None]) + + return {'srm':srm,'ph_axis':ph_energies_clipped,'geo_area':area_scale} - return counts, counts_err, times, t_norm, energies def concatenate(self, others): """ @@ -1355,3 +1953,71 @@ def __init__(self, *args, format_func=None, **kwargs): if format_func is not None: self._format = format_func super().__init__(*args, **kwargs) + + +def calc_count_rate(dat): + + rate, rate_err, _, t_norm_cs, energies, _, cor = dat + + de = np.array(energies['e_high'] - energies['e_low']) + + rate = np.array(rate) + rate_err = np.array(rate_err) + + t_norm = t_norm_cs.to(u.s).value + + counts_kev = rate * t_norm_cs + counts_err_kev = rate_err * t_norm_cs + + counts = counts_kev * de + counts_err = counts_err_kev * de + + result_count_rate = counts / t_norm + result_count_rate_err = counts_err / t_norm + + result_count_rate = result_count_rate[:, :, :8, :].sum(axis=(1,2)) * cor + result_count_rate_err = result_count_rate_err[:, :, :8, :].sum(axis=(1,2)) * cor + + return result_count_rate, result_count_rate_err + + + +# pixel_areas = STIX_INSTRUMENT.pixel_config["Area"].to("cm2") + +# print(pixel_areas) + +# det_indices_top24 = np.array([0, 
1, 2, 3, 4, 5, 6, 7, 13, 14, 15, 19, +# 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + +# det_indices_full = np.where(self.detector_masks.__dict__['masks'] == 1 )[1] + +# det_indices = [d for i,d in enumerate(det_indices_top24) if d in det_indices_full] + +# # pix_indices = np.where(self.pixel_masks.__dict__['masks'] == 1 )[1] +# pix_indices = [0,1,2,3,4,5,6,7] + +# pixel_areas = pixel_areas[pix_indices].value + +# area_scale = len(det_indices)*np.sum(pixel_areas) + +# energy_widths = np.diff(ph_energies) + +# e_mids = ph_energies[:-1] + (energy_widths / 2) + +# trans = Transmission() + +# print('e_mids = ',e_mids) + +# tot_trans = trans.get_transmission(energies=e_mids * u.keV) + +# # attenuation = tot_trans["det-0"] + +# attenuation = np.zeros(len(tot_trans["det-1"])) + +# for i,det in enumerate(det_indices): +# attenuation += tot_trans[f'det-{det}'] + +# attenuation = attenuation / len(det_indices) + + +# srm = ((drm_new * attenuation[:, None] ) * area_scale) / 4