diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl index f8e7e89d..8d3307c7 100644 --- a/src/blockbroadcast.jl +++ b/src/blockbroadcast.jl @@ -34,9 +34,9 @@ sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b))) sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b)) combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b))) -Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b) -Base.Broadcast.axistype(a::BlockedUnitRange, b) = length(b) == 1 ? a : combine_blockaxes(a, b) -Base.Broadcast.axistype(a, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b) +Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = combine_blockaxes(a, b) +Base.Broadcast.axistype(a::BlockedUnitRange, b) = combine_blockaxes(a, b) +Base.Broadcast.axistype(a, b::BlockedUnitRange) = combine_blockaxes(a, b) similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} = diff --git a/test/test_blockbroadcast.jl b/test/test_blockbroadcast.jl index 0d58aa3d..149e4020 100644 --- a/test/test_blockbroadcast.jl +++ b/test/test_blockbroadcast.jl @@ -182,6 +182,12 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal u = BlockArray(randn(5), [2,3]); @inferred(copyto!(similar(u), Base.broadcasted(exp, u))) @test exp.(u) == exp.(Vector(u)) + + shape1 = (BlockArrays._BlockedUnitRange((2,)),); + shape2 = (BlockArrays._BlockedUnitRange((2,)),); + @inferred Base.Broadcast.axistype(shape1[1], shape2[1]) + @inferred BlockArrays.combine_blockaxes(shape1[1], shape2[1]) + @inferred Base.Broadcast.broadcast_shape(shape1, shape2) end @testset "adjtrans" begin