diff --git a/lectures/ifp_advanced.md b/lectures/ifp_advanced.md index f6df09b27..6c0fbd919 100644 --- a/lectures/ifp_advanced.md +++ b/lectures/ifp_advanced.md @@ -487,15 +487,15 @@ a_init = σ_init.copy() Let's generate an approximation solution with JAX: ```{code-cell} ipython3 -a_star, σ_star = solve_model(ifp, a_init, σ_init) +σ_star, a_star = solve_model(ifp, σ_init, a_init) ``` Let's try it again with a timer. ```{code-cell} python3 with qe.Timer(precision=8): - a_star, σ_star = solve_model(ifp, a_init, σ_init) - a_star.block_until_ready() + σ_star, a_star = solve_model(ifp, σ_init, a_init) + σ_star.block_until_ready() ``` ## Simulation @@ -642,7 +642,7 @@ s_grid = ifp.s_grid n_z = len(ifp.P) a_init = s_grid[:, None] * jnp.ones(n_z) c_init = a_init -a_vec, c_vec = solve_model(ifp, a_init, c_init) +c_vec, a_vec = solve_model(ifp, c_init, a_init) assets = compute_asset_stationary(c_vec, a_vec, ifp, num_households=200_000) # Compute Gini coefficient for the plot @@ -734,8 +734,8 @@ for a_r in a_r_vals: n_z_temp = len(ifp_temp.P) a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp) c_init_temp = a_init_temp - a_vec_temp, c_vec_temp = solve_model( - ifp_temp, a_init_temp, c_init_temp + c_vec_temp, a_vec_temp = solve_model( + ifp_temp, c_init_temp, a_init_temp ) # Simulate households @@ -811,8 +811,8 @@ for a_y in a_y_vals: n_z_temp = len(ifp_temp.P) a_init_temp = s_grid_temp[:, None] * jnp.ones(n_z_temp) c_init_temp = a_init_temp - a_vec_temp, c_vec_temp = solve_model( - ifp_temp, a_init_temp, c_init_temp + c_vec_temp, a_vec_temp = solve_model( + ifp_temp, c_init_temp, a_init_temp ) # Simulate households