From ddc8994a15f1e92118ac95a6c0a30b6ed918a6d4 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Fri, 2 Jun 2023 14:54:30 +0200
Subject: [PATCH 1/4] Add LazyKernelMatrix and lazykernelmatrix

---
 src/KernelFunctions.jl         |   2 +
 src/matrix/lazykernelmatrix.jl | 109 +++++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+)
 create mode 100644 src/matrix/lazykernelmatrix.jl

diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl
index 63205b5bf..56725313a 100644
--- a/src/KernelFunctions.jl
+++ b/src/KernelFunctions.jl
@@ -1,6 +1,7 @@
 module KernelFunctions
 
 export kernelmatrix, kernelmatrix!, kernelmatrix_diag, kernelmatrix_diag!
+export LazyKernelMatrix, lazykernelmatrix
 export duplicate, set! # Helpers
 
 export Kernel, MOKernel
@@ -106,6 +107,7 @@ include("kernels/gibbskernel.jl")
 include("kernels/scaledkernel.jl")
 include("kernels/normalizedkernel.jl")
 include("matrix/kernelmatrix.jl")
+include("matrix/lazykernelmatrix.jl")
 include("kernels/kernelsum.jl")
 include("kernels/kernelproduct.jl")
 include("kernels/kerneltensorproduct.jl")
