Skip to content

Commit 4c70aff

Browse files
authored
[BlockSparseArrays] Fix adjoint and transpose (#1470)
* [BlockSparseArrays] Fix adjoint and transpose * [NDTensors] Bump to v0.3.11
1 parent cd766a2 commit 4c70aff

File tree

5 files changed

+184
-22
lines changed

5 files changed

+184
-22
lines changed

NDTensors/Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <[email protected]>"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl

+15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ArrayLayouts: LayoutArray
22
using BlockArrays: blockisequal
3+
using LinearAlgebra: Adjoint, Transpose
34
using ..SparseArrayInterface:
45
SparseArrayInterface,
56
SparseArrayStyle,
@@ -73,6 +74,20 @@ function Base.copyto!(a_dest::LayoutArray, a_src::BlockSparseArrayLike)
7374
return a_dest
7475
end
7576

77+
function Base.copyto!(
78+
a_dest::AbstractMatrix, a_src::Transpose{T,<:AbstractBlockSparseMatrix{T}}
79+
) where {T}
80+
sparse_copyto!(a_dest, a_src)
81+
return a_dest
82+
end
83+
84+
function Base.copyto!(
85+
a_dest::AbstractMatrix, a_src::Adjoint{T,<:AbstractBlockSparseMatrix{T}}
86+
) where {T}
87+
sparse_copyto!(a_dest, a_src)
88+
return a_dest
89+
end
90+
7691
function Base.permutedims!(a_dest, a_src::BlockSparseArrayLike, perm)
7792
sparse_permutedims!(a_dest, a_src, perm)
7893
return a_dest

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ using SplitApplyCombine: groupcount
33

44
using Adapt: Adapt, WrappedArray
55

6-
const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{
7-
T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N}
6+
const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
7+
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
88
}
99

1010
# TODO: Rename `AnyBlockSparseArray`.

NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl

+79-13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using BlockArrays:
88
blocks,
99
blocklengths,
1010
findblockindex
11+
using LinearAlgebra: Adjoint, Transpose
1112
using ..SparseArrayInterface: perm, iperm, nstored
1213
## using MappedArrays: mappedarray
1314

@@ -86,35 +87,96 @@ end
8687

8788
# BlockArrays
8889

89-
using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
90+
using ..SparseArrayInterface:
91+
SparseArrayInterface, AbstractSparseArray, AbstractSparseMatrix
9092

91-
# Represents the array of arrays of a `SubArray`
92-
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
93+
_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
94+
_getindices(t::Tuple, indices) = map(i -> t[i], indices)
95+
_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices))
96+
97+
# Represents the array of arrays of a `PermutedDimsArray`
98+
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
9399
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
94100
AbstractSparseArray{T,N}
95101
array::Array
96102
end
97103
function blocksparse_blocks(a::PermutedDimsArray)
98104
return SparsePermutedDimsArrayBlocks(a)
99105
end
100-
_perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
101-
_getindices(t::Tuple, indices) = map(i -> t[i], indices)
102-
_getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), indices))
103-
function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks)
104-
return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array))))
105-
end
106106
function Base.size(a::SparsePermutedDimsArrayBlocks)
107107
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
108108
end
109-
function Base.getindex(a::SparsePermutedDimsArrayBlocks, index::Vararg{Int})
109+
function Base.getindex(
110+
a::SparsePermutedDimsArrayBlocks{<:Any,N}, index::Vararg{Int,N}
111+
) where {N}
110112
return PermutedDimsArray(
111113
blocks(parent(a.array))[_getindices(index, _perm(a.array))...], _perm(a.array)
112114
)
113115
end
116+
function SparseArrayInterface.stored_indices(a::SparsePermutedDimsArrayBlocks)
117+
return map(I -> _getindices(I, _perm(a.array)), stored_indices(blocks(parent(a.array))))
118+
end
119+
# TODO: Either make this the generic interface or define
120+
# `SparseArrayInterface.sparse_storage`, which is used
121+
# to defined this.
122+
SparseArrayInterface.nstored(a::SparsePermutedDimsArrayBlocks) = length(stored_indices(a))
114123
function SparseArrayInterface.sparse_storage(a::SparsePermutedDimsArrayBlocks)
115124
return error("Not implemented")
116125
end
117126

