Skip to content

Commit 7b86a9a

Browse files
committed
Refine Mooncake rules
1 parent 6f44671 commit 7b86a9a

File tree

1 file changed

+244
-21
lines changed

1 file changed

+244
-21
lines changed

ext/MooncakeExt.jl

Lines changed: 244 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,269 @@ import Mooncake
44
import MacroModelling
55
import SparseArrays
66
import LinearAlgebra as ℒ
7+
import ChainRulesCore
78

8-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.mul_reverse_AD!), Matrix{S}, AbstractMatrix{M}, AbstractMatrix{N}} where {S <: Real, M <: Real, N <: Real}
9+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, typeof(MacroModelling.mul_reverse_AD!), Matrix{S}, AbstractMatrix{M}, AbstractMatrix{N}} where {S <: Real, M <: Real, N <: Real}
910

10-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.sparse_preallocated!), Matrix{T}} where {T <: Real} true
11+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, MacroModelling.higher_order_caches{S,F}, typeof(MacroModelling.sparse_preallocated!), Matrix{S}} where {S <: Real, F <: AbstractFloat} true
1112

12-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_second_order_stochastic_steady_state), Val{:newton}, Matrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, Vector{Float64}, MacroModelling.ℳ} true
13+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, AbstractFloat, typeof(MacroModelling.calculate_second_order_stochastic_steady_state), Val{:newton}, Matrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, Vector{Float64}, MacroModelling.ℳ} true
14+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, AbstractFloat, typeof(MacroModelling.calculate_third_order_stochastic_steady_state), Val{:newton}, Matrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, Vector{Float64}, MacroModelling.ℳ} true
1315

14-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_third_order_stochastic_steady_state), Val{:newton}, Matrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, SparseArrays.AbstractSparseMatrix{Float64}, Vector{Float64}, MacroModelling.ℳ} true
16+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, MacroModelling.CalculationOptions, typeof(MacroModelling.get_NSSS_and_parameters), MacroModelling.ℳ, Vector{S}} where S <: AbstractFloat true
1517

16-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_jacobian), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
18+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, typeof(MacroModelling.calculate_jacobian), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
1719

18-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_hessian), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
20+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, typeof(MacroModelling.calculate_hessian), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
1921

20-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_third_order_derivatives), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
22+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, typeof(MacroModelling.calculate_third_order_derivatives), Vector{M}, Vector{N}, MacroModelling.ℳ} where {M <: AbstractFloat, N <: AbstractFloat}
2123

22-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.get_NSSS_and_parameters), MacroModelling.ℳ, Vector{S} where S <: AbstractFloat} true
24+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, MacroModelling.timings, MacroModelling.CalculationOptions, AbstractMatrix{R}, typeof(MacroModelling.calculate_first_order_solution), Matrix{R}} where R <: AbstractFloat true
2325

24-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_first_order_solution), Matrix{R}} where R <: AbstractFloat true
26+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, MacroModelling.timings, AbstractMatrix{R}, MacroModelling.CalculationOptions, typeof(MacroModelling.calculate_second_order_solution), AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, AbstractMatrix{R}, MacroModelling.second_order_auxiliary_matrices, MacroModelling.caches} where R <: AbstractFloat true
2527

26-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_second_order_solution), AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, AbstractMatrix{R}, MacroModelling.second_order_auxiliary_matrices, MacroModelling.caches} where R <: AbstractFloat true
28+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, MacroModelling.timings, AbstractMatrix{R}, MacroModelling.CalculationOptions, typeof(MacroModelling.calculate_second_order_solution), AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, SparseArrays.SparseMatrixCSC{R}, AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, MacroModelling.second_order_auxiliary_matrices, MacroModelling.third_order_auxiliary_matrices, MacroModelling.caches} where R <: AbstractFloat true
2729

28-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_second_order_solution), AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, SparseArrays.SparseMatrixCSC{R}, AbstractMatrix{R}, SparseArrays.SparseMatrixCSC{R}, MacroModelling.second_order_auxiliary_matrices, MacroModelling.third_order_auxiliary_matrices, MacroModelling.caches} where R <: AbstractFloat true
2930

30-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.solve_lyapunov_equation), AbstractMatrix{R}, AbstractMatrix{R}} where R <: AbstractFloat true
3131

