Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/76.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `~xrayvision.coordinates.frames.Projective` coordinate frame. This is intended to represent the projective coordinate system of images with unknown pointing.
1 change: 1 addition & 0 deletions changelog/77.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Enable ``phase_center`` inputs to functions that calculate visibilities to be given as `astropy.coordinates.SkyCoord`.
3 changes: 3 additions & 0 deletions examples/stix.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import astropy.units as apu
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord

from xrayvision.clean import vis_clean
from xrayvision.coordinates.frames import Projective
from xrayvision.imaging import vis_psf_map, vis_to_map
from xrayvision.mem import mem, resistant_mean

Expand All @@ -29,6 +31,7 @@
time_range = stix_data["time_range"]
energy_range = stix_data["energy_range"]
stix_vis = stix_data["stix_visibilities"]
stix_vis.phase_center = SkyCoord(Tx=stix_vis.phase_center[1], Ty=stix_vis.phase_center[0], frame=Projective)

Comment on lines +34 to 35
Copy link

Copilot AI Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indexing a SkyCoord object as if it were a sequence may result in an error since phase_center is expected to be a scalar SkyCoord. Use the existing attributes (e.g. phase_center.Tx and phase_center.Ty) instead of indexing.

Suggested change
stix_vis.phase_center = SkyCoord(Tx=stix_vis.phase_center[1], Ty=stix_vis.phase_center[0], frame=Projective)
stix_vis.phase_center = SkyCoord(Tx=stix_vis.phase_center.Tx, Ty=stix_vis.phase_center.Ty, frame=Projective)

Copilot uses AI. Check for mistakes.
###############################################################################
# Lets have a look at the point spread function (PSF) or dirty beam
Expand Down
Empty file.
122 changes: 122 additions & 0 deletions xrayvision/coordinates/frames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import astropy.coordinates
import astropy.units as u
from astropy.wcs import WCS
from sunpy.coordinates.frameattributes import ObserverCoordinateAttribute
from sunpy.coordinates.frames import HeliographicStonyhurst, SunPyBaseCoordinateFrame
from sunpy.sun.constants import radius as _RSUN

__all__ = ["Projective"]


X_CTYPE = "PJLN"
Y_CTYPE = "PJLT"


class Projective(SunPyBaseCoordinateFrame):
"""A generic projective coordinate frame for an image taken by an arbitrary imager."""

observer = ObserverCoordinateAttribute(HeliographicStonyhurst)
rsun = astropy.coordinates.QuantityAttribute(default=_RSUN, unit=u.km)
frame_specific_representation_info = {
astropy.coordinates.SphericalRepresentation: [
astropy.coordinates.RepresentationMapping("lon", u.arcsec),
astropy.coordinates.RepresentationMapping("lat", u.arcsec),
astropy.coordinates.RepresentationMapping("distance", "distance"),
],
astropy.coordinates.UnitSphericalRepresentation: [
astropy.coordinates.RepresentationMapping("lon", u.arcsec),
astropy.coordinates.RepresentationMapping("lat", u.arcsec),
],
}


def projective_wcs_to_frame(wcs):
r"""
This function registers the coordinate frames to their FITS-WCS coordinate
type values in the `astropy.wcs.utils.wcs_to_celestial_frame` registry.

Parameters
----------
wcs : `astropy.wcs.WCS`

Returns
-------
: `Projective`
"""
if hasattr(wcs, "coordinate_frame"):
return wcs.coordinate_frame

# Not a lat,lon coordinate system bail out early
if set(wcs.wcs.ctype) != {X_CTYPE, Y_CTYPE}:
return None

dateavg = wcs.wcs.dateobs

rsun = wcs.wcs.aux.rsun_ref
if rsun is not None:
rsun *= u.m

hgs_longitude = wcs.wcs.aux.hgln_obs
hgs_latitude = wcs.wcs.aux.hglt_obs
hgs_distance = wcs.wcs.aux.dsun_obs

observer = HeliographicStonyhurst(
lat=hgs_latitude * u.deg, lon=hgs_longitude * u.deg, radius=hgs_distance * u.m, obstime=dateavg, rsun=rsun
)

