Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scopesim/effects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .spectral_trace_list import *
from .spectral_efficiency import *
from .metis_lms_trace_list import *
from .mosaic_trace_list import *
from .surface_list import *
from .ter_curves import *
from . import ter_curves_utils
Expand Down
1 change: 1 addition & 0 deletions scopesim/effects/metis_lms_trace_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def apply_to(self, obj, **kwargs):

if isinstance(obj, FieldOfView):
# Application to field of view
logger.debug("Executing %s, FoV", self.meta['name'])
if obj.hdu is not None and obj.hdu.header["NAXIS"] == 3:
obj.cube = obj.hdu
elif obj.hdu is None and obj.cube is None:
Expand Down
252 changes: 252 additions & 0 deletions scopesim/effects/mosaic_trace_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
"""SpectralTraceList and SpectralTrace for MOSAIC"""
from tqdm.auto import tqdm
from typing import ClassVar

import numpy as np
from astropy.table import Table
from astropy import units as u
from astropy.io import fits
from astropy.wcs import WCS
from astropy.modeling import fitting
from astropy.modeling.models import Polynomial1D
from synphot import SourceSpectrum, Empirical1D
from .spectral_trace_list import SpectralTraceList
from .spectral_trace_list_utils import SpectralTrace

from ..utils import get_logger, quantify, power_vector
from ..optics.fov import FieldOfView
from ..optics.fov_volume_list import FovVolumeList
from ..detector import Detector

logger = get_logger(__name__)




