Skip to content

Commit

Permalink
Fix bug in Stacked (#184)
Browse files Browse the repository at this point in the history
* make Stacked fall back to type-unstable impl properly

* version bump

* updated logabsdetjac impl too

* added tests for Stacked with mixed types

* fixed impl of Stacked with only one bijector

* fixed tests for Stacked

* removed excessive return statement
  • Loading branch information
torfjelde authored Jun 9, 2021
1 parent f3bb749 commit a64750b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.4"
version = "0.9.5"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
25 changes: 13 additions & 12 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ function _transform(x, rs::NTuple{1, UnitRange{Int}}, b::Bijector)
@assert rs[1] == 1:length(x)
return b(x)
end
function (sb::Stacked{<:Tuple})(x::AbstractVector{<:Real})
function (sb::Stacked{<:Tuple,<:Tuple})(x::AbstractVector{<:Real})
y = _transform(x, sb.ranges, sb.bs...)
@assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))"
return y
end
# The Stacked{<:AbstractArray} version is not TrackedArray friendly
function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real})
function (sb::Stacked)(x::AbstractVector{<:Real})
N = length(sb.bs)
N == 1 && return sb.bs[1](x[sb.ranges[1]])

Expand All @@ -91,6 +91,7 @@ function (sb::Stacked{<:AbstractArray})(x::AbstractVector{<:Real})
end

(sb::Stacked)(x::AbstractMatrix{<:Real}) = eachcolmaphcat(sb, x)

function logabsdetjac(
b::Stacked,
x::AbstractVector{<:Real}
Expand All @@ -108,18 +109,18 @@ function logabsdetjac(
end

function logabsdetjac(
b::Stacked{<:Tuple{Vararg{<:Any, N}}},
b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}},
x::AbstractVector{<:Real}
) where {N}
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
end
end

# Handle the case of just one bijector
function logabsdetjac(b::Stacked{<:Tuple{<:Bijector}}, x::AbstractVector{<:Real})
return sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
return if N == 1
init
else
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
end
end
end

function logabsdetjac(b::Stacked, x::AbstractMatrix{<:Real})
Expand All @@ -137,7 +138,7 @@ end
# logjac += sum(_logjac)
# return (rv = vcat(y_1, y_2), logabsdetjac = logjac)
# end
@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N}
@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N}
expr = Expr(:block)
y_names = []

Expand All @@ -159,7 +160,7 @@ end
return expr
end

function forward(sb::Stacked{<:AbstractArray}, x::AbstractVector)
function forward(sb::Stacked, x::AbstractVector)
N = length(sb.bs)
yinit, linit = forward(sb.bs[1], x[sb.ranges[1]])
logjac = sum(linit)
Expand Down
32 changes: 30 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,34 @@ end
x = ones(4) ./ 4.0
@test_throws AssertionError sb(x)

# Mixed versions
# Tuple, Array
sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], (1:1, 2:3))
x = ones(3) ./ 3.0
res = forward(sb, x)
@test sb(param(x)) isa TrackedArray
@test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...]
@test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2])
@test res.logabsdetjac == logabsdetjac(sb, x)

x = ones(4) ./ 4.0
@test_throws AssertionError sb(x)

# Array, Tuple
sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3])
x = ones(3) ./ 3.0
res = forward(sb, x)
@test sb(param(x)) isa TrackedArray
@test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...]
@test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2])
@test res.logabsdetjac == logabsdetjac(sb, x)

x = ones(4) ./ 4.0
@test_throws AssertionError sb(x)


@testset "Stacked: ADVI with MvNormal" begin
# MvNormal test
dists = [
Expand Down Expand Up @@ -798,8 +826,8 @@ end
SimplexBijector(),
Stacked((Exp{0}(), Log{0}())),
Stacked((Log{0}(), Exp{0}())),
# Stacked([Exp{0}(), Log{0}()]),
# Stacked([Log{0}(), Exp{0}()]),
Stacked([Exp{0}(), Log{0}()]),
Stacked([Log{0}(), Exp{0}()]),
Composed((Exp{0}(), Log{0}())),
Composed((Log{0}(), Exp{0}())),
# Composed([Exp{0}(), Log{0}()]),
Expand Down

2 comments on commit a64750b

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/38475

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.5 -m "<description of version>" a64750b774ede7572ff3fa0d087b7b4b84d5e3c4
git push origin v0.9.5

Please sign in to comment.