Skip to content

Commit 4a018d8

Browse files
zuhengxuyebaisunxd3devmotiongdalle
authored
Change to DifferentiationInterface (#46)
* switch to differentiationinterface from diffresults * rename train.jl to optimize.jl * fix some compat issue and bump version * update tests to new interface * add Moonkcake to extras * rm all ext for now * rm enzyme test, and import mooncake for test * fixing compat and test with mooncake * fixing test bug * fix _value_and_grad wrapper bug * fix AutoReverseDiff argument typo * minor ed * minor ed * fixing test * minor ed * rm test for mooncake * fix doc * chagne CI * update CI * streamline project toml * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * add enzyme to tests * add Enzyme to using list * fixing enzyme readonly error by wrapping loss in Const * mv enzyme related edits to ext/ and fix tests * fixing extension loading error * Update Project.toml Co-authored-by: Guillaume Dalle <[email protected]> * remove Requires Co-authored-by: Guillaume Dalle <[email protected]> * remove explit load ext Co-authored-by: Guillaume Dalle <[email protected]> * Update src/objectives/loglikelihood.jl Co-authored-by: Guillaume Dalle <[email protected]> * make ext dep explicit * rm empty argument specialization for _prep_grad and _value_grad * signal empty rng arg * drop Requires * drop Requires * update test to include mooncake * rm unnecessary EnzymeCoreExt * minor update of readme * typo fix in readme * Update src/NormalizingFlows.jl Co-authored-by: David Widmann <[email protected]> * Update src/NormalizingFlows.jl Co-authored-by: David Widmann <[email protected]> * rm time_elapsed from train_flow * Update docs/src/api.md Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: Xianda Sun <[email protected]> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Guillaume Dalle <[email protected]>
1 parent 79ebfb2 commit 4a018d8

20 files changed

+301
-382
lines changed

Diff for: .github/workflows/CI.yml

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
name: CI
2+
23
on:
34
push:
45
branches:
56
- main
67
tags: ['*']
78
pull_request:
9+
810
concurrency:
911
# Skip intermediate builds: always.
1012
# Cancel intermediate builds: only if it is a pull request build.
1113
group: ${{ github.workflow }}-${{ github.ref }}
1214
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
15+
1316
jobs:
1417
test:
1518
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
@@ -19,17 +22,17 @@ jobs:
1922
matrix:
2023
version:
2124
- '1'
22-
- '1.6'
25+
- 'min'
2326
os:
2427
- ubuntu-latest
2528
arch:
2629
- x64
2730
steps:
28-
- uses: actions/checkout@v3
29-
- uses: julia-actions/setup-julia@v1
31+
- uses: actions/checkout@v4
32+
- uses: julia-actions/setup-julia@v2
3033
with:
3134
version: ${{ matrix.version }}
3235
arch: ${{ matrix.arch }}
33-
- uses: julia-actions/cache@v1
36+
- uses: julia-actions/cache@v2
3437
- uses: julia-actions/julia-buildpkg@v1
3538
- uses: julia-actions/julia-runtest@v1

Diff for: Project.toml

+7-31
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,26 @@
11
name = "NormalizingFlows"
22
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
3-
version = "0.1.1"
3+
version = "0.2.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
8-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
8+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1313
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1615
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1716

18-
[weakdeps]
19-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
20-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
21-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
22-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
23-
24-
[extensions]
25-
NormalizingFlowsEnzymeExt = "Enzyme"
26-
NormalizingFlowsForwardDiffExt = "ForwardDiff"
27-
NormalizingFlowsReverseDiffExt = "ReverseDiff"
28-
NormalizingFlowsZygoteExt = "Zygote"
29-
3017
[compat]
31-
ADTypes = "0.1, 0.2, 1"
32-
Bijectors = "0.12.6, 0.13, 0.14"
33-
DiffResults = "1"
18+
ADTypes = "1"
19+
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
20+
DifferentiationInterface = "0.6.42"
3421
Distributions = "0.25"
3522
DocStringExtensions = "0.9"
36-
Enzyme = "0.11, 0.12, 0.13"
37-
ForwardDiff = "0.10.25"
38-
Optimisers = "0.2.16, 0.3"
23+
Optimisers = "0.2.16, 0.3, 0.4"
3924
ProgressMeter = "1.0.0"
40-
Requires = "1"
41-
ReverseDiff = "1.14"
4225
StatsBase = "0.33, 0.34"
43-
Zygote = "0.6"
44-
julia = "1.6"
45-
46-
[extras]
47-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
48-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
49-
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
50-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
julia = "1.10"

Diff for: README.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
[![Build Status](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/TuringLang/NormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
55

66

7-
**Last updated: 2023-Aug-23**
7+
**Last updated: 2025-Mar-04**
88

99
A normalizing flow library for Julia.
1010

@@ -21,16 +21,16 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor
2121
To install the package, run the following command in the Julia REPL:
2222
```julia
2323
] # enter Pkg mode
24-
(@v1.9) pkg> add git@github.com:TuringLang/NormalizingFlows.jl.git
24+
(@v1.11) pkg> add NormalizingFlows
2525
```
2626
Then simply run the following command to use the package:
2727
```julia
2828
using NormalizingFlows
2929
```
3030

3131
## Quick recap of normalizing flows
32-
Normalizing flows transform a simple reference distribution $q_0$ (sometimes known as base distribution) to
33-
a complex distribution $q$ using invertible functions.
32+
Normalizing flows transform a simple reference distribution $q_0$ (sometimes referred to as the base distribution)
33+
to a complex distribution $q$ using invertible functions.
3434

3535
In more detail, given the base distribution, usually a standard Gaussian distribution, i.e., $q_0 = \mathcal{N}(0, I)$,
3636
we apply a series of parameterized invertible transformations (called flow layers), $T_{1, \theta_1}, \cdots, T_{N, \theta_N}$, yielding that
@@ -56,7 +56,7 @@ Given the feasibility of i.i.d. sampling and density evaluation, normalizing flo
5656
\text{Reverse KL:}\quad
5757
&\arg\min _{\theta} \mathbb{E}_{q_{\theta}}\left[\log q_{\theta}(Z)-\log p(Z)\right] \\
5858
&= \arg\min _{\theta} \mathbb{E}_{q_0}\left[\log \frac{q_\theta(T_N\circ \cdots \circ T_1(Z_0))}{p(T_N\circ \cdots \circ T_1(Z_0))}\right] \\
59-
&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(X)+\sum_{n=1}^N \log J_n\left(F_n \circ \cdots \circ F_1(X)\right)\right]
59+
&= \arg\max _{\theta} \mathbb{E}_{q_0}\left[ \log p\left(T_N \circ \cdots \circ T_1(Z_0)\right)-\log q_0(Z_0)+\sum_{n=1}^N \log J_n\left(T_n \circ \cdots \circ T_1(Z_0)\right)\right]
6060
\end{aligned}
6161
```
6262
and
@@ -76,10 +76,12 @@ normalizing constant.
7676
In contrast, forward KL minimization is typically used for **generative modeling**,
7777
where one wants to learn the underlying distribution of some data.
7878

79-
## Current status and TODOs
79+
## Current status and to-dos
8080

8181
- [x] general interface development
8282
- [x] documentation
83+
- [ ] integrating [Lux.jl](https://lux.csail.mit.edu/stable/tutorials/intermediate/7_RealNVP) and [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl).
84+
This could potentially solve the GPU compatibility issue as well.
8385
- [ ] including more NF examples/Tutorials
8486
- WIP: [PR#11](https://github.com/TuringLang/NormalizingFlows.jl/pull/11)
8587
- [ ] GPU compatibility

Diff for: docs/src/api.md

+1-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows:
1515
```@julia
1616
using Distributions, Bijectors
1717
T= Float32
18+
@leaf MvNormal # to prevent params in q₀ from being optimized
1819
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
1920
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2)))
2021
```
@@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood
8384
```@docs
8485
NormalizingFlows.optimize
8586
```
86-
87-
88-
## Utility Functions for Taking Gradient
89-
```@docs
90-
NormalizingFlows.grad!
91-
NormalizingFlows.value_and_gradient!
92-
```
93-

Diff for: docs/src/example.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a
3636

3737
```julia
3838
using Bijectors, FunctionChains
39+
using Functors
3940

