Skip to content

Commit 62077db

Browse files
authored
Merge branch 'master' into kx/general-relativistic
2 parents f446f0e + 47b212a commit 62077db

11 files changed

+131
-61
lines changed

Diff for: .github/dependabot.yml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
2+
version: 2
3+
updates:
4+
- package-ecosystem: "github-actions"
5+
directory: "/" # Location of package manifests
6+
schedule:
7+
interval: "monthly"

Diff for: .github/workflows/CI.yml

+18-19
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,41 @@ on:
66
- master
77
pull_request:
88

9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
14+
915
jobs:
1016
test:
17+
# needed to allow julia-actions/cache to delete old caches that it has created
18+
permissions:
19+
actions: write
20+
contents: read
1121
runs-on: ${{ matrix.os }}
12-
continue-on-error: ${{ matrix.version == 'nightly' }}
22+
continue-on-error: ${{ matrix.version == 'pre' }}
1323
strategy:
1424
matrix:
1525
version:
16-
- '1.6'
26+
- 'min'
1727
- '1'
18-
- 'nightly'
28+
- 'pre'
1929
os:
2030
- ubuntu-latest
2131
- macOS-latest
2232
- windows-latest
2333
arch:
24-
- x86
2534
- x64
26-
exclude:
27-
- os: ubuntu-latest
28-
arch: x86
29-
- os: macOS-latest
30-
arch: x86
31-
- os: windows-latest
32-
arch: x86
33-
# GitHub Action seems to have issue of running julia-nightly with windows-latest
34-
# TODO Revisit in the future
35-
- version: 'nightly'
36-
os: windows-latest
3735
steps:
38-
- uses: actions/checkout@v2
39-
- uses: julia-actions/setup-julia@v1
36+
- uses: actions/checkout@v4
37+
- uses: julia-actions/setup-julia@v2
4038
with:
4139
version: ${{ matrix.version }}
4240
arch: ${{ matrix.arch }}
43-
- uses: julia-actions/julia-buildpkg@latest
41+
- uses: julia-actions/cache@v2
42+
- uses: julia-actions/julia-buildpkg@v1
4443
- name: Run tests
45-
uses: julia-actions/julia-runtest@latest
44+
uses: julia-actions/julia-runtest@v1
4645
env:
4746
AHMC_TEST_GROUP: AdvancedHMC

Diff for: .github/workflows/ExperimentalTests.yml

+15-12
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@ on:
66
- master
77
pull_request:
88

9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
14+
915
jobs:
1016
test:
17+
# needed to allow julia-actions/cache to delete old caches that it has created
18+
permissions:
19+
actions: write
20+
contents: read
1121
runs-on: ${{ matrix.os }}
1222
strategy:
1323
matrix:
@@ -18,23 +28,16 @@ jobs:
1828
- macOS-latest
1929
- windows-latest
2030
arch:
21-
- x86
2231
- x64
23-
exclude:
24-
- os: ubuntu-latest
25-
arch: x86
26-
- os: macOS-latest
27-
arch: x86
28-
- os: windows-latest
29-
arch: x86
3032
steps:
31-
- uses: actions/checkout@v2
32-
- uses: julia-actions/setup-julia@v1
33+
- uses: actions/checkout@v4
34+
- uses: julia-actions/setup-julia@v2
3335
with:
3436
version: ${{ matrix.version }}
3537
arch: ${{ matrix.arch }}
36-
- uses: julia-actions/julia-buildpkg@latest
38+
- uses: julia-actions/cache@v2
39+
- uses: julia-actions/julia-buildpkg@v1
3740
- name: Run integration tests
38-
uses: julia-actions/julia-runtest@latest
41+
uses: julia-actions/julia-runtest@v1
3942
env:
4043
AHMC_TEST_GROUP: Experimental

Diff for: .github/workflows/Format.yml

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ on:
55
push:
66
branches:
77
- master
8-
- main
98

