Skip to content

Commit 69b4e4d

Browse files
strict types
1 parent 262293b commit 69b4e4d

File tree

1 file changed

+30
-27
lines changed

1 file changed

+30
-27
lines changed

src/fillalgebra.jl

+30-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
const FillVector{F,A} = Fill{F,1,A}
22
const FillMatrix{F,A} = Fill{F,2,A}
3+
const OnesVector{F,A} = Ones{F,1,A}
4+
const OnesMatrix{F,A} = Ones{F,2,A}
5+
const ZerosVector{F,A} = Zeros{F,1,A}
6+
const ZerosMatrix{F,A} = Zeros{F,2,A}
7+
38

49
## vec
510
vec(a::Ones{T}) where T = Ones{T}(length(a))
@@ -17,6 +22,13 @@ adjoint(a::Zeros{T,2}) where T = Zeros{T}(reverse(a.axes))
1722
transpose(a::Fill{T,2}) where T = Fill{T}(transpose(a.value), reverse(a.axes))
1823
adjoint(a::Fill{T,2}) where T = Fill{T}(adjoint(a.value), reverse(a.axes))
1924

25+
fillsimilar(a::Ones{T}, axes) where T = Ones{T}(axes)
26+
fillsimilar(a::Zeros{T}, axes) where T = Zeros{T}(axes)
27+
fillsimilar(a::AbstractFill, axes) = Fill(getindex_value(a), axes)
28+
fillsimilar(a::Ones, ::Type{T}, axes) where T = Ones{T}(axes)
29+
fillsimilar(a::Zeros, ::Type{T}, axes) where T = Zeros{T}(axes)
30+
fillsimilar(a::AbstractFill, ::Type{T}, axes) where T = Fill{T}(getindex_value(a), axes)
31+
2032
permutedims(a::AbstractFill{<:Any,1}) = fillsimilar(a, (1, length(a)))
2133
permutedims(a::AbstractFill{<:Any,2}) = fillsimilar(a, reverse(a.axes))
2234

@@ -34,20 +46,25 @@ end
3446
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,2})
3547
axes(a, 2) axes(b, 1) &&
3648
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
37-
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1), axes(b, 2)))
49+
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1), axes(b, 2)))
3850
end
3951

4052
function mult_fill(a::AbstractFill, b::AbstractFill{<:Any,1})
4153
axes(a, 2) axes(b, 1) &&
4254
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
43-
return Fill(getindex_value(a)*getindex_value(b)*size(a,2), (axes(a, 1),))
55+
return Fill(getindex_value(a)*getindex_value(b), (axes(a, 1),))
4456
end
4557

46-
function mult_ones(a::AbstractVector, b::AbstractMatrix)
58+
function mult_ones(a, b::AbstractMatrix)
4759
axes(a, 2) axes(b, 1) &&
4860
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
4961
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1), axes(b, 2)))
5062
end
63+
function mult_ones(a, b::AbstractVector)
64+
axes(a, 2) axes(b, 1) &&
65+
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
66+
return Ones{promote_type(eltype(a), eltype(b))}((axes(a, 1),))
67+
end
5168

5269
function mult_zeros(a, b::AbstractMatrix)
5370
axes(a, 2) axes(b, 1) &&
@@ -65,6 +82,8 @@ end
6582
*(a::AbstractFill{<:Any,2}, b::AbstractFill{<:Any,1}) = mult_fill(a,b)
6683

6784
*(a::Ones{<:Any,1}, b::Ones{<:Any,2}) = mult_ones(a, b)
85+
*(a::Ones{<:Any,2}, b::Ones{<:Any,2}) = mult_ones(a, b)
86+
*(a::Ones{<:Any,2}, b::Ones{<:Any,1}) = mult_ones(a, b)
6887

