diff --git a/Project.toml b/Project.toml index 63d48bee..9caf1d00 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.2" +version = "0.9.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -21,7 +21,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ArgCheck = "1, 2" -ChainRulesCore = "0.9" +ChainRulesCore = "0.9, 0.10" Compat = "3" Distributions = "0.23.3, 0.24, 0.25" Functors = "0.1, 0.2" diff --git a/test/Project.toml b/test/Project.toml index a8f231e0..5c10541d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,13 +14,13 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ChainRulesTestUtils = "0.6.3" +ChainRulesTestUtils = "0.6.3, 0.7" Combinatorics = "1.0.2" DistributionsAD = "0.6.3" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" Functors = "0.1, 0.2" -NNlib = "0.7" +NNlib = "0.7.18" ReverseDiff = "1.4.2" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index c1a8a5de..76846885 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -85,8 +85,8 @@ DistSpec(Poisson, (0.5,), 1), DistSpec(Poisson, (0.5,), [1, 1]), - DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)), - DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)), + DistSpec(Skellam, (1.0, 2.0), -2), + DistSpec(Skellam, (1.0, 2.0), [-2, -2]), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), @@ -193,8 +193,9 @@ DistSpec(NormalCanon, (1.0, 2.0), 0.5), - DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)), + DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5), + DistSpec(Pareto, (), 1.5), DistSpec(Pareto, (1.0,), 1.5), DistSpec(Pareto, (1.0, 1.0), 1.5), @@ -245,11 +246,8 @@ DistSpec(VonMises, (1.0,), 1.0), DistSpec(VonMises, (1, 1), 1), - # Only some Zygote tests are broken and therefore this can not be checked - DistSpec(Pareto, (), 1.5; broken=(:Zygote,)), - # Some tests are broken on some Julia versions, therefore it can't be checked reliably - DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), + DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), ] # Tests that have a `broken` field can be executed but, according to FiniteDifferences, @@ -405,7 +403,7 @@ B, to_posdef, ), - DistSpec((eta) -> LKJ(10, eta), (1.), A_big, to_corr) + DistSpec(eta -> LKJ(10, eta), (1.,), A_big, to_corr) # AD for parameters of LKJ requires more DistributionsAD supports ] @@ -435,17 +433,21 @@ # Skellam only fails in these tests with ReverseDiff # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 # PoissonBinomial fails with Zygote + # Matrix case does not work with Skellam: + # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 filldist_broken = if d.f(d.θ...) isa Skellam - (d.broken..., :ReverseDiff) + ((d.broken..., :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) elseif d.f(d.θ...) isa PoissonBinomial - (d.broken..., :Zygote) + ((d.broken..., :Zygote), (d.broken..., :Zygote)) else - d.broken + (d.broken, d.broken) end - arraydist_broken = if d.f(d.θ...) isa PoissonBinomial - (d.broken..., :Zygote) + arraydist_broken = if d.f(d.θ...) isa Skellam + (d.broken, (d.broken..., :Zygote)) + elseif d.f(d.θ...) isa PoissonBinomial + ((d.broken..., :Zygote), (d.broken..., :Zygote)) else - d.broken + (d.broken, d.broken) end # Create `filldist` distribution @@ -456,7 +458,7 @@ f_arraydist = (θ...,) -> arraydist([d.f(θ...) for _ in 1:n]) d_arraydist = f_arraydist(d.θ...) - for sz in ((n,), (n, 2)) + for (i, sz) in enumerate(((n,), (n, 2))) # Matrix case doesn't work for continuous distributions for some reason # now but not too important (?!) if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous @@ -474,7 +476,7 @@ d.θ, x, d.xtrans; - broken=filldist_broken, + broken=filldist_broken[i], ) ) test_ad( @@ -484,7 +486,7 @@ d.θ, x, d.xtrans; - broken=arraydist_broken, + broken=arraydist_broken[i], ) ) end