Skip to content

Sparse mapping refactor #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 17, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.11"
version = "0.6.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Dictionaries = "0.4.4"
Documenter = "1.8.1"
Literate = "2.20.1"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Dictionaries = "0.4.4"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
Test = "<0.0.1, 1"
1 change: 1 addition & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export SparseArrayDOK,
include("abstractsparsearrayinterface.jl")
include("sparsearrayinterface.jl")
include("indexing.jl")
include("map.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
include("sparsearraydok.jl")
Expand Down
66 changes: 3 additions & 63 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,39 +92,6 @@ end
# TODO: Define `default_similartype` or something like that?
return SparseArrayDOK{T}(undef, size)
end

# map over a specified subset of indices of the inputs.
function map_indices! end

@interface interface::AbstractArrayInterface function map_indices!(
indices, f, a_dest::AbstractArray, as::AbstractArray...
)
for I in indices
a_dest[I] = f(map(a -> a[I], as)...)
end
return a_dest
end

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...)
return a_dest
end

# Only map all values, not just the stored ones.
function map_all! end

@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map_indices!(eachindex(as...), f, a_dest, as...)
return a_dest
end

using DerivableInterfaces: DerivableInterfaces, zero!

# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
Expand All @@ -137,37 +104,10 @@ using DerivableInterfaces: DerivableInterfaces, zero!
# More generally, this codepath could be taking if `zero(eltype(a))`
# is defined and the elements are immutable.
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
return @interface interface map_stored!(f, a, a)
end

# Determines if a function preserves the stored values
# of the destination sparse array.
# The current code may be inefficient since it actually
# accesses an unstored element, which in the case of a
# sparse array of arrays can allocate an array.
# Sparse arrays could be expected to define a cheap
# unstored element allocator, for example
# `get_prototypical_unstored(a::AbstractArray)`.
function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
I = first(eachindex(as...))
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
indices = if !preserves_unstored(f, a_dest, as...)
eachindex(a_dest)
elseif any(a -> a_dest !== a, as)
as = map(a -> Base.unalias(a_dest, a), as)
@interface interface zero!(a_dest)
eachstoredindex(as...)
else
eachstoredindex(a_dest)
@inbounds for I in eachstoredindex(a)
a[I] = f(a[I])
end
@interface interface map_indices!(indices, f, a_dest, as...)
return a_dest
return a
end

# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
Expand Down
23 changes: 20 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,27 @@ end
end
end

# required:
@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw(
MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})
@noinline function error_if_canonical_eachstoredindex(style::IndexStyle, A::AbstractArray)
style === IndexStyle(A) && throw(Base.CanonicalIndexError("eachstoredindex", typeof(A)))
return nothing
end

# required: one implementation for canonical index style
@interface ::AbstractSparseArrayInterface function eachstoredindex(
style::IndexStyle, A::AbstractArray
)
error_if_canonical_eachstoredindex(style, A)
inds = eachstoredindex(A)
if style === IndexCartesian()
eltype(inds) === CartesianIndex{ndims(A)} && return inds
return map(Base.Fix1(Base.getindex, CartesianIndices(A)), inds)
elseif style === IndexLinear()
eltype(inds) === Int && return inds
return map(Base.Fix1(Base.getindex, LinearIndices(A)), inds)
else
error(lazy"unkown index style $style")
end
end

# derived but may be specialized:
@interface ::AbstractSparseArrayInterface function eachstoredindex(
Expand Down
162 changes: 162 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# zero-preserving Traits
# ----------------------
"""
abstract type ZeroPreserving <: Function end

Holy Trait to indicate how a function interacts with abstract zero values:

- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
- `NonPreserving` : no guarantees on output.

To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
`ZeroPreserving(f, T::Type...)` can be used/overloaded.

!!! warning
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
"""
abstract type ZeroPreserving <: Function end

struct StrongPreserving{F} <: ZeroPreserving
f::F
end
struct WeakPreserving{F} <: ZeroPreserving
f::F
end
struct NonPreserving{F} <: ZeroPreserving
f::F
end

# Backport: remove in 1.12
@static if !isdefined(Base, :haszero)
_haszero(T::Type) = false
_haszero(::Type{<:Number}) = true
else
_haszero = Base.haszero
end

# warning: cannot automatically detect WeakPreserving since this would mean checking all values
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
end
# TODO: the following might not properly specialize on the types
# TODO: non-concrete element types
function ZeroPreserving(f, T::Type, Ts::Type...)
if all(_haszero, (T, Ts...))
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
else
return NonPreserving(f)
end
end

const _WEAK_FUNCTIONS = (:+, :-)
for f in _WEAK_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
end
end

const _STRONG_FUNCTIONS = (:*,)
for f in _STRONG_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
$f
)
end
end

# map(!)
# ------
@interface I::AbstractSparseArrayInterface function Base.map(
f, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map(f_pres, A, Bs...)
end
@interface I::AbstractSparseArrayInterface function Base.map(
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
)
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
C = similar(I, T, size(A))
return @interface I map!(f, C, A, Bs...)
end

@interface I::AbstractSparseArrayInterface function Base.map!(
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map!(f_pres, C, A, Bs...)
end

@interface ::AbstractSparseArrayInterface function Base.map!(
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
checkshape(C, A, Bs...)
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))

if f isa StrongPreserving
style = IndexStyle(C, unaliased...)
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa WeakPreserving
style = IndexStyle(C, unaliased...)
inds = union(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa NonPreserving
inds = eachindex(C, unaliased...)
else
error(lazy"unknown zero-preserving type $(typeof(f))")
end

@inbounds for I in inds
C[I] = f.f(ith_all(I, unaliased)...)
end

return C
end

# Derived functions
# -----------------
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
identity, C, A
)

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map!(WeakPreserving(f), a_dest, as...)
return a_dest
end

# Only map all values, not just the stored ones.
function map_all! end

@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map!(NonPreserving(f), a_dest, as...)
return a_dest
end

# Utility functions
# -----------------
# shape check similar to checkbounds
checkshape(::Type{Bool}, A::AbstractArray) = true
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
return allequal(size, (A, Bs...))
end

function checkshape(A::AbstractArray, Bs::AbstractArray...)
return checkshape(Bool, A, Bs...) ||
throw(DimensionMismatch("argument shapes must match"))
end

@inline ith_all(i, ::Tuple{}) = ()
function ith_all(i, as)
@_propagate_inbounds_meta
return (as[1][i], ith_all(i, Base.tail(as))...)
end
2 changes: 1 addition & 1 deletion src/oneelementarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index)
function isstored(a::OneElementArray, I::Int...)
return I == storedindex(a)
end
function eachstoredindex(a::OneElementArray)
function eachstoredindex(::IndexCartesian, a::OneElementArray)
return Fill(CartesianIndex(storedindex(a)), 1)
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ JLArrays = "0.2.0"
LinearAlgebra = "<0.0.1, 1"
Random = "<0.0.1, 1"
SafeTestsets = "0.1.0"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
StableRNGs = "1.0.2"
Suppressor = "0.2.8"
Test = "<0.0.1, 1"
Loading