-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
copy over implementation from LowRankApprox
- Loading branch information
Showing
5 changed files
with
292 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,11 +3,22 @@ uuid = "e65ccdef-c354-471a-8090-89bec1c20ec3" | |
authors = ["Jishnu Bhattacharya <[email protected]> and contributors"] | ||
version = "1.0.0-DEV" | ||
|
||
[deps] | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
||
[weakdeps] | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
|
||
[extensions] | ||
LowRankMatricesFillArraysExt = "FillArrays" | ||
|
||
[compat] | ||
julia = "1" | ||
|
||
[extras] | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test"] | ||
test = ["Test", "FillArrays"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
module LowRankMatricesFillArraysExt | ||
|
||
using LowRankMatrices | ||
using FillArrays | ||
using FillArrays: AbstractFill | ||
|
||
LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T<:Number} = | ||
LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r)) | ||
LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T} = | ||
LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r)) | ||
|
||
LowRankMatrix(Z::Zeros, r::Int=0) = LowRankMatrix{eltype(Z)}(Z, r) | ||
function LowRankMatrix{T}(F::AbstractFill) where T | ||
v = T(FillArrays.getindex_value(F)) | ||
m,n = size(F) | ||
LowRankMatrix(fill(v,m,1), fill(one(T),n,1)) | ||
end | ||
LowRankMatrix(F::AbstractFill{T}) where T = LowRankMatrix{T}(F) | ||
|
||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,18 @@ | ||
module LowRankMatrices | ||
|
||
# Write your package code here. | ||
import Base: similar, convert, promote_rule, size, fill!, getindex, | ||
*, +, -, \, /, | ||
Matrix, copy, copyto! | ||
|
||
using LinearAlgebra | ||
import LinearAlgebra: rank, transpose, adjoint, mul! | ||
|
||
export LowRankMatrix | ||
|
||
include("lowrankmatrix.jl") | ||
|
||
if !isdefined(Base, :get_extension) | ||
include("../ext/LowRankMatricesFillArraysExt.jl") | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
## | ||
# Represent an m x n rank-r matrix | ||
# A = U*Vᵗ | ||
## | ||
function _LowRankMatrix end | ||
|
||
mutable struct LowRankMatrix{T} <: AbstractMatrix{T} | ||
U::Matrix{T} # m x r Matrix | ||
V::Matrix{T} # n x r Matrix | ||
|
||
global function _LowRankMatrix(U::AbstractMatrix{T}, V::AbstractMatrix{T}) where T | ||
m,r = size(U) | ||
n,rv = size(V) | ||
if r ≠ rv throw(ArgumentError("U and V must have same number of columns")) end | ||
new{T}(Matrix{T}(U), Matrix{T}(V)) | ||
end | ||
end | ||
|
||
LowRankMatrix(U::AbstractMatrix, V::AbstractMatrix) = _LowRankMatrix(promote(U,V)...) | ||
LowRankMatrix(U::AbstractVector, V::AbstractMatrix) = LowRankMatrix(reshape(U,length(U),1),V) | ||
LowRankMatrix(U::AbstractMatrix, V::AbstractVector) = LowRankMatrix(U,reshape(V,length(V),1)) | ||
LowRankMatrix(U::AbstractVector, V::AbstractVector) = | ||
_LowRankMatrix(reshape(U,length(U),1), reshape(V,length(V),1)) | ||
|
||
LowRankMatrix{T}(::UndefInitializer, mn::NTuple{2,Int}, r::Int) where {T} = | ||
LowRankMatrix(Matrix{T}(undef,mn[1],r),Matrix{T}(undef,mn[2],r)) | ||
|
||
similar(L::LowRankMatrix, ::Type{T}, dims::Dims{2}) where {T} = LowRankMatrix{T}(undef, dims, rank(L)) | ||
similar(L::LowRankMatrix{T}) where {T} = LowRankMatrix{T}(undef, size(L), rank(L)) | ||
similar(L::LowRankMatrix{T}, dims::Dims{2}) where {T} = LowRankMatrix(undef, dims, rank(L)) | ||
similar(L::LowRankMatrix{T}, m::Int) where {T} = Vector{T}(undef, m) | ||
similar(L::LowRankMatrix{T}, ::Type{S}) where {S,T} = LowRankMatrix{S}(undef, size(L), rank(L)) | ||
|
||
function LowRankMatrix{T}(A::AbstractMatrix{T}) where T | ||
U,Σ,V = svd(A) | ||
r = refactorsvd!(U,Σ,V) | ||
LowRankMatrix(U[:,1:r], V[:,1:r]) | ||
end | ||
|
||
LowRankMatrix{T}(A::AbstractMatrix) where T = LowRankMatrix{T}(AbstractMatrix{T}(A)) | ||
LowRankMatrix(A::AbstractMatrix{T}) where T = LowRankMatrix{T}(A) | ||
|
||
# Moves Σ into U and V | ||
function refactorsvd!(U::AbstractMatrix{S}, Σ::AbstractVector{T}, V::AbstractMatrix{S}) where {S,T} | ||
Base.require_one_based_indexing(U, Σ, V) | ||
conj!(V) | ||
σmax = Σ[1] | ||
r = count(s->s>10σmax*eps(T),Σ) | ||
m,n = size(U,1),size(V,1) | ||
for k=1:r | ||
σk = sqrt(Σ[k]) | ||
for i=1:m | ||
@inbounds U[i,k] *= σk | ||
end | ||
for j=1:n | ||
@inbounds V[j,k] *= σk | ||
end | ||
end | ||
r | ||
end | ||
|
||
for MAT in (:LowRankMatrix, :AbstractMatrix, :AbstractArray) | ||
@eval convert(::Type{$MAT{T}}, L::LowRankMatrix) where {T} = | ||
LowRankMatrix(convert(Matrix{T}, L.U), convert(Matrix{T}, L.V)) | ||
end | ||
convert(::Type{Matrix{T}}, L::LowRankMatrix) where {T} = convert(Matrix{T}, Matrix(L)) | ||
promote_rule(::Type{LowRankMatrix{T}}, ::Type{LowRankMatrix{V}}) where {T,V} = LowRankMatrix{promote_type(T,V)} | ||
promote_rule(::Type{LowRankMatrix{T}}, ::Type{Matrix{V}}) where {T,V} = Matrix{promote_type(T,V)} | ||
|
||
size(L::LowRankMatrix) = size(L.U,1),size(L.V,1) | ||
rank(L::LowRankMatrix) = size(L.U,2) | ||
transpose(L::LowRankMatrix) = LowRankMatrix(L.V,L.U) # TODO: change for 0.7 | ||
adjoint(L::LowRankMatrix{T}) where {T<:Real} = LowRankMatrix(L.V,L.U) | ||
adjoint(L::LowRankMatrix) = LowRankMatrix(conj(L.V),conj(L.U)) | ||
fill!(L::LowRankMatrix{T}, x::T) where {T} = (fill!(L.U, sqrt(abs(x)/rank(L))); fill!(L.V,sqrt(abs(x)/rank(L))/sign(x)); L) | ||
|
||
function unsafe_getindex(L::LowRankMatrix, i::Int, j::Int) | ||
ret = zero(eltype(L)) | ||
@inbounds for k=1:rank(L) | ||
ret = muladd(L.U[i,k],L.V[j,k],ret) | ||
end | ||
return ret | ||
end | ||
|
||
function getindex(L::LowRankMatrix, i::Int, j::Int) | ||
m,n = size(L) | ||
if 1 ≤ i ≤ m && 1 ≤ j ≤ n | ||
unsafe_getindex(L,i,j) | ||
else | ||
throw(BoundsError()) | ||
end | ||
end | ||
getindex(L::LowRankMatrix, i::Int, jr::AbstractRange) = transpose(eltype(L)[L[i,j] for j=jr]) | ||
getindex(L::LowRankMatrix, ir::AbstractRange, j::Int) = eltype(L)[L[i,j] for i=ir] | ||
getindex(L::LowRankMatrix, ir::AbstractRange, jr::AbstractRange) = eltype(L)[L[i,j] for i=ir,j=jr] | ||
Matrix(L::LowRankMatrix) = L[1:size(L,1),1:size(L,2)] | ||
|
||
# constructors | ||
|
||
copy(L::LowRankMatrix) = LowRankMatrix(copy(L.U),copy(L.V)) | ||
copyto!(L::LowRankMatrix, N::LowRankMatrix) = (copyto!(L.U,N.U); copyto!(L.V,N.V);L) | ||
|
||
|
||
# algebra | ||
|
||
for op in (:+,:-) | ||
@eval begin | ||
$op(L::LowRankMatrix) = LowRankMatrix($op(L.U),L.V) | ||
|
||
$op(a::Bool, L::LowRankMatrix{Bool}) = error("Not callable") | ||
$op(L::LowRankMatrix{Bool}, a::Bool) = error("Not callable") | ||
$op(a::Number,L::LowRankMatrix) = $op(LowRankMatrix(Fill(a,size(L))), L) | ||
$op(L::LowRankMatrix,a::Number) = $op(L, LowRankMatrix(Fill(a,size(L)))) | ||
|
||
function $op(L::LowRankMatrix, M::LowRankMatrix) | ||
size(L) == size(M) || throw(DimensionMismatch("A has dimensions $(size(L)) but B has dimensions $(size(M))")) | ||
LowRankMatrix(hcat(L.U,$op(M.U)), hcat(L.V,M.V)) | ||
end | ||
$op(L::LowRankMatrix,A::Matrix) = $op(promote(L,A)...) | ||
$op(A::Matrix,L::LowRankMatrix) = $op(promote(A,L)...) | ||
end | ||
end | ||
|
||
*(a::Number, L::LowRankMatrix) = LowRankMatrix(a*L.U,L.V) | ||
*(L::LowRankMatrix, a::Number) = LowRankMatrix(L.U,L.V*a) | ||
|
||
# override default: | ||
|
||
*(A::LowRankMatrix, B::Adjoint{T,LowRankMatrix{T}}) where T = A*adjoint(B) | ||
|
||
function mul!(b::AbstractVector, L::LowRankMatrix, x::AbstractVector) | ||
temp = zeros(promote_type(eltype(L),eltype(x)), rank(L)) | ||
mul!(temp, transpose(L.V), x) | ||
mul!(b, L.U, temp) | ||
b | ||
end | ||
function *(L::LowRankMatrix, M::LowRankMatrix) | ||
T = promote_type(eltype(L),eltype(M)) | ||
temp = zeros(T,rank(L),rank(M)) | ||
mul!(temp, transpose(L.V), M.U) | ||
V = zeros(T,size(M,2),rank(L)) | ||
mul!(V, M.V, transpose(temp)) | ||
LowRankMatrix(copy(L.U),V) | ||
end | ||
|
||
|
||
|
||
function *(L::LowRankMatrix, A::Matrix) | ||
V = zeros(promote_type(eltype(L),eltype(A)),size(A,2),rank(L)) | ||
mul!(V, transpose(A), L.V) | ||
LowRankMatrix(copy(L.U),V) | ||
end | ||
|
||
|
||
|
||
function *(A::Matrix, L::LowRankMatrix) | ||
U = zeros(promote_type(eltype(A),eltype(L)),size(A,1),rank(L)) | ||
mul!(U,A,L.U) | ||
LowRankMatrix(U,copy(L.V)) | ||
end | ||
|
||
\(L::LowRankMatrix, b::AbstractVecOrMat) = transpose(L.V) \ (L.U \ b) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,87 @@ | ||
using LowRankMatrices | ||
using Test | ||
using LinearAlgebra | ||
using FillArrays | ||
|
||
@testset "Constructors" begin | ||
@test Matrix(LowRankMatrix(Zeros(10,5))) == zeros(10,5) | ||
|
||
@test LowRankMatrix{Float64}(Zeros(10,5)) == LowRankMatrix(Zeros(10,5)) == | ||
LowRankMatrix{Float64}(Zeros(10,5),1) == LowRankMatrix{Float64}(Zeros{Int}(10,5),1) == | ||
LowRankMatrix{Float64}(zeros(10,5)) == LowRankMatrix(zeros(10,5)) == | ||
LowRankMatrix{Float64}(zeros(Int,10,5)) | ||
|
||
|
||
@test isempty(LowRankMatrix{Float64}(zeros(10,5)).U) | ||
@test isempty(LowRankMatrix{Float64}(zeros(10,5)).V) | ||
@test isempty(LowRankMatrix(Zeros(10,5)).U) | ||
@test isempty(LowRankMatrix(Zeros(10,5)).V) | ||
|
||
@test rank(LowRankMatrix(Zeros(10,5))) == 0 | ||
|
||
|
||
@test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5) | ||
@test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5)) == | ||
LowRankMatrix{Float64}(Ones{Int}(10,5)) | ||
@test LowRankMatrix{Float64}(fill(1.0,10,5)) == LowRankMatrix(fill(1.0,10,5)) == | ||
LowRankMatrix{Float64}(fill(1,10,5)) | ||
@test rank(LowRankMatrix(Ones(10,5))) == 1 | ||
|
||
@test LowRankMatrix(Ones(10,5)) ≈ LowRankMatrix(fill(1.0,10,5)) | ||
|
||
|
||
@test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5) | ||
@test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5)) | ||
@test rank(LowRankMatrix(fill(1.0,10,5))) == 1 | ||
|
||
x = 2 | ||
@test Matrix(LowRankMatrix(Fill(x,10,5))) ≈ fill(x,10,5) | ||
@test LowRankMatrix{Float64}(Fill(x,10,5)) == LowRankMatrix(Fill(x,10,5)) == | ||
LowRankMatrix{Float64}(Fill{Float64}(x,10,5)) | ||
@test LowRankMatrix{Float64}(fill(x,10,5)) == LowRankMatrix(fill(x,10,5)) == | ||
LowRankMatrix{Float64}(fill(x,10,5)) | ||
@test rank(LowRankMatrix(Fill(x,10,5))) == 1 | ||
|
||
|
||
@testset "LowRankMatrices.jl" begin | ||
# Write your tests here. | ||
end | ||
|
||
|
||
@testset "LowRankMatrix algebra" begin | ||
A = LowRankMatrices._LowRankMatrix(randn(20,4), randn(12,4)) | ||
@test Matrix(2*A) ≈ Matrix(A*2) ≈ 2*Matrix(A) | ||
|
||
B = LowRankMatrices._LowRankMatrix(randn(20,2), randn(12,2)) | ||
@test Matrix(A+B) ≈ Matrix(A) + Matrix(B) ≈ Matrix(Matrix(A) + B) ≈ | ||
Matrix(A + Matrix(B)) | ||
@test rank(A+B) ≤ rank(A) + rank(B) | ||
|
||
@test Matrix(A-B) ≈ Matrix(A) - Matrix(B) ≈ Matrix(Matrix(A) - B) ≈ | ||
Matrix(A - Matrix(B)) | ||
@test rank(A-B) ≤ rank(A) + rank(B) | ||
|
||
B = LowRankMatrices._LowRankMatrix(randn(12,2), randn(14,2)) | ||
|
||
@test A*B isa LowRankMatrix | ||
@test rank(A*B) == size((A*B).U,2) == 4 | ||
|
||
@test Matrix(A)*Matrix(B) ≈ Matrix(A*Matrix(B)) ≈ Matrix(Matrix(A)*B) ≈ Matrix(A*B) | ||
|
||
B = LowRankMatrices._LowRankMatrix(randn(10,2), randn(14,2)) | ||
@test_throws DimensionMismatch A*B | ||
@test_throws DimensionMismatch B*A | ||
|
||
B = randn(12,14) | ||
@test A*B isa LowRankMatrix | ||
@test rank(A*B) == size((A*B).U,2) == 4 | ||
@test Matrix(A)*B ≈ Matrix(A*B) | ||
|
||
B = randn(20,20) | ||
@test B*A isa LowRankMatrix | ||
@test rank(B*A) == size((B*A).U,2) == 4 | ||
@test B*Matrix(A) ≈ Matrix(B*A) | ||
|
||
v = randn(12) | ||
@test all(mul!(randn(size(A,1)), A, v) .=== A*v ) | ||
@test A*v ≈ Matrix(A)*v | ||
end | ||
|