frame_args = {"obstime": dateavg, "observer": observer, "rsun": rsun}

return Projective(**frame_args)


def projective_frame_to_wcs(frame, projection="TAN"):
r"""
For a given frame, this function returns the corresponding WCS object.

It registers the WCS coordinates types from their associated frame in the
`astropy.wcs.utils.celestial_frame_to_wcs` registry.

Parameters
----------
frame : `Projective`
projection : `str`, optional

Returns
-------
`astropy.wcs.WCS`
"""
# Bail out early if not STIXImaging frame
if not isinstance(frame, Projective):
return None

wcs = WCS(naxis=2)
wcs.wcs.aux.rsun_ref = frame.rsun.to_value(u.m)

# Sometimes obs_coord can be a SkyCoord, so convert down to a frame
obs_frame = frame.observer
if hasattr(obs_frame, "frame"):
obs_frame = frame.observer.frame

if obs_frame:
wcs.wcs.aux.hgln_obs = obs_frame.lon.to_value(u.deg)
wcs.wcs.aux.hglt_obs = obs_frame.lat.to_value(u.deg)
wcs.wcs.aux.dsun_obs = obs_frame.radius.to_value(u.m)

if frame.obstime:
wcs.wcs.dateobs = frame.obstime.utc.iso
wcs.wcs.cunit = ["arcsec", "arcsec"]
wcs.wcs.ctype = [X_CTYPE, Y_CTYPE]

return wcs


# Remove once min version of sunpy has https://github.com/sunpy/sunpy/pull/7594
astropy.wcs.utils.WCS_FRAME_MAPPINGS.insert(-1, [projective_wcs_to_frame])
astropy.wcs.utils.FRAME_WCS_MAPPINGS.insert(-1, [projective_frame_to_wcs])

PROJECTIVE_CTYPE_TO_UCD1 = {
"PJLT": "custom:pos.projective.lat",
"PJLN": "custom:pos.projective.lon",
"PJRZ": "custom:pos.projective.z",
}
astropy.wcs.wcsapi.fitswcs.CTYPE_TO_UCD1.update(PROJECTIVE_CTYPE_TO_UCD1)
Empty file.
73 changes: 73 additions & 0 deletions xrayvision/coordinates/tests/test_frames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import astropy.units as u
import numpy as np
import pytest
from astropy.wcs import WCS
from sunpy.coordinates import HeliographicStonyhurst

from xrayvision.coordinates.frames import Projective, projective_frame_to_wcs, projective_wcs_to_frame


@pytest.fixture
def projective_wcs():
w = WCS(naxis=2)

w.wcs.dateobs = "2024-01-01 00:00:00.000"
w.wcs.crpix = [10, 20]
w.wcs.cdelt = np.array([2, 2])
w.wcs.crval = [0, 0]
w.wcs.ctype = ["PJLN", "PJLT"]

w.wcs.aux.hgln_obs = 10
w.wcs.aux.hglt_obs = 20
w.wcs.aux.dsun_obs = 1.5e11

return w


@pytest.fixture
def projective_frame():
obstime = "2024-01-01"
observer = HeliographicStonyhurst(10 * u.deg, 20 * u.deg, 1.5e11 * u.m, obstime=obstime)

frame_args = {"obstime": obstime, "observer": observer, "rsun": 695_700_000 * u.m}

frame = Projective(**frame_args)
return frame


def test_projective_wcs_to_frame(projective_wcs):
frame = projective_wcs_to_frame(projective_wcs)
assert isinstance(frame, Projective)

assert frame.obstime.isot == "2024-01-01T00:00:00.000"
assert frame.rsun == 695700 * u.km
assert frame.observer == HeliographicStonyhurst(
10 * u.deg, 20 * u.deg, 1.5e11 * u.m, obstime="2024-01-01T00:00:00.000"
)


def test_projective_wcs_to_frame_none():
w = WCS(naxis=2)
w.wcs.ctype = ["ham", "cheese"]
frame = projective_wcs_to_frame(w)

assert frame is None


def test_projective_frame_to_wcs(projective_frame):
wcs = projective_frame_to_wcs(projective_frame)