4041
function create_planar_flow(n_layers::Int, q₀)
4142
d = length(q₀)
@@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀)
4546
end
4647

4748
# create a 20-layer planar flow
48-
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
49+
@leaf MvNormal # to prevent params in q₀ from being optimized
50+
q₀ = MvNormal(zeros(Float32, 2), I)
51+
flow = create_planar_flow(20, q₀)
4952
flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison
5053
```
5154
*Notice that here the flow layers are chained together using the `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl).
@@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel=
116119

117120
## Reference
118121

119-
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
122+
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning

Diff for: ext/NormalizingFlowsEnzymeExt.jl

-25
This file was deleted.

Diff for: ext/NormalizingFlowsForwardDiffExt.jl

-28
This file was deleted.

Diff for: ext/NormalizingFlowsReverseDiffExt.jl

-22
This file was deleted.

Diff for: ext/NormalizingFlowsZygoteExt.jl

-23
This file was deleted.

Diff for: src/NormalizingFlows.jl

+15-31
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ using Bijectors
44
using Optimisers
55
using LinearAlgebra, Random, Distributions, StatsBase
66
using ProgressMeter
7-
using ADTypes, DiffResults
7+
using ADTypes
8+
import DifferentiationInterface as DI
89

910
using DocStringExtensions
1011

