Skip to content

Commit 62d3f3a

Browse files
authored
Merge pull request #2374 from tdegeus/average2
average: fixing overload issue for axis argument
2 parents 6504ecd + 4774c4d commit 62d3f3a

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

docs/source/api/xmath.rst

+2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ Mathematical functions
276276
+-----------------------------------------------+---------------------------------------------------------------------+
277277
| :ref:`mean <mean-function-reference>` | mean of elements over given axes |
278278
+-----------------------------------------------+---------------------------------------------------------------------+
279+
| :ref:`average <average-function-reference>` | weighted average along the specified axis |
280+
+-----------------------------------------------+---------------------------------------------------------------------+
279281
| :ref:`variance <variance-function-reference>` | variance of elements over given axes |
280282
+-----------------------------------------------+---------------------------------------------------------------------+
281283
| :ref:`stddev <stddev-function-reference>` | standard deviation of elements over given axes |

include/xtensor/xmath.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -1999,6 +1999,13 @@ namespace detail {
19991999
return sum<T>(std::forward<E>(e) * std::move(weights_view), std::move(ax), ev) / std::move(scl);
20002000
}
20012001

2002+
template <class T = void, class E, class W, class X, class EVS = DEFAULT_STRATEGY_REDUCERS,
2003+
XTL_REQUIRES(is_reducer_options<EVS>, xtl::is_integral<X>)>
2004+
inline auto average(E&& e, W&& weights, X axis, EVS ev = EVS())
2005+
{
2006+
return average(std::forward<E>(e), std::forward<W>(weights), {axis}, std::forward<EVS>(ev));
2007+
}
2008+
20022009
template <class T = void, class E, class W, class X, std::size_t N, class EVS = DEFAULT_STRATEGY_REDUCERS>
20032010
inline auto average(E&& e, W&& weights, const X(&axes)[N], EVS ev = EVS())
20042011
{

test/test_xmath.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,41 @@ namespace xt
854854
EXPECT_EQ(res5[0], 8.0);
855855
}
856856

857+
/********************
858+
* Mean and average *
859+
********************/
860+
861+
TEST(xmath, mean)
862+
{
863+
xt::xtensor<double,2> v = {{1.0, 1.0, 1.0}, {2.0, 2.0, 2.0}};
864+
xt::xtensor<double,1> m0 = {1.5, 1.5, 1.5};
865+
xt::xtensor<double,1> m1 = {1.0, 2.0};
866+
double m = 9.0 / 6.0;
867+
868+
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, 0), m0)));
869+
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, {0}), m0)));
870+
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, 1), m1)));
871+
EXPECT_TRUE(xt::all(xt::equal(xt::mean(v, {1}), m1)));
872+
EXPECT_EQ(xt::mean(v)(), m);
873+
EXPECT_EQ(xt::mean(v, {0, 1})(), m);
874+
}
875+
876+
TEST(xmath, average)
877+
{
878+
xt::xtensor<double,2> v = {{1.0, 1.0, 1.0}, {2.0, 2.0, 2.0}};
879+
xt::xtensor<double,2> w = {{2.0, 2.0, 2.0}, {2.0, 2.0, 2.0}};
880+
xt::xtensor<double,1> m0 = {1.5, 1.5, 1.5};
881+
xt::xtensor<double,1> m1 = {1.0, 2.0};
882+
double m = 9.0 / 6.0;
883+
884+
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, 0), m0)));
885+
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, {0}), m0)));
886+
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, 1), m1)));
887+
EXPECT_TRUE(xt::all(xt::equal(xt::average(v, w, {1}), m1)));
888+
EXPECT_EQ(xt::average(v, w)(), m);
889+
EXPECT_EQ(xt::average(v, w, {0, 1})(), m);
890+
}
891+
857892
/************************
858893
* Linear interpolation *
859894
************************/

0 commit comments

Comments
 (0)