109
concurrency:
1110
# Skip intermediate builds: always.
@@ -17,8 +16,8 @@ jobs:
1716
format:
1817
runs-on: ubuntu-latest
1918
steps:
20-
- uses: actions/checkout@v2
21-
- uses: julia-actions/setup-julia@latest
19+
- uses: actions/checkout@v4
20+
- uses: julia-actions/setup-julia@v2
2221
with:
2322
version: 1
2423
- name: Format code

Diff for: .github/workflows/IntegrationTests.yml

+15-12
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,18 @@ on:
66
- master
77
pull_request:
88

9+
concurrency:
10+
# Skip intermediate builds: always.
11+
# Cancel intermediate builds: only if it is a pull request build.
12+
group: ${{ github.workflow }}-${{ github.ref }}
13+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
14+
915
jobs:
1016
test:
17+
# needed to allow julia-actions/cache to delete old caches that it has created
18+
permissions:
19+
actions: write
20+
contents: read
1121
runs-on: ${{ matrix.os }}
1222
strategy:
1323
matrix:
@@ -18,23 +28,16 @@ jobs:
1828
- macOS-latest
1929
- windows-latest
2030
arch:
21-
- x86
2231
- x64
23-
exclude:
24-
- os: ubuntu-latest
25-
arch: x86
26-
- os: macOS-latest
27-
arch: x86
28-
- os: windows-latest
29-
arch: x86
3032
steps:
31-
- uses: actions/checkout@v2
32-
- uses: julia-actions/setup-julia@v1
33+
- uses: actions/checkout@v4
34+
- uses: julia-actions/setup-julia@v2
3335
with:
3436
version: ${{ matrix.version }}
3537
arch: ${{ matrix.arch }}
36-
- uses: julia-actions/julia-buildpkg@latest
38+
- uses: julia-actions/cache@v2
39+
- uses: julia-actions/julia-buildpkg@v1
3740
- name: Run integration tests
38-
uses: julia-actions/julia-runtest@latest
41+
uses: julia-actions/julia-runtest@v1
3942
env:
4043
AHMC_TEST_GROUP: Downstream

Diff for: .github/workflows/documentation.yml

+5-4
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@ on:
99

1010
jobs:
1111
build:
12+
permissions:
13+
statuses: write # Used to report documentation build statuses
1214
runs-on: ubuntu-latest
1315
steps:
14-
- uses: actions/checkout@v2
15-
- uses: julia-actions/setup-julia@latest
16+
- uses: actions/checkout@v4
17+
- uses: julia-actions/setup-julia@v2
1618
with:
17-
version: '1.7'
19+
version: '1'
1820
- name: Install dependencies
1921
run: julia --project=docs/ -e '
2022
using Pkg;
@@ -24,4 +26,3 @@ jobs:
2426
env:
2527
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key
2628
run: julia --project=docs/ docs/make.jl
27-

Diff for: Project.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -31,7 +31,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
3131
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3232

3333
[compat]
34-
AbstractMCMC = "4.2, 5"
34+
AbstractMCMC = "5"
3535
AdaptiveRejectionSampling = "0.1.1"
3636
ArgCheck = "1, 2"
3737
CUDA = "3, 4, 5"
@@ -50,7 +50,9 @@ SimpleUnPack = "1.1"
5050
Statistics = "1.6"
5151
StatsBase = "0.31, 0.32, 0.33, 0.34"
5252
StatsFuns = "0.8, 0.9, 1"
53-
julia = "1.6"
53+
LinearAlgebra = "1.6"
54+
Random = "1.6"
55+
julia = "1.6.7"
5456