6988
*(a::Zeros{<:Any,1}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
7089
*(a::Zeros{<:Any,2}, b::Zeros{<:Any,2}) = mult_zeros(a, b)
@@ -89,14 +108,6 @@ end
89108
*(a::Zeros{<:Any,2}, b::Diagonal) = mult_zeros(a, b)
90109
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
91110
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
92-
function *(a::Diagonal, b::AbstractFill{<:Any,2})
93-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
94-
a.diag .* b # use special broadcast
95-
end
96-
function *(a::AbstractFill{<:Any,2}, b::Diagonal)
97-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
98-
a .* permutedims(b.diag) # use special broadcast
99-
end
100111

101112
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
102113
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
@@ -116,30 +127,29 @@ function *(f::FillMatrix, x::AbstractMatrix)
116127
repeat(sum(x, dims=1) * f.value, m, 1)
117128
end
118129

119-
function *(x::AbstractMatrix, f::Ones)
130+
function *(x::AbstractMatrix, f::OnesMatrix)
120131
axes(x, 2) axes(f, 1) &&
121132
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
122133
m = size(f, 2)
123134
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
124135
end
125136

126-
function *(f::Ones, x::AbstractMatrix)
137+
function *(f::OnesMatrix, x::AbstractMatrix)
127138
axes(f, 2) axes(x, 1) &&
128139
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
129140
m = size(f, 1)
130141
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
131142
end
132143

133-
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
144+
145+
146+
function *(a::Adjoint{T, <:AbstractVector{T}}, b::Zeros{S, 1}) where {T, S}
134147
la, lb = length(a), length(b)
135148
if la lb
136149
throw(DimensionMismatch("dot product arguments have lengths $la and $lb"))
137150
end
138-
return zero(Base.promote_op(*, T, S))
151+
return zero(promote_type(T, S))
139152
end
140-
141-
*(a::AdjointAbsVec, b::Zeros{<:Any, 1}) = _adjvec_mul_zeros(a, b)
142-
*(a::AdjointAbsVec{<:Number}, b::Zeros{<:Number, 1}) = _adjvec_mul_zeros(a, b)
143153
*(a::Adjoint{T, <:AbstractMatrix{T}} where T, b::Zeros{<:Any, 1}) = mult_zeros(a, b)
144154

145155
function *(a::Transpose{T, <:AbstractVector{T}}, b::Zeros{T, 1}) where T<:Real
@@ -161,7 +171,6 @@ function +(a::Zeros{T}, b::Zeros{V}) where {T, V}
161171
return Zeros{promote_type(T,V)}(size(a)...)
162172
end
163173
-(a::Zeros, b::Zeros) = -(a + b)
164-
-(a::Ones, b::Ones) = Zeros(a)+Zeros(b)
165174

166175
# Zeros +/- Fill and Fill +/- Zeros
167176
function +(a::AbstractFill{T}, b::Zeros{V}) where {T, V}
@@ -241,14 +250,8 @@ end
241250

242251
function rmul!(z::AbstractFill, x::Number)
243252
λ = getindex_value(z)
244-
# Following check ensures consistency w/ lmul!(x, Array(z))
253+
# Following check ensures consistency w/ lmul!(x, Array(z))
245254
# for, e.g., lmul!(NaN, z)
246255
λ*x == λ || throw(ArgumentError("Cannot scale by $x"))
247256
z
248-
end
249-
250-
fillzero(::Type{Fill{T,N,AXIS}}, n, m) where {T,N,AXIS} = Fill{T,N,AXIS}(zero(T), (n, m))
251-
fillzero(::Type{Zeros{T,N,AXIS}}, n, m) where {T,N,AXIS} = Zeros{T,N,AXIS}((n, m))
252-
fillzero(::Type{F}, n, m) where F = throw(ArgumentError("Cannot create a zero array of type $F"))
253-
254-
diagzero(D::Diagonal{F}, i, j) where F<:AbstractFill = fillzero(F, axes(D.diag[i], 1), axes(D.diag[j], 2))
257+
end

0 commit comments

Comments
 (0)