assert isinstance(wcs, WCS)
assert wcs.wcs.ctype[0] == "PJLN"
assert wcs.wcs.cunit[0] == "arcsec"
assert wcs.wcs.dateobs == "2024-01-01 00:00:00.000"

assert wcs.wcs.aux.rsun_ref == projective_frame.rsun.to_value(u.m)
assert wcs.wcs.aux.dsun_obs == 1.5e11
assert wcs.wcs.aux.hgln_obs == 10
assert wcs.wcs.aux.hglt_obs == 20


def test_projective_frame_to_wcs_none():
wcs = projective_frame_to_wcs(HeliographicStonyhurst())
assert wcs is None
53 changes: 31 additions & 22 deletions xrayvision/imaging.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Optional
from typing import Union, Optional

import astropy.units as apu
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.wcs.utils import celestial_frame_to_wcs
from sunpy.map import GenericMap, Map

from xrayvision.coordinates.frames import Projective
from xrayvision.transform import dft_map, idft_map
from xrayvision.visibility import Visibilities

Expand Down Expand Up @@ -91,13 +94,14 @@ def validate_and_expand_kwarg(q: Quantity, name: Optional[str] = "") -> Quantity
return q


@apu.quantity_input
def image_to_vis(
image: Quantity,
*,
u: Quantity[apu.arcsec**-1],
v: Quantity[apu.arcsec**-1],
phase_center: Optional[Quantity[apu.arcsec]] = (0.0, 0.0) * apu.arcsec,
phase_center: Union[SkyCoord, Quantity[apu.arcsec]] = SkyCoord(
Tx=0.0 * apu.arcsec, Ty=0.0 * apu.arcsec, frame=Projective
),
pixel_size: Optional[Quantity[apu.arcsec / apu.pix]] = 1.0 * apu.arcsec / apu.pix,
) -> Visibilities:
r"""
Expand All @@ -111,7 +115,7 @@ def image_to_vis(
Array of u coordinates where the visibilities will be evaluated
v :
Array of v coordinates where the visibilities will be evaluated
phase_center :
phase_center : `astropy.coordinates.SkyCoord`
The coordinates the phase_center.
pixel_size :
Size of pixels, if only one value is passed, assume square pixels (repeating the value).
Expand All @@ -126,7 +130,11 @@ def image_to_vis(
if not (apu.get_physical_type((1 / u).unit) == ANGLE and apu.get_physical_type((1 / v).unit) == ANGLE):
raise ValueError("u and v must be inverse angle (e.g. 1/deg or 1/arcsec")
vis = dft_map(
image, u=u, v=v, phase_center=[0.0, 0.0] * apu.arcsec, pixel_size=pixel_size
image,
u=u,
v=v,
phase_center=SkyCoord(Tx=0.0 * apu.arcsec, Ty=0.0 * apu.arcsec, frame=Projective),
pixel_size=pixel_size,
) # TODO: adapt to generic map center
return Visibilities(vis, u=u, v=v, phase_center=phase_center)

Expand Down Expand Up @@ -169,7 +177,9 @@ def vis_to_image(
shape=shape,
weights=weights,
pixel_size=pixel_size,
phase_center=[0.0, 0.0] * apu.arcsec, # TODO update to have generic image center
phase_center=SkyCoord(
Tx=0.0 * apu.arcsec, Ty=0.0 * apu.arcsec, frame=Projective
), # TODO update to have generic image center
)

return bp_arr
Expand Down Expand Up @@ -308,19 +318,19 @@ def generate_header(vis: Visibilities, *, shape: Quantity[apu.pix], pixel_size:
-------
:
"""
header = {
"crval1": (vis.phase_center[1]).to_value(apu.arcsec),
"crval2": (vis.phase_center[0]).to_value(apu.arcsec),
"cdelt1": (pixel_size[1] * apu.pix).to_value(apu.arcsec),
"cdelt2": (pixel_size[0] * apu.pix).to_value(apu.arcsec),
"ctype1": "HPLN-TAN",
"ctype2": "HPLT-TAN",
"naxis": 2,
"naxis1": shape[1].value,
"naxis2": shape[0].value,
"cunit1": "arcsec",
"cunit2": "arcsec",
}
cunit = apu.arcsec
phase_center = vis.phase_center
header = celestial_frame_to_wcs(phase_center.frame).to_header()
header["crval1"] = (phase_center.Tx).to_value(cunit)
header["crval2"] = (phase_center.Ty).to_value(cunit)
header["crpix1"] = shape[1].to_value(apu.pix) / 2
header["crpix2"] = shape[0].to_value(apu.pix) / 2
header["cdelt1"] = (pixel_size[1] * apu.pix).to_value(cunit)
header["cdelt2"] = (pixel_size[0] * apu.pix).to_value(cunit)
header["naxis1"] = shape[1].value
header["naxis2"] = shape[0].value
header["cunit1"] = str(cunit)
header["cunit2"] = str(cunit)
return header


