Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/151.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed mixed frequency unit plotting on shared axes to correctly convert units when multiple spectrograms with different frequency units are plotted together.
75 changes: 65 additions & 10 deletions radiospectra/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,40 @@
from astropy.visualization import time_support
from matplotlib import pyplot as plt
from matplotlib.image import NonUniformImage

from astropy.visualization import quantity_support, time_support


def _get_axis_converter(axis):
"""
Safe method to get axis converter for older and newer MPL versions.

``axis.get_converter()`` / ``axis.set_converter()`` were added in
Matplotlib 3.9. Once the minimum supported Matplotlib version is
>= 3.9 these helpers can be replaced by direct get/set calls.
"""
try:
return axis.get_converter()
except AttributeError:
try:
return axis.converter
except AttributeError:
return None


def _set_axis_converter(axis, converter):
"""
Safe method to set axis converter for older and newer MPL versions.

See `_get_axis_converter` for version notes.
"""
try:
axis.set_converter(converter)
except AttributeError:
try:
axis._set_converter(converter)
axis._converter_is_explicit = True
except AttributeError:
axis.converter = converter


class PcolormeshPlotMixin:
Expand All @@ -21,7 +57,6 @@ def plot(self, axes=None, **kwargs):
-------
`matplotlib.collections.QuadMesh`
"""
from matplotlib import pyplot as plt

if axes is None:
fig, axes = plt.subplots()
Expand All @@ -38,12 +73,22 @@ def plot(self, axes=None, **kwargs):
title = f"{title}, {self.detector}"

axes.set_title(title)
with time_support():

with time_support(), quantity_support():
# Pin existing converters to avoid warnings when re-plotting on shared axes.
converter_y = _get_axis_converter(axes.yaxis)
if converter_y is not None and not getattr(axes.yaxis, "_converter_is_explicit", False):
_set_axis_converter(axes.yaxis, converter_y)

converter_x = _get_axis_converter(axes.xaxis)
if converter_x is not None and not getattr(axes.xaxis, "_converter_is_explicit", False):
_set_axis_converter(axes.xaxis, converter_x)

axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None")
if self.times.shape[0] == self.data.shape[0] and self.frequencies.shape[0] == self.data.shape[1]:
ret = axes.pcolormesh(self.times, self.frequencies.value, data, shading="auto", **kwargs)
ret = axes.pcolormesh(self.times, self.frequencies, data, shading="auto", **kwargs)
else:
ret = axes.pcolormesh(self.times, self.frequencies.value, data[:-1, :-1], shading="auto", **kwargs)
ret = axes.pcolormesh(self.times, self.frequencies, data[:-1, :-1], shading="auto", **kwargs)
axes.set_xlim(self.times[0], self.times[-1])
fig.autofmt_xdate()

Expand All @@ -61,16 +106,26 @@ class NonUniformImagePlotMixin:
"""

def plotim(self, fig=None, axes=None, **kwargs):
from matplotlib import pyplot as plt
from matplotlib.image import NonUniformImage

if axes is None:
fig, axes = plt.subplots()

with time_support():
with time_support(), quantity_support():
# Pin existing converters to avoid warnings when re-plotting on shared axes.
converter_y = _get_axis_converter(axes.yaxis)
if converter_y is not None and not getattr(axes.yaxis, "_converter_is_explicit", False):
_set_axis_converter(axes.yaxis, converter_y)

converter_x = _get_axis_converter(axes.xaxis)
if converter_x is not None and not getattr(axes.xaxis, "_converter_is_explicit", False):
_set_axis_converter(axes.xaxis, converter_x)

axes.yaxis.update_units(self.frequencies)
frequencies = axes.yaxis.convert_units(self.frequencies)

axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None")
im = NonUniformImage(axes, interpolation="none", **kwargs)
im.set_data(axes.convert_xunits(self.times), self.frequencies.value, self.data)
im.set_data(axes.convert_xunits(self.times), frequencies, self.data)
axes.add_image(im)
axes.set_xlim(self.times[0], self.times[-1])
axes.set_ylim(self.frequencies.value[0], self.frequencies.value[-1])
axes.set_ylim(frequencies[0], frequencies[-1])
30 changes: 30 additions & 0 deletions radiospectra/spectrogram/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
import matplotlib

matplotlib.use("Agg")

import numpy as np
import pytest

import astropy.units as u
from astropy.time import Time

from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram


@pytest.fixture
def make_spectrogram():
"""Factory fixture to create test spectrograms with given frequencies."""

def _make(frequencies, times=None, scale="utc"):
if times is None:
times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min
meta = {
"observatory": "Test",
"instrument": "TestInst",
"detector": "TestDet",
"start_time": times[0],
"end_time": times[-1],
"wavelength": np.array([1, 10]) * u.m,
"times": times,
"freqs": frequencies,
}
return GenericSpectrogram(np.arange(16).reshape(4, 4), meta)

