From df63ed43c03274813d302a45188678bb7236d9cd Mon Sep 17 00:00:00 2001 From: Matt Peddie Date: Wed, 8 Feb 2023 10:25:55 +1000 Subject: [PATCH] Add rational quadratic covariance --- .../src/covariance_functions/radial.hpp | 55 +++++++++++++++++++ tests/test_radial.cc | 13 +++++ 2 files changed, 68 insertions(+) diff --git a/include/albatross/src/covariance_functions/radial.hpp b/include/albatross/src/covariance_functions/radial.hpp index 3a8ae9bc..ee87b2a9 100644 --- a/include/albatross/src/covariance_functions/radial.hpp +++ b/include/albatross/src/covariance_functions/radial.hpp @@ -15,6 +15,7 @@ constexpr double default_length_scale = 100000.; constexpr double default_radial_sigma = 10.; +constexpr double default_scale_mixture = 1.; namespace albatross { @@ -232,5 +233,59 @@ class Matern52 : public CovarianceFunction> { DistanceMetricType distance_metric_; }; +inline double rational_quadratic_covariance(double distance, + double length_scale, + double sigma = 1., + double alpha = 1.) { + if (length_scale <= 0.) { + return 0.; + } + return sigma * sigma * + std::pow(1 + distance * distance / + (2 * alpha * length_scale * length_scale), + -alpha); +} + +template +class RationalQuadratic + : public CovarianceFunction> { + public: + // I don't know whether the rational-quadratic function is positive + // definite when the distance is an angular (or great circle) + // distance. + static_assert( + !std::is_base_of::value, + "RationalQuadratic covariance with AngularDistance is not PSD."); + + ALBATROSS_DECLARE_PARAMS(rational_quadratic_length_scale, + sigma_rational_quadratic, alpha_rational_quadratic); + + RationalQuadratic(double length_scale_ = default_length_scale, + double sigma_rational_quadratic_ = default_radial_sigma, + double alpha_rational_quadratic_ = default_scale_mixture) + : distance_metric_() { + rational_quadratic_length_scale = {length_scale_, PositivePrior()}; + sigma_rational_quadratic = {sigma_rational_quadratic_, NonNegativePrior()}; + alpha_rational_quadratic = {alpha_rational_quadratic_, PositivePrior()}; + }; + + std::string name() const { + return "rational_quadratic[" + this->distance_metric_.get_name() + "]"; + } + + template ::value, + int>::type = 0> + double _call_impl(const X &x, const X &y) const { + double distance = this->distance_metric_(x, y); + return rational_quadratic_covariance( + distance, rational_quadratic_length_scale.value, + sigma_rational_quadratic.value, alpha_rational_quadratic.value); + } + + DistanceMetricType distance_metric_; +}; + } // namespace albatross #endif diff --git a/tests/test_radial.cc b/tests/test_radial.cc index 5010c49c..0bebb026 100644 --- a/tests/test_radial.cc +++ b/tests/test_radial.cc @@ -386,4 +386,17 @@ TEST(test_radial, test_matern_32_oracle) { } } +TEST(test_radial, test_rq_peak) { + constexpr std::size_t test_iters = 10000; + std::mt19937 gen{22}; + std::normal_distribution<> d{0., 10.}; + for (std::size_t iter = 0; iter < test_iters; ++iter) { + const double x = d(gen); + const double length = 1e-6 + fabs(d(gen)); + const double scale_mixture = 1e-6 + fabs(d(gen)); + const RationalQuadratic cov(length, 1., scale_mixture); + EXPECT_EQ(cov(x, x), 1.0); + } +} + } // namespace albatross