Expand Down Expand Up @@ -352,15 +362,14 @@ def map_to_vis(amap: GenericMap, *, u: Quantity[1 / apu.arcsec], v: Quantity[1 /
new_pos[1] = float(meta["crval1"])
if "crval2" in meta:
new_pos[0] = float(meta["crval2"])
new_pos = SkyCoord(Tx=new_pos[1] * apu.arcsec, Ty=new_pos[0] * apu.arcsec, frame=Projective)

new_psize = np.array([1.0, 1.0])
if "cdelt1" in meta:
new_psize[1] = float(meta["cdelt1"])
if "cdelt2" in meta:
new_psize[0] = float(meta["cdelt2"])

vis = image_to_vis(
amap.quantity, u=u, v=v, pixel_size=new_psize * apu.arcsec / apu.pix, phase_center=new_pos * apu.arcsec
)
vis = image_to_vis(amap.quantity, u=u, v=v, pixel_size=new_psize * apu.arcsec / apu.pix, phase_center=new_pos)

return vis
13 changes: 8 additions & 5 deletions xrayvision/tests/test_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import numpy as np
import pytest
from astropy.convolution.kernels import Gaussian2DKernel
from numpy.testing import assert_allclose, assert_array_equal
from astropy.coordinates import SkyCoord
from numpy.testing import assert_allclose
from sunpy.map import Map

from xrayvision.coordinates.frames import Projective
from xrayvision.imaging import image_to_vis, map_to_vis, vis_psf_image, vis_to_image, vis_to_map
from xrayvision.transform import dft_map, generate_uv, idft_map
from xrayvision.visibility import Visibilities
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_image_to_vis():

# For an empty map visibilities should all be zero (0+0j)
empty_vis = image_to_vis(image, u=v, v=v)
assert np.array_equal(empty_vis.phase_center, (0.0, 0.0) * apu.arcsec)
assert empty_vis.phase_center == SkyCoord(Tx=0.0 * apu.arcsec, Ty=0.0 * apu.arcsec, frame=Projective)
assert np.array_equal(empty_vis.visibilities, np.zeros(n * m, dtype=complex))


Expand All @@ -144,8 +146,9 @@ def test_image_to_vis_center():
u, v = np.array([u, v]).reshape(2, size) / apu.arcsec

# For an empty map visibilities should all be zero (0+0j)
empty_vis = image_to_vis(image, u=u, v=v, phase_center=(2.0, -3.0) * apu.arcsec)
assert np.array_equal(empty_vis.phase_center, (2.0, -3.0) * apu.arcsec)
phase_center = SkyCoord(Tx=2 * apu.arcsec, Ty=-3.0 * apu.arcsec, frame=Projective)
empty_vis = image_to_vis(image, u=u, v=v, phase_center=phase_center)
assert empty_vis.phase_center == phase_center
assert np.array_equal(empty_vis.visibilities, np.zeros(n * m, dtype=complex))


Expand Down Expand Up @@ -181,7 +184,7 @@ def test_map_to_vis(pos, pixel):
mp = Map((data, header))
vis = map_to_vis(mp, u=u, v=v)

assert_array_equal(vis.phase_center, pos)
assert vis.phase_center == SkyCoord(Tx=pos[1], Ty=pos[0], frame=Projective)

res = vis_to_image(vis, shape=(m, n) * apu.pixel, pixel_size=pixel)
assert_allclose(res, data)
Expand Down
Loading