Skip to content

Commit a706ae3

Browse files
committed
optimize +
1 parent 68d01c9 commit a706ae3

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

src/rulesets/Base/indexing.jl

+11
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,17 @@ Base.size(A::OneElement) = map(length, A.axes)
111111
Base.axes(A::OneElement) = A.axes
112112
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
113113

114+
function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N}
115+
if !ChainRulesCore.is_inplaceable_destination(xs)
116+
xs = collect(xs)
117+
end
118+
xs[oe.ind...] += oe.val
119+
return xs
120+
end
121+
122+
Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe)
123+
Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe)
124+
Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)
114125

115126
"""
116127
_setindex_zero(x, dy, inds...)

0 commit comments

Comments
 (0)