class MosaicSpectralTraceList(SpectralTraceList):
"""SpectralTraceList for MOSAIC"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.aplist = self._file["Aperture List"].data
# TODO: check units or normalise to arcsec
self.view = np.array([(self.aplist["right"].max() -
self.aplist["left"].min()),
(self.aplist["top"].max() -
self.aplist["bottom"].min())])

def apply_to(self, obj, **kwargs):
"""See parent docstring."""
### This is copied from MetisSpectralTraceList, make less redundant?
if isinstance(obj, FovVolumeList):
logger.debug("Executing %s, FoV setup", self.meta['name'])
# Create a single volume that covers the aperture and
# the maximum wavelength range of the grating
volumes = [spectral_trace.fov_grid()
for spectral_trace in self.spectral_traces.values()]
wave_min = min(vol["wave_min"] for vol in volumes)
wave_max = max(vol["wave_max"] for vol in volumes)
extracted_vols = obj.extract(axes=["wave"],
edges=([[wave_min, wave_max]]))
obj.volumes = extracted_vols

if isinstance(obj, FieldOfView):
# Application to field of view
logger.debug("Executing %s, FoV", self.meta['name'])
if obj.hdu is not None and obj.hdu.header["NAXIS"] == 3:
obj.cube = obj.hdu
elif obj.hdu is None and obj.cube is None:
obj.cube = obj.make_cube_hdu()

fovcube = obj.cube.data
n_z = fovcube.shape[0]
fovwcs = WCS(obj.cube.header)
# Make this linear to avoid jump at RA 0 deg
fovwcs.wcs.ctype = ["LINEAR", "LINEAR", fovwcs.wcs.ctype[2]]
fovwcs_spat = fovwcs.sub(2)
fovwcs_spec = fovwcs.spectral
fovlam = fovwcs_spec.all_pix2world(np.arange(n_z), 0)[0]
fovlam <<= u.Unit(fovwcs_spec.wcs.cunit[0])

det_header = obj.detector_header
detwcs = WCS(det_header, key='D')
naxis1d, naxis2d = det_header["NAXIS1"], det_header["NAXIS2"]

## This is the place where we need to look at the apertures
## - collapse each aperture to 1D spectrum by integrating spatially
## - map each 1D spectrum to detector/fov

image = np.zeros((naxis2d, naxis1d), dtype=np.float32)

for sptid, spt in tqdm(self.spectral_traces.items(),
desc="Fiber traces", position=2):
theap = self.aplist[self.aplist['id'] == sptid]

# solid angle in arcsec**2
solid_angle = ((theap["right"] - theap["left"]) *
(theap["top"] - theap["bottom"]))

# apertures are defined in arcsec. fovwcs is in degrees
xmin, xmax, ymin, ymax = (theap["left"]/3600, theap["right"]/3600,
theap["bottom"]/3600, theap["top"]/3600)

imin = max(0, int(np.round(fovwcs_spat.all_world2pix(xmin, 0, 0)[0][0])))
imax = int(np.round(fovwcs_spat.all_world2pix(xmax, 0, 0)[0][0]))
jmin = max(0, int(np.round(fovwcs_spat.all_world2pix(0, ymin, 0)[1][0])))
jmax = int(np.round(fovwcs_spat.all_world2pix(0, ymax, 0)[1][0]))

# Average over the spatial dimensions of the aperture (still per arcsec2)
fovflux = fovcube[:, jmin:jmax, imin:imax].mean(axis=(1,2)) * solid_angle
spec = SourceSpectrum(Empirical1D, points=fovlam.to(u.um),
lookup_table=fovflux)

# Need to interpolate this to the output wavelength grid
detlam = spt.x2lam(detwcs.all_pix2world(np.arange(naxis1d), 0, 0)[0])
detlam <<= u.um
yvals = spt.lam2y(detlam.value)
jfib = detwcs.all_world2pix(0, yvals.mean(), 0)[1].astype(int)
logger.debug("Flux from %s: %.4g", spt.trace_id, spec(detlam).value.sum())

detdisp = np.diff(detlam, prepend=detlam[0])
image[jfib,] += (spec(detlam) * detdisp).value

image_hdr = detwcs.to_header()
image_hdr["BUNIT"] = "ph s-1"
image_hdr.extend(det_header)
obj.hdu = fits.ImageHDU(data=image, header=image_hdr)
return obj



def make_spectral_traces(self):
"""Return a dictionary of spectral traces read in from a file."""
self.ext_data = self._file[0].header["EDATA"]
self.ext_cat = self._file[0].header["ECAT"]
self.catalog = Table(self._file[self.ext_cat].data)
spec_traces = {}
for row in self.catalog:
# image_plane_id = -1 marks rows that should not be read,
# e.g. the aperture list. Although not necessary if the catalogue
# is formatted in a way that only traces are listed, this provides
# a possibility to "mask" traces.
if row["image_plane_id"] == -1:
continue
params = {col: row[col] for col in row.colnames}
params.update(self.meta)
hdu = self._file[row["extension_id"]]
spec_traces[row["description"]] = MosaicSpectralTrace(hdu, **params)

self.spectral_traces = spec_traces


class MosaicSpectralTrace(SpectralTrace):
"""A single spectral trace for MOSAIC"""

def __init__(self, trace_tbl, **kwargs):
super().__init__(trace_tbl, **kwargs)

def compute_interpolation_functions(self):
x_arr = self.table[self.meta["x_colname"]]
y_arr = self.table[self.meta["y_colname"]]
#xi_arr = self.table[self.meta["s_colname"]]
lam_arr = self.table[self.meta["wave_colname"]]

self.wave_min = quantify(np.min(lam_arr), u.um).value
self.wave_max = quantify(np.max(lam_arr), u.um).value

self.lam2x = Transform1D.fit(lam_arr, x_arr, degree=2)
self.x2lam = Transform1D.fit(x_arr, lam_arr, degree=2)
self.lam2y = Transform1D.fit(lam_arr, y_arr, degree=2)

class Transform1D():
"""
1-dimensional polynomial transform.
"""

def __init__(self, coeffs, pretransform=None,
posttransform=None):
self.coeffs = np.asarray(coeffs)
self.nx = self.coeffs.shape[0]
self.pretransform = self._repackage(pretransform)
self.posttransform = self._repackage(posttransform)

def _repackage(self, trafo):
"""Make sure `trafo` is a tuple."""
if trafo is not None and not isinstance(trafo, tuple):
trafo = (trafo, {})
return trafo

def __call__(self, x, **kwargs):
"""
Apply the polynomial transform.

