diff --git a/Project.toml b/Project.toml index 3fca586d..fe79753d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.6" +version = "0.9.7" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl new file mode 100644 index 00000000..d1bfd8f0 --- /dev/null +++ b/src/bijectors/ordered.jl @@ -0,0 +1,78 @@ +""" + OrderedBijector() + +A bijector mapping unordered vectors in ℝᵈ to ordered vectors in ℝᵈ. + +## See also +- [Stan's documentation](https://mc-stan.org/docs/2_27/reference-manual/ordered-vector.html) + - Note that this transformation and its inverse are the _opposite_ of those in this reference. +""" +struct OrderedBijector <: Bijector{1} end + +""" + ordered(d::Distribution) + +Return a `Distribution` whose support consists of ordered vectors, i.e., vectors with increasingly ordered elements. +""" +ordered(d::ContinuousMultivariateDistribution) = Bijectors.transformed(d, OrderedBijector()) + +(b::OrderedBijector)(y::AbstractVecOrMat) = _transform_ordered(y) + +function _transform_ordered(y::AbstractVector) + x = similar(y) + @assert !isempty(y) + + @inbounds x[1] = y[1] + @inbounds for i = 2:length(x) + x[i] = x[i - 1] + exp(y[i]) + end + + return x +end + +function _transform_ordered(y::AbstractMatrix) + x = similar(y) + @assert !isempty(y) + + @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + if i == 1 + x[i, j] = y[i, j] + else + x[i, j] = x[i - 1, j] + exp(y[i, j]) + end + end + + return x +end + +(ib::Inverse{<:OrderedBijector})(x::AbstractVecOrMat) = _transform_inverse_ordered(x) + +function _transform_inverse_ordered(x::AbstractVector) + y = similar(x) + @assert !isempty(y) + + @inbounds y[1] = x[1] + @inbounds for i = 2:length(y) + y[i] = log(x[i] - x[i - 1]) + end + + return y +end + +function _transform_inverse_ordered(x::AbstractMatrix) + y = similar(x) + @assert !isempty(y) + + @inbounds for j = 1:size(y, 
2), i = 1:size(y, 1) + if i == 1 + y[i, j] = x[i, j] + else + y[i, j] = log(x[i, j] - x[i - 1, j]) + end + end + + return y +end + +logabsdetjac(b::OrderedBijector, x::AbstractVector) = sum(@view(x[2:end])) +logabsdetjac(b::OrderedBijector, x::AbstractMatrix) = vec(sum(@view(x[2:end, :]); dims = 1)) diff --git a/src/chainrules.jl b/src/chainrules.jl index b7f2b51a..6e2e296a 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -6,3 +6,123 @@ ChainRulesCore.@scalar_rule( ), (x, - tanh(Ω + b) * x, x - 1), ) + +# `OrderedBijector` +function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractVector) + function _transform_ordered_adjoint(Δ) + Δ_new = similar(y) + n = length(Δ) + @assert n == length(Δ_new) + + s = sum(Δ) + Δ_new[1] = s + @inbounds for i in 2:n + # Equivalent to + # + # Δ_new[i] = sum(Δ[i:end]) * yexp[i - 1] + # + s -= Δ[i - 1] + Δ_new[i] = s * exp(y[i]) + end + + # Using `NO_FIELDS` to be backwards-compatible. + return (ChainRulesCore.NO_FIELDS, Δ_new) + end + + return _transform_ordered(y), _transform_ordered_adjoint +end + +function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractMatrix) + function _transform_ordered_adjoint(Δ) + Δ_new = similar(y) + n = size(Δ, 1) + @assert size(Δ) == size(Δ_new) + + s = vec(sum(Δ; dims=1)) + Δ_new[1, :] .= s + @inbounds for i in 2:n + # Equivalent to + # + # Δ_new[i] = sum(Δ[i:end]) * yexp[i - 1] + # + s -= Δ[i - 1, :] + Δ_new[i, :] = s .* exp.(y[i, :]) + end + + return (ChainRulesCore.NO_FIELDS, Δ_new) + end + + return _transform_ordered(y), _transform_ordered_adjoint +end + +function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractVector) + r = similar(x) + @inbounds for i = 1:length(r) + if i == 1 + r[i] = 1 + else + r[i] = x[i] - x[i - 1] + end + end + + function _transform_inverse_ordered_adjoint(Δ) + Δ_new = similar(x) + @assert length(Δ_new) == length(Δ) + + n = length(Δ_new) + @inbounds for j = 1:n - 1 + Δ_new[j] = (Δ[j] / r[j]) - (Δ[j + 1] / r[j + 1]) + end 
+ @inbounds Δ_new[n] = Δ[n] / r[n] + + return (ChainRulesCore.NO_FIELDS, Δ_new) + end + + y = similar(x) + @inbounds y[1] = x[1] + @inbounds for i = 2:length(x) + y[i] = log(r[i]) + end + + return y, _transform_inverse_ordered_adjoint +end + +function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractMatrix) + r = similar(x) + @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + if i == 1 + r[i, j] = 1 + else + r[i, j] = x[i, j] - x[i - 1, j] + end + end + + function _transform_inverse_ordered_adjoint(Δ) + Δ_new = similar(x) + n = size(Δ, 1) + @assert size(Δ) == size(Δ_new) + + @inbounds for j = 1:size(Δ_new, 2), i = 1:n - 1 + Δ_new[i, j] = (Δ[i, j] / r[i, j]) - (Δ[i + 1, j] / r[i + 1, j]) + end + + @inbounds for j = 1:size(Δ_new, 2) + Δ_new[n, j] = Δ[n, j] / r[n, j] + end + + return (ChainRulesCore.NO_FIELDS, Δ_new) + end + + # Compute primal here so we can make use of the already + # computed `r`. + y = similar(x) + @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + if i == 1 + y[i, j] = x[i, j] + else + y[i, j] = log(r[i, j]) + end + end + + return y, _transform_inverse_ordered_adjoint +end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 6dbc3d4e..62c0691d 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -9,7 +9,9 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, ReverseDiffAD, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, - _inv_link_chol_lkj, _link_chol_lkj + _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered + +import ChainRulesCore using Compat: eachcol using Distributions: LocationScale @@ -195,4 +197,26 @@ end return α, find_alpha_pullback end +# `OrderedBijector` +function _transform_ordered(y::Union{TrackedVector, TrackedMatrix}) + return track(_transform_ordered, y) +end +@grad function _transform_ordered(y::AbstractVecOrMat) + 
x, dx = ChainRulesCore.rrule(_transform_ordered, value(y)) + return x, (wrap_chainrules_output ∘ Base.tail ∘ dx) +end + +function _transform_inverse_ordered(x::Union{TrackedVector, TrackedMatrix}) + return track(_transform_inverse_ordered, x) +end +@grad function _transform_inverse_ordered(x::AbstractVecOrMat) + y, dy = ChainRulesCore.rrule(_transform_inverse_ordered, value(x)) + return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) +end + +# NOTE: Probably doesn't work in complete generality. +wrap_chainrules_output(x) = x +wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing +wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) + end diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 576b3788..885c1d11 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -1,19 +1,28 @@ -using .Tracker: Tracker, - TrackedReal, - TrackedVector, - TrackedMatrix, - TrackedArray, - TrackedVecOrMat, - @grad, - track, - data, - param +module TrackerCompat + +using ..Tracker: Tracker, + TrackedReal, + TrackedVector, + TrackedMatrix, + TrackedArray, + TrackedVecOrMat, + @grad, + track, + data, + param + +import ..Bijectors +using ..Bijectors: Log, SimplexBijector, ADBijector, + TrackerAD, Inverse, Stacked, Exp + +import ChainRulesCore using Compat: eachcol using LinearAlgebra +using Distributions: LocationScale -maporbroadcast(f, x::TrackedArray...) = f.(x...) -function maporbroadcast( +Bijectors.maporbroadcast(f, x::TrackedArray...) = f.(x...) +function Bijectors.maporbroadcast( f, x1::TrackedArray{T, N}, x::AbstractArray{<:TrackedReal}..., @@ -21,7 +30,7 @@ function maporbroadcast( return f.(convert(Array{TrackedReal{T}, N}, x1), x...) 
end -_eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) +Bijectors._eps(::Type{<:TrackedReal{T}}) where {T} = Bijectors._eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) m = minimum(d.ρ) if isfinite(m) @@ -40,13 +49,13 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end # AD implementations -function jacobian( +function Bijectors.jacobian( b::Union{<:ADBijector{<:TrackerAD}, Inverse{<:ADBijector{<:TrackerAD}}}, x::Real ) return data(Tracker.gradient(b, x)[1]) end -function jacobian( +function Bijectors.jacobian( b::Union{<:ADBijector{<:TrackerAD}, Inverse{<:ADBijector{<:TrackerAD}}}, x::AbstractVector{<:Real} ) @@ -55,20 +64,20 @@ function jacobian( end # implementations for Shift bijector -function _logabsdetjac_shift(a::TrackedReal, x::Real, ::Val{0}) +function Bijectors._logabsdetjac_shift(a::TrackedReal, x::Real, ::Val{0}) return tracker_shift_logabsdetjac(a, x, Val(0)) end -function _logabsdetjac_shift(a::TrackedReal, x::AbstractVector{<:Real}, ::Val{0}) +function Bijectors._logabsdetjac_shift(a::TrackedReal, x::AbstractVector{<:Real}, ::Val{0}) return tracker_shift_logabsdetjac(a, x, Val(0)) end -function _logabsdetjac_shift( +function Bijectors._logabsdetjac_shift( a::Union{TrackedReal, TrackedVector{<:Real}}, x::AbstractVector{<:Real}, ::Val{1} ) return tracker_shift_logabsdetjac(a, x, Val(1)) end -function _logabsdetjac_shift( +function Bijectors._logabsdetjac_shift( a::Union{TrackedReal, TrackedVector{<:Real}}, x::AbstractMatrix{<:Real}, ::Val{1} @@ -76,66 +85,66 @@ function _logabsdetjac_shift( return tracker_shift_logabsdetjac(a, x, Val(1)) end function tracker_shift_logabsdetjac(a, x, ::Val{N}) where {N} - return param(_logabsdetjac_shift(data(a), data(x), Val(N))) + return param(Bijectors._logabsdetjac_shift(data(a), data(x), Val(N))) end # Log bijector -@grad function logabsdetjac(b::Log{1}, x::AbstractVector) +@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractVector) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ 
data(x)) end -@grad function logabsdetjac(b::Log{1}, x::AbstractMatrix) +@grad function Bijectors.logabsdetjac(b::Log{1}, x::AbstractMatrix) return -vec(sum(log, data(x); dims = 1)), Δ -> (nothing, .- Δ' ./ data(x)) end -@grad function logabsdetjac(b::Log{2}, x::AbstractMatrix) +@grad function Bijectors.logabsdetjac(b::Log{2}, x::AbstractMatrix) return -sum(log, data(x)), Δ -> (nothing, -Δ ./ data(x)) end # implementations for Scale bijector # Adjoints for 0-dim and 1-dim `Scale` using `Real` -function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) - return track(_logabsdetjac_scale, a, data(x), Val(0)) +function Bijectors._logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) + return track(Bijectors._logabsdetjac_scale, a, data(x), Val(0)) end -@grad function _logabsdetjac_scale(a::Real, x::Real, ::Val{0}) - return _logabsdetjac_scale(data(a), data(x), Val(0)), Δ -> (inv(data(a)) .* Δ, nothing, nothing) +@grad function Bijectors._logabsdetjac_scale(a::Real, x::Real, ::Val{0}) + return Bijectors._logabsdetjac_scale(data(a), data(x), Val(0)), Δ -> (inv(data(a)) .* Δ, nothing, nothing) end # Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors -function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) - return track(_logabsdetjac_scale, a, data(x), Val(0)) +function Bijectors._logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) + return track(Bijectors._logabsdetjac_scale, a, data(x), Val(0)) end -@grad function _logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) +@grad function Bijectors._logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) da = data(a) J = fill(inv.(da), length(x)) - return _logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) end -function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) - return 
track(_logabsdetjac_scale, a, data(x), Val(0)) +function Bijectors._logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) + return track(Bijectors._logabsdetjac_scale, a, data(x), Val(0)) end -@grad function _logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{0}) +@grad function Bijectors._logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{0}) da = data(a) J = fill(size(x, 1) / da, size(x, 2)) - return _logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) end # adjoints for 1-dim and 2-dim `Scale` using `AbstractVector` -function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) - return track(_logabsdetjac_scale, a, data(x), Val(1)) +function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) + return track(Bijectors._logabsdetjac_scale, a, data(x), Val(1)) end -@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) +@grad function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) # ∂ᵢ (∑ⱼ log|aⱼ|) = ∑ⱼ δᵢⱼ ∂ᵢ log|aⱼ| # = ∂ᵢ log |aᵢ| # = (1 / aᵢ) ∂ᵢ aᵢ # = (1 / aᵢ) da = data(a) J = inv.(da) - return _logabsdetjac_scale(da, data(x), Val(1)), Δ -> (J .* Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), Δ -> (J .* Δ, nothing, nothing) end -function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) - return track(_logabsdetjac_scale, a, data(x), Val(1)) +function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) + return track(Bijectors._logabsdetjac_scale, a, data(x), Val(1)) end -@grad function _logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) +@grad function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) da = data(a) Jᵀ = repeat(inv.(da), 1, size(x, 2)) - return _logabsdetjac_scale(da, data(x), Val(1)), Δ 
-> (Jᵀ * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) end # TODO: implement analytical gradient for scaling a vector using a matrix # function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1}) @@ -145,50 +154,50 @@ end # throw # end # implementations for Stacked bijector -function logabsdetjac(b::Stacked, x::TrackedMatrix{<:Real}) +function Bijectors.logabsdetjac(b::Stacked, x::TrackedMatrix{<:Real}) return map(eachcol(x)) do c - logabsdetjac(b, c) + Bijectors.logabsdetjac(b, c) end end # TODO: implement custom adjoint since we can exploit block-diagonal nature of `Stacked` function (sb::Stacked)(x::TrackedMatrix{<:Real}) - return eachcolmaphcat(sb, x) + return Bijectors.eachcolmaphcat(sb, x) end # Simplex adjoints -function _simplex_bijector(X::TrackedVecOrMat, b::SimplexBijector{1}) - return track(_simplex_bijector, X, b) +function Bijectors._simplex_bijector(X::TrackedVecOrMat, b::SimplexBijector{1}) + return track(Bijectors._simplex_bijector, X, b) end -function _simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector{1}) - return track(_simplex_inv_bijector, Y, b) +function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector{1}) + return track(Bijectors._simplex_inv_bijector, Y, b) end -@grad function _simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector{1}) Xd = data(X) - return _simplex_bijector(Xd, b), Δ -> (simplex_link_jacobian(Xd)' * Δ, nothing) + return Bijectors._simplex_bijector(Xd, b), Δ -> (Bijectors.simplex_link_jacobian(Xd)' * Δ, nothing) end -@grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) +@grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector{1}) Yd = data(Y) - return _simplex_inv_bijector(Yd, b), Δ -> (simplex_invlink_jacobian(Yd)' * Δ, nothing) + return Bijectors._simplex_inv_bijector(Yd, 
b), Δ -> (Bijectors.simplex_invlink_jacobian(Yd)' * Δ, nothing) end -@grad function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) +@grad function Bijectors._simplex_bijector(X::AbstractMatrix, b::SimplexBijector{1}) Xd = data(X) - return _simplex_bijector(Xd, b), Δ -> begin - maphcat(eachcol(Xd), eachcol(Δ)) do c1, c2 - simplex_link_jacobian(c1)' * c2 + return Bijectors._simplex_bijector(Xd, b), Δ -> begin + Bijectors.maphcat(eachcol(Xd), eachcol(Δ)) do c1, c2 + Bijectors.simplex_link_jacobian(c1)' * c2 end, nothing end end -@grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) +@grad function Bijectors._simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) Yd = data(Y) - return _simplex_inv_bijector(Yd, b), Δ -> begin - maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 - simplex_invlink_jacobian(c1)' * c2 + return Bijectors._simplex_inv_bijector(Yd, b), Δ -> begin + Bijectors.maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 + Bijectors.simplex_invlink_jacobian(c1)' * c2 end, nothing end end -replace_diag(::typeof(log), X::TrackedMatrix) = track(replace_diag, log, X) -@grad function replace_diag(::typeof(log), X) +Bijectors.replace_diag(::typeof(log), X::TrackedMatrix) = track(Bijectors.replace_diag, log, X) +@grad function Bijectors.replace_diag(::typeof(log), X) Xd = data(X) f(i, j) = i == j ? 
log(Xd[i, j]) : Xd[i, j] out = f.(1:size(Xd, 1), (1:size(Xd, 2))') @@ -198,8 +207,8 @@ replace_diag(::typeof(log), X::TrackedMatrix) = track(replace_diag, log, X) end end -replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X) -@grad function replace_diag(::typeof(exp), X) +Bijectors.replace_diag(::typeof(exp), X::TrackedMatrix) = track(Bijectors.replace_diag, exp, X) +@grad function Bijectors.replace_diag(::typeof(exp), X) Xd = data(X) f(i, j) = ifelse(i == j, exp(Xd[i, j]), Xd[i, j]) out = f.(1:size(Xd, 1), (1:size(Xd, 2))') @@ -209,18 +218,18 @@ replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X) end end -logabsdetjac(b::SimplexBijector{1}, x::TrackedVecOrMat) = track(logabsdetjac, b, x) -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) +Bijectors.logabsdetjac(b::SimplexBijector{1}, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) +@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractVector) xd = data(x) - return logabsdetjac(b, xd), Δ -> begin - (nothing, simplex_logabsdetjac_gradient(xd) * Δ) + return Bijectors.logabsdetjac(b, xd), Δ -> begin + (nothing, Bijectors.simplex_logabsdetjac_gradient(xd) * Δ) end end -@grad function logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) +@grad function Bijectors.logabsdetjac(b::SimplexBijector{1}, x::AbstractMatrix) xd = data(x) - return logabsdetjac(b, xd), Δ -> begin - (nothing, maphcat(eachcol(xd), Δ) do c, g - simplex_logabsdetjac_gradient(c) * g + return Bijectors.logabsdetjac(b, xd), Δ -> begin + (nothing, Bijectors.maphcat(eachcol(xd), Δ) do c, g + Bijectors.simplex_logabsdetjac_gradient(c) * g end) end end @@ -240,9 +249,9 @@ for header in [ (:(α_::TrackedReal), :(β::TrackedReal), :(z_0::TrackedVector), :(z::TrackedVector)), ] @eval begin - function _radial_transform($(header...)) - α = softplus(α_) # from A.2 - β_hat = -α + softplus(β) # from A.2 + function Bijectors._radial_transform($(header...)) + α = 
Bijectors.softplus(α_) # from A.2 + β_hat = -α + Bijectors.softplus(β) # from A.2 if β_hat isa TrackedReal TV = vectorof(typeof(β_hat)) T = vectorof(typeof(β_hat)) @@ -275,9 +284,9 @@ for header in [ (:(α_::TrackedReal), :(β::TrackedReal), :(z_0::TrackedVector), :(z::TrackedMatrix)), ] @eval begin - function _radial_transform($(header...)) - α = softplus(α_) # from A.2 - β_hat = -α + softplus(β) # from A.2 + function Bijectors._radial_transform($(header...)) + α = Bijectors.softplus(α_) # from A.2 + β_hat = -α + Bijectors.softplus(β) # from A.2 if β_hat isa TrackedReal TV = vectorof(typeof(β_hat)) T = matrixof(TV) @@ -327,26 +336,26 @@ end (b::Log{1})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) (b::Log{2})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -logabsdetjac(b::Log{0}, x::TrackedVector) = .-log.(x)::vectorof(float(eltype(x))) -logabsdetjac(b::Log{1}, x::TrackedMatrix) = - vec(sum(log.(x); dims = 1)) +Bijectors.logabsdetjac(b::Log{0}, x::TrackedVector) = .-log.(x)::vectorof(float(eltype(x))) +Bijectors.logabsdetjac(b::Log{1}, x::TrackedMatrix) = - vec(sum(log.(x); dims = 1)) -getpd(X::TrackedMatrix) = track(getpd, X) -@grad function getpd(X::AbstractMatrix) +Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) +@grad function Bijectors.getpd(X::AbstractMatrix) Xd = data(X) - return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin - Xl = LowerTriangular(Xd) - return (LowerTriangular(Δ' * Xl + Δ * Xl),) + return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin + Xl = Bijectors.LowerTriangular(Xd) + return (Bijectors.LowerTriangular(Δ' * Xl + Δ * Xl),) end end -lower(A::TrackedMatrix) = track(lower, A) -@grad function lower(A::AbstractMatrix) +Bijectors.lower(A::TrackedMatrix) = track(Bijectors.lower, A) +@grad function Bijectors.lower(A::AbstractMatrix) Ad = data(A) - return lower(Ad), Δ -> (lower(Δ),) + return Bijectors.lower(Ad), Δ -> (Bijectors.lower(Δ),) end -_inv_link_chol_lkj(y::TrackedMatrix) 
= track(_inv_link_chol_lkj, y) -@grad function _inv_link_chol_lkj(y_tracked) +Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) +@grad function Bijectors._inv_link_chol_lkj(y_tracked) y = data(y_tracked) K = LinearAlgebra.checksquare(y) @@ -393,8 +402,8 @@ _inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y) return w, pullback_inv_link_chol_lkj end -_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w) -@grad function _link_chol_lkj(w_tracked) +Bijectors._link_chol_lkj(w::TrackedMatrix) = track(Bijectors._link_chol_lkj, w) +@grad function Bijectors._link_chol_lkj(w_tracked) w = data(w_tracked) K = LinearAlgebra.checksquare(w) @@ -447,11 +456,11 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w) return z, pullback_link_chol_lkj end -function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} - return track(find_alpha, wt_y, wt_u_hat, b) +function Bijectors.find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} + return track(Bijectors.find_alpha, wt_y, wt_u_hat, b) end -@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal) - α = find_alpha(data(wt_y), data(wt_u_hat), data(b)) +@grad function Bijectors.find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal) + α = Bijectors.find_alpha(data(wt_y), data(wt_u_hat), data(b)) ∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2) ∂wt_u_hat = - tanh(α + b) * ∂wt_y @@ -460,3 +469,25 @@ end return α, find_alpha_pullback end + +# `OrderedBijector` +Bijectors._transform_ordered(y::Union{TrackedVector,TrackedMatrix}) = track(Bijectors._transform_ordered, y) +@grad function Bijectors._transform_ordered(y::AbstractVecOrMat) + x, dx = ChainRulesCore.rrule(Bijectors._transform_ordered, data(y)) + return x, (wrap_chainrules_output ∘ Base.tail ∘ dx) +end + +function Bijectors._transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix}) + return track(Bijectors._transform_inverse_ordered, x) +end +@grad function 
Bijectors._transform_inverse_ordered(x::AbstractVecOrMat) + y, dy = ChainRulesCore.rrule(Bijectors._transform_inverse_ordered, data(x)) + return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) +end + +# NOTE: Probably doesn't work in complete generality. +wrap_chainrules_output(x) = x +wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing +wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) + +end diff --git a/src/interface.jl b/src/interface.jl index 9454956b..c1ca115c 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -156,6 +156,7 @@ include("bijectors/pd.jl") include("bijectors/corr.jl") include("bijectors/truncated.jl") include("bijectors/named_bijector.jl") +include("bijectors/ordered.jl") # Normalizing flow related include("bijectors/planar_layer.jl") diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl new file mode 100644 index 00000000..88fec6d3 --- /dev/null +++ b/test/bijectors/ordered.jl @@ -0,0 +1,28 @@ +import Bijectors: OrderedBijector + +@testset "OrderedBijector" begin + b = OrderedBijector() + + # Length 1 + x = randn(1) + y = b(x) + test_bijector(b, hcat(x, x), hcat(y, y), zeros(2)) + + # Larger + x = randn(5) + xs = hcat(x, x) + test_bijector(b, xs) + + y = b(x) + @test sort(y) == y + + ys = b(xs) + @test sort(ys[:, 1]) == ys[:, 1] + @test sort(ys[:, 2]) == ys[:, 2] + + # `ChainRules` + test_rrule(Bijectors._transform_ordered, x) + test_rrule(Bijectors._transform_ordered, xs) + test_rrule(Bijectors._transform_inverse_ordered, y) + test_rrule(Bijectors._transform_inverse_ordered, ys) +end diff --git a/test/runtests.jl b/test/runtests.jl index f9cfbcc3..c1765f73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/named_bijector.jl") include("bijectors/leaky_relu.jl") include("bijectors/coupling.jl") + include("bijectors/ordered.jl") end if GROUP == "All" || GROUP == "AD"