Skip to content

Commit 14b4f1b

Browse files
committed
Propagate attrs with DataArray unary ops
1 parent a122e87 commit 14b4f1b

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

xarray/core/dataarray.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from .indexes import Indexes, default_indexes, propagate_indexes
5656
from .indexing import is_fancy_indexer
5757
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
58-
from .options import OPTIONS
58+
from .options import OPTIONS, _get_keep_attrs
5959
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
6060
from .variable import (
6161
IndexVariable,
@@ -2730,13 +2730,19 @@ def __rmatmul__(self, other):
27302730
def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]:
27312731
@functools.wraps(f)
27322732
def func(self, *args, **kwargs):
2733+
keep_attrs = kwargs.pop("keep_attrs", None)
2734+
if keep_attrs is None:
2735+
keep_attrs = _get_keep_attrs(default=True)
27332736
with warnings.catch_warnings():
27342737
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
27352738
warnings.filterwarnings(
27362739
"ignore", r"Mean of empty slice", category=RuntimeWarning
27372740
)
27382741
with np.errstate(all="ignore"):
2739-
return self.__array_wrap__(f(self.variable.data, *args, **kwargs))
2742+
da = self.__array_wrap__(f(self.variable.data, *args, **kwargs))
2743+
if keep_attrs:
2744+
da.attrs = self.attrs
2745+
return da
27402746

27412747
return func
27422748

xarray/tests/test_dataarray.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,21 +2488,20 @@ def test_assign_attrs(self):
24882488
assert_identical(new_actual, expected)
24892489
assert actual.attrs == {"a": 1, "b": 2}
24902490

2491-
def test_propagate_attrs(self):
2491+
@pytest.mark.parametrize(
2492+
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
2493+
)
2494+
def test_propagate_attrs(self, func):
24922495
da = DataArray(self.va)
24932496

24942497
# test defaults
2495-
assert (np.float64(1.0) * da).attrs == da.attrs
2496-
assert np.abs(da).attrs == da.attrs
2497-
# not sure about the next two
2498-
assert da.clip(0, 1).attrs != da.attrs
2499-
assert abs(da).attrs != da.attrs
2498+
assert func(da).attrs == da.attrs
25002499

25012500
with set_options(keep_attrs=False):
2502-
assert da.clip(0, 1).attrs != da.attrs
2503-
assert (np.float64(1.0) * da).attrs != da.attrs
2504-
assert np.abs(da).attrs != da.attrs
2505-
assert abs(da).attrs != da.attrs
2501+
assert func(da).attrs == {}
2502+
2503+
with set_options(keep_attrs=True):
2504+
assert func(da).attrs == da.attrs
25062505

25072506
def test_fillna(self):
25082507
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")

0 commit comments

Comments
 (0)