diff --git a/changelog/151.bugfix.rst b/changelog/151.bugfix.rst new file mode 100644 index 0000000..6b6943a --- /dev/null +++ b/changelog/151.bugfix.rst @@ -0,0 +1 @@ +Fixed mixed frequency unit plotting on shared axes to correctly convert units when multiple spectrograms with different frequency units are plotted together. diff --git a/radiospectra/mixins.py b/radiospectra/mixins.py index f8b4fc3..488c7cf 100644 --- a/radiospectra/mixins.py +++ b/radiospectra/mixins.py @@ -1,4 +1,40 @@ -from astropy.visualization import time_support +from matplotlib import pyplot as plt +from matplotlib.image import NonUniformImage + +from astropy.visualization import quantity_support, time_support + + +def _get_axis_converter(axis): + """ + Safe method to get axis converter for older and newer MPL versions. + + ``axis.get_converter()`` / ``axis.set_converter()`` were added in + Matplotlib 3.9. Once the minimum supported Matplotlib version is + >= 3.9 these helpers can be replaced by direct get/set calls. + """ + try: + return axis.get_converter() + except AttributeError: + try: + return axis.converter + except AttributeError: + return None + + +def _set_axis_converter(axis, converter): + """ + Safe method to set axis converter for older and newer MPL versions. + + See `_get_axis_converter` for version notes. + """ + try: + axis.set_converter(converter) + except AttributeError: + try: + axis._set_converter(converter) + axis._converter_is_explicit = True + except AttributeError: + axis.converter = converter class PcolormeshPlotMixin: @@ -21,7 +57,6 @@ def plot(self, axes=None, **kwargs): ------- `matplotlib.collections.QuadMesh` """ - from matplotlib import pyplot as plt if axes is None: fig, axes = plt.subplots() @@ -38,12 +73,22 @@ def plot(self, axes=None, **kwargs): title = f"{title}, {self.detector}" axes.set_title(title) - with time_support(): + + with time_support(), quantity_support(): + # Pin existing converters to avoid warnings when re-plotting on shared axes. + converter_y = _get_axis_converter(axes.yaxis) + if converter_y is not None and not getattr(axes.yaxis, "_converter_is_explicit", False): + _set_axis_converter(axes.yaxis, converter_y) + + converter_x = _get_axis_converter(axes.xaxis) + if converter_x is not None and not getattr(axes.xaxis, "_converter_is_explicit", False): + _set_axis_converter(axes.xaxis, converter_x) + axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None") if self.times.shape[0] == self.data.shape[0] and self.frequencies.shape[0] == self.data.shape[1]: - ret = axes.pcolormesh(self.times, self.frequencies.value, data, shading="auto", **kwargs) + ret = axes.pcolormesh(self.times, self.frequencies, data, shading="auto", **kwargs) else: - ret = axes.pcolormesh(self.times, self.frequencies.value, data[:-1, :-1], shading="auto", **kwargs) + ret = axes.pcolormesh(self.times, self.frequencies, data[:-1, :-1], shading="auto", **kwargs) axes.set_xlim(self.times[0], self.times[-1]) fig.autofmt_xdate() @@ -61,16 +106,26 @@ class NonUniformImagePlotMixin: """ def plotim(self, fig=None, axes=None, **kwargs): - from matplotlib import pyplot as plt - from matplotlib.image import NonUniformImage if axes is None: fig, axes = plt.subplots() - with time_support(): + with time_support(), quantity_support(): + # Pin existing converters to avoid warnings when re-plotting on shared axes. + converter_y = _get_axis_converter(axes.yaxis) + if converter_y is not None and not getattr(axes.yaxis, "_converter_is_explicit", False): + _set_axis_converter(axes.yaxis, converter_y) + + converter_x = _get_axis_converter(axes.xaxis) + if converter_x is not None and not getattr(axes.xaxis, "_converter_is_explicit", False): + _set_axis_converter(axes.xaxis, converter_x) + + axes.yaxis.update_units(self.frequencies) + frequencies = axes.yaxis.convert_units(self.frequencies) + axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None") im = NonUniformImage(axes, interpolation="none", **kwargs) - im.set_data(axes.convert_xunits(self.times), self.frequencies.value, self.data) + im.set_data(axes.convert_xunits(self.times), frequencies, self.data) axes.add_image(im) axes.set_xlim(self.times[0], self.times[-1]) - axes.set_ylim(self.frequencies.value[0], self.frequencies.value[-1]) + axes.set_ylim(frequencies[0], frequencies[-1]) diff --git a/radiospectra/spectrogram/tests/conftest.py b/radiospectra/spectrogram/tests/conftest.py index 290cc21..c8f025c 100644 --- a/radiospectra/spectrogram/tests/conftest.py +++ b/radiospectra/spectrogram/tests/conftest.py @@ -1,3 +1,33 @@ import matplotlib matplotlib.use("Agg") + +import numpy as np +import pytest + +import astropy.units as u +from astropy.time import Time + +from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram + + +@pytest.fixture +def make_spectrogram(): + """Factory fixture to create test spectrograms with given frequencies.""" + + def _make(frequencies, times=None, scale="utc"): + if times is None: + times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min + meta = { + "observatory": "Test", + "instrument": "TestInst", + "detector": "TestDet", + "start_time": times[0], + "end_time": times[-1], + "wavelength": np.array([1, 10]) * u.m, + "times": times, + "freqs": frequencies, + } + return GenericSpectrogram(np.arange(16).reshape(4, 4), meta) + + return _make diff --git a/radiospectra/spectrogram/tests/test_spectrogrambase.py b/radiospectra/spectrogram/tests/test_spectrogrambase.py index 6f31d3c..ddd3a36 100644 --- a/radiospectra/spectrogram/tests/test_spectrogrambase.py +++ b/radiospectra/spectrogram/tests/test_spectrogrambase.py @@ -4,29 +4,108 @@ import numpy as np import astropy.units as u -from astropy.time import Time -from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram +def test_plot_mixed_frequency_units_on_same_axes(make_spectrogram): + """Two spectrograms with different frequency units should plot on the same axes.""" + rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz) -def _get_spectrogram_with_time_scale(scale): - times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min - frequencies = np.linspace(10, 40, 4) * u.MHz - meta = { - "observatory": "Test", - "instrument": "TestInst", - "detector": "TestDet", - "start_time": times[0], - "end_time": times[-1], - "wavelength": np.array([1, 10]) * u.m, - "times": times, - "freqs": frequencies, - } - return GenericSpectrogram(np.arange(16).reshape(4, 4), meta) + rad1.plot() + axes = plt.gca() + rad2.plot(axes=axes) + # The y-axis unit should be set by the first spectrogram (kHz) + assert axes.yaxis.units == u.kHz + # The y-axis range should cover the converted MHz values (up to 4000 kHz) + y_min, y_max = axes.get_ylim() + plt.close("all") -def test_plot_uses_time_support_for_datetime_conversion(): - spec = _get_spectrogram_with_time_scale("tt") + assert y_max > 1000, "MHz values should be converted to kHz on the y-axis" + + +def test_plot_mixed_frequency_units_mhz_first(make_spectrogram): + """Plot MHz spectrogram first, then kHz — units should stay as MHz.""" + rad1 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz) + rad2 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + + rad1.plot() + axes = plt.gca() + rad2.plot(axes=axes) + + # The y-axis unit should be set by the first spectrogram (MHz) + assert axes.yaxis.units == u.MHz + # kHz values should be converted to MHz; 40 kHz = 0.04 MHz + y_min, y_max = axes.get_ylim() + plt.close("all") + + assert y_max >= 4, "y-axis should cover up to 4 MHz" + + +def test_plotim(make_spectrogram): + """Test NonUniformImagePlotMixin.plotim() executes without error.""" + rad_im = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + with ( + mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True), + mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data, + ): + rad_im.plotim() + plt.close("all") + + _, x_values, y_values, image = set_data.call_args.args + assert len(x_values) == len(rad_im.times) + np.testing.assert_allclose(y_values, rad_im.frequencies.value) + np.testing.assert_allclose(image, rad_im.data) + + +def test_plotim_mixed_frequency_units_on_same_axes(make_spectrogram): + """Two NonUniformImage plots with different units should share conversion.""" + rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz) + fig, axes = plt.subplots() + with ( + mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True), + mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data, + ): + rad1.plotim(axes=axes) + rad2.plotim(axes=axes) + plt.close(fig) + + _, _, y_values, _ = set_data.call_args.args + np.testing.assert_allclose(y_values, np.array([1000, 2000, 3000, 4000])) + + +def test_plot_with_quantity_data(make_spectrogram): + """Test plotting when data is an astropy Quantity.""" + rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + rad.data = rad.data * u.ct + rad.plot() + plt.close("all") + + +def test_plot_with_shape_mismatch(make_spectrogram): + """Test plotting branch when data shape doesn't exactly match time/freq arrays.""" + rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + # times/freqs are length 4, data shape (4, 4) matches. + # make data (5, 5) to trigger the `else` branch (data[:-1, :-1]) + rad.data = np.zeros((5, 5)) + rad.plot() + plt.close("all") + + +def test_plot_instrument_detector_differ(make_spectrogram): + """Test title generation when instrument and detector differ.""" + rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz) + # GenericSpectrogram gets instrument/detector from meta dictionary + rad.meta["detector"] = "DifferentDetector" + mesh = rad.plot() + assert "DifferentDetector".upper() in mesh.axes.get_title().upper() + plt.close("all") + + +def test_plot_uses_time_support_for_datetime_conversion(make_spectrogram): + """Plotting with non-UTC time scale should use time_support.""" + spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt") mesh = spec.plot() x_limits = np.array(mesh.axes.get_xlim()) @@ -37,8 +116,9 @@ def test_plot_uses_time_support_for_datetime_conversion(): np.testing.assert_allclose(x_limits, expected_tt_limits) -def test_plotim_uses_time_support_for_datetime_conversion(): - spec = _get_spectrogram_with_time_scale("tt") +def test_plotim_uses_time_support_for_datetime_conversion(make_spectrogram): + """plotim with non-UTC time scale should use time_support.""" + spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt") fig, axes = plt.subplots() with (