diff --git a/src/skyborn/plot/curved_quiver_plot.py b/src/skyborn/plot/curved_quiver_plot.py index cf28073..f48116d 100644 --- a/src/skyborn/plot/curved_quiver_plot.py +++ b/src/skyborn/plot/curved_quiver_plot.py @@ -8,9 +8,9 @@ from typing import TYPE_CHECKING, Literal import matplotlib.pyplot as plt +import numpy as np import xarray as xr -# import numpy as np # import matplotlib from matplotlib.artist import Artist, allow_rasterization from matplotlib.backend_bases import RendererBase @@ -24,7 +24,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes -__all__ = ["curved_quiver", "add_curved_quiverkey"] +__all__ = ["curved_quiver", "add_curved_quiverkey", "curved_quiver_numpy", "curved_quiver_profile"] def curved_quiver( @@ -690,3 +690,215 @@ def add_curved_quiverkey( return CurvedQuiverLegend( ax, curved_quiver_set, U, units=units, loc=loc, labelpos=labelpos, **kwargs ) + + +def curved_quiver_numpy( + x: np.ndarray, + y: np.ndarray, + u: np.ndarray, + v: np.ndarray, + ax: Axes | None = None, + density=1, + linewidth=None, + color=None, + cmap=None, + norm=None, + arrowsize=1, + arrowstyle="-|>", + transform=None, + zorder=None, + start_points=None, + integration_direction="both", + grains=15, + broken_streamlines=True, + interpolate_irregular=True, + nx=None, + ny=None, + interpolation_method="linear", +) -> CurvedQuiverplotSet: + """ + Plot curved arrows using numpy arrays directly (xarray not required). + + .. warning:: + + This function is experimental and the API is subject to change. + + Parameters + ---------- + x, y : 1D arrays + Coordinate arrays for the grid. Can be irregularly spaced. + u, v : 2D arrays + Velocity components. Shape must match (len(y), len(x)). + ax : matplotlib.axes.Axes, optional + Axes on which to plot. By default, use the current axes. + density : float or (float, float) + Controls the closeness of streamlines. When ``density = 1``, the domain + is divided into a 30x30 grid. *density* linearly scales this grid. + linewidth : float or 2D array + The width of the streamlines. + color : color or 2D array + The streamline color. + cmap, norm + Data normalization and colormapping parameters for *color*. + arrowsize : float + Scaling factor for the arrow size. + arrowstyle : str + Arrow style specification. See `~matplotlib.patches.FancyArrowPatch`. + transform : Transform, optional + Coordinate transformation for the plot. + zorder : float + The zorder of the streamlines and arrows. + start_points : (N, 2) array + Coordinates of starting points for the streamlines. + integration_direction : {'forward', 'backward', 'both'}, default: 'both' + Integrate the streamline in forward, backward or both directions. + grains : int, default: 15 + Number of grains used in streamline integration. + broken_streamlines : boolean, default: True + If False, forces streamlines to continue until they leave the plot domain. + interpolate_irregular : boolean, default: True + If True, interpolate irregular grids to regular grid. If False, raise error. + nx, ny : int, optional + Target grid resolution for interpolation. If None, use original resolution. + interpolation_method : str, default: 'linear' + Interpolation method: 'linear', 'nearest', or 'cubic'. + + Returns + ------- + CurvedQuiverplotSet + Container object with lines and arrows. + + Example + ------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> x = np.linspace(0, 10, 50) + >>> y = np.linspace(0, 5, 25) + >>> X, Y = np.meshgrid(x, y) + >>> U = np.sin(X) * np.cos(Y) + >>> V = np.cos(X) * np.sin(Y) + >>> fig, ax = plt.subplots() + >>> curved_quiver_numpy(x, y, U, V, ax=ax) + """ + from .modplot import velovect, _is_regular_grid, _interpolate_to_regular_grid + + if ax is None: + ax = plt.gca() + + if type(transform).__name__ == "PlateCarree": + transform = transform._as_mpl_transform(ax) + + x = np.asarray(x) + y = np.asarray(y) + u = np.asarray(u) + v = np.asarray(v) + + if x.ndim != 1 or y.ndim != 1: + raise ValueError("'x' and 'y' must be 1D arrays") + + if u.shape != (len(y), len(x)) or v.shape != (len(y), len(x)): + raise ValueError( + f"'u' and 'v' must have shape ({len(y)}, {len(x)}), " + f"got {u.shape} and {v.shape}" + ) + + if _is_regular_grid(x, y): + x_plot, y_plot, u_plot, v_plot = x, y, u, v + else: + if not interpolate_irregular: + raise ValueError( + "Grid is irregularly spaced. Set interpolate_irregular=True " + "to interpolate to a regular grid." + ) + x_plot, y_plot, u_plot, v_plot = _interpolate_to_regular_grid( + x, y, u, v, nx=nx, ny=ny, method=interpolation_method + ) + + obj = velovect( + ax, + x_plot, + y_plot, + u_plot, + v_plot, + density=density, + linewidth=linewidth, + color=color, + cmap=cmap, + norm=norm, + arrowsize=arrowsize, + arrowstyle=arrowstyle, + transform=transform, + zorder=zorder, + start_points=start_points, + integration_direction=integration_direction, + grains=grains, + broken_streamlines=broken_streamlines, + ) + return obj + + +def curved_quiver_profile( + distance: np.ndarray, + height: np.ndarray, + u: np.ndarray, + v: np.ndarray, + ax: Axes | None = None, + distance_units: str = "km", + height_units: str = "m", + **kwargs, +) -> CurvedQuiverplotSet: + """ + Plot curved arrows on a vertical profile/cross-section. + + This function is designed for vertical cross-section data where the + coordinates represent distance along a profile and vertical height/depth. + + .. warning:: + + This function is experimental and the API is subject to change. + + Parameters + ---------- + distance : 1D array + Distance along the profile (e.g., longitude, or actual distance in km). + height : 1D array + Vertical coordinate (e.g., altitude, pressure level, depth). + u, v : 2D arrays + Velocity components. Shape must match (len(height), len(distance)). + ax : matplotlib.axes.Axes, optional + Axes on which to plot. By default, use the current axes. + distance_units : str, default: 'km' + Unit string for distance axis (for labeling purposes). + height_units : str, default: 'm' + Unit string for height axis (for labeling purposes). + **kwargs + Additional keyword arguments passed to `curved_quiver_numpy`. + + Returns + ------- + CurvedQuiverplotSet + Container object with lines and arrows. + + Example + ------- + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> distance = np.linspace(0, 100, 50) # km + >>> height = np.linspace(0, 20, 30) # km + >>> H, D = np.meshgrid(height, distance, indexing='ij') + >>> U = np.sin(H / 5) * np.cos(D / 20) + >>> V = np.cos(H / 5) * np.sin(D / 20) + >>> fig, ax = plt.subplots() + >>> curved_quiver_profile(distance, height, U, V, ax=ax) + >>> ax.set_xlabel(f'Distance ({distance_units})') + >>> ax.set_ylabel(f'Height ({height_units})') + """ + obj = curved_quiver_numpy( + x=distance, + y=height, + u=u, + v=v, + ax=ax, + **kwargs, + ) + return obj diff --git a/src/skyborn/plot/modplot.py b/src/skyborn/plot/modplot.py index 64bbb98..2123f4b 100644 --- a/src/skyborn/plot/modplot.py +++ b/src/skyborn/plot/modplot.py @@ -16,7 +16,7 @@ import numpy as np from matplotlib import cm, patches -__all__ = ["velovect"] +__all__ = ["velovect", "_is_regular_grid", "_interpolate_to_regular_grid"] def velovect( @@ -789,3 +789,73 @@ def _gen_starting_points(x, y, grains): seed_points = np.array([xs, ys]) return seed_points.T + + +def _is_regular_grid(x, y, rtol=1e-5, atol=1e-8): + """ + Check if grid is regular (equally spaced). + + Parameters + ---------- + x, y : 1D arrays + Coordinate arrays + rtol, atol : float + Relative and absolute tolerances for np.allclose + + Returns + ------- + bool + True if grid is regular, False otherwise + """ + x_regular = np.allclose(np.diff(x), (x[-1] - x[0]) / (len(x) - 1), rtol=rtol, atol=atol) + y_regular = np.allclose(np.diff(y), (y[-1] - y[0]) / (len(y) - 1), rtol=rtol, atol=atol) + return x_regular and y_regular + + +def _interpolate_to_regular_grid(x, y, u, v, nx=None, ny=None, method='linear'): + """ + Interpolate irregular grid data to a regular grid. + + Parameters + ---------- + x, y : 1D arrays + Original coordinate arrays (can be irregular) + u, v : 2D arrays + Velocity fields on original grid + nx, ny : int, optional + Target grid resolution. If None, use original resolution. + method : str + Interpolation method: 'linear' (default), 'nearest', 'cubic' + + Returns + ------- + x_reg, y_reg : 1D arrays + Regular coordinate arrays + u_reg, v_reg : 2D arrays + Interpolated velocity fields on regular grid + """ + try: + from scipy.interpolate import griddata + except ImportError: + raise ImportError( + "scipy is required for interpolation on irregular grids. " + "Install with: pip install scipy" + ) + + if nx is None: + nx = len(x) + if ny is None: + ny = len(y) + + X, Y = np.meshgrid(x, y, indexing='xy') + + points = np.column_stack([X.ravel(), Y.ravel()]) + + xi = np.linspace(x.min(), x.max(), nx) + yi = np.linspace(y.min(), y.max(), ny) + Xi, Yi = np.meshgrid(xi, yi, indexing='xy') + + u_reg = griddata(points, u.ravel(), (Xi, Yi), method=method) + v_reg = griddata(points, v.ravel(), (Xi, Yi), method=method) + + return xi, yi, u_reg, v_reg