return _make
120 changes: 100 additions & 20 deletions radiospectra/spectrogram/tests/test_spectrogrambase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,108 @@
import numpy as np

import astropy.units as u
from astropy.time import Time

from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram

def test_plot_mixed_frequency_units_on_same_axes(make_spectrogram):
"""Two spectrograms with different frequency units should plot on the same axes."""
rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz)

def _get_spectrogram_with_time_scale(scale):
times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min
frequencies = np.linspace(10, 40, 4) * u.MHz
meta = {
"observatory": "Test",
"instrument": "TestInst",
"detector": "TestDet",
"start_time": times[0],
"end_time": times[-1],
"wavelength": np.array([1, 10]) * u.m,
"times": times,
"freqs": frequencies,
}
return GenericSpectrogram(np.arange(16).reshape(4, 4), meta)
rad1.plot()
axes = plt.gca()
rad2.plot(axes=axes)

# The y-axis unit should be set by the first spectrogram (kHz)
assert axes.yaxis.units == u.kHz
# The y-axis range should cover the converted MHz values (up to 4000 kHz)
y_min, y_max = axes.get_ylim()
plt.close("all")

def test_plot_uses_time_support_for_datetime_conversion():
spec = _get_spectrogram_with_time_scale("tt")
assert y_max > 1000, "MHz values should be converted to kHz on the y-axis"


def test_plot_mixed_frequency_units_mhz_first(make_spectrogram):
"""Plot MHz spectrogram first, then kHz — units should stay as MHz."""
rad1 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz)
rad2 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)

rad1.plot()
axes = plt.gca()
rad2.plot(axes=axes)

# The y-axis unit should be set by the first spectrogram (MHz)
assert axes.yaxis.units == u.MHz
# kHz values should be converted to MHz; 40 kHz = 0.04 MHz
y_min, y_max = axes.get_ylim()
plt.close("all")

assert y_max >= 4, "y-axis should cover up to 4 MHz"


def test_plotim(make_spectrogram):
"""Test NonUniformImagePlotMixin.plotim() executes without error."""
rad_im = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
with (
mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True),
mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data,
):
rad_im.plotim()
plt.close("all")

_, x_values, y_values, image = set_data.call_args.args
assert len(x_values) == len(rad_im.times)
np.testing.assert_allclose(y_values, rad_im.frequencies.value)
np.testing.assert_allclose(image, rad_im.data)


def test_plotim_mixed_frequency_units_on_same_axes(make_spectrogram):
"""Two NonUniformImage plots with different units should share conversion."""
rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz)
fig, axes = plt.subplots()
with (
mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True),
mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data,
):
rad1.plotim(axes=axes)
rad2.plotim(axes=axes)
plt.close(fig)

_, _, y_values, _ = set_data.call_args.args
np.testing.assert_allclose(y_values, np.array([1000, 2000, 3000, 4000]))


def test_plot_with_quantity_data(make_spectrogram):
"""Test plotting when data is an astropy Quantity."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad.data = rad.data * u.ct
rad.plot()
plt.close("all")


def test_plot_with_shape_mismatch(make_spectrogram):
"""Test plotting branch when data shape doesn't exactly match time/freq arrays."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
# times/freqs are length 4, data shape (4, 4) matches.
# make data (5, 5) to trigger the `else` branch (data[:-1, :-1])
rad.data = np.zeros((5, 5))
rad.plot()
plt.close("all")


def test_plot_instrument_detector_differ(make_spectrogram):
"""Test title generation when instrument and detector differ."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
# GenericSpectrogram gets instrument/detector from meta dictionary
rad.meta["detector"] = "DifferentDetector"
mesh = rad.plot()
assert "DifferentDetector".upper() in mesh.axes.get_title().upper()
plt.close("all")


def test_plot_uses_time_support_for_datetime_conversion(make_spectrogram):
"""Plotting with non-UTC time scale should use time_support."""
spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt")

mesh = spec.plot()
x_limits = np.array(mesh.axes.get_xlim())
Expand All @@ -37,8 +116,9 @@ def test_plot_uses_time_support_for_datetime_conversion():
np.testing.assert_allclose(x_limits, expected_tt_limits)


def test_plotim_uses_time_support_for_datetime_conversion():
spec = _get_spectrogram_with_time_scale("tt")
def test_plotim_uses_time_support_for_datetime_conversion(make_spectrogram):
"""plotim with non-UTC time scale should use time_support."""
spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt")
fig, axes = plt.subplots()

with (
Expand Down