From 640333d60eef404e025d74d5ca774c83360f3325 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 27 Jul 2023 18:29:03 +0530 Subject: [PATCH] copy over implementation from LowRankApprox --- Project.toml | 13 ++- ext/LowRankMatricesFillArraysExt.jl | 21 ++++ src/LowRankMatrices.jl | 15 ++- src/lowrankmatrix.jl | 162 ++++++++++++++++++++++++++++ test/runtests.jl | 85 ++++++++++++++- 5 files changed, 292 insertions(+), 4 deletions(-) create mode 100644 ext/LowRankMatricesFillArraysExt.jl create mode 100644 src/lowrankmatrix.jl diff --git a/Project.toml b/Project.toml index bb7001c..e882aea 100644 --- a/Project.toml +++ b/Project.toml @@ -3,11 +3,22 @@ uuid = "e65ccdef-c354-471a-8090-89bec1c20ec3" authors = ["Jishnu Bhattacharya and contributors"] version = "1.0.0-DEV" +[deps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[weakdeps] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + +[extensions] +LowRankMatricesFillArraysExt = "FillArrays" + [compat] julia = "1" [extras] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test", "FillArrays"] diff --git a/ext/LowRankMatricesFillArraysExt.jl b/ext/LowRankMatricesFillArraysExt.jl new file mode 100644 index 0000000..6ed64ed --- /dev/null +++ b/ext/LowRankMatricesFillArraysExt.jl @@ -0,0 +1,21 @@ +module LowRankMatricesFillArraysExt + +using LowRankMatrices +using FillArrays +using FillArrays: AbstractFill + +LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T<:Number} = + LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r)) +LowRankMatrix{T}(Z::Zeros, r::Int=0) where {T} = + LowRankMatrix(zeros(T,size(Z,1),r), zeros(T,size(Z,2),r)) + +LowRankMatrix(Z::Zeros, r::Int=0) = LowRankMatrix{eltype(Z)}(Z, r) +function LowRankMatrix{T}(F::AbstractFill) where T + v = T(FillArrays.getindex_value(F)) + m,n = size(F) + LowRankMatrix(fill(v,m,1), fill(one(T),n,1)) +end +LowRankMatrix(F::AbstractFill{T}) where T = LowRankMatrix{T}(F) + + +end diff --git a/src/LowRankMatrices.jl b/src/LowRankMatrices.jl index 252abad..c61add8 100644 --- a/src/LowRankMatrices.jl +++ b/src/LowRankMatrices.jl @@ -1,5 +1,18 @@ module LowRankMatrices -# Write your package code here. +import Base: similar, convert, promote_rule, size, fill!, getindex, + *, +, -, \, /, + Matrix, copy, copyto! + +using LinearAlgebra +import LinearAlgebra: rank, transpose, adjoint, mul! + +export LowRankMatrix + +include("lowrankmatrix.jl") + +if !isdefined(Base, :get_extension) + include("../ext/LowRankMatricesFillArraysExt.jl") +end end diff --git a/src/lowrankmatrix.jl b/src/lowrankmatrix.jl new file mode 100644 index 0000000..9bb75e8 --- /dev/null +++ b/src/lowrankmatrix.jl @@ -0,0 +1,162 @@ +## +# Represent an m x n rank-r matrix +# A = U*Vᵗ +## +function _LowRankMatrix end + +mutable struct LowRankMatrix{T} <: AbstractMatrix{T} + U::Matrix{T} # m x r Matrix + V::Matrix{T} # n x r Matrix + + global function _LowRankMatrix(U::AbstractMatrix{T}, V::AbstractMatrix{T}) where T + m,r = size(U) + n,rv = size(V) + if r ≠ rv throw(ArgumentError("U and V must have same number of columns")) end + new{T}(Matrix{T}(U), Matrix{T}(V)) + end +end + +LowRankMatrix(U::AbstractMatrix, V::AbstractMatrix) = _LowRankMatrix(promote(U,V)...) +LowRankMatrix(U::AbstractVector, V::AbstractMatrix) = LowRankMatrix(reshape(U,length(U),1),V) +LowRankMatrix(U::AbstractMatrix, V::AbstractVector) = LowRankMatrix(U,reshape(V,length(V),1)) +LowRankMatrix(U::AbstractVector, V::AbstractVector) = + _LowRankMatrix(reshape(U,length(U),1), reshape(V,length(V),1)) + +LowRankMatrix{T}(::UndefInitializer, mn::NTuple{2,Int}, r::Int) where {T} = + LowRankMatrix(Matrix{T}(undef,mn[1],r),Matrix{T}(undef,mn[2],r)) + +similar(L::LowRankMatrix, ::Type{T}, dims::Dims{2}) where {T} = LowRankMatrix{T}(undef, dims, rank(L)) +similar(L::LowRankMatrix{T}) where {T} = LowRankMatrix{T}(undef, size(L), rank(L)) +similar(L::LowRankMatrix{T}, dims::Dims{2}) where {T} = LowRankMatrix(undef, dims, rank(L)) +similar(L::LowRankMatrix{T}, m::Int) where {T} = Vector{T}(undef, m) +similar(L::LowRankMatrix{T}, ::Type{S}) where {S,T} = LowRankMatrix{S}(undef, size(L), rank(L)) + +function LowRankMatrix{T}(A::AbstractMatrix{T}) where T + U,Σ,V = svd(A) + r = refactorsvd!(U,Σ,V) + LowRankMatrix(U[:,1:r], V[:,1:r]) +end + +LowRankMatrix{T}(A::AbstractMatrix) where T = LowRankMatrix{T}(AbstractMatrix{T}(A)) +LowRankMatrix(A::AbstractMatrix{T}) where T = LowRankMatrix{T}(A) + +# Moves Σ into U and V +function refactorsvd!(U::AbstractMatrix{S}, Σ::AbstractVector{T}, V::AbstractMatrix{S}) where {S,T} + Base.require_one_based_indexing(U, Σ, V) + conj!(V) + σmax = Σ[1] + r = count(s->s>10σmax*eps(T),Σ) + m,n = size(U,1),size(V,1) + for k=1:r + σk = sqrt(Σ[k]) + for i=1:m + @inbounds U[i,k] *= σk + end + for j=1:n + @inbounds V[j,k] *= σk + end + end + r +end + +for MAT in (:LowRankMatrix, :AbstractMatrix, :AbstractArray) + @eval convert(::Type{$MAT{T}}, L::LowRankMatrix) where {T} = + LowRankMatrix(convert(Matrix{T}, L.U), convert(Matrix{T}, L.V)) +end +convert(::Type{Matrix{T}}, L::LowRankMatrix) where {T} = convert(Matrix{T}, Matrix(L)) +promote_rule(::Type{LowRankMatrix{T}}, ::Type{LowRankMatrix{V}}) where {T,V} = LowRankMatrix{promote_type(T,V)} +promote_rule(::Type{LowRankMatrix{T}}, ::Type{Matrix{V}}) where {T,V} = Matrix{promote_type(T,V)} + +size(L::LowRankMatrix) = size(L.U,1),size(L.V,1) +rank(L::LowRankMatrix) = size(L.U,2) +transpose(L::LowRankMatrix) = LowRankMatrix(L.V,L.U) # TODO: change for 0.7 +adjoint(L::LowRankMatrix{T}) where {T<:Real} = LowRankMatrix(L.V,L.U) +adjoint(L::LowRankMatrix) = LowRankMatrix(conj(L.V),conj(L.U)) +fill!(L::LowRankMatrix{T}, x::T) where {T} = (fill!(L.U, sqrt(abs(x)/rank(L))); fill!(L.V,sqrt(abs(x)/rank(L))/sign(x)); L) + +function unsafe_getindex(L::LowRankMatrix, i::Int, j::Int) + ret = zero(eltype(L)) + @inbounds for k=1:rank(L) + ret = muladd(L.U[i,k],L.V[j,k],ret) + end + return ret +end + +function getindex(L::LowRankMatrix, i::Int, j::Int) + m,n = size(L) + if 1 ≤ i ≤ m && 1 ≤ j ≤ n + unsafe_getindex(L,i,j) + else + throw(BoundsError()) + end +end +getindex(L::LowRankMatrix, i::Int, jr::AbstractRange) = transpose(eltype(L)[L[i,j] for j=jr]) +getindex(L::LowRankMatrix, ir::AbstractRange, j::Int) = eltype(L)[L[i,j] for i=ir] +getindex(L::LowRankMatrix, ir::AbstractRange, jr::AbstractRange) = eltype(L)[L[i,j] for i=ir,j=jr] +Matrix(L::LowRankMatrix) = L[1:size(L,1),1:size(L,2)] + +# constructors + +copy(L::LowRankMatrix) = LowRankMatrix(copy(L.U),copy(L.V)) +copyto!(L::LowRankMatrix, N::LowRankMatrix) = (copyto!(L.U,N.U); copyto!(L.V,N.V);L) + + +# algebra + +for op in (:+,:-) + @eval begin + $op(L::LowRankMatrix) = LowRankMatrix($op(L.U),L.V) + + $op(a::Bool, L::LowRankMatrix{Bool}) = error("Not callable") + $op(L::LowRankMatrix{Bool}, a::Bool) = error("Not callable") + $op(a::Number,L::LowRankMatrix) = $op(LowRankMatrix(Fill(a,size(L))), L) + $op(L::LowRankMatrix,a::Number) = $op(L, LowRankMatrix(Fill(a,size(L)))) + + function $op(L::LowRankMatrix, M::LowRankMatrix) + size(L) == size(M) || throw(DimensionMismatch("A has dimensions $(size(L)) but B has dimensions $(size(M))")) + LowRankMatrix(hcat(L.U,$op(M.U)), hcat(L.V,M.V)) + end + $op(L::LowRankMatrix,A::Matrix) = $op(promote(L,A)...) + $op(A::Matrix,L::LowRankMatrix) = $op(promote(A,L)...) + end +end + +*(a::Number, L::LowRankMatrix) = LowRankMatrix(a*L.U,L.V) +*(L::LowRankMatrix, a::Number) = LowRankMatrix(L.U,L.V*a) + +# override default: + +*(A::LowRankMatrix, B::Adjoint{T,LowRankMatrix{T}}) where T = A*adjoint(B) + +function mul!(b::AbstractVector, L::LowRankMatrix, x::AbstractVector) + temp = zeros(promote_type(eltype(L),eltype(x)), rank(L)) + mul!(temp, transpose(L.V), x) + mul!(b, L.U, temp) + b +end +function *(L::LowRankMatrix, M::LowRankMatrix) + T = promote_type(eltype(L),eltype(M)) + temp = zeros(T,rank(L),rank(M)) + mul!(temp, transpose(L.V), M.U) + V = zeros(T,size(M,2),rank(L)) + mul!(V, M.V, transpose(temp)) + LowRankMatrix(copy(L.U),V) +end + + + +function *(L::LowRankMatrix, A::Matrix) + V = zeros(promote_type(eltype(L),eltype(A)),size(A,2),rank(L)) + mul!(V, transpose(A), L.V) + LowRankMatrix(copy(L.U),V) +end + + + +function *(A::Matrix, L::LowRankMatrix) + U = zeros(promote_type(eltype(A),eltype(L)),size(A,1),rank(L)) + mul!(U,A,L.U) + LowRankMatrix(U,copy(L.V)) +end + +\(L::LowRankMatrix, b::AbstractVecOrMat) = transpose(L.V) \ (L.U \ b) diff --git a/test/runtests.jl b/test/runtests.jl index 708bd08..205258d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,87 @@ using LowRankMatrices using Test +using LinearAlgebra +using FillArrays + +@testset "Constructors" begin + @test Matrix(LowRankMatrix(Zeros(10,5))) == zeros(10,5) + + @test LowRankMatrix{Float64}(Zeros(10,5)) == LowRankMatrix(Zeros(10,5)) == + LowRankMatrix{Float64}(Zeros(10,5),1) == LowRankMatrix{Float64}(Zeros{Int}(10,5),1) == + LowRankMatrix{Float64}(zeros(10,5)) == LowRankMatrix(zeros(10,5)) == + LowRankMatrix{Float64}(zeros(Int,10,5)) + + + @test isempty(LowRankMatrix{Float64}(zeros(10,5)).U) + @test isempty(LowRankMatrix{Float64}(zeros(10,5)).V) + @test isempty(LowRankMatrix(Zeros(10,5)).U) + @test isempty(LowRankMatrix(Zeros(10,5)).V) + + @test rank(LowRankMatrix(Zeros(10,5))) == 0 + + + @test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5) + @test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5)) == + LowRankMatrix{Float64}(Ones{Int}(10,5)) + @test LowRankMatrix{Float64}(fill(1.0,10,5)) == LowRankMatrix(fill(1.0,10,5)) == + LowRankMatrix{Float64}(fill(1,10,5)) + @test rank(LowRankMatrix(Ones(10,5))) == 1 + + @test LowRankMatrix(Ones(10,5)) ≈ LowRankMatrix(fill(1.0,10,5)) + + + @test Matrix(LowRankMatrix(Ones(10,5))) == fill(1.0,10,5) + @test LowRankMatrix{Float64}(Ones(10,5)) == LowRankMatrix(Ones(10,5)) + @test rank(LowRankMatrix(fill(1.0,10,5))) == 1 + + x = 2 + @test Matrix(LowRankMatrix(Fill(x,10,5))) ≈ fill(x,10,5) + @test LowRankMatrix{Float64}(Fill(x,10,5)) == LowRankMatrix(Fill(x,10,5)) == + LowRankMatrix{Float64}(Fill{Float64}(x,10,5)) + @test LowRankMatrix{Float64}(fill(x,10,5)) == LowRankMatrix(fill(x,10,5)) == + LowRankMatrix{Float64}(fill(x,10,5)) + @test rank(LowRankMatrix(Fill(x,10,5))) == 1 + -@testset "LowRankMatrices.jl" begin - # Write your tests here. end + + +@testset "LowRankMatrix algebra" begin + A = LowRankMatrices._LowRankMatrix(randn(20,4), randn(12,4)) + @test Matrix(2*A) ≈ Matrix(A*2) ≈ 2*Matrix(A) + + B = LowRankMatrices._LowRankMatrix(randn(20,2), randn(12,2)) + @test Matrix(A+B) ≈ Matrix(A) + Matrix(B) ≈ Matrix(Matrix(A) + B) ≈ + Matrix(A + Matrix(B)) + @test rank(A+B) ≤ rank(A) + rank(B) + + @test Matrix(A-B) ≈ Matrix(A) - Matrix(B) ≈ Matrix(Matrix(A) - B) ≈ + Matrix(A - Matrix(B)) + @test rank(A-B) ≤ rank(A) + rank(B) + + B = LowRankMatrices._LowRankMatrix(randn(12,2), randn(14,2)) + + @test A*B isa LowRankMatrix + @test rank(A*B) == size((A*B).U,2) == 4 + + @test Matrix(A)*Matrix(B) ≈ Matrix(A*Matrix(B)) ≈ Matrix(Matrix(A)*B) ≈ Matrix(A*B) + + B = LowRankMatrices._LowRankMatrix(randn(10,2), randn(14,2)) + @test_throws DimensionMismatch A*B + @test_throws DimensionMismatch B*A + + B = randn(12,14) + @test A*B isa LowRankMatrix + @test rank(A*B) == size((A*B).U,2) == 4 + @test Matrix(A)*B ≈ Matrix(A*B) + + B = randn(20,20) + @test B*A isa LowRankMatrix + @test rank(B*A) == size((B*A).U,2) == 4 + @test B*Matrix(A) ≈ Matrix(B*A) + + v = randn(12) + @test all(mul!(randn(size(A,1)), A, v) .=== A*v ) + @test A*v ≈ Matrix(A)*v +end +