32-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.solve_sylvester_equation), AbstractMatrix{R}, AbstractMatrix{R}, AbstractMatrix{R}} where R <: AbstractFloat true
32+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Float64, typeof(MacroModelling.find_shocks), Val{:LagrangeNewton}, Vector{Float64}, Vector{Float64}, AbstractMatrix{Float64}, ℒ.Diagonal{Bool, Vector{Bool}}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, Vector{Float64}} true
3333

34-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.find_shocks), Val{:LagrangeNewton}, Vector{Float64}, Vector{Float64}, AbstractMatrix{Float64}, ℒ.Diagonal{Bool, Vector{Bool}}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, Vector{Float64}} true
34+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Float64, typeof(MacroModelling.find_shocks), Val{:LagrangeNewton}, Vector{Float64}, Vector{Float64}, Vector{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, .Diagonal{Bool, Vector{Bool}}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, Vector{Float64}} true
3535

36-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.find_shocks), Val{:LagrangeNewton}, Vector{Float64}, Vector{Float64}, Vector{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64},
37-
.Diagonal{Bool, Vector{Bool}}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, AbstractMatrix{Float64}, Vector{Float64}} true
36+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Int, MacroModelling.CalculationOptions, Symbol, typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:first_order}, Vector{Vector{Float64}}, Matrix{Float64}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
37+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Int, MacroModelling.CalculationOptions, Symbol, typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:pruned_second_order}, Vector{Vector{Float64}}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
38+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Int, MacroModelling.CalculationOptions, Symbol, typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:second_order}, Vector{Float64}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
39+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Int, MacroModelling.CalculationOptions, Symbol, typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:pruned_third_order}, Vector{Vector{Float64}}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
40+
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{Any, Int, Int, MacroModelling.CalculationOptions, Symbol, typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:third_order}, Vector{Float64}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
3841

39-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:first_order}, Vector{Vector{Float64}}, Matrix{Float64}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
4042

41-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:pruned_second_order},Vector{Vector{Float64}}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
43+
function ChainRulesCore.rrule(func_ir::Any,
44+
::MacroModelling.higher_order_caches{T,F},
45+
::typeof(MacroModelling.sparse_preallocated!),
46+
::Matrix{T}) where {T<:Real,F<:AbstractFloat}
47+
ChainRulesCore.rrule(MacroModelling.sparse_preallocated!, Ŝ; ℂ=ℂ)
48+
end
4249

43-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:second_order},Vector{Float64}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
50+
function ChainRulesCore.rrule(func_ir::Any,
51+
::typeof(MacroModelling.mul_reverse_AD!),
52+
C::Matrix{S},
53+
A::AbstractMatrix{M},
54+
B::AbstractMatrix{N}) where {S<:Real,M<:Real,N<:Real}
55+
ChainRulesCore.rrule(MacroModelling.mul_reverse_AD!, C, A, B)
56+
end
4457

45-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:pruned_third_order},Vector{Vector{Float64}}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
58+
function ChainRulesCore.rrule(func_ir::Any,
59+
::typeof(MacroModelling.calculate_jacobian),
60+
parameters::Vector{M},
61+
SS_and_pars::Vector{N},
62+
m::MacroModelling.ℳ) where {M<:AbstractFloat,N<:AbstractFloat}
63+
ChainRulesCore.rrule(MacroModelling.calculate_jacobian, parameters, SS_and_pars, m)
64+
end
4665

