|
102 | 102 |
|
103 | 103 | # Param-style wrappers
|
104 | 104 |
|
| 105 | +""" |
| 106 | + Params([A, B]) |
| 107 | +
|
| 108 | +Container for implicit parameters, used when differentiating |
| 109 | +a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. |
| 110 | +""" |
| 111 | +struct Params |
| 112 | + order::Buffer # {Any, Vector{Any}} |
| 113 | + params::IdSet{Any} # TODO store ids only |
| 114 | +end |
| 115 | + |
| 116 | +Params() = Params(Buffer([], false), IdSet()) |
| 117 | +Params(xs) = Params(Buffer(xs, false), IdSet(xs)) |
| 118 | +Params(ps::Params) = ps |
| 119 | +Params(xs::Tuple) = Params(collect(xs)) |
| 120 | + |
| 121 | +@forward Params.order Base.iterate, Base.length, Base.getindex |
| 122 | +@forward Params.params Base.in |
| 123 | + |
105 | 124 | """
|
106 | 125 | gradient(() -> loss(), ps::Params) -> Grads
|
107 | 126 |
|
@@ -135,25 +154,6 @@ function withgradient(f, ps::Params)
|
135 | 154 | (val = y, grad = back(sensitivity(y)))
|
136 | 155 | end
|
137 | 156 |
|
138 |
| -""" |
139 |
| - Params([A, B]) |
140 |
| -
|
141 |
| -Container for implicit parameters, used when differentiating |
142 |
| -a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`. |
143 |
| -""" |
144 |
| -struct Params |
145 |
| - order::Buffer # {Any, Vector{Any}} |
146 |
| - params::IdSet{Any} # TODO store ids only |
147 |
| -end |
148 |
| - |
149 |
| -Params() = Params(Buffer([], false), IdSet()) |
150 |
| -Params(xs) = Params(Buffer(xs, false), IdSet(xs)) |
151 |
| -Params(ps::Params) = ps |
152 |
| -Params(xs::Tuple) = Params(collect(xs)) |
153 |
| - |
154 |
| -@forward Params.order Base.iterate, Base.length, Base.getindex |
155 |
| -@forward Params.params Base.in |
156 |
| - |
157 | 157 | function Base.union!(ps::Params, itrs...)
|
158 | 158 | foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
|
159 | 159 | return ps
|
|
0 commit comments