Skip to content

Commit 2c12cfc

Browse files
Illviljandcherian
andauthored
Handle None in assert_valid_xy (#6871)
* Handle none in assert_valid_xy * add test * Update xarray/tests/test_plot.py * Update test_plot.py Co-authored-by: Deepak Cherian <[email protected]>
1 parent 0c8a78b commit 2c12cfc

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

xarray/plot/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from datetime import datetime
77
from inspect import getfullargspec
8-
from typing import Any, Iterable, Mapping, Sequence
8+
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Mapping, Sequence
99

1010
import numpy as np
1111
import pandas as pd
@@ -28,6 +28,11 @@
2828
except ImportError:
2929
cftime = None
3030

31+
32+
if TYPE_CHECKING:
33+
from ..core.dataarray import DataArray
34+
35+
3136
ROBUST_PERCENTILE = 2.0
3237

3338
# copied from seaborn
@@ -396,7 +401,8 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
396401
return x, y
397402

398403

399-
def _assert_valid_xy(darray, xy, name):
404+
# TODO: Can by used to more than x or y, rename?
405+
def _assert_valid_xy(darray: DataArray, xy: None | Hashable, name: str) -> None:
400406
"""
401407
make sure x and y passed to plotting functions are valid
402408
"""
@@ -410,9 +416,11 @@ def _assert_valid_xy(darray, xy, name):
410416

411417
valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims
412418

413-
if xy not in valid_xy:
414-
valid_xy_str = "', '".join(sorted(valid_xy))
415-
raise ValueError(f"{name} must be one of None, '{valid_xy_str}'")
419+
if (xy is not None) and (xy not in valid_xy):
420+
valid_xy_str = "', '".join(sorted(tuple(str(v) for v in valid_xy)))
421+
raise ValueError(
422+
f"{name} must be one of None, '{valid_xy_str}'. Received '{xy}' instead."
423+
)
416424

417425

418426
def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs):
@@ -1152,7 +1160,7 @@ def _adjust_legend_subtitles(legend):
11521160

11531161
def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
11541162
dvars = set(ds.variables.keys())
1155-
error_msg = f" must be one of ({', '.join(dvars)})"
1163+
error_msg = f" must be one of ({', '.join(sorted(tuple(str(v) for v in dvars)))})"
11561164

11571165
if x not in dvars:
11581166
raise ValueError(f"Expected 'x' {error_msg}. Received {x} instead.")

xarray/tests/test_plot.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from xarray.plot.dataset_plot import _infer_meta_data
1818
from xarray.plot.plot import _infer_interval_breaks
1919
from xarray.plot.utils import (
20+
_assert_valid_xy,
2021
_build_discrete_cmap,
2122
_color_palette,
2223
_determine_cmap_params,
@@ -3025,3 +3026,19 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co
30253026
add_legend=add_legend,
30263027
add_colorbar=add_colorbar,
30273028
)
3029+
3030+
3031+
@requires_matplotlib
3032+
def test_assert_valid_xy() -> None:
3033+
ds = xr.tutorial.scatter_example_dataset()
3034+
darray = ds.A
3035+
3036+
# x is valid and should not error:
3037+
_assert_valid_xy(darray=darray, xy="x", name="x")
3038+
3039+
# None should be valid as well even though it isn't in the valid list:
3040+
_assert_valid_xy(darray=darray, xy=None, name="x")
3041+
3042+
# A hashable that is not valid should error:
3043+
with pytest.raises(ValueError, match="x must be one of"):
3044+
_assert_valid_xy(darray=darray, xy="error_now", name="x")

0 commit comments

Comments
 (0)