@@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype
866
866
@test a1' * a2 ≈ Array (a1)' * Array (a2)
867
867
@test dot (a1, a2) ≈ a1' * a2
868
868
end
869
+ @testset " cat" begin
870
+ a1 = dev (BlockSparseArray {elt} ([2 , 3 ], [2 , 3 ]))
871
+ a1[Block (2 , 1 )] = dev (randn (elt, size (@view (a1[Block (2 , 1 )]))))
872
+ a2 = dev (BlockSparseArray {elt} ([2 , 3 ], [2 , 3 ]))
873
+ a2[Block (1 , 2 )] = dev (randn (elt, size (@view (a2[Block (1 , 2 )]))))
874
+
875
+ a_dest = cat (a1, a2; dims= 1 )
876
+ @test block_nstored (a_dest) == 2
877
+ @test blocklengths .(axes (a_dest)) == ([2 , 3 , 2 , 3 ], [2 , 3 ])
878
+ @test issetequal (block_stored_indices (a_dest), [Block (2 , 1 ), Block (3 , 2 )])
879
+ @test a_dest[Block (2 , 1 )] == a1[Block (2 , 1 )]
880
+ @test a_dest[Block (3 , 2 )] == a2[Block (1 , 2 )]
881
+
882
+ a_dest = cat (a1, a2; dims= 2 )
883
+ @test block_nstored (a_dest) == 2
884
+ @test blocklengths .(axes (a_dest)) == ([2 , 3 ], [2 , 3 , 2 , 3 ])
885
+ @test issetequal (block_stored_indices (a_dest), [Block (2 , 1 ), Block (1 , 4 )])
886
+ @test a_dest[Block (2 , 1 )] == a1[Block (2 , 1 )]
887
+ @test a_dest[Block (1 , 4 )] == a2[Block (1 , 2 )]
888
+
889
+ a_dest = cat (a1, a2; dims= (1 , 2 ))
890
+ @test block_nstored (a_dest) == 2
891
+ @test blocklengths .(axes (a_dest)) == ([2 , 3 , 2 , 3 ], [2 , 3 , 2 , 3 ])
892
+ @test issetequal (block_stored_indices (a_dest), [Block (2 , 1 ), Block (3 , 4 )])
893
+ @test a_dest[Block (2 , 1 )] == a1[Block (2 , 1 )]
894
+ @test a_dest[Block (3 , 4 )] == a2[Block (1 , 2 )]
895
+ end
869
896
@testset " TensorAlgebra" begin
870
897
a1 = dev (BlockSparseArray {elt} ([2 , 3 ], [2 , 3 ]))
871
898
a1[Block (1 , 1 )] = dev (randn (elt, size (@view (a1[Block (1 , 1 )]))))
0 commit comments