Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/151.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed mixed frequency unit plotting on shared axes to correctly convert units when multiple spectrograms with different frequency units are plotted together.
90 changes: 81 additions & 9 deletions radiospectra/mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,61 @@
from matplotlib import pyplot as plt
from matplotlib.image import NonUniformImage

from astropy import units as u
from astropy.visualization import time_support


def _get_axis_converter(axis):
"""Safe method to get axis converter for older and newer MPL versions."""
try:
return axis.get_converter()
except AttributeError:
try:
return axis.converter
except AttributeError:
return None


def _set_axis_converter(axis, converter):
"""Safe method to set axis converter for older and newer MPL versions."""
try:
axis.set_converter(converter)
except AttributeError:
try:
axis._set_converter(converter)
axis._converter_is_explicit = True
except AttributeError:
axis.converter = converter


def _frequency_values_for_axes(axes, frequencies):
"""
Convert frequencies to the unit already used on the axes when available.
"""
if not hasattr(frequencies, "unit"):
return frequencies, None

target_unit = None
if axes.has_data():
target_unit = getattr(axes, "_radiospectra_frequency_unit", None)
if target_unit is None:
target_unit = axes.yaxis.get_units()
if target_unit is not None:
try:
target_unit = u.Unit(target_unit)
except (TypeError, ValueError):
target_unit = None

if target_unit is None:
target_unit = frequencies.unit
try:
frequency_values = frequencies.to_value(target_unit)
except u.UnitConversionError:
target_unit = frequencies.unit
frequency_values = frequencies.value
return frequency_values, target_unit


class PcolormeshPlotMixin:
"""
Class provides plotting functions using `~pcolormesh`.
Expand All @@ -21,7 +76,6 @@ def plot(self, axes=None, **kwargs):
-------
`matplotlib.collections.QuadMesh`
"""
from matplotlib import pyplot as plt

if axes is None:
fig, axes = plt.subplots()
Expand All @@ -37,13 +91,23 @@ def plot(self, axes=None, **kwargs):
if self.instrument != self.detector:
title = f"{title}, {self.detector}"

plot_frequencies, plot_frequency_unit = _frequency_values_for_axes(axes, self.frequencies)
if plot_frequency_unit is not None:
axes._radiospectra_frequency_unit = plot_frequency_unit
axes.set_ylabel(f"Frequency [{plot_frequency_unit.to_string()}]")

axes.set_title(title)
with time_support():
axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None")
# Pin existing converter to avoid warnings when re-plotting with different units
converter = _get_axis_converter(axes.xaxis)
if converter is not None:
_set_axis_converter(axes.xaxis, converter)

axes.plot(self.times[[0, -1]], [plot_frequencies[0], plot_frequencies[-1]], linestyle="None", marker="None")
if self.times.shape[0] == self.data.shape[0] and self.frequencies.shape[0] == self.data.shape[1]:
ret = axes.pcolormesh(self.times, self.frequencies.value, data, shading="auto", **kwargs)
ret = axes.pcolormesh(self.times, plot_frequencies, data, shading="auto", **kwargs)
else:
ret = axes.pcolormesh(self.times, self.frequencies.value, data[:-1, :-1], shading="auto", **kwargs)
ret = axes.pcolormesh(self.times, plot_frequencies, data[:-1, :-1], shading="auto", **kwargs)
axes.set_xlim(self.times[0], self.times[-1])
fig.autofmt_xdate()