47-
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(MacroModelling.calculate_inversion_filter_loglikelihood), Val{:third_order},Vector{Float64}, Vector{AbstractMatrix{Float64}}, Matrix{Float64}, Union{Vector{String}, Vector{Symbol}}, MacroModelling.timings} true
66+
function ChainRulesCore.rrule(func_ir::Any,
67+
::typeof(MacroModelling.calculate_hessian),
68+
parameters::Vector{M},
69+
SS_and_pars::Vector{N},
70+
m::MacroModelling.ℳ) where {M<:AbstractFloat,N<:AbstractFloat}
71+
ChainRulesCore.rrule(MacroModelling.calculate_hessian, parameters, SS_and_pars, m)
72+
end
73+
74+
function ChainRulesCore.rrule(func_ir::Any,
75+
::typeof(MacroModelling.calculate_third_order_derivatives),
76+
parameters::Vector{M},
77+
SS_and_pars::Vector{N},
78+
m::MacroModelling.ℳ) where {M<:AbstractFloat,N<:AbstractFloat}
79+
ChainRulesCore.rrule(MacroModelling.calculate_third_order_derivatives, parameters, SS_and_pars, m)
80+
end
81+
82+
function ChainRulesCore.rrule(func_ir::Any,
83+
tol::AbstractFloat,
84+
::typeof(MacroModelling.calculate_second_order_stochastic_steady_state),
85+
::Val{:newton},
86+
S1::Matrix{Float64},
87+
S2::SparseArrays.AbstractSparseMatrix{Float64},
88+
x::Vector{Float64},
89+
m::MacroModelling.ℳ)
90+
ChainRulesCore.rrule(MacroModelling.calculate_second_order_stochastic_steady_state,
91+
Val(:newton), S1, S2, x, m; tol=tol)
92+
end
93+
94+
function ChainRulesCore.rrule(func_ir::Any,
95+
tol::AbstractFloat,
96+
::typeof(MacroModelling.calculate_third_order_stochastic_steady_state),
97+
::Val{:newton},
98+
S1::Matrix{Float64},
99+
S2::SparseArrays.AbstractSparseMatrix{Float64},
100+
S3::SparseArrays.AbstractSparseMatrix{Float64},
101+
x::Vector{Float64},
102+
m::MacroModelling.ℳ)
103+
ChainRulesCore.rrule(MacroModelling.calculate_third_order_stochastic_steady_state,
104+
Val(:newton), S1, S2, S3, x, m; tol=tol)
105+
end
106+
107+
function ChainRulesCore.rrule(func_ir::Any,
108+
opts::MacroModelling.CalculationOptions,
109+
::typeof(MacroModelling.get_NSSS_and_parameters),
110+
m::MacroModelling.ℳ,
111+
x::Vector{S}) where {S<:AbstractFloat}
112+
ChainRulesCore.rrule(MacroModelling.get_NSSS_and_parameters, m, x; opts=opts)
113+
end
114+
115+
function ChainRulesCore.rrule(func_ir::Any,
116+
T::MacroModelling.timings,
117+
opts::MacroModelling.CalculationOptions,
118+
initial_guess::AbstractMatrix{R},
119+
::typeof(MacroModelling.calculate_first_order_solution),
120+
∇₁::Matrix{R}) where {R<:AbstractFloat}
121+
ChainRulesCore.rrule(MacroModelling.calculate_first_order_solution,
122+
∇₁; T=T, opts=opts, initial_guess=initial_guess)
123+
end
124+
125+
function ChainRulesCore.rrule(func_ir::Any,
126+
T::MacroModelling.timings,
127+
initial_guess::AbstractMatrix{R},
128+
opts::MacroModelling.CalculationOptions,
129+
::typeof(MacroModelling.calculate_second_order_solution),
130+
∇₁::AbstractMatrix{R},
131+
∇₂::SparseArrays.SparseMatrixCSC{R},
132+
𝑺₁::AbstractMatrix{R},
133+
M₂::MacroModelling.second_order_auxiliary_matrices,
134+
ℂC::MacroModelling.caches) where {R<:AbstractFloat}
135+
ChainRulesCore.rrule(MacroModelling.calculate_second_order_solution,
136+
∇₁, ∇₂, 𝑺₁, M₂, ℂC; T=T, initial_guess=initial_guess, opts=opts)
137+
end
138+
139+
function ChainRulesCore.rrule(func_ir::Any,
140+
T::MacroModelling.timings,
141+
initial_guess::AbstractMatrix{R},
142+
opts::MacroModelling.CalculationOptions,
143+
::typeof(MacroModelling.calculate_second_order_solution),
144+
∇₁::AbstractMatrix{R},
145+
∇₂::SparseArrays.SparseMatrixCSC{R},
146+
∇₃::SparseArrays.SparseMatrixCSC{R},
147+
𝑺₁::AbstractMatrix{R},
148+
𝐒₂::SparseArrays.SparseMatrixCSC{R},
149+
M₂::MacroModelling.second_order_auxiliary_matrices,
150+
M₃::MacroModelling.third_order_auxiliary_matrices,
151+
ℂC::MacroModelling.caches) where {R<:AbstractFloat}
152+
ChainRulesCore.rrule(MacroModelling.calculate_second_order_solution,
153+
∇₁, ∇₂, ∇₃, 𝑺₁, 𝐒₂, M₂, M₃, ℂC;
154+
T=T, initial_guess=initial_guess, opts=opts)
155+
end
156+
157+
function ChainRulesCore.rrule(func_ir::Any,
158+
alg::Symbol,
159+
tol::AbstractFloat,
160+
acc_tol::AbstractFloat,
161+
verbose::Bool,
162+
::typeof(MacroModelling.solve_lyapunov_equation),
163+
A::AbstractMatrix{R},
164+
C::AbstractMatrix{R}) where {R<:AbstractFloat}
165+
ChainRulesCore.rrule(MacroModelling.solve_lyapunov_equation,
166+
A, C;
167+
lyapunov_algorithm=alg,
168+
tol=tol,
169+
acceptance_tol=acc_tol,
170+
verbose=verbose)
171+
end
172+
173+
function ChainRulesCore.rrule(func_ir::Any,
174+
initial_guess::AbstractMatrix{<:AbstractFloat},
175+
syl_alg::Symbol,
176+
acc_tol::AbstractFloat,
177+
tol::AbstractFloat,
178+
𝕊ℂ::MacroModelling.sylvester_caches,
179+
verbose::Bool,
180+
::typeof(MacroModelling.solve_sylvester_equation),
181+
A::AbstractMatrix{R},
182+
B::AbstractMatrix{R},
183+
C::AbstractMatrix{R}) where {R<:AbstractFloat}
184+
ChainRulesCore.rrule(MacroModelling.solve_sylvester_equation,
185+
A, B, C;
186+
initial_guess=initial_guess,
187+
sylvester_algorithm=syl_alg,
188+
acceptance_tol=acc_tol,
189+
tol=tol,
190+
𝕊ℂ=𝕊ℂ,
191+
verbose=verbose)
192+
end
193+
194+
function ChainRulesCore.rrule(func_ir::Any,
195+
max_iter::Int,
196+
tol::Float64,
197+
::typeof(MacroModelling.find_shocks),
198+
::Val{:LagrangeNewton},
199+
initial_guess::Vector{Float64},
200+
kron_buffer::Vector{Float64},
201+
kron_buffer2::AbstractMatrix{Float64},
202+
J::ℒ.Diagonal{Bool,Vector{Bool}},
203+
S_i::AbstractMatrix{Float64},
204+
S_i2e::AbstractMatrix{Float64},
205+
shock_independent::Vector{Float64})
206+
ChainRulesCore.rrule(MacroModelling.find_shocks,
207+
Val(:LagrangeNewton),
208+
initial_guess,
209+
kron_buffer,
210+
kron_buffer2,
211+
J,
212+
S_i,
213+
S_i2e,
214+
shock_independent;
215+
max_iter=max_iter,
216+
tol=tol)
217+
end
218+
219+
function ChainRulesCore.rrule(func_ir::Any,
220+
max_iter::Int,
221+
tol::Float64,
222+
::typeof(MacroModelling.find_shocks),
223+
::Val{:LagrangeNewton},
224+
initial_guess::Vector{Float64},
225+
kron_buffer::Vector{Float64},
226+
kron_buffer2::Vector{Float64},
227+
kron_buffer3::AbstractMatrix{Float64},
228+
kron_buffer4::AbstractMatrix{Float64},
229+
kron_buffer5::AbstractMatrix{Float64},
230+
J::ℒ.Diagonal{Bool,Vector{Bool}},
231+
S_i::AbstractMatrix{Float64},
232+
S_i2e::AbstractMatrix{Float64},
233+
S_i3e::AbstractMatrix{Float64},
234+
shock_independent::Vector{Float64})
235+
ChainRulesCore.rrule(MacroModelling.find_shocks,
236+
Val(:LagrangeNewton),
237+
initial_guess,
238+
kron_buffer,
239+
kron_buffer2,
240+
kron_buffer3,
241+
kron_buffer4,
242+
kron_buffer5,
243+
J,
244+
S_i,
245+
S_i2e,
246+
S_i3e,
247+
shock_independent;
248+
max_iter=max_iter,
249+
tol=tol)
250+
end
251+
252+
function ChainRulesCore.rrule(func_ir::Any,
253+
warm_iters::Int,
254+
presample::Int,
255+
opts::MacroModelling.CalculationOptions,
256+
filt_alg::Symbol,
257+
::typeof(MacroModelling.calculate_inversion_filter_loglikelihood),
258+
alg::Val{A},
259+
state,
260+
S,
261+
data,
262+
observables,
263+
T::MacroModelling.timings) where A
264+
ChainRulesCore.rrule(MacroModelling.calculate_inversion_filter_loglikelihood,
265+
alg, state, S, data, observables, T;
266+
warmup_iterations=warm_iters,
267+
presample_periods=presample,
268+
opts=opts,
269+
filter_algorithm=filt_alg)
270+
end
48271

49272
end # module

0 commit comments

Comments
 (0)