Skip to content

Commit

Permalink
copy over implementation from LowRankApprox
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jul 27, 2023
1 parent 154c342 commit 640333d
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 4 deletions.
13 changes: 12 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,22 @@ uuid = "e65ccdef-c354-471a-8090-89bec1c20ec3"
authors = ["Jishnu Bhattacharya <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

[extensions]
LowRankMatricesFillArraysExt = "FillArrays"

[compat]
julia = "1"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "FillArrays"]
21 changes: 21 additions & 0 deletions ext/LowRankMatricesFillArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module LowRankMatricesFillArraysExt

using LowRankMatrices
using FillArrays
using FillArrays: AbstractFill

LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T<:Number} =
LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r))
LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T} =
LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r))

LowRankMatrix(Z::Zeros, r::Int=0) = LowRankMatrix{eltype(Z)}(Z, r)
function LowRankMatrix{T}(F::AbstractFill) where T
v = T(FillArrays.getindex_value(F))
m,n = size(F)
LowRankMatrix(fill(v,m,1), fill(one(T),n,1))
end
LowRankMatrix(F::AbstractFill{T}) where T = LowRankMatrix{T}(F)


end
15 changes: 14 additions & 1 deletion src/LowRankMatrices.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
module LowRankMatrices

# Write your package code here.
import Base: similar, convert, promote_rule, size, fill!, getindex,
*, +, -, \, /,
Matrix, copy, copyto!

using LinearAlgebra
import LinearAlgebra: rank, transpose, adjoint, mul!

export LowRankMatrix

include("lowrankmatrix.jl")

if !isdefined(Base, :get_extension)
include("../ext/LowRankMatricesFillArraysExt.jl")
end

end
162 changes: 162 additions & 0 deletions src/lowrankmatrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
##
# Represent an m x n rank-r matrix
# A = U*Vᵗ
##
function _LowRankMatrix end

mutable struct LowRankMatrix{T} <: AbstractMatrix{T}
U::Matrix{T} # m x r Matrix
V::Matrix{T} # n x r Matrix

global function _LowRankMatrix(U::AbstractMatrix{T}, V::AbstractMatrix{T}) where T
m,r = size(U)
n,rv = size(V)
if r rv throw(ArgumentError("U and V must have same number of columns")) end
new{T}(Matrix{T}(U), Matrix{T}(V))
end
end

LowRankMatrix(U::AbstractMatrix, V::AbstractMatrix) = _LowRankMatrix(promote(U,V)...)
LowRankMatrix(U::AbstractVector, V::AbstractMatrix) = LowRankMatrix(reshape(U,length(U),1),V)
LowRankMatrix(U::AbstractMatrix, V::AbstractVector) = LowRankMatrix(U,reshape(V,length(V),1))
LowRankMatrix(U::AbstractVector, V::AbstractVector) =
_LowRankMatrix(reshape(U,length(U),1), reshape(V,length(V),1))

LowRankMatrix{T}(::UndefInitializer, mn::NTuple{2,Int}, r::Int) where {T} =
LowRankMatrix(Matrix{T}(undef,mn[1],r),Matrix{T}(undef,mn[2],r))

similar(L::LowRankMatrix, ::Type{T}, dims::Dims{2}) where {T} = LowRankMatrix{T}(undef, dims, rank(L))
similar(L::LowRankMatrix{T}) where {T} = LowRankMatrix{T}(undef, size(L), rank(L))
similar(L::LowRankMatrix{T}, dims::Dims{2}) where {T} = LowRankMatrix(undef, dims, rank(L))
similar(L::LowRankMatrix{T}, m::Int) where {T} = Vector{T}(undef, m)
similar(L::LowRankMatrix{T}, ::Type{S}) where {S,T} = LowRankMatrix{S}(undef, size(L), rank(L))

function LowRankMatrix{T}(A::AbstractMatrix{T}) where T
U,Σ,V = svd(A)
r = refactorsvd!(U,Σ,V)
LowRankMatrix(U[:,1:r], V[:,1:r])
end

LowRankMatrix{T}(A::AbstractMatrix) where T = LowRankMatrix{T}(AbstractMatrix{T}(A))
LowRankMatrix(A::AbstractMatrix{T}) where T = LowRankMatrix{T}(A)

# Moves Σ into U and V
function refactorsvd!(U::AbstractMatrix{S}, Σ::AbstractVector{T}, V::AbstractMatrix{S}) where {S,T}
Base.require_one_based_indexing(U, Σ, V)
conj!(V)
σmax = Σ[1]
r = count(s->s>10σmax*eps(T),Σ)
m,n = size(U,1),size(V,1)
for k=1:r
σk = sqrt(Σ[k])
for i=1:m
@inbounds U[i,k] *= σk
end
for j=1:n
@inbounds V[j,k] *= σk
end
end
r
end