Expand All @@ -61,16 +125,24 @@ class NonUniformImagePlotMixin:
"""

def plotim(self, fig=None, axes=None, **kwargs):
from matplotlib import pyplot as plt
from matplotlib.image import NonUniformImage

if axes is None:
fig, axes = plt.subplots()

plot_frequencies, plot_frequency_unit = _frequency_values_for_axes(axes, self.frequencies)
if plot_frequency_unit is not None:
axes._radiospectra_frequency_unit = plot_frequency_unit
axes.set_ylabel(f"Frequency [{plot_frequency_unit.to_string()}]")

with time_support():
axes.plot(self.times[[0, -1]], self.frequencies[[0, -1]], linestyle="None", marker="None")
# Pin existing converter to avoid warnings when re-plotting with different units
converter = _get_axis_converter(axes.xaxis)
if converter is not None:
_set_axis_converter(axes.xaxis, converter)

axes.plot(self.times[[0, -1]], [plot_frequencies[0], plot_frequencies[-1]], linestyle="None", marker="None")
im = NonUniformImage(axes, interpolation="none", **kwargs)
im.set_data(axes.convert_xunits(self.times), self.frequencies.value, self.data)
im.set_data(axes.convert_xunits(self.times), plot_frequencies, self.data)
axes.add_image(im)
axes.set_xlim(self.times[0], self.times[-1])
axes.set_ylim(self.frequencies.value[0], self.frequencies.value[-1])
axes.set_ylim(plot_frequencies[0], plot_frequencies[-1])
30 changes: 30 additions & 0 deletions radiospectra/spectrogram/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
import matplotlib

matplotlib.use("Agg")

import numpy as np
import pytest

import astropy.units as u
from astropy.time import Time

from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram


@pytest.fixture
def make_spectrogram():
"""Factory fixture to create test spectrograms with given frequencies."""

def _make(frequencies, times=None, scale="utc"):
if times is None:
times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min
meta = {
"observatory": "Test",
"instrument": "TestInst",
"detector": "TestDet",
"start_time": times[0],
"end_time": times[-1],
"wavelength": np.array([1, 10]) * u.m,
"times": times,
"freqs": frequencies,
}
return GenericSpectrogram(np.arange(16).reshape(4, 4), meta)

return _make
114 changes: 94 additions & 20 deletions radiospectra/spectrogram/tests/test_spectrogrambase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,93 @@
import numpy as np

import astropy.units as u
from astropy.time import Time

from radiospectra.spectrogram.spectrogrambase import GenericSpectrogram
from radiospectra.mixins import _get_axis_converter


def _get_spectrogram_with_time_scale(scale):
times = Time("2020-01-01T00:00:00", format="isot", scale=scale) + np.arange(4) * u.min
frequencies = np.linspace(10, 40, 4) * u.MHz
meta = {
"observatory": "Test",
"instrument": "TestInst",
"detector": "TestDet",
"start_time": times[0],
"end_time": times[-1],
"wavelength": np.array([1, 10]) * u.m,
"times": times,
"freqs": frequencies,
}
return GenericSpectrogram(np.arange(16).reshape(4, 4), meta)
def test_plot_mixed_frequency_units_on_same_axes(make_spectrogram):
"""Two spectrograms with different frequency units should plot on the same axes."""
rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz)
fig, axes = plt.subplots()

rad1.plot(axes=axes)
rad2.plot(axes=axes)

# The y-axis unit should be set by the first spectrogram (kHz)
y_label = axes.get_ylabel()
# The y-axis range should cover the converted MHz values (up to 4000 kHz)
y_min, y_max = axes.get_ylim()
plt.close(fig)

assert "kHz" in y_label
assert y_max > 1000, "MHz values should be converted to kHz on the y-axis"


def test_plotim(make_spectrogram):
"""Test NonUniformImagePlotMixin.plotim() executes without error."""
rad_im = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
with (
mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True),
mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data,
):
rad_im.plotim()
plt.close("all")

_, x_values, y_values, image = set_data.call_args.args
assert len(x_values) == len(rad_im.times)
np.testing.assert_allclose(y_values, rad_im.frequencies.value)
np.testing.assert_allclose(image, rad_im.data)


def test_plotim_mixed_frequency_units_on_same_axes(make_spectrogram):
"""Two NonUniformImage plots with different units should share conversion."""
rad1 = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad2 = make_spectrogram(np.array([1, 2, 3, 4]) * u.MHz)
fig, axes = plt.subplots()
with (
mock.patch("matplotlib.image.NonUniformImage.set_interpolation", autospec=True),
mock.patch("matplotlib.image.NonUniformImage.set_data", autospec=True) as set_data,
):
rad1.plotim(axes=axes)
rad2.plotim(axes=axes)
plt.close(fig)

_, _, y_values, _ = set_data.call_args.args
np.testing.assert_allclose(y_values, np.array([1000, 2000, 3000, 4000]))


def test_plot_with_quantity_data(make_spectrogram):
"""Test plotting when data is an astropy Quantity."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
rad.data = rad.data * u.ct
rad.plot()
plt.close("all")


def test_plot_uses_time_support_for_datetime_conversion():
spec = _get_spectrogram_with_time_scale("tt")
def test_plot_with_shape_mismatch(make_spectrogram):
"""Test plotting branch when data shape doesn't exactly match time/freq arrays."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
# times/freqs are length 4, data shape (4, 4) matches.
# make data (5, 5) to trigger the `else` branch (data[:-1, :-1])
rad.data = np.zeros((5, 5))
rad.plot()
plt.close("all")


def test_plot_instrument_detector_differ(make_spectrogram):
"""Test title generation when instrument and detector differ."""
rad = make_spectrogram(np.array([10, 20, 30, 40]) * u.kHz)
# GenericSpectrogram gets instrument/detector from meta dictionary
rad.meta["detector"] = "DifferentDetector"
mesh = rad.plot()
assert "DifferentDetector".upper() in mesh.axes.get_title().upper()
plt.close("all")


def test_plot_uses_time_support_for_datetime_conversion(make_spectrogram):
"""Plotting with non-UTC time scale should use time_support."""
spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt")

mesh = spec.plot()
x_limits = np.array(mesh.axes.get_xlim())
Expand All @@ -37,8 +101,9 @@ def test_plot_uses_time_support_for_datetime_conversion():
np.testing.assert_allclose(x_limits, expected_tt_limits)


def test_plotim_uses_time_support_for_datetime_conversion():
spec = _get_spectrogram_with_time_scale("tt")
def test_plotim_uses_time_support_for_datetime_conversion(make_spectrogram):
"""plotim with non-UTC time scale should use time_support."""
spec = make_spectrogram(np.linspace(10, 40, 4) * u.MHz, scale="tt")
fig, axes = plt.subplots()

with (
Expand All @@ -55,3 +120,12 @@ def test_plotim_uses_time_support_for_datetime_conversion():
np.testing.assert_allclose(x_values, expected_tt)
np.testing.assert_allclose(y_values, spec.frequencies.value)
np.testing.assert_allclose(image, spec.data)


def test_get_axis_converter_without_attribute():
"""_get_axis_converter should return None when no converter exists."""

class DummyAxis:
pass

assert _get_axis_converter(DummyAxis()) is None