-
Notifications
You must be signed in to change notification settings - Fork 5
more NF examples #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
more NF examples #11
Conversation
This reverts commit 1c1c88a.
Could we add a script to test whether all the example models run, e.g. on |
Yes, let me clean up the Ham VI demo and then add the test script. Currently I can confirm that all the examples run on CPUs by doing
|
@yebai Currently all the examples are only tested on CPUs. Do you have examples of such script (for testing runs) that I can refer to? (I did do In terms of GPU, shall we first merge this PR and do GPU examples/test (#49) on a separate PR? the purpose this PR is just to demonstrate how to define and customize flow layers (at least for those popular flows). I think that to properly take advantage of GPU, we might want some interface adjustment. |
It is probably okay to create a new CI group called
No problem -- it sounds good! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General, this is great effort!
Some minor comments, I didn't run the code to see the results.
In the future, probably all of these should be in a documentation page that automatically run and render results. But this PR is fine as it is.
(Worth it to add an issue as TODO?)
@@ -8,7 +8,9 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | |||
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | |||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | |||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | |||
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason for adding these dependencies?
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(also this)
function load_model(name::String) | ||
if name == "Banana" | ||
return Banana(2, 1.0, 10.0) | ||
elseif name == "Cross" | ||
return Cross() | ||
elseif name == "Funnel" | ||
return Funnel(2) | ||
elseif name == "WarpedGaussian" | ||
return WarpedGauss() | ||
else | ||
error("Model not defined") | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is slightly ad hoc (it required a jump to know what this function will return at the call site), maybe use model constructor directly?
logp_ρ = sum(logpdf(Normal(), ρ)) | ||
return logp_x + logp_ρ | ||
end | ||
∇logp = Base.Fix1(score, target) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
score
is only defined for banana, worth mention a bit, I think
# function check_trained_flow( | ||
# flow_trained::Bijectors.MultivariateTransformed, | ||
# true_dist::ContinuousMultivariateDistribution, | ||
# n_samples::Int; | ||
# kwargs..., | ||
# ) | ||
# samples_trained = rand_batch(flow_trained, n_samples) | ||
# samples_true = rand(true_dist, n_samples) | ||
|
||
# p = Plots.scatter( | ||
# samples_true[1, :], | ||
# samples_true[2, :]; | ||
# label="True Distribution", | ||
# color=:green, | ||
# markersize=2, | ||
# alpha=0.5, | ||
# ) | ||
# Plots.scatter!( | ||
# p, | ||
# samples_trained[1, :], | ||
# samples_trained[2, :]; | ||
# label="Trained Flow", | ||
# color=:red, | ||
# markersize=2, | ||
# alpha=0.5, | ||
# ) | ||
# Plots.plot!(; kwargs...) | ||
|
||
# Plots.title!(p, "Trained HamFlow") | ||
|
||
# return p | ||
# end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this required anymore?
@@ -2,16 +2,22 @@ | |||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | |||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | |||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | |||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | |||
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this might not be required anymore?
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SimpleUnPack can be replaced with destructing syntax (like (; a, b) = x
) https://github.com/devmotion/SimpleUnPack.jl/blob/5943a14470a08f919f82809536ac63aed67bb28f/src/SimpleUnPack.jl#L22
as per issue #15
NormalizingFlows.jl
+Bijectors.jl
(a tentative list)