The transformation is a polynomial based on the simple monomials x^i.
"""

if "pretransform" in kwargs:
self.pretransform = self._repackage(kwargs["pretransform"])
if "postransform" in kwargs:
self.posttransform = self._repackage(kwargs["posttransform"])

x = np.array(x)

# Apply pre transform
if self.pretransform is not None:
x = self.pretransform[0](x, **self.pretransform[1])

xvec = power_vector(x, self.nx - 1)

result = self.coeffs @ xvec

# Apply posttransform
if self.posttransform is not None:
result = self.posttransform[0](result, **self.posttransform[1])

return result

@classmethod
def fit(cls, xin, xout, degree=4):
"""Determine polynomial fit"""
pinit = Polynomial1D(degree=degree)
fitter = fitting.LinearLSQFitter()
fit = fitter(pinit, xin, xout)
return Transform1D(fit.parameters)

def gradient(self):
"""Compute the gradient of a 1d polynomial transformation"""
coeffs = self.coeffs

dcoeffs = (coeffs * np.arange(self.nx))[1:]
return Transform1D(dcoeffs)


class MosaicCollapseSpectralTraces(MosaicSpectralTraceList):
"""Collapse SpectralTraces to 1D spectrum"""
required_keys = {"filename"}
z_order: ClassVar[tuple[int, ...]] = (899,)

def __init__(self, **kwargs):
super().__init__(**kwargs)

def apply_to(self, det, **kwargs):
"""Apply to detector readout"""
if not isinstance(det, Detector):
return det

image = det._hdu.data
detwcs = WCS(det._hdu.header, key='D')
spec = np.zeros(image.shape[1], dtype=np.float32)
for sptid, spt in tqdm(self.spectral_traces.items(),
desc="Fiber traces", position=2):
y_mm = spt.table['y'][0]
jfib = int(detwcs.all_world2pix(0, y_mm, 0)[1])
spec += image[jfib,]

x_mm = detwcs.all_pix2world(np.arange(image.shape[1]), 1, 0)[0]
lam = spt.x2lam(x_mm)
det._hdu = fits.BinTableHDU.from_columns([
fits.Column(name='wavelength', format='D', array=lam, unit='um'),
fits.Column(name='spectrum', format='D', array=spec, unit='ADU')])
return det
64 changes: 64 additions & 0 deletions scopesim/tests/tests_effects/test_mosaic_trace_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Unit tests for mosaic_trace_list.py"""

# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name
# pylint: disable=too-few-public-methods
import pytest

import numpy as np

from astropy.io import fits

from scopesim.utils import power_vector
from scopesim.effects.mosaic_trace_list import Transform1D

@pytest.fixture(name="tf1d", scope="class")
def fixture_tf1d():
"""Instantiate a Transform1D"""
coeffs = np.array([2, -1, 1])
return Transform1D(coeffs)

@pytest.fixture(name="quadratic", scope="class")
def fixture_quadratic():
"""Quadratic model, analytic and coeffients"""
coeffs = np.array([1, -1, 2])

def quadfunc(x):
z_a = 1 - 1 * x + 2 * x**2
return z_a

def dquad_dx(x):
return -1 + 4 * x

return {'coeffs': coeffs,
'function': quadfunc,
'gradient': dquad_dx}

class TestTransform1D:
"""Tests for Transform1D()"""
def test_initialises_with_coeffs(self, tf1d):
assert isinstance(tf1d, Transform1D)

def test_call_gives_correct_result(self, quadratic):
x = np.random.randn()

# coefficients and explicit function
tf1d = Transform1D(quadratic['coeffs'])
assert tf1d(x) == quadratic['function'](x)

def test_gradient_gives_correct_result(self, quadratic):
x = np.random.randn()

tf2d = Transform1D(quadratic['coeffs'])
tf2d_grad = tf2d.gradient()

assert tf2d_grad(x) == quadratic['gradient'](x)

def test_fit_gives_correct_coeffs(self):
x = np.linspace(0, 1, 10)
y = 1 - 0.5 * x + 2.3 * x**2 - 3 * x**3

coeffs = np.array([1, -0.5, 2.3, -3])
tf1d = Transform1D.fit(x, y, degree=3)

assert tf1d.coeffs == pytest.approx(coeffs)
Loading