-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathtransform.jl
39 lines (31 loc) · 1.38 KB
/
transform.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
    Transform

Abstract supertype for transformations applied to inputs.

Concrete subtypes are expected to be callable on a single input. Mapping a
`Transform` over an `AbstractVector` applies it element-wise via broadcasting;
subtypes may provide a more specific `Base.map` method.
"""
abstract type Transform end

function Base.map(t::Transform, x::AbstractVector)
    # Same as `t.(x)`: dot-call syntax lowers to this broadcast.
    return broadcast(t, x)
end
"""
    IdentityTransform()

Transformation that returns its input unchanged.

Mapping it over an `AbstractVector` returns the very same vector object (no
copy is made), so mutating the result also mutates the input.
"""
struct IdentityTransform <: Transform end

(::IdentityTransform)(x) = x

function Base.map(::IdentityTransform, x::AbstractVector)
    # Short-circuit: skip element-wise broadcasting and hand back the input as-is.
    return x
end
### TODO: Custom `@adjoint` definitions might help here, but the attempts below do not work yet.
# @adjoint function ScaleTransform(s::T) where {T<:Real}
# @check_args(ScaleTransform, s, s > zero(T), "s > 0")
# ScaleTransform{T}(s),Δ->ScaleTransform{T}(Δ)
# end
#
# @adjoint function ScaleTransform(s::A) where {A<:AbstractVector{<:Real}}
# @check_args(ScaleTransform, s, all(s.>zero(eltype(A))), "s > 0")
# ScaleTransform{A}(s),Δ->begin; @show Δ,size(Δ); ScaleTransform{A}(Δ); end
# end
# @adjoint transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = transform(t,x),Δ->(ScaleTransform(nothing),t.s.*Δ)
#
# @adjoint transform(t::ARDTransform{<:Real},X::AbstractMatrix{<:Real},obsdim::Int) = transform(t,X,obsdim),Δ->begin
# @show Δ,size(Δ);
# return (obsdim == 1 ? ARD()Δ'.*X : ScaleTransform()Δ.*X,transform(t,Δ,obsdim),nothing)
# end
#
# @adjoint transform(t::ScaleTransform{T},x::AbstractVecOrMat,obsdim::Int) where {T<:Real} = transform(t,x), Δ->(ScaleTransform(one(T)),t.s.*Δ,nothing)