Skip to content

Commit 57994ff

Browse files
authored
[BlockSparseArrays] Direct sum/cat (#1579)
* [BlockSparseArrays] Direct sum/`cat` * [NDTensors] Bump to v0.3.64
1 parent cf050da commit 57994ff

File tree

10 files changed

+184
-1
lines changed

10 files changed

+184
-1
lines changed

NDTensors/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.63"
4+
version = "0.3.64"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl")
77
include("blocksparsearrayinterface/map.jl")
88
include("blocksparsearrayinterface/arraylayouts.jl")
99
include("blocksparsearrayinterface/views.jl")
10+
include("blocksparsearrayinterface/cat.jl")
1011
include("abstractblocksparsearray/abstractblocksparsearray.jl")
1112
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
1213
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
@@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl")
1718
include("abstractblocksparsearray/broadcast.jl")
1819
include("abstractblocksparsearray/map.jl")
1920
include("abstractblocksparsearray/linearalgebra.jl")
21+
include("abstractblocksparsearray/cat.jl")
2022
include("blocksparsearray/defaults.jl")
2123
include("blocksparsearray/blocksparsearray.jl")
2224
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# TODO: Change to `AnyAbstractBlockSparseArray`.
2+
function Base.cat(as::BlockSparseArrayLike...; dims)
3+
# TODO: Use `sparse_cat` instead, currently
4+
# that erroneously allocates too many blocks that are
5+
# zero and shouldn't be stored.
6+
return blocksparse_cat(as...; dims)
7+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
2+
using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat!
3+
4+
# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`.
5+
# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`.
6+
function SparseArrayInterface.axis_cat(
7+
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
8+
)
9+
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
10+
end
11+
12+
# that erroneously allocates too many blocks that are
13+
# zero and shouldn't be stored.
14+
function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
15+
sparse_cat!(blocks(a_dest), blocks.(as)...; dims)
16+
return a_dest
17+
end
18+
19+
# TODO: Delete this in favor of `sparse_cat`, currently
20+
# that erroneously allocates too many blocks that are
21+
# zero and shouldn't be stored.
22+
function blocksparse_cat(as::AbstractArray...; dims)
23+
a_dest = allocate_cat_output(as...; dims)
24+
blocksparse_cat!(a_dest, as...; dims)
25+
return a_dest
26+
end

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

+27
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
866866
@test a1' * a2 Array(a1)' * Array(a2)
867867
@test dot(a1, a2) a1' * a2
868868
end
869+
@testset "cat" begin
870+
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
871+
a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)]))))
872+
a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
873+
a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)]))))
874+
875+
a_dest = cat(a1, a2; dims=1)
876+
@test block_nstored(a_dest) == 2
877+
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3])
878+
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)])
879+
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
880+
@test a_dest[Block(3, 2)] == a2[Block(1, 2)]
881+
882+
a_dest = cat(a1, a2; dims=2)
883+
@test block_nstored(a_dest) == 2
884+
@test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3])
885+
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)])
886+
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
887+
@test a_dest[Block(1, 4)] == a2[Block(1, 2)]
888+
889+
a_dest = cat(a1, a2; dims=(1, 2))
890+
@test block_nstored(a_dest) == 2
891+
@test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3])
892+
@test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)])
893+
@test a_dest[Block(2, 1)] == a1[Block(2, 1)]
894+
@test a_dest[Block(3, 4)] == a2[Block(1, 2)]
895+
end
869896
@testset "TensorAlgebra" begin
870897
a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3]))
871898
a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)]))))

NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl")
1212
include("sparsearrayinterface/conversion.jl")
1313
include("sparsearrayinterface/wrappers.jl")
1414
include("sparsearrayinterface/zero.jl")
15+
include("sparsearrayinterface/cat.jl")
1516
include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl")
1617
include("abstractsparsearray/abstractsparsearray.jl")
1718
include("abstractsparsearray/abstractsparsematrix.jl")
@@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl")
2425
include("abstractsparsearray/map.jl")
2526
include("abstractsparsearray/baseinterface.jl")
2627
include("abstractsparsearray/convert.jl")
28+
include("abstractsparsearray/cat.jl")
2729
include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl")
2830
include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl")
2931
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# TODO: Change to `AnyAbstractSparseArray`.
2+
function Base.cat(as::SparseArrayLike...; dims)
3+
return sparse_cat(as...; dims)
4+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
unval(x) = x
2+
unval(::Val{x}) where {x} = x
3+
4+
# TODO: Assert that `a1` and `a2` start at one.
5+
axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
6+
function axis_cat(
7+
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
8+
)
9+
return axis_cat(axis_cat(a1, a2), a_rest...)
10+
end
11+
function cat_axes(as::AbstractArray...; dims)
12+
return ntuple(length(first(axes.(as)))) do dim
13+
return if dim in unval(dims)
14+
axis_cat(map(axes -> axes[dim], axes.(as))...)
15+
else
16+
axes(first(as))[dim]
17+
end
18+
end
19+
end
20+
21+
function allocate_cat_output(as::AbstractArray...; dims)
22+
eltype_dest = promote_type(eltype.(as)...)
23+
axes_dest = cat_axes(as...; dims)
24+
# TODO: Promote the block types of the inputs rather than using
25+
# just the first input.
26+
# TODO: Make this customizable with `cat_similar`.
27+
# TODO: Base the zero element constructor on those of the inputs,
28+
# for example block sparse arrays.
29+
return similar(first(as), eltype_dest, axes_dest...)
30+
end
31+
32+
# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
33+
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
34+
# This is very similar to the `Base.cat` implementation but handles zero values better.
35+
function cat_offset!(
36+
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
37+
)
38+
inds = ntuple(ndims(a_dest)) do dim
39+
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
40+
end
41+
a_dest[inds...] = a1
42+
new_offsets = ntuple(ndims(a_dest)) do dim
43+
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
44+
end
45+
cat_offset!(a_dest, new_offsets, a_rest...; dims)
46+
return a_dest
47+
end
48+
function cat_offset!(a_dest::AbstractArray, offsets; dims)
49+
return a_dest
50+
end
51+
52+
# TODO: Define a generic `cat!` function.
53+
function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
54+
offsets = ntuple(zero, ndims(a_dest))
55+
# TODO: Fill `a_dest` with zeros if needed.
56+
cat_offset!(a_dest, offsets, as...; dims)
57+
return a_dest
58+
end
59+
60+
function sparse_cat(as::AbstractArray...; dims)
61+
a_dest = allocate_cat_output(as...; dims)
62+
sparse_cat!(a_dest, as...; dims)
63+
return a_dest
64+
end

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl

+19
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,31 @@ function sparse_setindex!(a::AbstractArray, value, I::Vararg{Int})
137137
return a
138138
end
139139

140+
# Fix ambiguity error
141+
function sparse_setindex!(a::AbstractArray, value)
142+
sparse_setindex!(a, value, CartesianIndex())
143+
return a
144+
end
145+
140146
# Linear indexing
141147
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1})
142148
sparse_setindex!(a, value, CartesianIndices(a)[I])
143149
return a
144150
end
145151

152+
# Slicing
153+
# TODO: Make this handle more general slicing operations,
154+
# base it off of `ArrayLayouts.sub_materialize`.
155+
function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...)
156+
inds = CartesianIndices(I)
157+
for i in stored_indices(value)
158+
if i in CartesianIndices(inds)
159+
a[inds[i]] = value[i]
160+
end
161+
end
162+
return a
163+
end
164+
146165
# Handle trailing indices
147166
function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex)
148167
t = Tuple(I)

NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl

+32
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,38 @@ using Test: @test, @testset
342342
@test a_dest isa SparseArray{elt}
343343
@test SparseArrayInterface.nstored(a_dest) == 2
344344

345+
# cat
346+
a1 = SparseArray{elt}(2, 3)
347+
a1[1, 2] = 12
348+
a1[2, 1] = 21
349+
a2 = SparseArray{elt}(2, 3)
350+
a2[1, 1] = 11
351+
a2[2, 2] = 22
352+
353+
a_dest = cat(a1, a2; dims=1)
354+
@test size(a_dest) == (4, 3)
355+
@test SparseArrayInterface.nstored(a_dest) == 4
356+
@test a_dest[1, 2] == a1[1, 2]
357+
@test a_dest[2, 1] == a1[2, 1]
358+
@test a_dest[3, 1] == a2[1, 1]
359+
@test a_dest[4, 2] == a2[2, 2]
360+
361+
a_dest = cat(a1, a2; dims=2)
362+
@test size(a_dest) == (2, 6)
363+
@test SparseArrayInterface.nstored(a_dest) == 4
364+
@test a_dest[1, 2] == a1[1, 2]
365+
@test a_dest[2, 1] == a1[2, 1]
366+
@test a_dest[1, 4] == a2[1, 1]
367+
@test a_dest[2, 5] == a2[2, 2]
368+
369+
a_dest = cat(a1, a2; dims=(1, 2))
370+
@test size(a_dest) == (4, 6)
371+
@test SparseArrayInterface.nstored(a_dest) == 4
372+
@test a_dest[1, 2] == a1[1, 2]
373+
@test a_dest[2, 1] == a1[2, 1]
374+
@test a_dest[3, 4] == a2[1, 1]
375+
@test a_dest[4, 5] == a2[2, 2]
376+
345377
## # Sparse matrix of matrix multiplication
346378
## TODO: Make this work, seems to require
347379
## a custom zero constructor.

0 commit comments

Comments
 (0)