for MAT in (:LowRankMatrix, :AbstractMatrix, :AbstractArray)
@eval convert(::Type{$MAT{T}}, L::LowRankMatrix) where {T} =
LowRankMatrix(convert(Matrix{T}, L.U), convert(Matrix{T}, L.V))
end
convert(::Type{Matrix{T}}, L::LowRankMatrix) where {T} = convert(Matrix{T}, Matrix(L))
promote_rule(::Type{LowRankMatrix{T}}, ::Type{LowRankMatrix{V}}) where {T,V} = LowRankMatrix{promote_type(T,V)}
promote_rule(::Type{LowRankMatrix{T}}, ::Type{Matrix{V}}) where {T,V} = Matrix{promote_type(T,V)}

size(L::LowRankMatrix) = size(L.U,1),size(L.V,1)
rank(L::LowRankMatrix) = size(L.U,2)
transpose(L::LowRankMatrix) = LowRankMatrix(L.V,L.U) # TODO: change for 0.7
adjoint(L::LowRankMatrix{T}) where {T<:Real} = LowRankMatrix(L.V,L.U)
adjoint(L::LowRankMatrix) = LowRankMatrix(conj(L.V),conj(L.U))
fill!(L::LowRankMatrix{T}, x::T) where {T} = (fill!(L.U, sqrt(abs(x)/rank(L))); fill!(L.V,sqrt(abs(x)/rank(L))/sign(x)); L)

function unsafe_getindex(L::LowRankMatrix, i::Int, j::Int)
ret = zero(eltype(L))
@inbounds for k=1:rank(L)
ret = muladd(L.U[i,k],L.V[j,k],ret)
end
return ret
end

function getindex(L::LowRankMatrix, i::Int, j::Int)
m,n = size(L)
if 1 i m && 1 j n
unsafe_getindex(L,i,j)
else
throw(BoundsError())
end
end
getindex(L::LowRankMatrix, i::Int, jr::AbstractRange) = transpose(eltype(L)[L[i,j] for j=jr])
getindex(L::LowRankMatrix, ir::AbstractRange, j::Int) = eltype(L)[L[i,j] for i=ir]
getindex(L::LowRankMatrix, ir::AbstractRange, jr::AbstractRange) = eltype(L)[L[i,j] for i=ir,j=jr]
Matrix(L::LowRankMatrix) = L[1:size(L,1),1:size(L,2)]

# constructors

copy(L::LowRankMatrix) = LowRankMatrix(copy(L.U),copy(L.V))
copyto!(L::LowRankMatrix, N::LowRankMatrix) = (copyto!(L.U,N.U); copyto!(L.V,N.V);L)


# algebra

for op in (:+,:-)
@eval begin
$op(L::LowRankMatrix) = LowRankMatrix($op(L.U),L.V)

$op(a::Bool, L::LowRankMatrix{Bool}) = error("Not callable")
$op(L::LowRankMatrix{Bool}, a::Bool) = error("Not callable")
$op(a::Number,L::LowRankMatrix) = $op(LowRankMatrix(Fill(a,size(L))), L)
$op(L::LowRankMatrix,a::Number) = $op(L, LowRankMatrix(Fill(a,size(L))))

function $op(L::LowRankMatrix, M::LowRankMatrix)
size(L) == size(M) || throw(DimensionMismatch("A has dimensions $(size(L)) but B has dimensions $(size(M))"))
LowRankMatrix(hcat(L.U,$op(M.U)), hcat(L.V,M.V))
end
$op(L::LowRankMatrix,A::Matrix) = $op(promote(L,A)...)
$op(A::Matrix,L::LowRankMatrix) = $op(promote(A,L)...)
end
end

*(a::Number, L::LowRankMatrix) = LowRankMatrix(a*L.U,L.V)
*(L::LowRankMatrix, a::Number) = LowRankMatrix(L.U,L.V*a)

# override default:

*(A::LowRankMatrix, B::Adjoint{T,LowRankMatrix{T}}) where T = A*adjoint(B)

function mul!(b::AbstractVector, L::LowRankMatrix, x::AbstractVector)
temp = zeros(promote_type(eltype(L),eltype(x)), rank(L))
mul!(temp, transpose(L.V), x)
mul!(b, L.U, temp)
b
end
function *(L::LowRankMatrix, M::LowRankMatrix)
T = promote_type(eltype(L),eltype(M))
temp = zeros(T,rank(L),rank(M))
mul!(temp, transpose(L.V), M.U)
V = zeros(T,size(M,2),rank(L))
mul!(V, M.V, transpose(temp))
LowRankMatrix(copy(L.U),V)
end



