diff --git a/.editorconfig b/.editorconfig index a2c6683..593ac6f 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,6 @@ trim_trailing_whitespace=true indent_style=space indent_size=4 -[*.yml] +[{*.yml,*.toml}] indent_style=space indent_size=2 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c0842b2..9fa3f2a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,15 +2,15 @@ name: CI on: push: -# branches: -# - 'main' -# - '*.*' -# - '!*backport*' -# tags: -# - 'v*' -# - '!*dev*' -# - '!*pre*' -# - '!*post*' + branches: + - 'main' + - '*.*' + - '!*backport*' + tags: + - 'v*' + - '!*dev*' + - '!*pre*' + - '!*post*' pull_request: # Allow manual runs through the web UI workflow_dispatch: diff --git a/changelog/83.feature.rst b/changelog/83.feature.rst new file mode 100644 index 0000000..79be450 --- /dev/null +++ b/changelog/83.feature.rst @@ -0,0 +1 @@ +Add a basic visibility forward fitting method (`xrayvision.vis_forward_fit.forward_fit.vis_forward_fit`). diff --git a/docs/reference/forward_fit.rst b/docs/reference/forward_fit.rst new file mode 100644 index 0000000..bdb0041 --- /dev/null +++ b/docs/reference/forward_fit.rst @@ -0,0 +1,14 @@ +.. vis_forward_fit: + +Vis Forward Fit ('xrayvision.vis_forward_fit') +********************************************** + +The ``vis_forward_fit`` submodule contains the visibility forward fitting methods + +.. automodapi:: xrayvision.vis_forward_fit + +.. automodapi:: xrayvision.vis_forward_fit.forward_fit + :include-all-objects: + +.. automodapi:: xrayvision.vis_forward_fit.sources + :include-all-objects: diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 389d57b..cf05ae7 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -11,5 +11,6 @@ Reference transform utils visibility + forward_fit ../whatsnew/index diff --git a/examples/rhessi.py b/examples/rhessi.py index b7232c7..572bb2d 100644 --- a/examples/rhessi.py +++ b/examples/rhessi.py @@ -4,6 +4,7 @@ ====================================== Create images from RHESSI visibility data + """ import astropy.units as apu diff --git a/examples/stix.py b/examples/stix.py index e4a36a3..d2f5aa4 100644 --- a/examples/stix.py +++ b/examples/stix.py @@ -16,6 +16,8 @@ from xrayvision.clean import vis_clean from xrayvision.imaging import vis_psf_map, vis_to_map from xrayvision.mem import mem, resistant_mean +from xrayvision.vis_forward_fit.forward_fit import vis_forward_fit +from xrayvision.vis_forward_fit.sources import Source, SourceList ############################################################################### # Create images from STIX visibility data. @@ -33,13 +35,13 @@ ############################################################################### # Lets have a look at the point spread function (PSF) or dirty beam -psf_map = vis_psf_map(stix_vis, shape=(129, 129) * apu.pixel, pixel_size=2 * apu.arcsec / apu.pix, scheme="uniform") +psf_map = vis_psf_map(stix_vis, shape=(129, 129) * apu.pixel, pixel_size=1 * apu.arcsec / apu.pix, scheme="uniform") psf_map.plot() ############################################################################### # Back projection -backproj_map = vis_to_map(stix_vis, shape=(129, 129) * apu.pixel, pixel_size=2 * apu.arcsec / apu.pix, scheme="uniform") +backproj_map = vis_to_map(stix_vis, shape=(129, 129) * apu.pixel, pixel_size=1 * apu.arcsec / apu.pix, scheme="uniform") backproj_map.plot() ############################################################################### @@ -48,7 +50,7 @@ clean_map, model_map, resid_map = vis_clean( stix_vis, shape=[129, 129] * apu.pixel, - pixel_size=[2, 2] * apu.arcsec / apu.pix, + pixel_size=[1, 1] * apu.arcsec / apu.pix, clean_beam_width=20 * apu.arcsec, niter=100, ) @@ -62,17 +64,42 @@ percent_lambda = 2 / (snr_value**2 + 90) mem_map = mem( - stix_vis, shape=[129, 129] * apu.pixel, pixel_size=[2, 2] * apu.arcsec / apu.pix, percent_lambda=percent_lambda + stix_vis, shape=[129, 129] * apu.pixel, pixel_size=[1, 1] * apu.arcsec / apu.pix, percent_lambda=percent_lambda +) + +############################################################################### +# VIS_FWD_FIT + +sources = SourceList( + [ + Source( + "elliptical", + 15 * stix_vis.visibilities.unit, + 1 * apu.arcsec, + 2 * apu.arcsec, + 5 * apu.arcsec, + 2 * apu.arcsec, + 1, + ) + ] +) + +vis_fwd_map = vis_forward_fit(stix_vis, sources, shape=[129, 129] * apu.pixel, pixel_size=[1, 1] * apu.arcsec / apu.pix) + +vis_fwd_pso_map = vis_forward_fit( + stix_vis, sources, method="PSO", shape=[129, 129] * apu.pixel, pixel_size=[1, 1] * apu.arcsec / apu.pix ) -mem_map.plot() ############################################################################### # Comparison fig = plt.figure(figsize=(10, 10)) -fig.add_subplot(221, projection=psf_map) -fig.add_subplot(222, projection=backproj_map) -fig.add_subplot(223, projection=clean_map) -fig.add_subplot(224, projection=mem_map) +fig.add_subplot(231, projection=psf_map) +fig.add_subplot(232, projection=backproj_map) +fig.add_subplot(233, projection=clean_map) +fig.add_subplot(234, projection=mem_map) +fig.add_subplot(235, projection=mem_map) +fig.add_subplot(236, projection=mem_map) + axs = fig.get_axes() psf_map.plot(axes=axs[0]) axs[0].set_title("PSF") @@ -82,4 +109,9 @@ axs[2].set_title("Clean") mem_map.plot(axes=axs[3]) axs[3].set_title("MEM") +vis_fwd_map.plot(axes=axs[4]) +axs[4].set_title("VIS_FWRDFIT") +vis_fwd_pso_map.plot(axes=axs[5]) +axs[5].set_title("VIS_FWRDFIT_PSO") + plt.show() diff --git a/pyproject.toml b/pyproject.toml index 8c9b86c..d369512 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "numpy>=1.24.0", "packaging>=23.0", "scipy>=1.13", - "xarray>=2023.5.0" + "xarray>=2023.5.0", ] dynamic = ["version"] keywords = ["solar", "physics", "solar", "sun", "x-rays"] @@ -42,7 +42,10 @@ classifiers = [ map = [ "sunpy[map]>=5.1.0" ] -all = ["xrayvisim[map]"] +pso = [ + "pymoo>=0.6.1.3" +] +all = ["xrayvisim[map,pso]"] tests = [ "matplotlib>=3.8.0", "pytest-astropy>=0.11.0", diff --git a/pytest.ini b/pytest.ini index bb3bc1a..b18e67d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -40,3 +40,6 @@ filterwarnings = # Until update code need to ignore missing WCS ignore:.*:sunpy.util.exceptions.SunpyMetadataWarning ignore:.*divide by zero.*:RuntimeWarning + ignore:The.*feasible.*:DeprecationWarning + # Can't use ipython in PyCharm debugger without this + ignore::DeprecationWarning diff --git a/ruff.toml b/ruff.toml index 6b03df6..4eb1d8b 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,7 +8,7 @@ exclude = [ ] [lint] -select = ["E", "F", "W", "UP", "PT"] +select = ["E", "F", "W", "UP", "PT", "C"] extend-ignore = [ # pycodestyle (E, W) "E501", # LineTooLong # TODO! fix @@ -27,7 +27,12 @@ extend-ignore = [ "docs/conf.py" = ["E402"] "docs/*.py" = [ "INP001", # Implicit-namespace-package. The examples are not a package. + "D100" ] +"examples/*.py" = [ + "D" +] + "__init__.py" = ["E402", "F401", "F403"] "test_*.py" = ["B011", "D", "E402", "PGH001", "S101"] # Need to import clients to register them, but don't use them in file diff --git a/xrayvision/clean.py b/xrayvision/clean.py index a07f152..ef3e871 100644 --- a/xrayvision/clean.py +++ b/xrayvision/clean.py @@ -198,7 +198,6 @@ def vis_clean( map : Return a `sunpy.map.Map` by default or array only if `False` """ - dirty_map = vis_to_map(vis, shape=shape, pixel_size=pixel_size, **kwargs) dirty_beam_shape = [x.value * 3 + 1 if x.value * 3 % 2 == 0 else x.value * 3 for x in shape] * shape.unit dirty_beam = vis_psf_image(vis, shape=dirty_beam_shape, pixel_size=pixel_size, **kwargs) diff --git a/xrayvision/conftest.py b/xrayvision/conftest.py index 25f4943..d77c6df 100644 --- a/xrayvision/conftest.py +++ b/xrayvision/conftest.py @@ -1,9 +1,8 @@ # Force MPL to use non-gui backends for testing. -import matplotlib try: pass except ImportError: pass -else: - matplotlib.use("Agg") +# else: +# matplotlib.use("Agg") diff --git a/xrayvision/coordinates/frames.py b/xrayvision/coordinates/frames.py index c61d248..d42b66c 100644 --- a/xrayvision/coordinates/frames.py +++ b/xrayvision/coordinates/frames.py @@ -71,7 +71,7 @@ def projective_wcs_to_frame(wcs): observer = None for frame, attr_names in required_attrs.items(): attrs = [getattr(wcs.wcs.aux, attr_name) for attr_name in attr_names] - if all([attr is not None for attr in attrs]): + if all(attr is not None for attr in attrs): kwargs = {"obstime": dateavg} if rsun is not None: kwargs["rsun"] = rsun diff --git a/xrayvision/imaging.py b/xrayvision/imaging.py index a185d52..017ad08 100644 --- a/xrayvision/imaging.py +++ b/xrayvision/imaging.py @@ -245,7 +245,7 @@ def vis_psf_image( # Make sure psf is always odd so power is in exactly one pixel shape = [s // 2 * 2 + 1 for s in shape.to_value(apu.pix)] * shape.unit psf_arr = idft_map( - np.ones(vis.visibilities.shape) * vis.visibilities.unit, + np.ones(vis.visibilities.shape) * np.prod(pixel_size.value) * vis.visibilities.unit, u=vis.u, v=vis.v, shape=shape, diff --git a/xrayvision/mem.py b/xrayvision/mem.py index d0353a1..bfee69d 100644 --- a/xrayvision/mem.py +++ b/xrayvision/mem.py @@ -133,7 +133,6 @@ def _estimate_flux(vis, shape, pixel, maxiter=1000, tol=1e-3): Estimated total flux """ - Hv, Lip, Visib = _prepare_for_optimise(pixel, shape, vis) # PROJECTED LANDWEBER @@ -237,7 +236,6 @@ def _get_mean_visibilities(vis, shape, pixel): ------- Mean Visibilities """ - if vis.amplitude_uncertainty is None: amplitude_uncertainty = np.ones_like(vis.visibilities) else: @@ -374,7 +372,6 @@ def _proximal_operator(z, f, m, lamb, Lip, niter=250): ------- """ - # INITIALIZATION OF THE DYKSTRA - LIKE SPLITTING x = z[:] p = np.zeros_like(x) @@ -434,11 +431,11 @@ def _optimise_fb(Hv, Visib, Lip, flux, lambd, shape, pixel, maxiter, tol): Maximum number of iterations tol : Tolerance value used in the stopping rule ( || x - x_old || <= tol || x_old ||) + Returns ------- MEM Image """ - # 'f': value of the total flux of the image (taking into account the area of the pixel) f = flux / (pixel[0] * pixel[1]) # 'm': total flux divided by the number of pixels of the image diff --git a/xrayvision/tests/test_transform.py b/xrayvision/tests/test_transform.py index 52da619..298d83d 100644 --- a/xrayvision/tests/test_transform.py +++ b/xrayvision/tests/test_transform.py @@ -1,7 +1,7 @@ import astropy.units as apu import numpy as np import pytest -from numpy.fft import fft2, fftshift, ifft2, ifftshift +from numpy.fft import fft2, ifft2 from numpy.testing import assert_allclose from scipy import signal @@ -377,34 +377,36 @@ def test_phase_center_equivalence(): assert np.allclose(data, img2) -def test_fft_equivalence(): - # Odd (3, 3) so symmetric and chose shape and pixel size so xy values will run from 0 to 2 the same as in fft - # TODO: add same kind of test for even for fft2 then A[n/2] has both pos and negative nyquist frequencies - # e.g shape (2, 2), (3, 2), (2, 3) - shape = (3, 3) - pixel = (1, 1) - center = (1, 1) +@pytest.mark.parametrize("dim", ((2, 3, 4, 5, 6))) +def test_fft_equivalence(dim): + shape = np.array((dim, dim)) + pixel = np.array((1, 1)) - data = np.arange(np.prod(shape)).reshape(shape) + # In order to replicate the FFT need to make the product of x*u + y*v match mk/M + nl/N where M, N and the array + # dimensions so neex x, y to go from 0 to M-1, N-1 and u, v 0 to m-1/M, and n-1/N so need to chose parameters so + # this happens namely the center for u, v and the dft + uv_center = -1 / ((0 - shape / 2 + 0.5) / (shape * pixel)) + xy_center = -1 * (0 - shape / 2 + 0.5) vv = generate_uv( - shape[0] * apu.pix, phase_center=center[0] * apu.arcsec, pixel_size=pixel[0] * apu.arcsec / apu.pix + shape[0] * apu.pix, phase_center=uv_center[0] * apu.arcsec, pixel_size=pixel[0] * apu.arcsec / apu.pix ) uu = generate_uv( - shape[1] * apu.pix, phase_center=center[1] * apu.arcsec, pixel_size=pixel[1] * apu.arcsec / apu.pix + shape[1] * apu.pix, phase_center=uv_center[1] * apu.arcsec, pixel_size=pixel[1] * apu.arcsec / apu.pix ) u, v = np.meshgrid(uu, vv) u = u.flatten() v = v.flatten() - vis = dft_map(data, u=u, v=v, pixel_size=pixel * apu.arcsec / apu.pix, phase_center=center * apu.arcsec) + data = np.arange(np.prod(shape)).reshape(shape) + vis = dft_map(data, u=u, v=v, pixel_size=pixel * apu.arcsec / apu.pix, phase_center=xy_center * apu.arcsec) ft = fft2(data) - fts = fftshift(ft) vis = vis.reshape(shape) - # Convention in xrayvison has the minus sign on the forward transform but numpy has it on reverse + # Convention in xrayvison has the minus sign on the forward transform but numpy has it on reverse and the first + # + # The rtol seems to vary as the size increase I'm assuming poor round off error handling in the naive DFT/IDFT vis_conj = np.conjugate(vis) - assert_allclose(fts, vis_conj, atol=1e-13) + assert_allclose(ft, vis_conj, rtol=1e-10, atol=1e-10) - vis_ft = ifftshift(vis_conj) - img = ifft2(vis_ft) - assert_allclose(np.real(img), data, atol=1e-14) + img = ifft2(vis_conj) + assert_allclose(np.real(img), data, rtol=1e-10, atol=1e-10) diff --git a/xrayvision/tests/test_visibility.py b/xrayvision/tests/test_visibility.py index 037388b..1d7542b 100644 --- a/xrayvision/tests/test_visibility.py +++ b/xrayvision/tests/test_visibility.py @@ -109,7 +109,7 @@ def test_vis_eq(visibilities): def test_meta_eq(vis_meta): meta = vis_meta assert meta == meta - meta = vm.VisMeta(dict()) + meta = vm.VisMeta({}) assert meta == meta diff --git a/xrayvision/transform.py b/xrayvision/transform.py index 749067b..ca01845 100644 --- a/xrayvision/transform.py +++ b/xrayvision/transform.py @@ -161,10 +161,8 @@ def dft_map( x, y = np.meshgrid(x, y) uv = np.vstack([u, v]) # Check units are correct for exp need to be dimensionless and then remove units for speed - if (uv[0, :] * x[0, 0]).unit == apu.dimensionless_unscaled and ( - uv[1, :] * y[0, 0] - ).unit == apu.dimensionless_unscaled: - uv = uv.value # type: ignore + if (uv.unit * x.unit) == apu.dimensionless_unscaled and (uv.unit * y.unit) == apu.dimensionless_unscaled: + uv = uv.value # src_type: ignore x = x.value y = y.value @@ -174,7 +172,7 @@ def dft_map( 2j * np.pi * (x[..., np.newaxis] * uv[np.newaxis, 0, :] + y[..., np.newaxis] * uv[np.newaxis, 1, :]) ), axis=(0, 1), - ) + ) * np.prod(pixel_size.value) return vis else: @@ -245,7 +243,7 @@ def idft_map( axis=2, ) - return np.real(image) + return np.real(image) / np.prod(pixel_size.value) else: raise UnitsError("Incompatible units on uv {uv.unit} should cancel with xy to leave a dimensionless quantity") diff --git a/xrayvision/vis_forward_fit/__init__.py b/xrayvision/vis_forward_fit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xrayvision/vis_forward_fit/forward_fit.py b/xrayvision/vis_forward_fit/forward_fit.py new file mode 100644 index 0000000..098b39c --- /dev/null +++ b/xrayvision/vis_forward_fit/forward_fit.py @@ -0,0 +1,241 @@ +from typing import Union, Callable, Optional + +import astropy.units as apu +import numpy as np +from astropy.units import Quantity, quantity_input +from numpy.typing import NDArray +from pymoo.algorithms.soo.nonconvex.pso import PSO +from pymoo.core.problem import ElementwiseProblem +from pymoo.optimize import minimize as moo_minimize +from scipy.optimize import OptimizeResult, minimize +from sunpy.map import Map + +from xrayvision.imaging import generate_header +from xrayvision.transform import generate_xy +from xrayvision.vis_forward_fit.sources import ( + Circular, + Elliptical, + Loop, + SourceList, + circular_gaussian_img, + circular_gaussian_vis, + elliptical_gaussian_img, + elliptical_gaussian_vis, + loop_img, + loop_vis, +) +from xrayvision.visibility import Visibilities + +__all__ = ["SOURCE_TO_IMAGE", "SOURCE_TO_VIS", "sources_to_image", "sources_to_vis", "vis_forward_fit"] + +#: Mapping of sources to image generation functions +SOURCE_TO_IMAGE: dict[str, Callable] = { + Circular.__name__: circular_gaussian_img, + Elliptical.__name__: elliptical_gaussian_img, + Loop.__name__: loop_img, +} + +#: Mapping of sources to visibility generation functions +SOURCE_TO_VIS: dict[str, Callable] = { + Circular.__name__: circular_gaussian_vis, + Elliptical.__name__: elliptical_gaussian_vis, + Loop.__name__: loop_vis, +} + + +def sources_to_image( + source_list: SourceList, + shape: Quantity[apu.pix], + pixel_size: Quantity[apu.arcsec / apu.pix], + center=(0, 0) * apu.arcsec, +) -> np.ndarray[float]: + r""" + Create an image from a list of sources. + + Parameters + ---------- + source_list : + List of sources and their parameters + shape : + Shape of the image create + pixel_size : + Size + + Returns + ------- + + """ + image = None + x = generate_xy(shape[1], pixel_size=pixel_size[1], phase_center=center[1]) + y = generate_xy(shape[0], pixel_size=pixel_size[0], phase_center=center[0]) + x, y = np.meshgrid(x, y) + for source in source_list: + try: + if image is None: + image = SOURCE_TO_IMAGE[source.__class__.__name__]( + *[source.param_list[0], x, y, *source.param_list[1:]] + ) + else: + image += SOURCE_TO_IMAGE[source.__class__.__name__]( + *[source.param_list[0], x, y, *source.param_list[1:]] + ) + except KeyError: + raise KeyError(f"Unknown source type: {source.__class__.__name__}") + return image + + +def sources_to_vis(source_list: SourceList, u, v) -> np.ndarray[np.complex128]: + r""" + Create visibilities from a list of sources. + + Parameters + ---------- + source_list + u : + u coordinates to evaluate sources + v : + u coordinates to evaluate sources + + Returns + ------- + vis : + Complex visibilities + """ + vis = np.zeros(u.shape, dtype=np.complex128) + for source in source_list: + try: + vis += SOURCE_TO_VIS[source.__class__.__name__](*[source.param_list[0], u, v, *source.param_list[1:]]) + except KeyError: + raise KeyError(f"Unknown source type: {source.__class__.__name__}") + return vis + + +def _vis_forward_fit_minimise( + visobs: Visibilities, sources: SourceList, method: str +) -> tuple[SourceList, OptimizeResult]: + r""" + Internal minimisation function + + Parameters + ---------- + visobs : + Input Visibilities + sources : + List of sources + method : + Method to use for the minimisation + + """ + if method.casefold() == "pso": + problem = VisForwardFitProblem(visobs.u, visobs.v, visobs, sources) + algo = PSO(pop=100) + res = moo_minimize(problem, algo) + sources_fit = sources.from_params(sources, res.X) + else: + visobs_ri = np.hstack([visobs.visibilities.real, visobs.visibilities.imag]) + + def objective(x, u, v, visobs_ri, sources): + cur_sources = sources.from_params(sources, x) + vispred = sources_to_vis(cur_sources, u.value, v.value) + vispred_ri = np.hstack([vispred.real, vispred.imag]) + return np.sum(np.abs(visobs_ri.value - vispred_ri) ** 2) + + res = minimize( + objective, + [getattr(p, "value", p) for p in sources.params], + (visobs.u, visobs.v, visobs_ri, sources), + method=method, + bounds=[(x, y) for x, y in zip(*sources.bounds)], + ) + sources_fit = sources.from_params(sources, res.x) + return sources_fit, res + + +@quantity_input() +def vis_forward_fit( + vis: Visibilities, + sources: SourceList, + shape: Quantity[apu.pix], + pixel_size: Quantity[apu.arcsec / apu.pix], + map: Optional[bool] = True, + method: Optional[str] = "Nelder-Mead", +) -> Union[Quantity, NDArray[np.float64]]: + r""" + Visibility forward fit method. + + Parameters + ---------- + vis : + Input visibilities + sources : + List of sources and their initial parameters + shape : + Shape of the image create + pixel_size : + Pixel size + map : + Return a `Map` + method : + Method to use any of those supported methods by `scipy.optimize.minimize` or 'PSO' for particle swarm optimization + """ + if method is None: + method = "Nelder-Mead" + sources_out, res = _vis_forward_fit_minimise(vis, sources, method=method) + # add units back + sources_out = SourceList.from_params( + sources, [pout * getattr(pin, "unit", 1) for pin, pout in zip(sources.params, sources_out.params)] + ) + image = sources_to_image(sources_out, shape, pixel_size) + if map: + header = generate_header(vis, shape=shape, pixel_size=pixel_size) + return Map((image, header)) + return image, sources_out + + +class VisForwardFitProblem(ElementwiseProblem): + def __init__(self, u, v, visobs: Visibilities, sources: SourceList): + self.u = u + self.v = v + self.visobs = visobs + self.visobs_ri = np.hstack([visobs.visibilities.real, visobs.visibilities.imag]) + self.sources = sources + n_var = len(sources.params) + xl, xu = sources.bounds + super().__init__(n_var=n_var, n_obj=1, n_ieq_constr=1, xl=xl, xu=xu) + + def _evaluate(self, x, out, *args, **kwargs): + out["G"] = x[3] - x[4] + cur_sources = self.sources.from_params(self.sources, x) + vispred = sources_to_vis(cur_sources, self.u.value, self.v.value) + vispred_ri = np.hstack([vispred.real, vispred.imag]) + out["F"] = np.sum(np.abs(self.visobs_ri.value - vispred_ri) ** 2) + + # def _evaluate(self, x, out, *args, **kwargs): + # # sigmamin < sigmamax -> sigmamin - sigmamax <= 0 + # out["G"] = x[3] - x[4] + # out["F"] = 1e10 + # + # # wrap angles + # eps = 1e-7 + # # Alpha (Rotation): Wrap between -pi/2 and pi/2 + # half_pi = np.pi / 2 + # x[-2] = ((x[-2] + half_pi) % np.pi) - half_pi + # # Beta ("Length"): Clip between 0 and pi + # x[-1] = np.clip(x[-1], eps, np.pi) + # + # s_min = min(x[3], x[4]) + # s_max = max(x[3], x[4]) + # + # x[3] = s_min + # x[4] = s_max + # + # cur_sources = self.sources.from_params(self.sources, x) + # + # try: + # vispred = sources_to_vis(cur_sources, self.u.value, self.v.value) + # vispred_ri = np.hstack([vispred.real, vispred.imag]) + # if not np.isfinite(np.sum(np.abs(self.visobs_ri.value - vispred_ri) ** 2)): + # return + # out["F"] = np.sum(np.abs(self.visobs_ri.value - vispred_ri) ** 2) + # except ValueError: + # out["F"] = 1e10 diff --git a/xrayvision/vis_forward_fit/sources.py b/xrayvision/vis_forward_fit/sources.py new file mode 100644 index 0000000..2a67f05 --- /dev/null +++ b/xrayvision/vis_forward_fit/sources.py @@ -0,0 +1,1030 @@ +from abc import ABC, abstractmethod +from typing import Callable, Optional +from itertools import chain +from collections import UserList +from dataclasses import dataclass + +import numpy as np +from scipy.special import binom, factorial + +__all__ = [ + "circular_gaussian_img", + "circular_gaussian_vis", + "elliptical_gaussian_img", + "elliptical_gaussian_vis", + "GenericSource", + "Circular", + "Elliptical", + "SourceList", + "SourceFactory", + "Source", +] + + +def circular_gaussian_img(amp, x, y, x0, y0, sigma): + r""" + Circular gaussian function sampled at x, y. + + .. math:: + + F(x, y) = A \exp{\left(-\frac{(x0-x)^2 + (y0 - y)^2}{2\sigma^2}\right)} + + + Parameters + ---------- + amp : + Amplitude + x : + x coordinates + y : + y coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigma : + Sigma + + See Also + -------- + circular_gaussian_vis + """ + return amp / (2 * np.pi * sigma**2) * np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2)) + + +def circular_gaussian_vis(amp, u, v, x0, y0, sigma): + r""" + Circular gaussian in Fourier space sampled at u, v. + + .. math:: + + F(u, v) = A \exp{\left( -2\pi^2 \sigma^2 (u^2 +v^2 \right)}) \exp( 2\pi i(x0u + y0v)) + + + Parameters + ---------- + amp : + Amplitude + u : + u coordinates + v : + v coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigma : + Sigma + + See Also + -------- + circular_gaussian + """ + return amp * np.exp(-2 * np.pi**2 * sigma**2 * (u**2 + v**2)) * np.exp(2j * np.pi * (x0 * u + y0 * v)) + + +def elliptical_gaussian_img(amp, x, y, x0, y0, sigmax, sigmay, theta): + r""" + Elliptical gaussian sampled at x, y. + + .. math:: + + x' &= ((x0 - x) \cos(\theta) + ((y0 - y) \sin(\theta)) \\ + y' &= -((x0 - x) \sin(\theta) + ((y0 - y) \cos(\theta)) \\ + F(x, y) &= \frac{A}{(2 \pi \sigma_x \sigma_y)} \exp \left( \frac{x'^2}{2\sigma_x^2} + \frac{y'^2}{\sigma_y^2} \right) + + + Parameters + ---------- + amp : + Amplitude + x : + x coordinates + y : + y coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigmax : + Sigma in x direction + sigmay : + Sigma in y direction + theta : + Rotation angle in anticlockwise + + See Also + -------- + elliptical_gaussian_vis + """ + sint = np.sin(theta) + cost = np.cos(theta) + xp = ((x0 - x) * cost) + ((y0 - y) * sint) + yp = -((x0 - x) * sint) + ((y0 - y) * cost) + return amp / (2 * np.pi * sigmax * sigmay) * np.exp(-((xp**2 / (2 * sigmax**2)) + (yp**2 / (2 * sigmay**2)))) + + +def elliptical_gaussian_vis(amp, u, v, x0, y0, sigmax, sigmay, theta): + r""" + Elliptical gaussian in Fourier space sampled at u, v. + + .. math:: + + x' &= u\cos(\theta) +v \sin(\theta) \\ + y' &= -u \sin(\theta) + v \cos(\theta) \\ + F(x, y) &= A \exp \left( -2\pi^2 ((u'^2\sigma_x^2) + (v'^2\sigma_y^2) \right) \exp( 2\pi i(x0u + y0v)) + + Parameters + ---------- + amp : + Amplitude + u : + u coordinates + v : + v coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigmax : + Sigma in x direction + sigmay : + Sigma in y direction + theta : + Rotation angle in anticlockwise + + See Also + -------- + elliptical_gaussian + """ + sint = np.sin(theta) + cost = np.cos(theta) + up = cost * u + sint * v + vp = -sint * u + cost * v + return ( + amp + * np.exp(-2 * np.pi**2 * ((up**2 * sigmax**2) + (vp**2 * sigmay**2))) + * np.exp(2j * np.pi * (x0 * u + y0 * v)) + ) + + +def loop_img_old(amp, x, y, x0, y0, fwhm_min, fwhm_max, rotation, loopwidth, max_comps=21): + r""" + Loop source in image space sampled at x, y. + + The loop source is approximated with a series of equispaced circular Gaussians. + + Parameters + ---------- + amp : + Total flux + x : + x coordinates + y : + y coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + fwhm_min : + FWHM of the semiminor axis + fwhm_max : + FWHM of the semimajor axis + rotation : + Position angle of the loop major axis in radians + loopwidth : + Arc extent parameter (related to opening angle) + max_comps : int + Upper limit on the number of equispaced circular Gaussians used to approximate the loop + + Returns + ------- + image : ndarray + 2D image of the loop source + + See Also + -------- + loop_img + """ + sig2fwhm = np.sqrt(8 * np.log(2)) + + # Calculate the relative strengths of the sources to reproduce a gaussian and their collective stddev. + iseq0 = np.arange(max_comps) + relflux0 = factorial(max_comps - 1) / (factorial(iseq0) * factorial(max_comps - 1 - iseq0)) / 2 ** (max_comps - 1) + ok = np.flatnonzero(relflux0 > 0.01) # Just keep; circles that contain; at least 1 % of flux + ncirc = ok.size + relflux = relflux0[ok] / relflux0[ok].sum() + iseq = np.arange(ncirc) + reltheta = iseq / (ncirc - 1.0) - 0.5 # locations of circles for arclength=1 + factor = np.sqrt((reltheta**2 * relflux).sum()) * sig2fwhm # FWHM of binomial distribution for arclength=1 + + loopangle = loopwidth / factor + if np.abs(loopangle) >= 2 * np.pi: + raise ValueError(f"Internal parameterization error - Loop arc {loopangle} exceeds 2 pi.") + + if loopangle == 0.0: + loopangle = 0.01 # Avoid problems if loopangle = 0 + + theta = np.abs(loopangle) * (iseq / (ncirc - 1.0) - 0.5) # equispaced between + - loopangle / 2 + xloop = np.sin(theta) # for unit radius of curvature, R + yloop = np.cos(theta) # relative to center of curvature + + if loopangle < 0: + yloop = -yloop # Sign of loopangle determines sense of loop curvature + + # Determine the size and location of the equivalent separated components in a coord system where x is an axis + # parallel to the line joining the footpoints. Note that there are combinations of loop angle, sigminor and + # sigmajor that cannot occur with radius > 1arcsec. In such a case circle radius is set to 1. Such cases will lead + # to bad solutions and be flagged as such at the end. + + sigminor = fwhm_min / sig2fwhm + sigmajor = fwhm_max / sig2fwhm + fsumx2 = (xloop**2 * relflux).sum() # scale - free factors describing loop moments for endpoint separation=1 + fsumy = (yloop * relflux).sum() + fsumy2 = (yloop**2 * relflux).sum() + loopradius = np.sqrt((sigmajor**2 - sigminor**2) / (fsumx2 - fsumy2 + fsumy**2)) + sgm_unti = getattr(sigmajor, "unit", 1) + term = max((sigmajor**2 - loopradius**2 * fsumx2), 1 * sgm_unti**2) # > 0 condition avoids problems in next step. + circfwhm = max(sig2fwhm * np.sqrt(term), 1 * sgm_unti) # Set minimum to avoid display problems + + cgshift = loopradius * fsumy # will enable emission centroid location to be unchanged + relx = xloop * loopradius # x is axis joining 'footpoints' + rely = yloop * loopradius - cgshift + + # Calculate source structures for each circle. + pasep = rotation + sinus = np.sin(pasep) + cosinus = np.cos(pasep) + + image = None + pixel = [1, 1] + for i in range(iseq.size): + flux_new = amp * relflux[i] # Split the flux between components. + + x_loc_new = x0 - relx[i] * sinus + rely[i] * cosinus + y_loc_new = y0 + relx[i] * cosinus + rely[i] * sinus + + x_tmp = ((x - x_loc_new) * cosinus) + ((y - y_loc_new) * sinus) + y_tmp = -((x - x_loc_new) * sinus) + ((y - y_loc_new) * cosinus) + x_tmp = 2.0 * np.sqrt(2.0 * np.log(2.0)) * x_tmp / circfwhm + y_tmp = 2.0 * np.sqrt(2.0 * np.log(2.0)) * y_tmp / circfwhm + im_tmp = np.exp(-(x_tmp**2.0 + y_tmp**2.0) / 2.0) + if image is None: + image = im_tmp / (im_tmp.sum() * pixel[0] * pixel[1]) * flux_new + else: + image += im_tmp / (im_tmp.sum() * pixel[0] * pixel[1]) * flux_new + + return image + + +def loop_vis_old(amp, u, v, x0, y0, fwhm_minor, fwhm_major, rotation, loopwidth, max_comps=21): + r""" + Loop source in Fourier space sampled at u, v. + + Parameters + ---------- + amp : + Total flux + u : + u coordinates + v : + v coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + fwhm_minor : + FWHM of the semiminor axis + fwhm_major : + FWHM of the semimajor axis + rotation : + Position angle of the loop major axis in radians + loopwidth : + Arc extent parameter (related to opening angle) + max_comps : int + Upper limit on the number of equispaced circular Gaussians used to approximate the loop + + Returns + ------- + vis : ndarray (complex128) + Complex visibilities evaluated at (u, v) + + See Also + -------- + loop_vis + """ + + sig2fwhm = np.sqrt(8 * np.log(2.0)) + + # Calculate the relative strengths of the sources to reproduce a gaussian and their collective stddev. + iseq0 = np.arange(max_comps) + relflux0 = ( + factorial(max_comps - 1) / (factorial(iseq0) * factorial(max_comps - 1 - iseq0)) / 2 ** (max_comps - 1) + ) # TOTAL(relflux)=1 + ok = np.flatnonzero(relflux0 > 0.01) # Just keep circles that contain at least 1% of flux + ncirc = ok.size + relflux = relflux0[ok] / (relflux0[ok]).sum() + iseq = np.arange(ncirc) + reltheta = iseq / (ncirc - 1.0) - 0.5 # locations of circles for arclength=1 + factor = np.sqrt((reltheta**2 * relflux).sum()) * sig2fwhm # FWHM of binomial distribution for arclength=1 + + loopangle = loopwidth / factor + if np.abs(loopangle).sum() >= 2 * np.pi: + raise ValueError(f"Internal parameterization error - Loop arc {loopangle} exceeds 2pi.") + + if loopangle == 0: + loopangle = 0.01 # Avoids problems if loopangle = 0 + + theta = np.abs(loopangle) * (iseq / (ncirc - 1.0) - 0.5) # equispaced between +- loopangle/2 + xloop = np.sin(theta) # for unit radius of curvature, R + yloop = np.cos(theta) # relative to center of curvature + + if loopangle < 0: + # Sign of loopangle determines sense of loop curvature # Sign of loopangle determines sense of loop curvature + yloop = -yloop + + # Determine the size and location of the equivalent separated components in a coord system where... + # x is an axis parallel to the line joining the footpoints + # Note that there are combinations of loop angle, sigminor and sigmajor that cannot occur with radius>1arcsec. + # In such a case circle radius is set to 1. Such cases will lead to bad solutions and be flagged as such at the end. + + # eccen = np.sqrt(1 - (sigma_min**2 / sigma_max**2)) + # sigminor = sigma_min * (1 - eccen ** 2) ** 0.25 / sig2fwhm + # sigmajor = sigma_max / (1 - eccen ** 2) ** 0.25 / sig2fwhm + + sigminor = fwhm_minor / sig2fwhm + sigmajor = fwhm_major / sig2fwhm + fsumx2 = (xloop**2 * relflux).sum() # scale-free factors describing loop moments for endpoint separation=1 + fsumy = (yloop * relflux).sum() + fsumy2 = (yloop**2 * relflux).sum() + loopradius = np.sqrt((sigmajor**2 - sigminor**2) / (fsumx2 - fsumy2 + fsumy**2)) + sgm_unti = getattr(sigmajor, "unit", 1) + term = max((sigmajor**2 - loopradius**2 * fsumx2), 1 * sgm_unti**2) # >0 condition avoids problems in next step. + circfwhm = max(sig2fwhm * np.sqrt(term), 1 * sgm_unti) # Set minimum to avoid display problems + + cgshift = loopradius * fsumy + relx = xloop * loopradius # x is axis joining 'footpoints' + rely = yloop * loopradius - cgshift # will enable emission centroid location to be unchanged + + # Calculate source structures for each circle. + pasep = rotation # position angle of line joining arc endpoints + x_loc_new = x0 - relx * np.sin(pasep) + rely * np.cos(pasep) + y_loc_new = y0 + relx * np.cos(pasep) + rely * np.sin(pasep) + + flux_new = amp * relflux # Split the flux between components. + + arg = (-(np.pi**2) * circfwhm**2) / (4 * np.log(2)) * (u**2 + v**2) + relvis = np.exp(arg) + + for j in range(ncirc): + if j == 0: + vis = flux_new[j] * relvis * np.exp(2j * np.pi * (x_loc_new[j] * u + y_loc_new[j] * v)) + else: + vis += flux_new[j] * relvis * np.exp(2j * np.pi * (x_loc_new[j] * u + y_loc_new[j] * v)) + return vis + + +def loop_img(flux, x, y, x0, y0, sigma_minor, sigma_major, rotation, loopwidth, min_fraction=0.01, max_comps=21): + r""" + Loop source in image space sampled at x, y. + + The loop is approximated as a series of circular Gaussians with binomially-weighted + fluxes arranged along a circular arc. + + Parameters + ---------- + flux : + Total integrated flux + x : + x coordinates + y : + y coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigma_minor : + Standard deviation of the loop width perpendicular to the arc + sigma_major : + Standard deviation of the loop extent along the arc + rotation : + Position angle of the loop major axis in radians + loopwidth : + Arc extent parameter (related to opening angle) + min_fraction : float + Minimum relative flux to retain a component + max_comps : int + Upper limit on the number of Gaussian components + + Returns + ------- + image : ndarray + Loop brightness distribution evaluated at (x, y) + + See Also + -------- + loop_vis, loop_img_old + """ + component_fluxes, n_components = _compute_binomial_weights(max_comps, min_fraction) + + loop_params = _compute_loop_geometry(sigma_minor, sigma_major, loopwidth, component_fluxes, n_components) + + x_components, y_components = _transform_to_image_coords( + loop_params["rel_x"], loop_params["rel_y"], x0, y0, rotation + ) + + data = _evaluate_gaussians_on_grid( + x, y, x_components, y_components, loop_params["component_sigma"], flux * component_fluxes, rotation + ) + + return data + + +def loop_vis(flux, u, v, x0, y0, sigma_minor, sigma_major, rotation, loopwidth, min_fraction=0.01, max_comps=21): + r""" + Loop source in Fourier space sampled at u, v. + + Parameters + ---------- + flux : + Total integrated flux + u : + u coordinates + v : + v coordinates + x0 : + Center x coordinate + y0 : + Center y coordinate + sigma_minor : + Standard deviation of the loop width perpendicular to the arc + sigma_major : + Standard deviation of the loop extent along the arc + rotation : + Position angle of the loop major axis in radians + loopwidth : + Arc extent parameter (related to opening angle) + min_fraction : float + Minimum relative flux to retain a component + max_comps : int + Upper limit on the number of Gaussian components + + Returns + ------- + vis : ndarray (complex128) + Complex visibilities evaluated at (u, v) + + See Also + -------- + loop_img, loop_vis_old + """ + component_fluxes, n_components = _compute_binomial_weights(max_comps, min_fraction) + + loop_params = _compute_loop_geometry(sigma_minor, sigma_major, loopwidth, component_fluxes, n_components) + + x_components, y_components = _transform_to_image_coords( + loop_params["rel_x"], loop_params["rel_y"], x0, y0, rotation + ) + + vis = _evaluate_visibility_analytical( + u, v, x_components, y_components, loop_params["component_sigma"], flux * component_fluxes + ) + + return vis + + +def _compute_binomial_weights(max_comps, min_fraction=0.01): + """ + Compute normalized binomial distribution weights for loop Gaussian components. + + Parameters + ---------- + max_comps : int + Maximum number of components to consider + min_fraction : float + Minimum relative flux to retain a component + + Returns + ------- + weights : ndarray + Normalized flux weights summing to 1 + n_kept : int + Number of components retained + """ + indices = np.arange(max_comps) + + # Binomial coefficients: C(n-1, k) / 2^(n-1) + # Using scipy.special.binom is more efficient and numerically stable than factorial + binomial_coeffs = binom(max_comps - 1, indices) / 2.0 ** (max_comps - 1) + + # Keep only significant components + significant = binomial_coeffs > min_fraction + weights_kept = binomial_coeffs[significant] + + # Normalize + weights_normalized = weights_kept / weights_kept.sum() + + return weights_normalized, len(weights_normalized) + + +def _compute_loop_geometry(sigma_minor, sigma_major, arc_param, flux_weights, n_comps): + """ + Compute loop geometry: component positions and size. + + Parameters + ---------- + sigma_minor : + Minor axis standard deviation (perpendicular to arc) + sigma_major : + Major axis standard deviation (along arc) + arc_param : + Arc extent parameter + flux_weights : ndarray + Normalized flux weights for each component + n_comps : int + Number of components + + Returns + ------- + params : dict + Dictionary with keys ``rel_x``, ``rel_y``, ``component_sigma``, ``radius`` + """ + # Component positions (normalized to [−0.5, +0.5]) + comp_indices = np.arange(n_comps) + normalized_positions = comp_indices / (n_comps - 1.0) - 0.5 + + # Binomial spatial distribution factor + # This relates the arc extent to the actual opening angle + binomial_width = np.sqrt((normalized_positions**2 * flux_weights).sum()) * 2 * np.sqrt(2 * np.log(2)) + + # Opening angle + loop_angle = arc_param / binomial_width + + # Validate opening angle + if np.abs(loop_angle) >= 2.0 * np.pi: + raise ValueError( + f"Loop arc parameter {arc_param} produces opening angle {loop_angle:.3f} rad " + f"(>= 2π). Reduce arc_param or adjust sigma_max." + ) + + # Handle zero angle + if loop_angle == 0.0: + loop_angle = 0.01 + + # Angular positions along arc + theta = np.abs(loop_angle) * (comp_indices / (n_comps - 1.0) - 0.5) + + # Positions on unit circle (radius = 1) + x_unit = np.sin(theta) + y_unit = np.cos(theta) + + # Flip y for negative angles (curvature sense) + if loop_angle < 0: + y_unit = -y_unit + + # Compute loop radius from ellipse parameters + # Statistical moments of the distribution + moment_x2 = (x_unit**2 * flux_weights).sum() + moment_y = (y_unit * flux_weights).sum() + moment_y2 = (y_unit**2 * flux_weights).sum() + + # Radius of curvature + denominator = moment_x2 - moment_y2 + moment_y**2 + if denominator <= 0: + # Degenerate case - use fallback + radius = sigma_major + else: + radius = np.sqrt((sigma_major**2 - sigma_minor**2) / denominator) + + # Component sigma (circular Gaussians) + # Need to account for the spread along the arc + variance_residual = sigma_major**2 - radius**2 * moment_x2 + + # Ensure non-negative variance + if hasattr(sigma_major, "unit"): + unit = sigma_major.unit + variance_residual = max(variance_residual, 0 * unit**2) + component_sigma = np.sqrt(variance_residual) + # Set minimum size to avoid numerical issues + min_sigma = 1.0 * unit + component_sigma = max(component_sigma, min_sigma) + else: + variance_residual = max(variance_residual, 0.0) + component_sigma = np.sqrt(variance_residual) + min_sigma = 1.0 + component_sigma = max(component_sigma, min_sigma) + + # Center-of-gravity shift (keeps loop centered at x0, y0) + cg_shift = radius * moment_y + + # Relative positions (before rotation) + rel_x = x_unit * radius + rel_y = y_unit * radius - cg_shift + + return {"rel_x": rel_x, "rel_y": rel_y, "component_sigma": component_sigma, "radius": radius} + + +def _transform_to_image_coords(rel_x, rel_y, x0, y0, rotation_angle): + """ + Transform relative component positions to absolute image coordinates. + + Parameters + ---------- + rel_x : ndarray + Relative x positions in the loop-aligned frame + rel_y : ndarray + Relative y positions in the loop-aligned frame + x0 : + Loop center x coordinate + y0 : + Loop center y coordinate + rotation_angle : + Rotation angle in radians + + Returns + ------- + x_abs : ndarray + Absolute x component positions + y_abs : ndarray + Absolute y component positions + """ + cos_angle = np.cos(rotation_angle) + sin_angle = np.sin(rotation_angle) + + # Rotation matrix application + # Note: negative rel_x term because of coordinate convention + x_abs = x0 - rel_x * sin_angle + rel_y * cos_angle + y_abs = y0 + rel_x * cos_angle + rel_y * sin_angle + + return x_abs, y_abs + + +def _evaluate_gaussians_on_grid(x, y, x_centers, y_centers, sigma, fluxes, rotation): + """ + Evaluate the sum of circular Gaussians on an image grid. + + Parameters + ---------- + x : ndarray + x coordinate grid + y : ndarray + y coordinate grid + x_centers : ndarray + x positions of Gaussian centers + y_centers : ndarray + y positions of Gaussian centers + sigma : + Standard deviation of each Gaussian component + fluxes : ndarray + Flux of each component + rotation : + Rotation angle in radians + + Returns + ------- + image : ndarray + Sum of all Gaussian components + """ + # Pre-compute constants + cos_rot = np.cos(rotation) + sin_rot = np.sin(rotation) + + # Initialize output + image = np.zeros(x.shape, like=fluxes) + + # Sum Gaussians + for x_c, y_c, flux in zip(x_centers, y_centers, fluxes): + # Shift to component center + dx_grid = x - x_c + dy_grid = y - y_c + + # Rotate to component frame (optional - components are circular) + # This rotation isn't strictly needed for circular Gaussians but matches original + x_rot = dx_grid * cos_rot + dy_grid * sin_rot + y_rot = -dx_grid * sin_rot + dy_grid * cos_rot + + # Gaussian profile: exp(-r²/(2σ²)) + gaussian = np.exp(-0.5 * ((x_rot / sigma) ** 2 + (y_rot / sigma) ** 2)) + + # Normalize using flux density (continuous) normalization + # This makes it consistent with circular_gaussian_img() + gaussian_normalized = gaussian / (gaussian.sum()) + + # Add contribution + image += gaussian_normalized * flux + + return image + + +def _evaluate_visibility_analytical(u, v, x_centers, y_centers, sigma, fluxes): + r""" + Analytical Fourier transform of a multi-component circular Gaussian loop. + + .. math:: + + V(u,v) = \sum_i F_i \exp\left(-2\pi^2\sigma^2(u^2+v^2)\right) + \exp\left(2\pi i(u x_i + v y_i)\right) + + Parameters + ---------- + u : ndarray + u coordinates + v : ndarray + v coordinates + x_centers : ndarray + x positions of Gaussian components + y_centers : ndarray + y positions of Gaussian components + sigma : + Standard deviation of each Gaussian component + fluxes : ndarray + Flux of each component + + Returns + ------- + vis : ndarray (complex128) + Complex visibilities evaluated at (u, v) + """ + # Gaussian envelope in Fourier space + # FT{exp(-r²/(2σ²))} ∝ exp(-2π²σ²k²) + uv_squared = u**2 + v**2 + envelope = np.exp(-2 * np.pi**2 * sigma**2 * uv_squared) + + # Initialize visibility + vis = np.zeros_like(u, dtype=np.complex128) + + # Sum over components + for x_c, y_c, flux in zip(x_centers, y_centers, fluxes): + # Phase from component position + phase = 2.0 * np.pi * (x_c * u + y_c * v) + + # Add component contribution + vis += flux * envelope * np.exp(1j * phase) + + return vis + + +class GenericSource(ABC): + r""" + Abstract source class defining the properties and methods. + """ + + _registry: dict[str, Callable] = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + key = cls.__name__.lower() + GenericSource._registry[key] = cls + + @property + def n_params(self): + r"""The number of parameters""" + return len(self.__dict__.keys()) + + @property + @abstractmethod + def bounds(self) -> list[list[float]]: + r"""Return the lower and upper bounds of the source.""" + pass + + @property + @abstractmethod + def param_list(self) -> list[float]: + """Return list of parameters if fixed order""" + pass + + @abstractmethod + def estimate_bounds(self, *args, **kwargs) -> list[list[float]]: + """Return estimated bounds""" + pass + + +@dataclass() +class Circular(GenericSource): + amp: float + x0: float + y0: float + sigma: float + + def __init__(self, amp: float, x0: float, y0: float, sigma: float): + r""" + Circular gaussian source parameters. + + Parameters + ---------- + amp : + Amplitude + x0 : + Center x coordinate + y0 : + Center y coordinate + sigma : + Standard deviation + """ + self.amp = amp + self.x0 = x0 + self.y0 = y0 + self.sigma = sigma + + @property + def bounds(self) -> list[list[float]]: + raw_bounds = [ + [self.amp / 4, self.x0 - 5 * np.abs(self.sigma), self.y0 - 5 * np.abs(self.sigma), self.sigma / 4], + [self.amp * 4, self.x0 + 5 * np.abs(self.sigma), self.y0 + 5 * np.abs(self.sigma), self.sigma * 4], + ] + return [[q.value if hasattr(q, "value") else q for q in sublist] for sublist in raw_bounds] + + @property + def param_list(self) -> list[float]: + return [self.amp, self.x0, self.y0, self.sigma] + + def estimate_bounds(self, *args, **kwargs) -> list[list[float]]: + raise NotImplementedError() + + +@dataclass +class Elliptical(GenericSource): + amp: float + x0: float + y0: float + sigmax: float + sigmay: float + theta: float + + def __init__(self, amp, x0, y0, sigmax, sigmay, theta): + r""" + Elliptical gaussian source parameters. + + Parameters + ---------- + amp : + Amplitude + x0 : + Center x coordinate + y0 : + Center y coordinate + sigmax : + Standard deviation in x direction + sigmay : + Standard deviation in y direction + theta : + Rotation angle in anticlockwise + """ + self.amp = amp + self.x0 = x0 + self.y0 = y0 + self.sigmax = sigmax + self.sigmay = sigmay + self.theta = theta + + @property + def bounds(self) -> list[list[float]]: + raw_bounds = [ + [ + self.amp / 4, + self.x0 - (5 * np.abs(self.sigmax)), + self.y0 - (5 * np.abs(self.sigmay)), + self.sigmax / 4, + self.sigmay / 4, + self.theta - 22.5, + ], + [ + self.amp * 4, + self.x0 + (5 * np.abs(self.sigmax)), + self.y0 + (5 * np.abs(self.sigmay)), + self.sigmax * 4, + self.sigmay * 4, + self.theta + 22.5, + ], + ] + return [[q.value if hasattr(q, "value") else q for q in sublist] for sublist in raw_bounds] + + @property + def param_list(self) -> list[float]: + return [self.amp, self.x0, self.y0, self.sigmax, self.sigmay, self.theta] + + def estimate_bounds(self, *args, **kwargs) -> list[list[float]]: + return self.bounds + + +@dataclass +class Loop(GenericSource): + amp: float + x0: float + y0: float + sigma_min: float + sigma_max: float + alpha: float + beta: float + + def __init__(self, amp, x0, y0, sigma_min, sigma_max, alpha, beta): + self.amp = amp + self.x0 = x0 + self.y0 = y0 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.alpha = alpha + self.beta = beta + + @property + def bounds(self) -> list[list[float]]: + raw_bounds = [ + [ + self.amp / 2, + self.x0 - (2 * np.abs(self.sigma_max)), + self.y0 - (2 * np.abs(self.sigma_max)), + self.sigma_min / 2, + self.sigma_max / 2, + -np.pi / 2, + 0, + ], + [ + self.amp * 2, + self.x0 + (2 * np.abs(self.sigma_max)), + self.y0 + (2 * np.abs(self.sigma_max)), + self.sigma_min * 2, + self.sigma_max * 2, + np.pi / 2, + np.pi, + ], + ] + return [[q.value if hasattr(q, "value") else q for q in sublist] for sublist in raw_bounds] + + @property + def param_list(self) -> list[float]: + return [self.amp, self.x0, self.y0, self.sigma_min, self.sigma_max, self.alpha, self.beta] + + def estimate_bounds(self, *args, **kwargs) -> list[list[float]]: + return self.bounds + + +class SourceList(UserList[GenericSource]): + r""" + List of Sources + """ + + def __init__(self, sources: Optional[list[GenericSource]] = None): + r""" + List of Sources + + Parameters + ---------- + sources : + Sources + """ + super().__init__(sources) + + @property + def params(self) -> list[float]: + r"""Flat list of all parameters for all sources""" + return list(chain.from_iterable([source.param_list for source in self.data])) + + @property + def bounds(self) -> list[list[float]]: + r"""Flat list of upper and lower bounds for all sources""" + return np.hstack([s.bounds for s in self.data]).tolist() + + @classmethod + def from_params(cls, sources: "SourceList", params: list[float]) -> "SourceList": + r""" + Create a source list from given parameters and sources. + + Parameters + ---------- + sources : + List of sources + params + Flat list of all parameters for all sources. + """ + j = 0 + new_sources = cls() + for i, source in enumerate(sources): + name = source.__class__.__name__.lower() + n_params = source.n_params + new_sources.append(Source(name, *list(params[j : j + n_params]))) + j += n_params + + return new_sources + + +class SourceFactory: + r""" + Source Factory + """ + + def __init__(self, registry: dict[str, Callable]): + self._registry: dict[str, Callable] = registry + + def __call__(self, shape_type: str, *args, **kwargs) -> GenericSource: + shape_type = shape_type.lower() + cls = self._registry.get(shape_type) + if not cls: + raise ValueError(f"Unknown shape type: {shape_type}") + try: + return cls(*args, **kwargs) + except TypeError as e: + raise ValueError(f"Error creating '{shape_type}': {e}") + + +#: Instance of SourceFactory +Source = SourceFactory(registry=GenericSource._registry) diff --git a/xrayvision/vis_forward_fit/tests/__init__.py b/xrayvision/vis_forward_fit/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/xrayvision/vis_forward_fit/tests/test_forward_fit.py b/xrayvision/vis_forward_fit/tests/test_forward_fit.py new file mode 100644 index 0000000..9b71189 --- /dev/null +++ b/xrayvision/vis_forward_fit/tests/test_forward_fit.py @@ -0,0 +1,87 @@ +import astropy.units as apu +import numpy as np +from numpy.testing import assert_allclose +from scipy.optimize import minimize + +from xrayvision.transform import generate_uv +from xrayvision.vis_forward_fit.forward_fit import ( + SourceList, + _vis_forward_fit_minimise, + circular_gaussian_vis, + sources_to_image, + sources_to_vis, +) +from xrayvision.vis_forward_fit.sources import Source +from xrayvision.visibility import Visibilities + + +def test_simple_fit(): + uu = generate_uv(11 * apu.pixel)[::4] + u, v = np.meshgrid(uu, uu) + u = u.flatten().value + v = v.flatten().value + vis = circular_gaussian_vis(1, u, v, 0, 0, 2) + vis_ri = np.hstack([vis.real, vis.imag]) + + def objective(x, u, v, vis_ri): + params = [x[0], u, v, *x[1:]] + vispred = circular_gaussian_vis(*params) + vispred_ri = np.hstack([vispred.real, vispred.imag]) + return np.sum(np.abs(vis_ri - vispred_ri) ** 2) + + res = minimize(objective, [0.5, 0.5, 0.5, 1], (u, v, vis_ri), method="Nelder-Mead") + assert_allclose(res.x, [1, 0, 0, 2], atol=1e-5, rtol=1e-5) + + +def test_sources_to_map(): + sources = SourceList( + [ + Source("circular", 2 * apu.arcsec, -4 * apu.arcsec, -5 * apu.arcsec, 2 * apu.arcsec), + Source("circular", 4 * apu.arcsec, 5 * apu.arcsec, 4 * apu.arcsec, 3 * apu.arcsec), + ] + ) + image = sources_to_image(sources, [33, 33] * apu.pixel, pixel_size=[1, 1] * apu.arcsec / apu.pixel) + assert_allclose(image.sum().value, 6, rtol=5e-5) + y, x = 33 // 2 - 4, 33 // 2 - 5 + assert_allclose(image[x, y].value, 2 / (2 * np.pi * 2**2), atol=1e-5, rtol=5e-5) + y, x = 33 // 2 + 5, 33 // 2 + 4 + assert_allclose(image[x, y].value, 4 / (2 * np.pi * 3**2), atol=1e-5, rtol=5e-5) + + +def test_sources_to_vis(): + sources = SourceList([Source("circular", 2, -4, -5, 2), Source("circular", 4, 5, 4, 3)]) + uu = generate_uv(33 * apu.pixel).value + u, v = np.meshgrid(uu, uu) + vis = sources_to_vis(sources, u, v) + assert_allclose(vis.real.max(), 6, rtol=5e-5) + + +# Just testing the machinery not if the fitting is robust/good +def test_vis_forward_fit_minimise(): + rng = np.random.default_rng(42) + sources = SourceList([Source("circular", 2, -4, -5, 2), Source("elliptical", 4, 5, 4, 3, 8, 45)]) + uu = generate_uv(33 * apu.pixel) + u, v = np.meshgrid(uu, uu) + vis = sources_to_vis(sources, u.value, v.value) + visobs = Visibilities(vis.flatten() * apu.ph, u.flatten(), v.flatten()) + # Create non-optimal source parameters + init_souces = SourceList.from_params(sources, rng.standard_normal(len(sources.params)) * 0.1 + sources.params) + sources_fit, res = _vis_forward_fit_minimise(visobs, init_souces, method="Nelder-Mead") + assert_allclose(sources_fit.params, sources.params, atol=1e-4, rtol=1e-5) + + +# Just testing the machinery not if the fitting is robust/good +def test_vis_forward_fit_minimise_pso(): + rng = np.random.default_rng(42) + sources = SourceList([Source("circular", 2, -4, -5, 2), Source("circular", 4, 5, 4, 3)]) + uu = generate_uv(33 * apu.pixel) + u, v = np.meshgrid(uu, uu) + vis = sources_to_vis(sources, u.value, v.value) + # Create non-optimal source parameters + init_souces = SourceList.from_params(sources, rng.standard_normal(len(sources.params)) * 0.1 + sources.params) + visobs = Visibilities(vis.flatten() * apu.ph, u.flatten(), v.flatten()) + sources_fit, res = _vis_forward_fit_minimise(visobs, init_souces, method="PSO") + # the sources can swap so sort by x0, y0 before comparison + sources.sort(key=lambda x: (x.x0, x.y0)) + sources_fit.sort(key=lambda x: (x.x0, x.y0)) + assert_allclose(sources_fit.params, sources.params, atol=1e-4, rtol=5e-5) diff --git a/xrayvision/vis_forward_fit/tests/test_sources.py b/xrayvision/vis_forward_fit/tests/test_sources.py new file mode 100644 index 0000000..59e6d13 --- /dev/null +++ b/xrayvision/vis_forward_fit/tests/test_sources.py @@ -0,0 +1,208 @@ +import astropy.units as apu +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from xrayvision.imaging import image_to_vis, vis_to_image +from xrayvision.transform import generate_uv, generate_xy +from xrayvision.vis_forward_fit.sources import ( + Circular, + Elliptical, + Source, + SourceList, + circular_gaussian_img, + circular_gaussian_vis, + elliptical_gaussian_img, + elliptical_gaussian_vis, + loop_img, + loop_img_old, + loop_vis, + loop_vis_old, +) +from xrayvision.visibility import Visibilities + + +@pytest.mark.parametrize("x0", [0, -10]) +@pytest.mark.parametrize("y0", [0, 6]) +@pytest.mark.parametrize("sigma", [3]) +@pytest.mark.parametrize("size", [1, 2]) +@pytest.mark.parametrize("shape", [63]) +def test_circular_ft_equivalence_fft(x0, y0, sigma, size, shape): + xx = generate_xy(shape * apu.pix, pixel_size=size * apu.arcsec / apu.pixel, phase_center=0 * apu.arcsec) + yy = generate_xy(shape * apu.pix, pixel_size=size * apu.arcsec / apu.pixel, phase_center=0 * apu.arcsec) + x, y = np.meshgrid(xx, yy) + # by definition map center on phase 0,0 + uu = generate_uv(shape * apu.pix, pixel_size=size * apu.arcsec / apu.pixel, phase_center=0 * apu.arcsec) + vv = generate_uv(shape * apu.pix, pixel_size=size * apu.arcsec / apu.pixel, phase_center=0 * apu.arcsec) + u, v = np.meshgrid(uu, vv) + u = u.flatten() + v = v.flatten() + + image = circular_gaussian_img(1, x, y, x0 * apu.arcsec, y0 * apu.arcsec, sigma * size * apu.arcsec) + + vis_obs = image_to_vis(image, u=u, v=v, pixel_size=[size, size] * apu.arcsec / apu.pixel) + # by definition map center on phase 0,0 + vis_func = Visibilities( + circular_gaussian_vis(1, u, v, x0 * apu.arcsec, y0 * apu.arcsec, sigma * size * apu.arcsec).flatten(), u, v + ) + + image_func = vis_to_image(vis_func, [shape, shape] * apu.pixel, pixel_size=[size, size] * apu.arcsec / apu.pixel) + image_vis = vis_to_image(vis_obs, [shape, shape] * apu.pixel, pixel_size=[size, size] * apu.arcsec / apu.pixel) + + assert_allclose(vis_obs.visibilities.value, vis_func.visibilities, atol=1e-8) + assert_allclose(image.value, image_func.value, atol=1e-9) + assert_allclose(image.value, image_vis.value, atol=1e-9) + + +@pytest.mark.parametrize("x0", [0, 1, 2, 3]) +@pytest.mark.parametrize("y0", [0, -1, 2, -3]) +@pytest.mark.parametrize("sigma", [1, 2, 3]) +def test_equivalence_elliptical_to_circular(x0, y0, sigma): + amp = 1 + x, y = np.meshgrid(np.linspace(-20, 20, 101), np.linspace(-20, 20, 101)) + image_circular = circular_gaussian_img(amp, x, y, x0, y0, sigma) + image_elliptical = elliptical_gaussian_img(amp, x, y, x0, y0, sigma, sigma, 0) + assert_allclose(image_circular, image_elliptical, atol=1e-13) + + +@pytest.mark.parametrize("x0", [0, 1, 2, 3]) +@pytest.mark.parametrize("y0", [0, -1, 2, -3]) +@pytest.mark.parametrize("sigma", [1, 2, 3]) +def test_equivalence_elliptical_to_circular_vis(x0, y0, sigma): + amp = 1 + u, v = np.meshgrid(np.linspace(-20, 20, 101), np.linspace(-20, 20, 101)) + u = u * 1 / 2.5 + v = v * 1 / 2.5 + vis_circular = circular_gaussian_vis(amp, u, v, x0, y0, sigma) + vis_elliptical = elliptical_gaussian_vis(amp, u, v, x0, y0, sigma, sigma, 0) + assert_allclose(vis_circular, vis_elliptical, atol=1e-13) + + +@pytest.mark.parametrize("size", (65, 79)) +def test_loop_ft_equivalence_fft(size): + # So unless the array is sufficiently large this test fails + # I think has to do with the fact no taking into account the sampleing and implicit windowing + # TODO: How does this affect algo where the vis derived from map are compare to the observed? + # sigma = 4 * apu.arcsec + xx = generate_xy(size * apu.pix) + x, y = np.meshgrid(xx, xx) + uu = generate_uv(size * apu.pix) + u, v = np.meshgrid(uu, uu) + u = u.flatten() + v = v.flatten() + + image = loop_img_old( + 80, x, y, 0 * apu.arcsec, 0 * apu.arcsec, 9.0 * apu.arcsec, 22.5 * apu.arcsec, np.deg2rad(90), np.deg2rad(70) + ) + + vis_obs = image_to_vis(image * apu.ph, u=u, v=v) + vis_func = Visibilities( + loop_vis_old( + 80, + u, + v, + 0 * apu.arcsec, + 0 * apu.arcsec, + 9.0 * apu.arcsec, + 22.5 * apu.arcsec, + np.deg2rad(90), + np.deg2rad(70), + ).flatten() + * apu.ph, + u, + v, + ) + + image_func = vis_to_image(vis_func, [size, size] * apu.pixel) + image_vis = vis_to_image(vis_obs, [size, size] * apu.pixel) + + assert_allclose(vis_func.visibilities, vis_obs.visibilities, atol=1e-9) + assert_allclose(image_func, image_vis, atol=1e-9) + + +# def test_model_equivalence_fft(): +# uu = generate_uv(65 * apu.pix) +# u, v = np.meshgrid(uu, uu) +# u = u.flatten() +# v = v.flatten() +# +# image = model_img(np.atleast_2d([1, 2, 0, 0, 10, -90, 0.1]).T, 65, 65, 1) +# vis_obs = image_to_vis(image * apu.ph, u=u, v=v) +# +# vis_func = Visibilities( +# model_vis(np.atleast_2d([1, 2, 0, 0, 10, -90, 0.1]).T, 65, 65, 1).flatten() * apu.ph, u, v, +# ) +# +# image_func = vis_to_image(vis_func, [65, 65] * apu.pixel) +# image_vis = vis_to_image(vis_obs, [65, 65] * apu.pixel) +# +# assert_allclose(vis_func.visibilities, vis_obs.visibilities, atol=1e-9) +# assert_allclose(image_func, image_vis, atol=1e-9) + + +def test_source_factory(): + circular = Source("circular", 1, 2, 3, 4) + assert isinstance(circular, Circular) + assert circular.amp == 1 + assert circular.x0 == 2 + assert circular.y0 == 3 + assert circular.sigma == 4 + + elliptical = Source("elliptical", 1, 2, 3, 4, 5, 6) + assert isinstance(elliptical, Elliptical) + assert elliptical.amp == 1 + assert elliptical.x0 == 2 + assert elliptical.y0 == 3 + assert elliptical.sigmax == 4 + assert elliptical.sigmay == 5 + assert elliptical.theta == 6 + + +def test_source_list(): + orig_sources = SourceList([Source("circular", 1, 2, 3, 4), Source("elliptical", 1, 2, 3, 4, 5, 6)]) + params = orig_sources.params + assert params == [1, 2, 3, 4, 1, 2, 3, 4, 5, 6] + new_sources = SourceList.from_params(orig_sources, params) + assert orig_sources == new_sources + + +def test_loop_image_oldvnew(): + x = generate_xy(65 * apu.pixel).value + y = generate_xy(65 * apu.pixel).value + x, y = np.meshgrid(x, y) + + flux = 100 + x0 = 0 + y0 = 0 + sigmaj = 20 + sigmin = 10 + rotatiion = np.pi / 4 + loopw = np.deg2rad(110) + + sigma_to_fwhm = 2 * np.sqrt(2 * np.log(2)) + + image_old = loop_img_old(flux, x, y, x0, y0, sigmin * sigma_to_fwhm, sigmaj * sigma_to_fwhm, rotatiion, loopw) + image = loop_img(flux, x, y, x0, y0, sigmin, sigmaj, rotatiion, loopw) + + assert_allclose(image, image_old) + + +def test_loop_vis_oldvnew(): + u = generate_uv(65 * apu.pixel).value + v = generate_uv(65 * apu.pixel).value + u, v = np.meshgrid(u, v) + + flux = 100 + x0 = 0 + y0 = 0 + sigmaj = 20 + sigmin = 10 + rotatiion = np.pi / 4 + loopw = np.deg2rad(110) + + sigma_to_fwhm = 2 * np.sqrt(2 * np.log(2)) + + vis_old = loop_vis_old(flux, u, v, x0, y0, sigmin * sigma_to_fwhm, sigmaj * sigma_to_fwhm, rotatiion, loopw) + vis = loop_vis(flux, u, v, x0, y0, sigmin, sigmaj, rotatiion, loopw) + + assert_allclose(vis, vis_old) diff --git a/xrayvision/visibility.py b/xrayvision/visibility.py index ca9698f..ac50fad 100644 --- a/xrayvision/visibility.py +++ b/xrayvision/visibility.py @@ -193,7 +193,7 @@ def instrument(self): class Visibilities(VisibilitiesABC): @apu.quantity_input() - def __init__( + def __init__( # noqa: C901 self, visibilities: apu.Quantity, u: apu.Quantity[1 / apu.deg], @@ -431,8 +431,8 @@ def __getitem__(self, item): if len(item) != len(dims): item = list(item) + [slice(None)] * (len(dims) - len(item)) if all(isinstance(idx, numbers.Integral) for idx in item): - ValueError("Slicing out single visibility not supported.") - ds_item = dict((key, idx) for key, idx in zip(dims, item)) + raise ValueError("Slicing out single visibility not supported.") + ds_item = dict(zip(dims, item)) new_data = self._data.isel(ds_item) new_data.attrs[self._meta_key][_VIS_LABELS_KEY] = new_data.coords[_VIS_LABELS_KEY].values new_vis = copy.deepcopy(self) diff --git a/xrayvision/visualisation.py b/xrayvision/visualisation.py new file mode 100644 index 0000000..201b512 --- /dev/null +++ b/xrayvision/visualisation.py @@ -0,0 +1,110 @@ +from typing import Optional + +import numpy as np +from astropy.visualization import quantity_support +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from matplotlib.markers import MarkerStyle +from matplotlib.transforms import Affine2D + +from xrayvision.visibility import Visibilities + + +def plot_vis(vis: Visibilities, fig: Optional[Figure] = None, **mplkwargs: dict) -> Figure: + r""" + Plot visibilities amplitude and phase. + + Plot as a function of the resolution (r) :math:`\sqrt{u^2 + v^2` and angle (theta) :math:`\mathrm{arctan2}(u, v)`. + Theta is represented as a rotation of the plot symbol. + + Parameters + ---------- + vis : + Visibilities + fig : + Figure to use if given will use the first and second axes to plot the amplitude and phase. + mplkwargs : + Keyword arguments passed to matplotlib + + Returns + ------- + + """ + if fig is None: + fig, axes = plt.subplots(2, 1, sharex=True) + else: + axes = fig.get_axes() + + fig.subplots_adjust(hspace=0) + + angles = np.arctan2(vis.u, vis.v) + size = 1 / np.sqrt(vis.u**2 + vis.v**2) + + with quantity_support(): + for i, _ in enumerate(vis.visibilities): + transform = Affine2D().rotate(angles[i].to_value("deg")) + axes[0].scatter( + size[i], np.absolute(vis.visibilities[i]), marker=MarkerStyle("|", transform=transform), **mplkwargs + ) + axes[1].scatter( + size[i], + np.angle(vis.visibilities[i]).to("deg"), + marker=MarkerStyle("|", transform=transform), + **mplkwargs, + ) + + axes[0].set_ylabel(f"Amplitude [{vis.visibilities.unit}]") + axes[1].set_ylabel("Phase [deg]") + axes[1].set_xlabel(f"Resolution [{size.unit}]") + + return fig, axes + + +def plot_vis_diff(visa: Visibilities, visb: Visibilities, fig=None, **mplkwargs): + r""" + Plot the difference between amplitude and phase of the visibilities. + + Plot as a function of the resolution :math:`\sqrt{u^2 + v^2` and angle :math:`\mathrm{arctan2}(u, v)`. + The resolution is used as the x-axis while the angle is displayed as a rotation of the plot symbol. + + Parameters + ---------- + visa + Visibilities to plot + visb + Visibilities to plot + fig : + Figure to use + mplkwargs + Keyword arguments passed to matplotlib + + Returns + ------- + + """ + if not (np.all(visa.u == visb.u) and np.all(visb.v == visb.v)): + raise ValueError("The visibilities must have the same u, v coordinates.") + + if fig is None: + fig, axes = plt.subplots(2, 1, sharex=True) + else: + axes = fig.get_axes() + + fig.subplots_adjust(hspace=0) + + angles = np.arctan2(visa.u, visa.v) + size = 1 / np.sqrt(visa.u**2 + visa.v**2) + vis_diff = visa.visibilities - visb.visibilities + + for i, _ in enumerate(vis_diff): + transform = Affine2D().rotate(angles[i].to_value("deg")) + axes[0].scatter(size[i], np.absolute(vis_diff[i]), marker=MarkerStyle("|", transform=transform), **mplkwargs) + axes[1].scatter( + size[i], np.angle(vis_diff[i]).to("deg"), marker=MarkerStyle("|", transform=transform), **mplkwargs + ) + + axes[0].set_ylabel(f"Amplitude Diff [{vis_diff.unit}]") + axes[1].set_ylabel("Phase Diff [deg]") + axes[1].set_xlabel(f"Resolution [{size.unit}]") + + return fig, axes