Skip to content

Commit dd40e69

Browse files
committed
Stepsize no longer reset to 1 if term_buffer = 0 (Issue #3023)
1 parent b14c402 commit dd40e69

6 files changed

+78
-14
lines changed

src/stan/mcmc/covar_adaptation.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ class covar_adaptation : public windowed_adaptation {
1414
explicit covar_adaptation(int n)
1515
: windowed_adaptation("covariance"), estimator_(n) {}
1616

17+
/**
18+
* Return true if covariance was updated and adaptation is not finished
19+
*
20+
* @param covar Covariance
21+
* @param q Last draw
22+
*/
1723
bool learn_covariance(Eigen::MatrixXd& covar, const Eigen::VectorXd& q) {
1824
if (adaptation_window())
1925
estimator_.add_sample(q);
@@ -30,11 +36,11 @@ class covar_adaptation : public windowed_adaptation {
3036

3137
estimator_.restart();
3238

33-
++adapt_window_counter_;
34-
return true;
39+
increment_window_counter();
40+
return true && !finished();
3541
}
3642

37-
++adapt_window_counter_;
43+
increment_window_counter();
3844
return false;
3945
}
4046

src/stan/mcmc/var_adaptation.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ class var_adaptation : public windowed_adaptation {
1414
explicit var_adaptation(int n)
1515
: windowed_adaptation("variance"), estimator_(n) {}
1616

17+
/**
18+
* Return true if variance was updated and adaptation is not finished
19+
*
20+
* @param var Diagonal covariance
21+
* @param q Last draw
22+
*/
1723
bool learn_variance(Eigen::VectorXd& var, const Eigen::VectorXd& q) {
1824
if (adaptation_window())
1925
estimator_.add_sample(q);
@@ -29,11 +35,11 @@ class var_adaptation : public windowed_adaptation {
2935

3036
estimator_.restart();
3137

32-
++adapt_window_counter_;
33-
return true;
38+
increment_window_counter();
39+
return true && !finished();
3440
}
3541

36-
++adapt_window_counter_;
42+
increment_window_counter();
3743
return false;
3844
}
3945

src/stan/mcmc/windowed_adaptation.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <stan/callbacks/logger.hpp>
55
#include <stan/mcmc/base_adaptation.hpp>
6-
#include <ostream>
6+
#include <iostream>
77
#include <string>
88

99
namespace stan {
@@ -109,6 +109,25 @@ class windowed_adaptation : public base_adaptation {
109109
}
110110
}
111111

112+
/**
113+
* Check if there is any more warmup left to do
114+
*/
115+
bool finished() {
116+
if(adapt_window_counter_ + 1 >= num_warmup_) {
117+
return true;
118+
} else {
119+
return false;
120+
}
121+
}
122+
123+
/**
124+
* Increment the window counter and return the new value
125+
*/
126+
unsigned int increment_window_counter() {
127+
adapt_window_counter_ += 1;
128+
return adapt_window_counter_;
129+
}
130+
112131
protected:
113132
std::string estimator_name_;
114133

src/test/unit/mcmc/covar_adaptation_test.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,25 @@ TEST(McmcCovarAdaptation, learn_covariance) {
1515
target_covar *= 1e-3 * 5.0 / (n_learn + 5.0);
1616

1717
stan::mcmc::covar_adaptation adapter(n);
18-
adapter.set_window_params(50, 0, 0, n_learn, logger);
19-
20-
for (int i = 0; i < n_learn; ++i)
21-
adapter.learn_covariance(covar, q);
18+
adapter.set_window_params(30, 0, 0, n_learn, logger);
2219

20+
for (int i = 0; i < n_learn - 1; ++i) {
21+
EXPECT_FALSE(adapter.learn_covariance(covar, q));
22+
}
23+
// Learn variance should return true at end of first window
24+
EXPECT_TRUE(adapter.learn_covariance(covar, q));
25+
2326
for (int i = 0; i < n; ++i) {
2427
for (int j = 0; j < n; ++j) {
2528
EXPECT_EQ(target_covar(i, j), covar(i, j));
2629
}
2730
}
31+
32+
// Make sure learn_variance doesn't return true after second window (adaptation finished)
33+
for (int i = 0; i < 2 * n_learn ; ++i) {
34+
EXPECT_FALSE(adapter.learn_covariance(covar, q));
35+
}
36+
EXPECT_TRUE(adapter.finished());
37+
2838
EXPECT_EQ(0, logger.call_count());
2939
}

src/test/unit/mcmc/var_adaptation_test.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,22 @@ TEST(McmcVarAdaptation, learn_variance) {
1515
target_var *= 1e-3 * 5.0 / (n_learn + 5.0);
1616

1717
stan::mcmc::var_adaptation adapter(n);
18-
adapter.set_window_params(50, 0, 0, n_learn, logger);
18+
adapter.set_window_params(30, 0, 0, n_learn, logger);
1919

20-
for (int i = 0; i < n_learn; ++i)
21-
adapter.learn_variance(var, q);
20+
for (int i = 0; i < n_learn - 1; ++i) {
21+
EXPECT_FALSE(adapter.learn_variance(var, q));
22+
}
23+
// Learn variance should return true at end of first window
24+
EXPECT_TRUE(adapter.learn_variance(var, q));
2225

2326
for (int i = 0; i < n; ++i)
2427
EXPECT_EQ(target_var(i), var(i));
2528

29+
// Make sure learn_variance doesn't return true after second window (adaptation finished)
30+
for (int i = 0; i < 2 * n_learn ; ++i) {
31+
EXPECT_FALSE(adapter.learn_variance(var, q));
32+
}
33+
EXPECT_TRUE(adapter.finished());
34+
2635
EXPECT_EQ(0, logger.call_count());
2736
}

src/test/unit/mcmc/windowed_adaptation_test.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,17 @@ TEST(McmcWindowedAdaptation, set_window_params3) {
4646
ASSERT_EQ(0, logger.call_count());
4747
ASSERT_EQ(0, logger.call_count_info());
4848
}
49+
50+
TEST(McmcWindowedAdaptation, finished) {
51+
stan::test::unit::instrumented_logger logger;
52+
53+
stan::mcmc::windowed_adaptation adapter("test");
54+
55+
adapter.set_window_params(1000, 75, 50, 25, logger);
56+
57+
for(size_t i = 0; i < 999; i++) {
58+
EXPECT_FALSE(adapter.finished());
59+
adapter.increment_window_counter();
60+
}
61+
EXPECT_TRUE(adapter.finished());
62+
}

0 commit comments

Comments
 (0)