diff --git a/src/generic/broadcast.jl b/src/generic/broadcast.jl index 20898ff5..d702a1cf 100644 --- a/src/generic/broadcast.jl +++ b/src/generic/broadcast.jl @@ -45,26 +45,33 @@ BroadcastStyle(::BandedStyle, ::DefaultArrayStyle{2}) = BandedStyle() size(bc::Broadcasted{BandedStyle}) = length.(axes(bc)) isbanded(bc::Broadcasted{BandedStyle}) = true -#### -# Default to standard Array broadcast -# -# This is because, for example, exp.(B) is not banded -#### - - -copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle}) = - copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes)) ## # copyto! ## +# Default to standard Array broadcast, because, for example, exp.(B) is not banded +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle}) = + copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes)) + +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{AbstractMatrix,AbstractMatrix}}) = + _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), map(MemoryLayout,bc.args)) +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{AbstractVector,AbstractMatrix}}) = + _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), map(MemoryLayout,bc.args)) +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{AbstractMatrix,AbstractVector}}) = + _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), map(MemoryLayout,bc.args)) +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{AbstractMatrix,Number}}) = + _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), map(MemoryLayout,bc.args)) +copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{Number,AbstractMatrix}}) = + _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), map(MemoryLayout,bc.args)) + + _copyto!(dest_L::AbstractBandedLayout, src_L::AbstractBandedLayout, dest::AbstractMatrix, src::AbstractMatrix) = _banded_broadcast!(dest, identity, src, dest_L, src_L) function checkbroadcastband(dest, sizesrc, bndssrc) size(dest) == sizesrc || throw(DimensionMismatch()) - min(sizesrc[1],bndssrc[1]+1) ≤ min(bandwidth(dest,1)+1,size(dest,1)) && + min(sizesrc[1],bndssrc[1]+1) ≤ min(bandwidth(dest,1)+1,size(dest,1)) && min(sizesrc[2],bndssrc[2]+1) ≤ min(bandwidth(dest,2)+2,size(dest,2)) || throw(BandError(dest,size(dest,2)-1)) end @@ -143,7 +150,7 @@ end @inline _colshift(A::AbstractMatrix, j) = _colshift(bandwidths(A), j) @inline _bulkshift(A::AbstractMatrix, j) = _bulkshift(bandwidths(A), j) -@inline function _colrange((m,n), (l,u), j) +@inline function _colrange((m,n), (l,u), j) j ≤ u && return _startcolrange((m,n), (l,u), j) j ≤ m-l && return _bulkcolrange((m,n), (l,u), j) return _stopcolrange((m,n), (l,u), j) @@ -359,16 +366,6 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (x, src)::Tuple{Number,Abst dest end -function copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{<:AbstractMatrix,<:Number}}) - (A,x) = bc.args - _banded_broadcast!(dest, bc.f, (A, x), MemoryLayout(typeof(dest)), MemoryLayout(typeof(A))) -end - -function copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{<:Number,<:AbstractMatrix}}) - (x,A) = bc.args - _banded_broadcast!(dest, bc.f, (x,A), MemoryLayout(typeof(dest)), MemoryLayout(typeof(A))) -end - ############### # matrix-vector broadcast ############### @@ -414,7 +411,7 @@ function _left_colvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{Ab dest end -_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractVector{V}}, _1, _2) where {T,V} = +_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractVector{V}}, _1, _2) where {T,V} = _right_colvec_banded_broadcast!(dest, f, (A,B), _1, _2) function _right_colvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractVecOrMat{V}}, _1, _2) where {T,V} @@ -533,15 +530,6 @@ function _right_rowvec_banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{A dest end - -copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{<:AbstractVector,<:AbstractMatrix}}) = - _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), MemoryLayout.(bc.args)) - - -copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{<:AbstractMatrix,<:AbstractVector}}) = - _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(dest), MemoryLayout.(bc.args)) - - ################ # matrix-matrix broadcast ################ @@ -611,7 +599,7 @@ function checkzerobands(dest, f, (A,B)::Tuple{AbstractMatrix,AbstractMatrix}) end -function _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix,AbstractMatrix}, ::BandedColumns, ::Tuple{<:BandedColumns,<:BandedColumns}) +function _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix,AbstractMatrix}, ::BandedColumns, ::Tuple{BandedColumns,BandedColumns}) z = f(zero(eltype(A)), zero(eltype(B))) bc = broadcasted(f, A, B) l, u = bandwidths(bc) @@ -653,7 +641,7 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix data_d_u_A = view(data_d, bs .+ d_u .+ 1, :) data_A_u_A = view(data_A, bs .+ A_u .+ 1, :) data_d_u_A .= f.(data_A_u_A, zero(eltype(B))) - + # construct where A upper is zero # this is from band B_u:min(A_u+1,-B_l) bs = max(-d_u,-B_u):min(-1-A_u,B_l,d_l) @@ -683,10 +671,6 @@ function _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix dest end - -copyto!(dest::AbstractArray, bc::Broadcasted{BandedStyle, <:Any, <:Any, <:Tuple{<:AbstractMatrix,<:AbstractMatrix}}) = - _banded_broadcast!(dest, bc.f, bc.args, MemoryLayout(typeof(dest)), map(MemoryLayout,bc.args)) - # override copy in case data has special broadcast _default_banded_broadcast(bc::Broadcasted{Style}, _) where Style = Base.invoke(copy, Tuple{Broadcasted{Style}}, bc) _default_banded_broadcast(bc::Broadcasted) = _default_banded_broadcast(bc, axes(bc)) @@ -742,7 +726,7 @@ _broadcast_bandwidths(bnds, _) = bnds _broadcast_bandwidths((l,u), a::AbstractVector) = (bandwidth(a,1),u) function __broadcast_bandwidths((l,u), A) sz = _bcsize(A) - (length(sz) == 1 || sz[2] == 1) && return (bandwidth(A,1),u) + (length(sz) == 1 || sz[2] == 1) && return (bandwidth(A,1),u) sz[1] == 1 && return (l, bandwidth(A,2)) bandwidths(A) # need to special case vector broadcasting end @@ -763,7 +747,7 @@ _band_eval_args(a::Broadcasted, b...) = (zero(_broadcast_eltype(a)), _band_eval_ __bnds(m, n) = (m-1, n-1) __bnds(m) = (m-1, 0) _bnds(bc) = __bnds(size(bc)...) - + bandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(*)}) = min.(_broadcast_bandwidths.(Ref(_bnds(bc)), bc.args)...)