127+
reverse_index(index) = reverse(index)
128+
reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))
129+
130+
# Represents the array of arrays of a `Transpose`
131+
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
132+
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
133+
array::Array
134+
end
135+
function blocksparse_blocks(a::Transpose)
136+
return SparseTransposeBlocks(a)
137+
end
138+
function Base.size(a::SparseTransposeBlocks)
139+
return reverse(size(blocks(parent(a.array))))
140+
end
141+
function Base.getindex(a::SparseTransposeBlocks, index::Vararg{Int,2})
142+
return transpose(blocks(parent(a.array))[reverse(index)...])
143+
end
144+
function SparseArrayInterface.stored_indices(a::SparseTransposeBlocks)
145+
return map(reverse_index, stored_indices(blocks(parent(a.array))))
146+
end
147+
# TODO: Either make this the generic interface or define
148+
# `SparseArrayInterface.sparse_storage`, which is used
149+
# to defined this.
150+
SparseArrayInterface.nstored(a::SparseTransposeBlocks) = length(stored_indices(a))
151+
function SparseArrayInterface.sparse_storage(a::SparseTransposeBlocks)
152+
return error("Not implemented")
153+
end
154+
155+
# Represents the array of arrays of a `Adjoint`
156+
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
157+
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
158+
array::Array
159+
end
160+
function blocksparse_blocks(a::Adjoint)
161+
return SparseAdjointBlocks(a)
162+
end
163+
function Base.size(a::SparseAdjointBlocks)
164+
return reverse(size(blocks(parent(a.array))))
165+
end
166+
function Base.getindex(a::SparseAdjointBlocks, index::Vararg{Int,2})
167+
return blocks(parent(a.array))[reverse(index)...]'
168+
end
169+
function SparseArrayInterface.stored_indices(a::SparseAdjointBlocks)
170+
return map(reverse_index, stored_indices(blocks(parent(a.array))))
171+
end
172+
# TODO: Either make this the generic interface or define
173+
# `SparseArrayInterface.sparse_storage`, which is used
174+
# to defined this.
175+
SparseArrayInterface.nstored(a::SparseAdjointBlocks) = length(stored_indices(a))
176+
function SparseArrayInterface.sparse_storage(a::SparseAdjointBlocks)
177+
return error("Not implemented")
178+
end
179+
118180
# TODO: Move to `BlockArraysExtensions`.
119181
# This takes a range of indices `indices` of array `a`
120182
# and maps it to the range of indices within block `block`.
@@ -167,9 +229,6 @@ end
167229
function Base.size(a::SparseSubArrayBlocks)
168230
return length.(axes(a))
169231
end
170-
function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks)
171-
return stored_indices(view(blocks(parent(a.array)), axes(a)...))
172-
end
173232
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
174233
return a[Tuple(I)...]
175234
end
@@ -192,6 +251,13 @@ function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) whe
192251
# TODO: Implement this properly.
193252
return true
194253
end
254+
function SparseArrayInterface.stored_indices(a::SparseSubArrayBlocks)
255+
return stored_indices(view(blocks(parent(a.array)), axes(a)...))
256+
end
257+
# TODO: Either make this the generic interface or define
258+
# `SparseArrayInterface.sparse_storage`, which is used
259+
# to defined this.
260+
SparseArrayInterface.nstored(a::SparseSubArrayBlocks) = length(stored_indices(a))
195261
function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
196262
return error("Not implemented")
197263
end

NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl

+87-6
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ include("TestBlockSparseArraysUtils.jl")
4848
@test block_nstored(a) == 2
4949
@test nstored(a) == 2 * 4 + 3 * 3
5050

51+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
52+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
53+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
5154
b = similar(a, complex(elt))
5255
@test eltype(b) == complex(eltype(a))
5356
@test iszero(b)
@@ -56,37 +59,58 @@ include("TestBlockSparseArraysUtils.jl")
5659
@test size(b) == size(a)
5760
@test blocksize(b) == blocksize(a)
5861

62+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
63+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
64+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
5965
b = copy(a)
6066
b[1, 1] = 11
6167
@test b[1, 1] == 11
6268
@test a[1, 1] 11
6369

70+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
71+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
72+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
6473
b = copy(a)
6574
b .*= 2
6675
@test b 2a
6776

77+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
78+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
79+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
6880
b = copy(a)
6981
b ./= 2
7082
@test b a / 2
7183

84+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
85+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
86+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
7287
b = 2 * a
7388
@test Array(b) 2 * Array(a)
7489
@test eltype(b) == elt
7590
@test block_nstored(b) == 2
7691
@test nstored(b) == 2 * 4 + 3 * 3
7792

93+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
94+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
95+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
7896
b = (2 + 3im) * a
7997
@test Array(b) (2 + 3im) * Array(a)
8098
@test eltype(b) == complex(elt)
8199
@test block_nstored(b) == 2
82100
@test nstored(b) == 2 * 4 + 3 * 3
83101

102+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
103+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
104+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
84105
b = a + a
85106
@test Array(b) 2 * Array(a)
86107
@test eltype(b) == elt
87108
@test block_nstored(b) == 2
88109
@test nstored(b) == 2 * 4 + 3 * 3
89110

111+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
112+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
113+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
90114
x = BlockSparseArray{elt}(undef, ([3, 4], [2, 3]))
91115
x[Block(1, 2)] = randn(elt, size(@view(x[Block(1, 2)])))
92116
x[Block(2, 1)] = randn(elt, size(@view(x[Block(2, 1)])))
@@ -96,12 +120,18 @@ include("TestBlockSparseArraysUtils.jl")
96120
@test block_nstored(b) == 2
97121
@test nstored(b) == 2 * 4 + 3 * 3
98122