5557
[extras]
5658
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ samples = AbstractMCMC.sample(
128128
model,
129129
sampler,
130130
n_adapts + n_samples;
131-
nadapts = n_adapts,
131+
n_adapts = n_adapts,
132132
initial_params = initial_θ,
133133
)
134134
```

Diff for: src/abstractmcmc.jl

+38-6
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ function AbstractMCMC.sample(
4747
callback = nothing,
4848
kwargs...,
4949
)
50+
if haskey(kwargs, :nadapts)
51+
throw(
52+
ArgumentError(
53+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
54+
),
55+
)
56+
end
57+
5058
if callback === nothing
5159
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
5260
progress = false # don't use AMCMC's progress-funtionality
@@ -78,6 +86,13 @@ function AbstractMCMC.sample(
7886
callback = nothing,
7987
kwargs...,
8088
)
89+
if haskey(kwargs, :nadapts)
90+
throw(
91+
ArgumentError(
92+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
93+
),
94+
)
95+
end
8196

8297
if callback === nothing
8398
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
@@ -141,8 +156,17 @@ function AbstractMCMC.step(
141156
model::AbstractMCMC.LogDensityModel,
142157
spl::AbstractHMCSampler,
143158
state::HMCState;
159+
n_adapts::Int = 0,
144160
kwargs...,
145161
)
162+
if haskey(kwargs, :nadapts)
163+
throw(
164+
ArgumentError(
165+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
166+
),
167+
)
168+
end
169+
146170
# Compute transition.
147171
i = state.i + 1
148172
t_old = state.transition
@@ -158,7 +182,6 @@ function AbstractMCMC.step(
158182

159183
# Adapt h and spl.
160184
tstat = stat(t)
161-
n_adapts = kwargs[:n_adapts]
162185
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate)
163186
tstat = merge(tstat, (is_adapt = isadapted,))
164187

@@ -189,8 +212,8 @@ struct HMCProgressCallback{P}
189212
"If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
190213
verbose::Bool
191214
"Number of divergent transitions fo far."
192-
num_divergent_transitions::Ref{Int}
193-
num_divergent_transitions_during_adaption::Ref{Int}
215+
num_divergent_transitions::Base.RefValue{Int}
216+
num_divergent_transitions_during_adaption::Base.RefValue{Int}
194217
end
195218

196219
function HMCProgressCallback(n_samples; progress = true, verbose = false)
@@ -200,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false)
200223
HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0))
201224
end
202225

203-
function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kwargs...)
226+
function (cb::HMCProgressCallback)(
227+
rng,
228+
model,
229+
spl,
230+
t,
231+
state,
232+
i;
233+
n_adapts::Int = 0,
234+
kwargs...,
235+
)
204236
progress = cb.progress
205237
verbose = cb.verbose
206238
pm = cb.pm
@@ -243,8 +275,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
243275
),
244276
)
245277
# Report finish of adapation
246-
elseif verbose && isadapted && i == nadapts
247-
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
278+
elseif verbose && isadapted && i == n_adapts
279+
@info "Finished $(n_adapts) adapation steps" adaptor κ.τ.integrator metric
248280
end
249281
end
250282

Diff for: test/abstractmcmc.jl

+24
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,30 @@ using Statistics: mean
3232
verbose = false,
3333
)
3434

35+
# Error if keyword argument `nadapts` is used
36+
@test_throws ArgumentError AbstractMCMC.sample(
37+
rng,
38+
model,
39+
nuts,
40+
n_adapts + n_samples;
41+
nadapts = n_adapts,
42+
initial_params = θ_init,
43+
progress = false,
44+
verbose = false,
45+
)
46+
@test_throws ArgumentError AbstractMCMC.sample(
47+
rng,
48+
model,
49+
nuts,
50+
MCMCThreads(),
51+
n_adapts + n_samples,
52+
2;
53+
nadapts = n_adapts,
54+
initial_params = θ_init,
55+
progress = false,
56+
verbose = false,
57+
)
58+
3559
# Transform back to original space.
3660
# NOTE: We're not correcting for the `logabsdetjac` here since, but
3761
# we're only interested in the mean it doesn't matter.

Diff for: test/mcmcchains.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Statistics: mean
2323
model,
2424
sampler,
2525
n_adapts + n_samples;
26-
nadapts = n_adapts,
26+
n_adapts = n_adapts,
2727
initial_params = θ_init,
2828
chain_type = Chains,
2929
progress = false,

0 commit comments

Comments
 (0)