Skip to content

Commit

Permalink
OrderedBijector (#186)
Browse files Browse the repository at this point in the history
* initial impl of OrderedBijector

* added tests for OrderedBijector

* added Tracker gradient for OrderedBijector and improved imports

* added ReverseDiff adjoint for OrderedBijector

* added docs

* updated impls of OrderedBijector for improved perf

* improved impls of adjoints too

* fixed typo in impl of inverse for OrderedBijector

* use views

* added tests using ChainRulesTestUtils

* up the lowerbound on ChainRulesTestUtils

* fixed backwards compat

* fixed the chainrules tests

* patch version bump

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion authored Jul 15, 2021
1 parent 3332c7f commit 0d9b8b4
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.6"
version = "0.9.7"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
78 changes: 78 additions & 0 deletions src/bijectors/ordered.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
    OrderedBijector()

A bijector mapping unordered vectors in ℝᵈ to ordered vectors in ℝᵈ.

Applied to `y`, it returns `x` with `x[1] = y[1]` and `x[i] = x[i - 1] + exp(y[i])` for
`i ≥ 2`; the inverse recovers `y[i] = log(x[i] - x[i - 1])`.

## See also
- [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html)
  - Note that this transformation and its inverse are the _opposite_ of those in this
    reference.
"""
struct OrderedBijector <: Bijector{1} end

"""
    ordered(d::ContinuousMultivariateDistribution)

Return `d` transformed by an [`OrderedBijector`](@ref), i.e. a `Distribution` whose
support consists of vectors with increasingly ordered elements.
"""
function ordered(d::ContinuousMultivariateDistribution)
    return Bijectors.transformed(d, OrderedBijector())
end

# Applying the bijector forwards to `_transform_ordered` for both vectors and matrices.
function (b::OrderedBijector)(y::AbstractVecOrMat)
    return _transform_ordered(y)
end

# Forward transform for a single vector: `x[1] = y[1]` and
# `x[i] = x[i - 1] + exp(y[i])` for `i ≥ 2`.
#
# The result is allocated with `similar(y)`; an `AssertionError` is raised for empty `y`.
function _transform_ordered(y::AbstractVector)
    @assert !isempty(y)
    x = similar(y)

    @inbounds begin
        x[1] = y[1]
        for k in 2:length(x)
            # Each entry adds a strictly positive increment to the previous output.
            increment = exp(y[k])
            x[k] = x[k - 1] + increment
        end
    end

    return x
end

# Forward transform for a matrix, applied to every column independently:
# `x[1, j] = y[1, j]` and `x[i, j] = x[i - 1, j] + exp(y[i, j])` for `i ≥ 2`.
#
# The result is allocated with `similar(y)`; an `AssertionError` is raised for empty `y`.
function _transform_ordered(y::AbstractMatrix)
    @assert !isempty(y)
    x = similar(y)

    nrows, ncols = size(x)
    @inbounds for col in 1:ncols
        # The first row is copied through; the remaining rows accumulate increments.
        x[1, col] = y[1, col]
        for row in 2:nrows
            x[row, col] = x[row - 1, col] + exp(y[row, col])
        end
    end

    return x
end

# Applying the inverse bijector forwards to `_transform_inverse_ordered` for both
# vectors and matrices.
function (ib::Inverse{<:OrderedBijector})(x::AbstractVecOrMat)
    return _transform_inverse_ordered(x)
end

# Inverse transform for a single vector: `y[1] = x[1]` and
# `y[i] = log(x[i] - x[i - 1])` for `i ≥ 2`.
#
# The result is allocated with `similar(x)`; an `AssertionError` is raised for empty input.
function _transform_inverse_ordered(x::AbstractVector)
    y = similar(x)
    @assert !isempty(y)

    @inbounds begin
        y[1] = x[1]
        for k in 2:length(y)
            # Gap between consecutive entries; `log` requires it to be non-negative.
            gap = x[k] - x[k - 1]
            y[k] = log(gap)
        end
    end

    return y
end

# Inverse transform for a matrix, applied to every column independently:
# `y[1, j] = x[1, j]` and `y[i, j] = log(x[i, j] - x[i - 1, j])` for `i ≥ 2`.
#
# The result is allocated with `similar(x)`; an `AssertionError` is raised for empty input.
function _transform_inverse_ordered(x::AbstractMatrix)
    y = similar(x)
    @assert !isempty(y)

    nrows, ncols = size(y)
    @inbounds for col in 1:ncols
        # The first row is copied through; the remaining rows are log-differences.
        y[1, col] = x[1, col]
        for row in 2:nrows
            y[row, col] = log(x[row, col] - x[row - 1, col])
        end
    end

    return y
end

# Log absolute determinant of the Jacobian of the forward transform at the vector `x`.
#
# The forward map copies the first entry and adds `exp(x[i])` for `i ≥ 2`, so the value
# is the sum of every entry of `x` except the first.
function logabsdetjac(b::OrderedBijector, x::AbstractVector)
    rest = view(x, 2:lastindex(x))
    return sum(rest)
end
# Column-wise log absolute determinant of the Jacobian of the forward transform.
#
# Returns a vector with one entry per column of `x`: the sum of that column's entries
# from the second row onwards.
function logabsdetjac(b::OrderedBijector, x::AbstractMatrix)
    rest = view(x, 2:lastindex(x, 1), :)
    return vec(sum(rest; dims = 1))
end
120 changes: 120 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,123 @@ ChainRulesCore.@scalar_rule(
),
(x, - tanh+ b) * x, x - 1),
)

# `OrderedBijector`
# Reverse-mode rule for `_transform_ordered` on a vector.
#
# The primal is `x[1] = y[1]`, `x[i] = x[i - 1] + exp(y[i])`, hence the pullback is
#
#     Δ_new[1] = sum(Δ)
#     Δ_new[i] = sum(Δ[i:end]) * exp(y[i])    for i ≥ 2
function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractVector)
    primal = _transform_ordered(y)

    function _transform_ordered_adjoint(Δ)
        Δ_new = similar(y)
        n = length(Δ)
        @assert n == length(Δ_new)

        # `suffix` is kept equal to `sum(Δ[k:end])` for the entry `k` being written,
        # obtained by peeling one cotangent off the total per step.
        suffix = sum(Δ)
        Δ_new[1] = suffix
        @inbounds for prev in 1:(n - 1)
            suffix -= Δ[prev]
            Δ_new[prev + 1] = suffix * exp(y[prev + 1])
        end

        # Using `NO_FIELDS` to be backwards-compatible.
        return (ChainRulesCore.NO_FIELDS, Δ_new)
    end

    return primal, _transform_ordered_adjoint
end

# Reverse-mode rule for `_transform_ordered` on a matrix, where every column is
# transformed independently.
#
# Per column the primal is `x[1] = y[1]`, `x[i] = x[i - 1] + exp(y[i])`, hence the
# pullback is
#
#     Δ_new[1, j] = sum(Δ[:, j])
#     Δ_new[i, j] = sum(Δ[i:end, j]) * exp(y[i, j])    for i ≥ 2
function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractMatrix)
    function _transform_ordered_adjoint(Δ)
        Δ_new = similar(y)
        n = size(Δ, 1)
        @assert size(Δ) == size(Δ_new)

        # Process one column at a time with a scalar running suffix sum. This avoids
        # allocating a row copy of `Δ` and `y` on every iteration and walks the data in
        # the order that is contiguous for Julia's column-major `Array`s.
        for j in 1:size(Δ, 2)
            s = sum(@view(Δ[:, j]))
            Δ_new[1, j] = s
            @inbounds for i in 2:n
                # After this update `s == sum(Δ[i:end, j])`.
                s -= Δ[i - 1, j]
                Δ_new[i, j] = s * exp(y[i, j])
            end
        end

        # Using `NO_FIELDS` to be backwards-compatible.
        return (ChainRulesCore.NO_FIELDS, Δ_new)
    end

    return _transform_ordered(y), _transform_ordered_adjoint
end

# Reverse-mode rule for `_transform_inverse_ordered` on a vector.
#
# The primal is `y[1] = x[1]`, `y[i] = log(x[i] - x[i - 1])` for `i ≥ 2`. The differences
# `r[i] = x[i] - x[i - 1]` are computed once and shared by the primal and the pullback;
# `r[1]` is set to `1` so that the first entry passes through unscaled.
function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractVector)
    # Same check as in `_transform_inverse_ordered` itself. The primal is recomputed
    # here rather than called, so without this the `@inbounds` accesses below would
    # read and write out of bounds for empty input.
    @assert !isempty(x)

    r = similar(x)
    @inbounds r[1] = 1
    @inbounds for i in 2:length(r)
        r[i] = x[i] - x[i - 1]
    end

    function _transform_inverse_ordered_adjoint(Δ)
        Δ_new = similar(x)
        @assert length(Δ_new) == length(Δ)

        n = length(Δ_new)
        # `x[j]` enters `y[j]` with derivative `1 / r[j]` and `y[j + 1]` with
        # derivative `-1 / r[j + 1]`; the last entry only has the first contribution.
        @inbounds for j in 1:(n - 1)
            Δ_new[j] = (Δ[j] / r[j]) - (Δ[j + 1] / r[j + 1])
        end
        @inbounds Δ_new[n] = Δ[n] / r[n]

        # Using `NO_FIELDS` to be backwards-compatible.
        return (ChainRulesCore.NO_FIELDS, Δ_new)
    end

    # Compute the primal here so the already computed `r` can be reused.
    y = similar(x)
    @inbounds y[1] = x[1]
    @inbounds for i in 2:length(x)
        y[i] = log(r[i])
    end

    return y, _transform_inverse_ordered_adjoint
end

# Reverse-mode rule for `_transform_inverse_ordered` on a matrix, where every column is
# transformed independently.
#
# Per column the primal is `y[1] = x[1]`, `y[i] = log(x[i] - x[i - 1])` for `i ≥ 2`. The
# differences `r[i, j] = x[i, j] - x[i - 1, j]` are computed once and shared by the
# primal and the pullback; the first row of `r` is set to `1` so that the first row
# passes through unscaled.
function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractMatrix)
    # Same check as in `_transform_inverse_ordered` itself. The primal is recomputed
    # here rather than called, so without this the pullback's `@inbounds` write to row
    # `n` would be out of bounds for a matrix with zero rows.
    @assert !isempty(x)

    r = similar(x)
    @inbounds for j in 1:size(x, 2), i in 1:size(x, 1)
        if i == 1
            r[i, j] = 1
        else
            r[i, j] = x[i, j] - x[i - 1, j]
        end
    end

    function _transform_inverse_ordered_adjoint(Δ)
        Δ_new = similar(x)
        n = size(Δ, 1)
        @assert size(Δ) == size(Δ_new)

        # `x[i, j]` enters `y[i, j]` with derivative `1 / r[i, j]` and `y[i + 1, j]`
        # with derivative `-1 / r[i + 1, j]`.
        @inbounds for j in 1:size(Δ_new, 2), i in 1:(n - 1)
            Δ_new[i, j] = (Δ[i, j] / r[i, j]) - (Δ[i + 1, j] / r[i + 1, j])
        end

        # The last row only receives the contribution from its own output.
        @inbounds for j in 1:size(Δ_new, 2)
            Δ_new[n, j] = Δ[n, j] / r[n, j]
        end

        # Using `NO_FIELDS` to be backwards-compatible.
        return (ChainRulesCore.NO_FIELDS, Δ_new)
    end

    # Compute primal here so we can make use of the already
    # computed `r`.
    y = similar(x)
    @inbounds for j in 1:size(x, 2), i in 1:size(x, 1)
        if i == 1
            y[i, j] = x[i, j]
        else
            y[i, j] = log(r[i, j])
        end
    end

    return y, _transform_inverse_ordered_adjoint
end
26 changes: 25 additions & 1 deletion src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj
_inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered

import ChainRulesCore

using Compat: eachcol
using Distributions: LocationScale
Expand Down Expand Up @@ -195,4 +197,26 @@ end
return α, find_alpha_pullback
end

# `OrderedBijector`
# Route tracked inputs through `track` so the `@grad` definition of
# `_transform_ordered` is used for them.
_transform_ordered(y::Union{TrackedVector, TrackedMatrix}) = track(_transform_ordered, y)
# Gradient definition for `_transform_ordered`, delegating to its ChainRulesCore `rrule`.
@grad function _transform_ordered(y::AbstractVecOrMat)
    x, dx = ChainRulesCore.rrule(_transform_ordered, value(y))
    # The rrule pullback returns `(NO_FIELDS, Δ_new)`: drop the leading entry, which
    # belongs to the function itself, and convert what remains with
    # `wrap_chainrules_output`.
    return x, Δ -> wrap_chainrules_output(Base.tail(dx(Δ)))
end

# Route tracked inputs through `track` so the `@grad` definition of
# `_transform_inverse_ordered` is used for them.
_transform_inverse_ordered(x::Union{TrackedVector, TrackedMatrix}) =
    track(_transform_inverse_ordered, x)
# Gradient definition for `_transform_inverse_ordered`, delegating to its ChainRulesCore
# `rrule`.
@grad function _transform_inverse_ordered(x::AbstractVecOrMat)
    y, dy = ChainRulesCore.rrule(_transform_inverse_ordered, value(x))
    # The rrule pullback returns `(NO_FIELDS, Δ_new)`: drop the leading entry, which
    # belongs to the function itself, and convert what remains with
    # `wrap_chainrules_output`.
    return y, Δ -> wrap_chainrules_output(Base.tail(dy(Δ)))
end

# Convert the output of a ChainRulesCore pullback: `AbstractZero` values become
# `nothing`, tuples are converted element-wise, and anything else is returned as is.
#
# NOTE: Probably doesn't work in complete generality.
function wrap_chainrules_output(x)
    return x
end

function wrap_chainrules_output(x::ChainRulesCore.AbstractZero)
    return nothing
end

function wrap_chainrules_output(x::Tuple)
    return map(wrap_chainrules_output, x)
end

end
Loading

2 comments on commit 0d9b8b4

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/40957

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.7 -m "<description of version>" 0d9b8b478331f79c49b2a66da2c24a3a9042dcaf
git push origin v0.9.7

Please sign in to comment.