function *(L::LowRankMatrix, A::Matrix)
V = zeros(promote_type(eltype(L),eltype(A)),size(A,2),rank(L))
mul!(V, transpose(A), L.V)
LowRankMatrix(copy(L.U),V)
end



function *(A::Matrix, L::LowRankMatrix)
U = zeros(promote_type(eltype(A),eltype(L)),size(A,1),rank(L))
mul!(U,A,L.U)
LowRankMatrix(U,copy(L.V))
end

\(L::LowRankMatrix, b::AbstractVecOrMat) = transpose(L.V) \ (L.U \ b)
85 changes: 83 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,87 @@
using LowRankMatrices
using Test
using LinearAlgebra
using FillArrays

@testset "Constructors" begin
@test Matrix(LowRankMatrix(Zeros(10,5))) == zeros(10,5)

@test LowRankMatrix{Float64}(Zeros(10,5)) == LowRankMatrix(Zeros(10,5)) ==
LowRankMatrix{Float64}(Zeros(10,5),1) == LowRankMatrix{Float64}(Zeros{Int}(10,5),1) ==
LowRankMatrix{Float64}(zeros(10,5)) == LowRankMatrix(zeros(10,5)) ==
LowRankMatrix{Float64}(zeros(Int,10,5))


@test isempty(LowRankMatrix{Float64}(zeros(10,5)).U)
@test isempty(LowRankMatrix{Float64}(zeros(10,5)).V)
@test isempty(LowRankMatrix(Zeros(10,5)).U)
@test isempty(LowRankMatrix(Zeros(10,5)).V)

@test rank(LowRankMatrix(Zeros(10,5))) == 0


@test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5)
@test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5)) ==
LowRankMatrix{Float64}(Ones{Int}(10,5))
@test LowRankMatrix{Float64}(fill(1.0,10,5)) == LowRankMatrix(fill(1.0,10,5)) ==
LowRankMatrix{Float64}(fill(1,10,5))
@test rank(LowRankMatrix(Ones(10,5))) == 1

@test LowRankMatrix(Ones(10,5)) LowRankMatrix(fill(1.0,10,5))


@test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5)
@test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5))
@test rank(LowRankMatrix(fill(1.0,10,5))) == 1

x = 2
@test Matrix(LowRankMatrix(Fill(x,10,5))) fill(x,10,5)
@test LowRankMatrix{Float64}(Fill(x,10,5)) == LowRankMatrix(Fill(x,10,5)) ==
LowRankMatrix{Float64}(Fill{Float64}(x,10,5))
@test LowRankMatrix{Float64}(fill(x,10,5)) == LowRankMatrix(fill(x,10,5)) ==
LowRankMatrix{Float64}(fill(x,10,5))
@test rank(LowRankMatrix(Fill(x,10,5))) == 1


@testset "LowRankMatrices.jl" begin
# Write your tests here.
end


@testset "LowRankMatrix algebra" begin
A = LowRankMatrices._LowRankMatrix(randn(20,4), randn(12,4))
@test Matrix(2*A) Matrix(A*2) 2*Matrix(A)

B = LowRankMatrices._LowRankMatrix(randn(20,2), randn(12,2))
@test Matrix(A+B) Matrix(A) + Matrix(B) Matrix(Matrix(A) + B)
Matrix(A + Matrix(B))
@test rank(A+B) rank(A) + rank(B)

@test Matrix(A-B) Matrix(A) - Matrix(B) Matrix(Matrix(A) - B)
Matrix(A - Matrix(B))
@test rank(A-B) rank(A) + rank(B)

B = LowRankMatrices._LowRankMatrix(randn(12,2), randn(14,2))

@test A*B isa LowRankMatrix
@test rank(A*B) == size((A*B).U,2) == 4

@test Matrix(A)*Matrix(B) Matrix(A*Matrix(B)) Matrix(Matrix(A)*B) Matrix(A*B)

B = LowRankMatrices._LowRankMatrix(randn(10,2), randn(14,2))
@test_throws DimensionMismatch A*B
@test_throws DimensionMismatch B*A

B = randn(12,14)
@test A*B isa LowRankMatrix
@test rank(A*B) == size((A*B).U,2) == 4
@test Matrix(A)*B Matrix(A*B)

B = randn(20,20)
@test B*A isa LowRankMatrix
@test rank(B*A) == size((B*A).U,2) == 4
@test B*Matrix(A) Matrix(B*A)

v = randn(12)
@test all(mul!(randn(size(A,1)), A, v) .=== A*v )
@test A*v Matrix(A)*v
end

0 comments on commit 640333d

Please sign in to comment.