@@ -28,6 +28,26 @@ inline double squared_exponential_covariance(double distance,
28
28
return sigma * sigma * exp (-pow (distance / length_scale, 2 ));
29
29
}
30
30
31
+ inline double squared_exponential_covariance_derivative (double distance,
32
+ double length_scale,
33
+ double sigma = 1 .) {
34
+ if (length_scale <= 0 .) {
35
+ return 0 .;
36
+ }
37
+ return -2 * distance / (length_scale * length_scale) *
38
+ squared_exponential_covariance (distance, length_scale, sigma);
39
+ }
40
+
41
+ inline double squared_exponential_covariance_second_derivative (
42
+ double distance, double length_scale, double sigma = 1 .) {
43
+ if (length_scale <= 0 .) {
44
+ return 0 .;
45
+ }
46
+ const auto ratio = distance / length_scale;
47
+ return (4 . * ratio * ratio - 2 .) / (length_scale * length_scale) *
48
+ squared_exponential_covariance (distance, length_scale, sigma);
49
+ }
50
+
31
51
/*
32
52
* SquaredExponential distance
33
53
* covariance(d) = sigma^2 exp(-(d/length_scale)^2)
@@ -83,6 +103,72 @@ class SquaredExponential
83
103
sigma_squared_exponential.value );
84
104
}
85
105
106
+ // This operator is only defined when the distance metric is also defined.
107
+ template <typename X,
108
+ typename std::enable_if<
109
+ has_call_operator<DistanceMetricType, X &, X &>::value,
110
+ int >::type = 0 >
111
+ double _call_impl (const Derivative<X> &x, const X &y) const {
112
+ double distance = this ->distance_metric_ (x.value , y);
113
+ double distance_derivative = this ->distance_metric_ .derivative (x.value , y);
114
+ return distance_derivative * squared_exponential_covariance_derivative (
115
+ distance,
116
+ squared_exponential_length_scale.value ,
117
+ sigma_squared_exponential.value );
118
+ }
119
+
120
+ template <typename X,
121
+ typename std::enable_if<
122
+ has_call_operator<DistanceMetricType, X &, X &>::value,
123
+ int >::type = 0 >
124
+ double _call_impl (const Derivative<X> &x, const Derivative<X> &y) const {
125
+ const double distance = this ->distance_metric_ (x.value , y.value );
126
+ const double d_x = this ->distance_metric_ .derivative (x.value , y.value );
127
+ const double d_y = this ->distance_metric_ .derivative (y.value , x.value );
128
+ const double d_xy =
129
+ this ->distance_metric_ .second_derivative (x.value , y.value );
130
+
131
+ const double f_d = squared_exponential_covariance_derivative (
132
+ distance, squared_exponential_length_scale.value ,
133
+ sigma_squared_exponential.value );
134
+
135
+ const double f_dd = squared_exponential_covariance_second_derivative (
136
+ distance, squared_exponential_length_scale.value ,
137
+ sigma_squared_exponential.value );
138
+
139
+ std::cout << x.value << " " << y.value << " " << d_xy << " , " << f_d
140
+ << " , " << d_x << " , " << d_y << " , " << f_dd << std::endl;
141
+ return d_xy * f_d + d_x * d_y * f_dd;
142
+ }
143
+
144
+ // This operator is only defined when the distance metric is also defined.
145
+ template <typename X,
146
+ typename std::enable_if<
147
+ has_call_operator<DistanceMetricType, X &, X &>::value,
148
+ int >::type = 0 >
149
+ double _call_impl (const SecondDerivative<X> &x, const X &y) const {
150
+ double d = this ->distance_metric_ (x.value , y);
151
+ double d_1 = this ->distance_metric_ .derivative (x.value , y);
152
+ double d_2 = this ->distance_metric_ .second_derivative (x.value , y);
153
+ double f_1 = squared_exponential_covariance_derivative (
154
+ d, squared_exponential_length_scale.value ,
155
+ sigma_squared_exponential.value );
156
+ double f_2 = squared_exponential_covariance_second_derivative (
157
+ d, squared_exponential_length_scale.value ,
158
+ sigma_squared_exponential.value );
159
+ return d_2 * f_1 + d_1 * d_1 * f_2;
160
+ }
161
+
162
+ // This operator is only defined when the distance metric is also defined.
163
+ template <typename X,
164
+ typename std::enable_if<
165
+ has_call_operator<DistanceMetricType, X &, X &>::value,
166
+ int >::type = 0 >
167
+ double _call_impl (const SecondDerivative<X> &x,
168
+ const SecondDerivative<X> &y) const {
169
+ return NAN;
170
+ }
171
+
86
172
DistanceMetricType distance_metric_;
87
173
};
88
174
0 commit comments