1
1
# zero-preserving Traits
2
2
# ----------------------
3
3
"""
4
- abstract type ZeroPreserving end
4
+ abstract type ZeroPreserving <: Function end
5
5
6
6
Holy Trait to indicate how a function interacts with abstract zero values:
7
7
@@ -15,10 +15,17 @@ To attempt to automatically determine this, either `ZeroPreserving(f, A::Abstrac
15
15
!!! warning
16
16
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
17
17
"""
18
- abstract type ZeroPreserving end
19
- struct StrongPreserving <: ZeroPreserving end
20
- struct WeakPreserving <: ZeroPreserving end
21
- struct NonPreserving <: ZeroPreserving end
18
+ abstract type ZeroPreserving <: Function end
19
+
20
+ struct StrongPreserving{F} <: ZeroPreserving
21
+ f:: F
22
+ end
23
+ struct WeakPreserving{F} <: ZeroPreserving
24
+ f:: F
25
+ end
26
+ struct NonPreserving{F} <: ZeroPreserving
27
+ f:: F
28
+ end
22
29
23
30
# Backport: remove in 1.12
24
31
@static if ! isdefined (Base, :haszero )
36
43
# TODO : non-concrete element types
37
44
function ZeroPreserving (f, T:: Type , Ts:: Type... )
38
45
if all (_haszero, (T, Ts... ))
39
- return iszero (f (zero (T), zero .(Ts)... )) ? WeakPreserving () : NonPreserving ()
46
+ return iszero (f (zero (T), zero .(Ts)... )) ? WeakPreserving (f ) : NonPreserving (f )
40
47
else
41
- return NonPreserving ()
48
+ return NonPreserving (f )
42
49
end
43
50
end
44
51
45
52
const _WEAK_FUNCTIONS = (:+ , :- )
46
53
for f in _WEAK_FUNCTIONS
47
54
@eval begin
48
- ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = WeakPreserving ()
55
+ ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = WeakPreserving ($ f )
49
56
end
50
57
end
51
58
52
59
const _STRONG_FUNCTIONS = (:* ,)
53
60
for f in _STRONG_FUNCTIONS
54
61
@eval begin
55
- ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = StrongPreserving ()
62
+ ZeroPreserving (:: typeof ($ f), :: Type{<:Number} , :: Type{<:Number} ...) = StrongPreserving (
63
+ $ f
64
+ )
56
65
end
57
66
end
58
67
61
70
@interface I:: AbstractSparseArrayInterface function Base. map (
62
71
f, A:: AbstractArray , Bs:: AbstractArray...
63
72
)
64
- T = Base. Broadcast. combine_eltypes (f, (A, Bs... ))
73
+ f_pres = ZeroPreserving (f, A, Bs... )
74
+ return @interface I map (f_pres, A, Bs... )
75
+ end
76
+ @interface I:: AbstractSparseArrayInterface function Base. map (
77
+ f:: ZeroPreserving , A:: AbstractArray , Bs:: AbstractArray...
78
+ )
79
+ T = Base. Broadcast. combine_eltypes (f. f, (A, Bs... ))
65
80
C = similar (I, T, size (A))
66
81
return @interface I map! (f, C, A, Bs... )
67
82
end
68
83
69
- @interface :: AbstractSparseArrayInterface function Base. map! (
84
+ @interface I :: AbstractSparseArrayInterface function Base. map! (
70
85
f, C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
71
86
)
72
- return _map! (f, ZeroPreserving (f, A, Bs... ), C, A, Bs... )
87
+ f_pres = ZeroPreserving (f, A, Bs... )
88
+ return @interface I map! (f_pres, C, A, Bs... )
73
89
end
74
90
75
- function _map! (
76
- f, :: StrongPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
77
- )
78
- checkshape (C, A, Bs... )
79
- style = IndexStyle (C, A, Bs... )
80
- unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
81
- zero! (C)
82
- for I in intersect (eachstoredindex .(Ref (style), unaliased)... )
83
- @inbounds C[I] = f (ith_all (I, unaliased)... )
84
- end
85
- return C
86
- end
87
- function _map! (
88
- f, :: WeakPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
91
+ @interface :: AbstractSparseArrayInterface function Base. map! (
92
+ f:: ZeroPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray...
89
93
)
90
94
checkshape (C, A, Bs... )
91
- style = IndexStyle (C, A, Bs... )
92
95
unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
93
- zero! (C)
94
- for I in union (eachstoredindex .(Ref (style), unaliased)... )
95
- @inbounds C[I] = f (ith_all (I, unaliased)... )
96
+
97
+ if f isa StrongPreserving
98
+ style = IndexStyle (C, unaliased... )
99
+ inds = intersect (eachstoredindex .(Ref (style), unaliased)... )
100
+ zero! (C)
101
+ elseif f isa WeakPreserving
102
+ style = IndexStyle (C, unaliased... )
103
+ inds = union (eachstoredindex .(Ref (style), unaliased)... )
104
+ zero! (C)
105
+ elseif f isa NonPreserving
106
+ inds = eachindex (C, unaliased... )
107
+ else
108
+ error (lazy " unknown zero-preserving type $(typeof(f))" )
96
109
end
97
- return C
98
- end
99
- function _map! (f, :: NonPreserving , C:: AbstractArray , A:: AbstractArray , Bs:: AbstractArray... )
100
- checkshape (C, A, Bs... )
101
- unaliased = map (Base. Fix1 (Base. unalias, C), (A, Bs... ))
102
- for I in eachindex (C, A, Bs... )
103
- @inbounds C[I] = f (ith_all (I, unaliased)... )
110
+
111
+ @inbounds for I in inds
112
+ C[I] = f. f (ith_all (I, unaliased)... )
104
113
end
114
+
105
115
return C
106
116
end
107
117
0 commit comments