diff --git a/src/matrix/lazykernelmatrix.jl b/src/matrix/lazykernelmatrix.jl
new file mode 100644
index 000000000..f20c96aaf
--- /dev/null
+++ b/src/matrix/lazykernelmatrix.jl
@@ -0,0 +1,109 @@
+"""
+    lazykernelmatrix(κ::Kernel, x::AbstractVector) -> AbstractMatrix
+
+Construct a lazy representation of the kernel `κ` for each pair of inputs in `x`.
+
+The result is a matrix with the same entries as [`kernelmatrix(κ, x)`](@ref) but where the
+entries are not computed until they are needed.
+"""
+lazykernelmatrix(κ::Kernel, x) = lazykernelmatrix(κ, x, x)
+
+"""
+    lazykernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector) -> AbstractMatrix
+
+Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`.
+
+The result is a matrix with the same entries as [`kernelmatrix(κ, x, y)`](@ref) but where
+the entries are not computed until they are needed.
+"""
+lazykernelmatrix(κ::Kernel, x, y) = LazyKernelMatrix(κ, x, y)
+
+"""
+    LazyKernelMatrix(κ::Kernel, x[, y])
+    LazyKernelMatrix{T<:Real}(κ::Kernel, x[, y])
+
+Construct a lazy representation of the kernel `κ` for each pair of inputs in `x` and `y`.
+
+Instead of constructing this directly, it is better to call
+[`lazykernelmatrix(κ, x[, y])`](@ref lazykernelmatrix).
+"""
+struct LazyKernelMatrix{T<:Real,Tk<:Kernel,Tx<:AbstractVector,Ty<:AbstractVector} <:
+       AbstractMatrix{T}
+    kernel::Tk
+    x::Tx
+    y::Ty
+    function LazyKernelMatrix{T}(κ::Tk, x::Tx, y::Ty) where {T<:Real,Tk<:Kernel,Tx,Ty}
+        Base.require_one_based_indexing(x)
+        Base.require_one_based_indexing(y)
+        return new{T,Tk,Tx,Ty}(κ, x, y)
+    end
+    function LazyKernelMatrix{T}(κ::Tk, x::Tx) where {T<:Real,Tk<:Kernel,Tx}
+        Base.require_one_based_indexing(x)
+        return new{T,Tk,Tx,Tx}(κ, x, x)
+    end
+end
+function LazyKernelMatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)
+    # evaluate once to get eltype
+    T = typeof(κ(first(x), first(y)))
+    return LazyKernelMatrix{T}(κ, x, y)
+end
+LazyKernelMatrix(κ::Kernel, x::AbstractVector) = LazyKernelMatrix(κ, x, x)
+
+Base.Matrix(K::LazyKernelMatrix) = kernelmatrix(K.kernel, K.x, K.y)
+function Base.AbstractMatrix{T}(K::LazyKernelMatrix) where {T}
+    return LazyKernelMatrix{T}(K.kernel, K.x, K.y)
+end
+
+Base.size(K::LazyKernelMatrix) = (length(K.x), length(K.y))
+
+Base.axes(K::LazyKernelMatrix) = (axes(K.x, 1), axes(K.y, 1))
+
+function Base.getindex(K::LazyKernelMatrix{T}, i::Int, j::Int) where {T}
+    return T(K.kernel(K.x[i], K.y[j]))
+end
+for f in (:getindex, :view)
+    @eval begin
+        function Base.$f(
+            K::LazyKernelMatrix{T},
+            I::Union{Colon,AbstractVector},
+            J::Union{Colon,AbstractVector},
+        ) where {T}
+            return LazyKernelMatrix{T}(K.kernel, $f(K.x, I), $f(K.y, J))
+        end
+    end
+end
+
+Base.zero(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(ZeroKernel(), K.x, K.y)
+Base.one(K::LazyKernelMatrix{T}) where {T} = LazyKernelMatrix{T}(WhiteKernel(), K.x, K.y)
+
+function Base.:*(c::S, K::LazyKernelMatrix{T}) where {T,S<:Real}
+    R = typeof(oneunit(S) * oneunit(T))
+    return LazyKernelMatrix{R}(c * K.kernel, K.x, K.y)
+end
+Base.:*(K::LazyKernelMatrix, c::Real) = c * K
+Base.:/(K::LazyKernelMatrix, c::Real) = K * inv(c)
+Base.:\(c::Real, K::LazyKernelMatrix) = inv(c) * K
+
+function Base.:+(K::LazyKernelMatrix{T}, C::UniformScaling{S}) where {T,S<:Real}
+    if isequal(K.x, K.y)
+        R = typeof(zero(T) + zero(S))
+        return LazyKernelMatrix{R}(K.kernel + C.λ * WhiteKernel(), K.x, K.y)
+    else
+        return Matrix(K) + C
+    end
+end
+function Base.:+(C::UniformScaling{S}, K::LazyKernelMatrix{T}) where {T,S<:Real}
+    if isequal(K.x, K.y)
+        R = typeof(zero(T) + zero(S))
+        return LazyKernelMatrix{R}(C.λ * WhiteKernel() + K.kernel, K.x, K.y)
+    else
+        return C + Matrix(K)
+    end
+end
+function Base.:+(K1::LazyKernelMatrix, K2::LazyKernelMatrix)
+    if isequal(K1.x, K2.x) && isequal(K1.y, K2.y)
+        return LazyKernelMatrix(K1.kernel + K2.kernel, K1.x, K1.y)
+    else
+        return Matrix(K1) + Matrix(K2)
+    end
+end

From 70ddd0355a66c6f8bda3b2944b682f73a8460f98 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Fri, 2 Jun 2023 14:54:39 +0200
Subject: [PATCH 2/4] Update API docs

---
 docs/src/api.md | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 9fb241fa4..1def1fcdb 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -6,7 +6,9 @@ CurrentModule = KernelFunctions
 
 ## Functions
 
-The KernelFunctions API comprises the following four functions.
+The KernelFunctions API comprises the following functions.
+
+The first set eagerly constructs all or part of a kernel matrix:
 ```@docs
 kernelmatrix
 kernelmatrix!
@@ -14,6 +16,12 @@ kernelmatrix_diag
 kernelmatrix_diag!
 ```
 
+It is also possible to lazily construct the same matrix
+```@docs
+lazykernelmatrix
+LazyKernelMatrix
+```
+
 ## Input Types
 
 The above API operates on collections of inputs.

From 32f17d691d4ecc2c619afa1610213e3af1a3aff6 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Fri, 2 Jun 2023 14:55:27 +0200
Subject: [PATCH 3/4] Update kernelmatrix docstring

---
 src/matrix/kernelmatrix.jl | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/matrix/kernelmatrix.jl b/src/matrix/kernelmatrix.jl
index e778f79eb..77ee9c007 100644
--- a/src/matrix/kernelmatrix.jl
+++ b/src/matrix/kernelmatrix.jl
@@ -30,12 +30,16 @@ Compute the kernel `κ` for each pair of inputs in `x`.
 Returns a matrix of size `(length(x), length(x))` satisfying
 `kernelmatrix(κ, x)[p, q] == κ(x[p], x[q])`.
 
+If `x` is large, consider using [`lazykernelmatrix`](@ref) instead.
+
     kernelmatrix(κ::Kernel, x::AbstractVector, y::AbstractVector)
 
 Compute the kernel `κ` for each pair of inputs in `x` and `y`.
 Returns a matrix of size `(length(x), length(y))` satisfying
 `kernelmatrix(κ, x, y)[p, q] == κ(x[p], y[q])`.
 
+If `x` and `y` are large, consider using [`lazykernelmatrix`](@ref) instead.
+
     kernelmatrix(κ::Kernel, X::AbstractMatrix; obsdim)
     kernelmatrix(κ::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim)
 

From 4d7d0b2fa6e51e7613fcecc319efeacbbb47d929 Mon Sep 17 00:00:00 2001
From: Seth Axen <seth@sethaxen.com>
Date: Fri, 2 Jun 2023 15:01:22 +0200
Subject: [PATCH 4/4] Update more docs

---
 docs/src/api.md           | 2 +-
 docs/src/create_kernel.md | 2 +-
 docs/src/userguide.md     | 9 ++++++++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/docs/src/api.md b/docs/src/api.md
index 1def1fcdb..13042cab0 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -16,7 +16,7 @@ kernelmatrix_diag
 kernelmatrix_diag!
 ```
 
-It is also possible to lazily construct the same matrix
+It is also possible to lazily construct the same matrix, which is recommended when the kernel matrix might be too large to store in memory:
 ```@docs
 lazykernelmatrix
 LazyKernelMatrix
diff --git a/docs/src/create_kernel.md b/docs/src/create_kernel.md
index c7234b41f..870792374 100644
--- a/docs/src/create_kernel.md
+++ b/docs/src/create_kernel.md
@@ -38,7 +38,7 @@ Finally there are additional functions you can define to bring in more features:
  - `KernelFunctions.iskroncompatible(k::MyKernel)`: if your kernel factorizes in dimensions, you can declare your kernel as `iskroncompatible(k) = true` to use Kronecker methods.
  - `KernelFunctions.dim(x::MyDataType)`: by default the dimension of the inputs will only be checked for vectors of type `AbstractVector{<:Real}`. If you want to check the dimensionality of your inputs, dispatch the `dim` function on your datatype. Note that `0` is the default.
  - `dim` is called within `KernelFunctions.validate_inputs(x::MyDataType, y::MyDataType)`, which can instead be directly overloaded if you want to run special checks for your input types.
- - `kernelmatrix(k::MyKernel, ...)`: you can redefine the diverse `kernelmatrix` functions to eventually optimize the computations.
+ - `kernelmatrix(k::MyKernel, ...)`: you can redefine the various `kernelmatrix` and `lazykernelmatrix` methods to optimize the computations.
  - `Base.print(io::IO, k::MyKernel)`: if you want to specialize the printing of your kernel.
 
 KernelFunctions uses [Functors.jl](https://github.com/FluxML/Functors.jl) for specifying trainable kernel parameters
diff --git a/docs/src/userguide.md b/docs/src/userguide.md
index d3e8789f7..269c979cc 100644
--- a/docs/src/userguide.md
+++ b/docs/src/userguide.md
@@ -61,7 +61,7 @@ k(x1, x2)
 
 ## Creating a Kernel Matrix
 
-Kernel matrices can be created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal.
+Kernel matrices can be eagerly created via the `kernelmatrix` function or `kernelmatrix_diag` for only the diagonal.
 For example, for a collection of 10 `Real`-valued inputs:
 ```julia
 k = SqExponentialKernel()
@@ -90,6 +90,13 @@ kernelmatrix(k, X; obsdim=2) # same as ColVecs(X)
 ```
 This is similar to the convention used in [Distances.jl](https://github.com/JuliaStats/Distances.jl).
 
+When data is large, it may not be possible to store the kernel matrix in memory.
+Then it is recommended to use `lazykernelmatrix`:
+```julia
+lazykernelmatrix(k, RowVecs(X))
+lazykernelmatrix(k, ColVecs(X))
+```
+
 ### So what type should I use to represent a collection of inputs?
 The central assumption made by KernelFunctions.jl is that all collections of `N` inputs are represented by `AbstractVector`s of length `N`.
 Abstraction is then used to ensure that efficiency is retained, `ColVecs` and `RowVecs`