Skip to content

Commit 892d3c5

Browse files
committed
Refactor map to allow specifying kind of preserving
1 parent 80f6805 commit 892d3c5

File tree

1 file changed

+47
-37
lines changed

1 file changed

+47
-37
lines changed

src/map.jl

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# zero-preserving Traits
22
# ----------------------
33
"""
4-
abstract type ZeroPreserving end
4+
abstract type ZeroPreserving <: Function end
55
66
Holy Trait to indicate how a function interacts with abstract zero values:
77
@@ -15,10 +15,17 @@ To attempt to automatically determine this, either `ZeroPreserving(f, A::Abstrac
1515
!!! warning
1616
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
1717
"""
18-
abstract type ZeroPreserving end
19-
struct StrongPreserving <: ZeroPreserving end
20-
struct WeakPreserving <: ZeroPreserving end
21-
struct NonPreserving <: ZeroPreserving end
18+
abstract type ZeroPreserving <: Function end
19+
20+
struct StrongPreserving{F} <: ZeroPreserving
21+
f::F
22+
end
23+
struct WeakPreserving{F} <: ZeroPreserving
24+
f::F
25+
end
26+
struct NonPreserving{F} <: ZeroPreserving
27+
f::F
28+
end
2229

2330
# Backport: remove in 1.12
2431
@static if !isdefined(Base, :haszero)
@@ -36,23 +43,25 @@ end
3643
# TODO: non-concrete element types
3744
function ZeroPreserving(f, T::Type, Ts::Type...)
3845
if all(_haszero, (T, Ts...))
39-
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving() : NonPreserving()
46+
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
4047
else
41-
return NonPreserving()
48+
return NonPreserving(f)
4249
end
4350
end
4451

4552
const _WEAK_FUNCTIONS = (:+, :-)
4653
for f in _WEAK_FUNCTIONS
4754
@eval begin
48-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving()
55+
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
4956
end
5057
end
5158

5259
const _STRONG_FUNCTIONS = (:*,)
5360
for f in _STRONG_FUNCTIONS
5461
@eval begin
55-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving()
62+
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
63+
$f
64+
)
5665
end
5766
end
5867

@@ -61,47 +70,48 @@ end
6170
@interface I::AbstractSparseArrayInterface function Base.map(
6271
f, A::AbstractArray, Bs::AbstractArray...
6372
)
64-
T = Base.Broadcast.combine_eltypes(f, (A, Bs...))
73+
f_pres = ZeroPreserving(f, A, Bs...)
74+
return @interface I map(f_pres, A, Bs...)
75+
end
76+
@interface I::AbstractSparseArrayInterface function Base.map(
77+
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
78+
)
79+
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
6580
C = similar(I, T, size(A))
6681
return @interface I map!(f, C, A, Bs...)
6782
end
6883

69-
@interface ::AbstractSparseArrayInterface function Base.map!(
84+
@interface I::AbstractSparseArrayInterface function Base.map!(
7085
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
7186
)
72-
return _map!(f, ZeroPreserving(f, A, Bs...), C, A, Bs...)
87+
f_pres = ZeroPreserving(f, A, Bs...)
88+
return @interface I map!(f_pres, C, A, Bs...)
7389
end
7490

75-
function _map!(
76-
f, ::StrongPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
77-
)
78-
checkshape(C, A, Bs...)
79-
style = IndexStyle(C, A, Bs...)
80-
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
81-
zero!(C)
82-
for I in intersect(eachstoredindex.(Ref(style), unaliased)...)
83-
@inbounds C[I] = f(ith_all(I, unaliased)...)
84-
end
85-
return C
86-
end
87-
function _map!(
88-
f, ::WeakPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
91+
@interface ::AbstractSparseArrayInterface function Base.map!(
92+
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
8993
)
9094
checkshape(C, A, Bs...)
91-
style = IndexStyle(C, A, Bs...)
9295
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
93-
zero!(C)
94-
for I in union(eachstoredindex.(Ref(style), unaliased)...)
95-
@inbounds C[I] = f(ith_all(I, unaliased)...)
96+
97+
if f isa StrongPreserving
98+
style = IndexStyle(C, unaliased...)
99+
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
100+
zero!(C)
101+
elseif f isa WeakPreserving
102+
style = IndexStyle(C, unaliased...)
103+
inds = union(eachstoredindex.(Ref(style), unaliased)...)
104+
zero!(C)
105+
elseif f isa NonPreserving
106+
inds = eachindex(C, unaliased...)
107+
else
108+
error(lazy"unknown zero-preserving type $(typeof(f))")
96109
end
97-
return C
98-
end
99-
function _map!(f, ::NonPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...)
100-
checkshape(C, A, Bs...)
101-
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
102-
for I in eachindex(C, A, Bs...)
103-
@inbounds C[I] = f(ith_all(I, unaliased)...)
110+
111+
@inbounds for I in inds
112+
C[I] = f.f(ith_all(I, unaliased)...)
104113
end
114+
105115
return C
106116
end
107117

0 commit comments

Comments
 (0)