11-
export train_flow, elbo, loglikelihood, value_and_gradient!
12-
13-
using ADTypes
14-
using DiffResults
12+
export train_flow, elbo, loglikelihood
1513

1614
"""
1715
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
@@ -30,7 +28,13 @@ Train the given normalizing flow `flow` by calling `optimize`.
3028
- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
3129
- `ADbackend::ADTypes.AbstractADType=ADTypes.AutoZygote()`:
3230
automatic differentiation backend, currently supports
33-
`ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, and `ADTypes.ReverseDiff()`.
31+
`ADTypes.AutoZygote()`, `ADTypes.AutoForwardDiff()`, `ADTypes.AutoReverseDiff()`,
32+
`ADTypes.AutoMooncake()` and
33+
`ADTypes.AutoEnzyme(;
34+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
35+
function_annotation=Enzyme.Const,
36+
)`.
37+
If you want to use `AutoEnzyme`, make sure to pass the `mode` (with `set_runtime_activity`) and `function_annotation` arguments as shown above.
3438
- `kwargs...`: additional keyword arguments for `optimize` (See [`optimize`](@ref) for details)
3539
3640
# Returns
@@ -57,13 +61,15 @@ function train_flow(
5761
# otherwise the compilation time for destructure will be too long
5862
θ_flat, re = Optimisers.destructure(flow)
5963

64+
loss(θ, rng, args...) = -vo(rng, re(θ), args...)
65+
6066
# Normalizing flow training loop
6167
θ_flat_trained, opt_stats, st = optimize(
62-
rng,
6368
ADbackend,
64-
vo,
69+
loss,
6570
θ_flat,
6671
re,
72+
rng,
6773
args...;
6874
max_iters=max_iters,
6975
optimiser=optimiser,
@@ -74,29 +80,7 @@ function train_flow(
7480
return flow_trained, opt_stats, st
7581
end
7682

77-
include("train.jl")
83+
include("optimize.jl")
7884
include("objectives.jl")
7985

80-
# optional dependencies
81-
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
82-
using Requires
83-
end
84-
85-
# Question: should Exts be loaded here or in train.jl?
86-
function __init__()
87-
@static if !isdefined(Base, :get_extension)
88-
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
89-
"../ext/NormalizingFlowsForwardDiffExt.jl"
90-
)
91-
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
92-
"../ext/NormalizingFlowsReverseDiffExt.jl"
93-
)
94-
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
95-
"../ext/NormalizingFlowsEnzymeExt.jl"
96-
)
97-
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
98-
"../ext/NormalizingFlowsZygoteExt.jl"
99-
)
100-
end
101-
end
10286
end

Diff for: src/objectives.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
include("objectives/elbo.jl")
2-
include("objectives/loglikelihood.jl")
2+
include("objectives/loglikelihood.jl") # not fully tested

Diff for: src/objectives/elbo.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ end
4242

4343
function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
4444
return elbo(Random.default_rng(), flow, logp, n_samples)
45-
end
45+
end

0 commit comments

Comments
 (0)