Skip to content

more NF examples #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 123 commits into
base: main
Choose a base branch
from
Open

more NF examples #11

wants to merge 123 commits into from

Conversation

zuhengxu
Copy link
Member

@zuhengxu zuhengxu commented Jul 13, 2023

as per issue #15

@zuhengxu zuhengxu self-assigned this Jul 13, 2023
@zuhengxu zuhengxu added documentation Improvements or additions to documentation enhancement New feature or request labels Jul 13, 2023
@zuhengxu zuhengxu changed the title more NF examples+ more proper synthetic examples + documentations more NF examples Jul 13, 2023
@zuhengxu zuhengxu marked this pull request as ready for review April 10, 2025 21:46
@yebai
Copy link
Member

yebai commented Apr 10, 2025

Could we add a script to test whether all the example models run, e.g. on buildkite GPU?

@zuhengxu
Copy link
Member Author

Yes, let me clean up the Ham VI demo and then add the test script. Currently I can confirm that all the examples run on CPUs by doing

include("demo_..._fllow.jl")

@zuhengxu
Copy link
Member Author

zuhengxu commented Apr 12, 2025

Could we add a script to test whether all the example models run, e.g. on buildkite GPU?

@yebai Currently all the examples are only tested on CPUs. Do you have examples of such script (for testing runs) that I can refer to? (I did do include(...jl) manually for all the demos though.)

In terms of GPU, shall we first merge this PR and do GPU examples/test (#49) on a separate PR? the purpose this PR is just to demonstrate how to define and customize flow layers (at least for those popular flows). I think that to properly take advantage of GPU, we might want some interface adjustment.

@zuhengxu zuhengxu requested a review from sunxd3 April 12, 2025 01:24
@yebai
Copy link
Member

yebai commented Apr 14, 2025

Currently all the examples are only tested on CPUs. Do you have examples of such a script (for testing runs) that I can refer to? (I did do include(...jl) manually for all the demos though.)

It is probably okay to create a new CI group called NF Examples and include all examples there. If running all examples is time-consuming, we could run only a few iterations and not wait for the convergence. This is mainly to ensure all examples can still run with new releases or PRs.

In terms of GPU, shall we first merge this PR and do GPU examples/test (#49) on a separate PR?

No problem -- it sounds good!

Copy link
Member

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General, this is great effort!

Some minor comments, I didn't run the code to see the results.

In the future, probably all of these should be in a documentation page that automatically run and render results. But this PR is fine as it is.
(Worth it to add an issue as TODO?)

@@ -8,7 +8,9 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for adding these dependencies?

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(also this)

Comment on lines +13 to +25
function load_model(name::String)
if name == "Banana"
return Banana(2, 1.0, 10.0)
elseif name == "Cross"
return Cross()
elseif name == "Funnel"
return Funnel(2)
elseif name == "WarpedGaussian"
return WarpedGauss()
else
error("Model not defined")
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is slightly ad hoc (it required a jump to know what this function will return at the call site), maybe use model constructor directly?

logp_ρ = sum(logpdf(Normal(), ρ))
return logp_x + logp_ρ
end
∇logp = Base.Fix1(score, target)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

score is only defined for banana, worth mention a bit, I think

Comment on lines +67 to +98
# function check_trained_flow(
# flow_trained::Bijectors.MultivariateTransformed,
# true_dist::ContinuousMultivariateDistribution,
# n_samples::Int;
# kwargs...,
# )
# samples_trained = rand_batch(flow_trained, n_samples)
# samples_true = rand(true_dist, n_samples)

# p = Plots.scatter(
# samples_true[1, :],
# samples_true[2, :];
# label="True Distribution",
# color=:green,
# markersize=2,
# alpha=0.5,
# )
# Plots.scatter!(
# p,
# samples_trained[1, :],
# samples_trained[2, :];
# label="Trained Flow",
# color=:red,
# markersize=2,
# alpha=0.5,
# )
# Plots.plot!(; kwargs...)

# Plots.title!(p, "Trained HamFlow")

# return p
# end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this required anymore?

@@ -2,16 +2,22 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might not be required anymore?

Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants