Skip to content

Commit 2fcf61a

Browse files
committed
add initial state inputs to diff tables
1 parent daf6a3a commit 2fcf61a

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

ext/StatsPlotsExt.jl

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,8 @@ function plot_irf(𝓂::ℳ;
636636

637637
reference_steady_state, NSSS, SSS_delta = get_relevant_steady_states(𝓂, algorithm, opts = opts)
638638

639+
initial_state_input = copy(initial_state)
640+
639641
unspecified_initial_state = initial_state == [0.0]
640642

641643
if unspecified_initial_state
@@ -854,7 +856,7 @@ function plot_irf(𝓂::ℳ;
854856
:shock_size => shock_size,
855857
:negative_shock => negative_shock,
856858
:generalised_irf => generalised_irf,
857-
:initial_state => initial_state,
859+
:initial_state => initial_state_input,
858860
:ignore_obc => ignore_obc,
859861
:tol => tol,
860862
:quadratic_matrix_equation_algorithm => quadratic_matrix_equation_algorithm,
@@ -1119,6 +1121,8 @@ function plot_irf!(𝓂::ℳ;
11191121

11201122
reference_steady_state, NSSS, SSS_delta = get_relevant_steady_states(𝓂, algorithm, opts = opts)
11211123

1124+
initial_state_input = copy(initial_state)
1125+
11221126
unspecified_initial_state = initial_state == [0.0]
11231127

11241128
if unspecified_initial_state
@@ -1321,7 +1325,7 @@ function plot_irf!(𝓂::ℳ;
13211325
:shock_size => shock_size,
13221326
:negative_shock => negative_shock,
13231327
:generalised_irf => generalised_irf,
1324-
:initial_state => initial_state,
1328+
:initial_state => initial_state_input,
13251329
:ignore_obc => ignore_obc,
13261330
:tol => tol,
13271331
:quadratic_matrix_equation_algorithm => quadratic_matrix_equation_algorithm,
@@ -1348,6 +1352,7 @@ function plot_irf!(𝓂::ℳ;
13481352
grouped_by_model = Dict{Any, Vector{Dict}}()
13491353

13501354
for d in irf_active_plot_container
1355+
# println(d[:initial_state])
13511356
model = d[:model_name]
13521357
d_sub = Dict(k => d[k] for k in setdiff(keys(args_and_kwargs),keys(args_and_kwargs_names)) if haskey(d, k))
13531358
push!(get!(grouped_by_model, model, Vector{Dict}()), d_sub)
@@ -1361,15 +1366,15 @@ function plot_irf!(𝓂::ℳ;
13611366

13621367
model_names = unique(model_names)
13631368

1364-
# for (i,d) in enumerate(irf_active_plot_container)
13651369
for model in model_names
13661370
if length(grouped_by_model[model]) > 1
13671371
diffdict_grouped = compare_args_and_kwargs(grouped_by_model[model])
13681372
diffdict = merge_by_runid(diffdict, diffdict_grouped)
13691373
end
13701374
end
13711375

1372-
@assert haskey(diffdict, :parameters) || haskey(diffdict, :shock_names) || any(haskey.(Ref(diffdict), keys(args_and_kwargs_names))) "New plot must be different from previous plot. Use the version without ! to plot."
1376+
@assert haskey(diffdict, :parameters) || haskey(diffdict, :shock_names) || haskey(diffdict, :initial_state) ||
1377+
any(haskey.(Ref(diffdict), keys(args_and_kwargs_names))) "New plot must be different from previous plot. Use the version without ! to plot."
13731378

13741379
annotate_ss = Vector{Pair{String, Any}}[]
13751380

@@ -1394,16 +1399,35 @@ function plot_irf!(𝓂::ℳ;
13941399
push!(annotate_diff_input, "Shock" => diffdict[:shocks])
13951400
end
13961401
end
1402+
1403+
if haskey(diffdict, :initial_state)
1404+
unique_initial_state = unique(diffdict[:initial_state])
1405+
1406+
initial_state_idx = Int[]
1407+
1408+
for init in diffdict[:initial_state]
1409+
for (i,u) in enumerate(unique_initial_state)
1410+
if u == init
1411+
push!(initial_state_idx,i)
1412+
continue
1413+
end
1414+
end
1415+
end
1416+
1417+
push!(annotate_diff_input, "Initial state" => ["#$i" for i in initial_state_idx])
1418+
end
13971419

13981420
same_shock_direction = true
13991421

14001422
for k in setdiff(keys(args_and_kwargs),
1401-
[
1402-
:run_id, :parameters, :initial_state, :plot_data, :tol, :reference_steady_state,
1403-
:shocks, :shock_names, :shock_idx,
1404-
:variables, :variable_names, :var_idx,
1405-
# :periods, :quadratic_matrix_equation_algorithm, :sylvester_algorithm, :lyapunov_algorithm,
1406-
])
1423+
[
1424+
:run_id, :parameters, :plot_data, :tol, :reference_steady_state, :initial_state,
1425+
:shocks, :shock_names, :shock_idx,
1426+
:variables, :variable_names, :var_idx,
1427+
# :periods, :quadratic_matrix_equation_algorithm, :sylvester_algorithm, :lyapunov_algorithm,
1428+
]
1429+
)
1430+
14071431
if haskey(diffdict, k)
14081432
push!(annotate_diff_input, args_and_kwargs_names[k] => reduce(vcat,diffdict[k]))
14091433

0 commit comments

Comments
 (0)