123+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
124+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
125+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
99126
b = permutedims(a, (2, 1))
100127
@test Array(b) permutedims(Array(a), (2, 1))
101128
@test eltype(b) == elt
102129
@test block_nstored(b) == 2
103130
@test nstored(b) == 2 * 4 + 3 * 3
104131

132+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
133+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
134+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
105135
b = map(x -> 2x, a)
106136
@test Array(b) 2 * Array(a)
107137
@test eltype(b) == elt
@@ -110,6 +140,9 @@ include("TestBlockSparseArraysUtils.jl")
110140
@test block_nstored(b) == 2
111141
@test nstored(b) == 2 * 4 + 3 * 3
112142

143+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
144+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
145+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
113146
b = a[[Block(2), Block(1)], [Block(2), Block(1)]]
114147
@test b[Block(1, 1)] == a[Block(2, 2)]
115148
@test b[Block(1, 2)] == a[Block(2, 1)]
@@ -120,13 +153,19 @@ include("TestBlockSparseArraysUtils.jl")
120153
@test nstored(b) == nstored(a)
121154
@test block_nstored(b) == 2
122155

156+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
157+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
158+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
123159
b = a[Block(1):Block(2), Block(1):Block(2)]
124160
@test b == a
125161
@test size(b) == size(a)
126162
@test blocksize(b) == (2, 2)
127163
@test nstored(b) == nstored(a)
128164
@test block_nstored(b) == 2
129165

166+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
167+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
168+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
130169
b = a[Block(1):Block(1), Block(1):Block(2)]
131170
@test b == Array(a)[1:2, 1:end]
132171
@test b[Block(1, 1)] == a[Block(1, 1)]
@@ -136,41 +175,83 @@ include("TestBlockSparseArraysUtils.jl")
136175
@test nstored(b) == nstored(a[Block(1, 2)])
137176
@test block_nstored(b) == 1
138177

178+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
179+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
180+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
139181
b = a[2:4, 2:4]
140182
@test b == Array(a)[2:4, 2:4]
141183
@test size(b) == (3, 3)
142184
@test blocksize(b) == (2, 2)
143185
@test nstored(b) == 1 * 1 + 2 * 2
144186
@test block_nstored(b) == 2
145187

188+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
189+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
190+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
146191
b = a[Block(2, 1)[1:2, 2:3]]
147192
@test b == Array(a)[3:4, 2:3]
148193
@test size(b) == (2, 2)
149194
@test blocksize(b) == (1, 1)
150195
@test nstored(b) == 2 * 2
151196
@test block_nstored(b) == 1
152197

198+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
199+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
200+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
201+
b = PermutedDimsArray(a, (2, 1))
202+
@test block_nstored(b) == 2
203+
@test Array(b) == permutedims(Array(a), (2, 1))
204+
c = 2 * b
205+
@test block_nstored(c) == 2
206+
@test Array(c) == 2 * permutedims(Array(a), (2, 1))
207+
208+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
209+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
210+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
211+
b = a'
212+
@test block_nstored(b) == 2
213+
@test Array(b) == Array(a)'
214+
c = 2 * b
215+
@test block_nstored(c) == 2
216+
@test Array(c) == 2 * Array(a)'
217+
218+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
219+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
220+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
221+
b = transpose(a)
222+
@test block_nstored(b) == 2
223+
@test Array(b) == transpose(Array(a))
224+
c = 2 * b
225+
@test block_nstored(c) == 2
226+
@test Array(c) == 2 * transpose(Array(a))
227+
153228
## Broken, need to fix.
154229

230+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
231+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
232+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
155233
@test_broken a[Block(1), Block(1):Block(2)]
156234

157235
# This is outputting only zero blocks.
236+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
237+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
238+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
158239
b = a[Block(2):Block(2), Block(1):Block(2)]
159240
@test_broken block_nstored(b) == 1
160241
@test_broken b == Array(a)[3:5, 1:end]
161242

162-
b = a'
163-
@test_broken block_nstored(b) == 2
164-
165-
b = transpose(a)
166-
@test_broken block_nstored(b) == 2
167-
243+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
244+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
245+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
168246
b = copy(a)
169247
x = randn(size(@view(a[Block(2, 2)])))
170248
b[Block(2), Block(2)] = x
171249
@test_broken b[Block(2, 2)] == x
172250

173251
# Doesnt' set the block
252+
a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
253+
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
254+
a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)])))
174255
b = copy(a)
175256
b[Block(1, 1)] .= 1
176257
@test_broken b[1, 1] == trues(size(@view(b[1, 1])))

0 commit comments

Comments
 (0)