Skip to content

Commit 6877ebf

Browse files
committed
Call user function only once in mean
Override the standard `mapreduce` machinery to promote accumulator type. This avoid calling the function twice, which can be confusing.
1 parent 81a90af commit 6877ebf

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

src/Statistics.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ if !isdefined(Base, :mean)
4444
"""
4545
mean(itr) = mean(identity, itr)
4646

47+
_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)
48+
4749
"""
4850
mean(f, itr)
4951
@@ -178,7 +180,33 @@ if !isdefined(Base, :mean)
178180
"""
179181
mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)
180182

181-
_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)
183+
struct _InitType end
184+
185+
Base.add_sum(x::_InitType, y::Any) = y/1
186+
187+
Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, dims) =
188+
Base.mapreducedim!(f, op, Base.reducedim_init(f, op, A, dims), A)
189+
Base._mapreduce_dim(f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, ::Colon) =
190+
Base.mapfoldl_impl(f, op, _InitType(), A)
191+
promote_add(x::T, y::S) where {T,S} =
192+
Base.add_sum(convert(promote_type(T, S), x),
193+
convert(promote_type(T, S), y))
194+
195+
function Base.reducedim_init(f, op::typeof(promote_add), A::AbstractArray, region)
196+
Base._reducedim_init(f, op, zero, mean, A, region)
197+
end
198+
function Base._reducedim_init(f, op::typeof(promote_add), fv, fop, A, region)
199+
T = Base._realtype(f, Base.promote_union(eltype(A)))
200+
if T !== Any && applicable(zero, T)
201+
x = f(zero(T)/1)
202+
z = op(fv(x), fv(x))
203+
Tr = z isa T ? T : typeof(z)
204+
else
205+
z = fv(fop(f, A))
206+
Tr = typeof(z)
207+
end
208+
return Base.reducedim_initarray(A, region, z, Tr)
209+
end
182210

183211
# ::Dims is there to force specializing on Colon (as it is a Function)
184212
function _mean(f, A::AbstractArray, dims::Dims=:) where Dims
@@ -188,8 +216,7 @@ if !isdefined(Base, :mean)
188216
else
189217
n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
190218
end
191-
x1 = f(first(A)) / 1
192-
result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
219+
result = mapreduce(f, promote_add, A, dims=dims, init=_InitType())
193220
if dims === (:)
194221
return result / n
195222
else

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ end
161161
float(typemax(Int)))
162162
end
163163
let x = rand(10000) # mean should use sum's accurate pairwise algorithm
164-
@test mean(x) == sum(x) / length(x)
164+
@test mean(x) == sum(x; init=0.0) / length(x)
165165
end
166166
@test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array
167167
@test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im

0 commit comments

Comments
 (0)