diff --git a/Project.toml b/Project.toml index 7f585fd..ef6906d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" authors = ["ITensor developers and contributors"] -version = "0.5.11" +version = "0.6.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/docs/Project.toml b/docs/Project.toml index 898915b..ad4f567 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,4 +8,4 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" Dictionaries = "0.4.4" Documenter = "1.8.1" Literate = "2.20.1" -SparseArraysBase = "0.5.0" +SparseArraysBase = "0.6.0" diff --git a/examples/Project.toml b/examples/Project.toml index 01ec629..5c0e154 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Dictionaries = "0.4.4" -SparseArraysBase = "0.5.0" +SparseArraysBase = "0.6.0" Test = "<0.0.1, 1" diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index 0e0db90..c51b99b 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -20,6 +20,7 @@ export SparseArrayDOK, include("abstractsparsearrayinterface.jl") include("sparsearrayinterface.jl") include("indexing.jl") +include("map.jl") include("wrappers.jl") include("abstractsparsearray.jl") include("sparsearraydok.jl") diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl index 6af80e6..719c94e 100644 --- a/src/abstractsparsearrayinterface.jl +++ b/src/abstractsparsearrayinterface.jl @@ -92,39 +92,6 @@ end # TODO: Define `default_similartype` or something like that? return SparseArrayDOK{T}(undef, size) end - -# map over a specified subset of indices of the inputs. -function map_indices! end - -@interface interface::AbstractArrayInterface function map_indices!( - indices, f, a_dest::AbstractArray, as::AbstractArray... -) - for I in indices - a_dest[I] = f(map(a -> a[I], as)...) - end - return a_dest -end - -# Only map the stored values of the inputs. -function map_stored! end - -@interface interface::AbstractArrayInterface function map_stored!( - f, a_dest::AbstractArray, as::AbstractArray... -) - @interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...) - return a_dest -end - -# Only map all values, not just the stored ones. -function map_all! end - -@interface interface::AbstractArrayInterface function map_all!( - f, a_dest::AbstractArray, as::AbstractArray... -) - @interface interface map_indices!(eachindex(as...), f, a_dest, as...) - return a_dest -end - using DerivableInterfaces: DerivableInterfaces, zero! # `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` @@ -137,37 +104,10 @@ using DerivableInterfaces: DerivableInterfaces, zero! # More generally, this codepath could be taking if `zero(eltype(a))` # is defined and the elements are immutable. f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! - return @interface interface map_stored!(f, a, a) -end - -# Determines if a function preserves the stored values -# of the destination sparse array. -# The current code may be inefficient since it actually -# accesses an unstored element, which in the case of a -# sparse array of arrays can allocate an array. -# Sparse arrays could be expected to define a cheap -# unstored element allocator, for example -# `get_prototypical_unstored(a::AbstractArray)`. -function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...) - I = first(eachindex(as...)) - return iszero(f(map(a -> getunstoredindex(a, I), as)...)) -end - -@interface interface::AbstractSparseArrayInterface function Base.map!( - f, a_dest::AbstractArray, as::AbstractArray... -) - isempty(a_dest) && return a_dest # special case to avoid trying to access empty array - indices = if !preserves_unstored(f, a_dest, as...) - eachindex(a_dest) - elseif any(a -> a_dest !== a, as) - as = map(a -> Base.unalias(a_dest, a), as) - @interface interface zero!(a_dest) - eachstoredindex(as...) - else - eachstoredindex(a_dest) + @inbounds for I in eachstoredindex(a) + a[I] = f(a[I]) end - @interface interface map_indices!(indices, f, a_dest, as...) - return a_dest + return a end # `f::typeof(norm)`, `op::typeof(max)` used by `norm`. diff --git a/src/indexing.jl b/src/indexing.jl index 3cd089d..2c2d6ca 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -308,10 +308,27 @@ end end end -# required: -@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw( - MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)}) +@noinline function error_if_canonical_eachstoredindex(style::IndexStyle, A::AbstractArray) + style === IndexStyle(A) && throw(Base.CanonicalIndexError("eachstoredindex", typeof(A))) + return nothing +end + +# required: one implementation for canonical index style +@interface ::AbstractSparseArrayInterface function eachstoredindex( + style::IndexStyle, A::AbstractArray ) + error_if_canonical_eachstoredindex(style, A) + inds = eachstoredindex(A) + if style === IndexCartesian() + eltype(inds) === CartesianIndex{ndims(A)} && return inds + return map(Base.Fix1(Base.getindex, CartesianIndices(A)), inds) + elseif style === IndexLinear() + eltype(inds) === Int && return inds + return map(Base.Fix1(Base.getindex, LinearIndices(A)), inds) + else + error(lazy"unkown index style $style") + end +end # derived but may be specialized: @interface ::AbstractSparseArrayInterface function eachstoredindex( diff --git a/src/map.jl b/src/map.jl new file mode 100644 index 0000000..f35426e --- /dev/null +++ b/src/map.jl @@ -0,0 +1,162 @@ +# zero-preserving Traits +# ---------------------- +""" + abstract type ZeroPreserving <: Function end + +Holy Trait to indicate how a function interacts with abstract zero values: + +- `StrongPreserving` : output is guaranteed to be zero if **any** input is. +- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are. +- `NonPreserving` : no guarantees on output. + +To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or +`ZeroPreserving(f, T::Type...)` can be used/overloaded. + +!!! warning + incorrectly registering a function to be zero-preserving will lead to silently wrong results. +""" +abstract type ZeroPreserving <: Function end + +struct StrongPreserving{F} <: ZeroPreserving + f::F +end +struct WeakPreserving{F} <: ZeroPreserving + f::F +end +struct NonPreserving{F} <: ZeroPreserving + f::F +end + +# Backport: remove in 1.12 +@static if !isdefined(Base, :haszero) + _haszero(T::Type) = false + _haszero(::Type{<:Number}) = true +else + _haszero = Base.haszero +end + +# warning: cannot automatically detect WeakPreserving since this would mean checking all values +function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...) + return ZeroPreserving(f, eltype(A), eltype.(Bs)...) +end +# TODO: the following might not properly specialize on the types +# TODO: non-concrete element types +function ZeroPreserving(f, T::Type, Ts::Type...) + if all(_haszero, (T, Ts...)) + return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f) + else + return NonPreserving(f) + end +end + +const _WEAK_FUNCTIONS = (:+, :-) +for f in _WEAK_FUNCTIONS + @eval begin + ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f) + end +end + +const _STRONG_FUNCTIONS = (:*,) +for f in _STRONG_FUNCTIONS + @eval begin + ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving( + $f + ) + end +end + +# map(!) +# ------ +@interface I::AbstractSparseArrayInterface function Base.map( + f, A::AbstractArray, Bs::AbstractArray... +) + f_pres = ZeroPreserving(f, A, Bs...) + return @interface I map(f_pres, A, Bs...) +end +@interface I::AbstractSparseArrayInterface function Base.map( + f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray... +) + T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...)) + C = similar(I, T, size(A)) + return @interface I map!(f, C, A, Bs...) +end + +@interface I::AbstractSparseArrayInterface function Base.map!( + f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray... +) + f_pres = ZeroPreserving(f, A, Bs...) + return @interface I map!(f_pres, C, A, Bs...) +end + +@interface ::AbstractSparseArrayInterface function Base.map!( + f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray... +) + checkshape(C, A, Bs...) + unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...)) + + if f isa StrongPreserving + style = IndexStyle(C, unaliased...) + inds = intersect(eachstoredindex.(Ref(style), unaliased)...) + zero!(C) + elseif f isa WeakPreserving + style = IndexStyle(C, unaliased...) + inds = union(eachstoredindex.(Ref(style), unaliased)...) + zero!(C) + elseif f isa NonPreserving + inds = eachindex(C, unaliased...) + else + error(lazy"unknown zero-preserving type $(typeof(f))") + end + + @inbounds for I in inds + C[I] = f.f(ith_all(I, unaliased)...) + end + + return C +end + +# Derived functions +# ----------------- +@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!( + identity, C, A +) + +# Only map the stored values of the inputs. +function map_stored! end + +@interface interface::AbstractArrayInterface function map_stored!( + f, a_dest::AbstractArray, as::AbstractArray... +) + @interface interface map!(WeakPreserving(f), a_dest, as...) + return a_dest +end + +# Only map all values, not just the stored ones. +function map_all! end + +@interface interface::AbstractArrayInterface function map_all!( + f, a_dest::AbstractArray, as::AbstractArray... +) + @interface interface map!(NonPreserving(f), a_dest, as...) + return a_dest +end + +# Utility functions +# ----------------- +# shape check similar to checkbounds +checkshape(::Type{Bool}, A::AbstractArray) = true +checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B) +function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...) + return allequal(size, (A, Bs...)) +end + +function checkshape(A::AbstractArray, Bs::AbstractArray...) + return checkshape(Bool, A, Bs...) || + throw(DimensionMismatch("argument shapes must match")) +end + +@inline ith_all(i, ::Tuple{}) = () +function ith_all(i, as) + @_propagate_inbounds_meta + return (as[1][i], ith_all(i, Base.tail(as))...) +end diff --git a/src/oneelementarray.jl b/src/oneelementarray.jl index 7936c6c..251acfa 100644 --- a/src/oneelementarray.jl +++ b/src/oneelementarray.jl @@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index) function isstored(a::OneElementArray, I::Int...) return I == storedindex(a) end -function eachstoredindex(a::OneElementArray) +function eachstoredindex(::IndexCartesian, a::OneElementArray) return Fill(CartesianIndex(storedindex(a)), 1) end diff --git a/test/Project.toml b/test/Project.toml index 9176f07..d7c7bb6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,7 @@ JLArrays = "0.2.0" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" SafeTestsets = "0.1.0" -SparseArraysBase = "0.5.0" +SparseArraysBase = "0.6.0" StableRNGs = "1.0.2" Suppressor = "0.2.8" Test = "<0.0.1, 1"