-
Notifications
You must be signed in to change notification settings - Fork 17
BreadCrumbs interface: an easier way to feed pigeons #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| """ | ||
| A struct that provides a basic, user-friendly interface to Pigeons. Only two inputs | ||
| are required, in positional order: | ||
| $FIELDS | ||
| !!! note | ||
| The PT state is initialized using a random sample from the reference. | ||
| """ | ||
| struct BreadCrumbs{TRefDist <: Distributions.Distribution, TTarget} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use @auto instead? Otherwise the standard out gets cluttered with pages of mostly useless type information when showing the stack trace. Also ref should be more general. We don't want Distributions to only work in BreadCrumbs. It should be seamless with the Conversely, things that can be fed into |
||
| """A function that evaluates the target log potential""" | ||
| target_log_potential::TTarget | ||
| """A Distributions.jl distribution used as reference""" | ||
| reference_distribution::TRefDist | ||
| end | ||
|
|
||
| # Target for a BreadCrumbs input | ||
| struct BreadCrumbsTarget{TBC <: BreadCrumbs} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems superfluous?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I.e. instead can we write the dispatches on Distributions types directly?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah! I think you're right! At least for the reference. I don't think people are thinking of passing the loglikelihood as a Distribution on the data given the parameter. That would also be very restrictive given that Distributions.jl only has simple models. |
||
| bc::TBC | ||
| end | ||
| function (bct::BreadCrumbsTarget)(x) | ||
| return if insupport(bct.bc.reference_distribution, x) | ||
| bct.bc.target_log_potential(x) | ||
| else | ||
| eltype(bct.bc.reference_distribution)(-Inf) | ||
| end | ||
| end | ||
| default_explorer(::BreadCrumbsTarget) = SliceSampler() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SliceSampler() is already the global default, no need to have that line. |
||
|
|
||
| # initialization | ||
| # general case | ||
| function initialization(bct::BreadCrumbsTarget, rng::AbstractRNG, ::Int) | ||
| rand(rng, bct.bc.reference_distribution) | ||
| end | ||
| # univariate case: need to wrap in vector to make the state mutable | ||
| function initialization( | ||
| bct::TBCT, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. isn't this equivalent to |
||
| rng::AbstractRNG, | ||
| ::Int | ||
| ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCT<:BreadCrumbsTarget{TBC}} | ||
| [rand(rng, bct.bc.reference_distribution)] | ||
| end | ||
|
|
||
| # reference for a BreadCrumbs input | ||
| struct BreadCrumbsReference{TBC <: BreadCrumbs} | ||
| bc::TBC | ||
| end | ||
| (bcr::BreadCrumbsReference)(x) = logpdf(bcr.bc.reference_distribution, x) | ||
| default_reference(bct::BreadCrumbsTarget) = BreadCrumbsReference(bct.bc) | ||
|
|
||
| # sampling from the reference | ||
| # general case | ||
| sample_iid!(bcr::BreadCrumbsReference, replica, shared) = | ||
| rand!(replica.rng, bcr.bc.reference_distribution, replica.state) | ||
| # univariate case | ||
| function sample_iid!( | ||
| bcr::TBCR, | ||
| replica, | ||
| shared | ||
| ) where {TRD<:Distributions.UnivariateDistribution, TBC<:BreadCrumbs{TRD}, TBCR<:BreadCrumbsReference{TBC}} | ||
| replica.state[] = rand(rng, bcr.bc.reference_distribution) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| using MCMCChains | ||
|
|
||
| @testset "Multivariate BreadCrumbs" begin | ||
| function unid_log_potential(x; n_trials=100, n_successes=50) | ||
| p = prod(x) | ||
| return n_successes*log(p) + (n_trials-n_successes)*log1p(-p) | ||
| end | ||
| ref_dist = product_distribution(Uniform(), Uniform()) | ||
| pt = pigeons( | ||
| BreadCrumbs(unid_log_potential, ref_dist), | ||
| n_rounds = 12, | ||
| record = [traces] | ||
| ) | ||
|
|
||
| # collect the statistics and convert to MCMCChains' Chains | ||
| samples = Chains(sample_array(pt), variable_names(pt)) | ||
| end |
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexandrebouchard unrelated to this PR but I had to change this limit so that the test would pass. Are you ok with lowering this limit? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to export it, but then the PR should contain documentation (in code and in the website).