diff --git a/lectures/_static/downloads/amss_environment.yml b/lectures/_static/downloads/amss_environment.yml index aa81ff5f..38a213e8 100644 --- a/lectures/_static/downloads/amss_environment.yml +++ b/lectures/_static/downloads/amss_environment.yml @@ -16,7 +16,6 @@ dependencies: - matplotlib - networkx - sphinx=2.4.4 - - interpolation - seaborn - pip: - sphinxcontrib-jupyter diff --git a/lectures/_static/lecture_specific/amss/recursive_allocation.py b/lectures/_static/lecture_specific/amss/recursive_allocation.py index f5495f1f..1dd083d1 100644 --- a/lectures/_static/lecture_specific/amss/recursive_allocation.py +++ b/lectures/_static/lecture_specific/amss/recursive_allocation.py @@ -1,3 +1,64 @@ +import numpy as np +from numba import njit, prange +from quantecon import optimize + +@njit +def get_grid_nodes(grid): + """ + Get the actual grid points from a grid tuple. + """ + x_min, x_max, x_num = grid + return np.linspace(x_min, x_max, x_num) + +@njit +def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val): + """Helper function for scalar interpolation""" + x_nodes = np.linspace(x_min, x_max, x_num) + + # Extrapolation with linear extension + if x_val <= x_nodes[0]: + # Linear extrapolation using first two points + if x_num >= 2: + slope = (y_values[1] - y_values[0]) / (x_nodes[1] - x_nodes[0]) + return y_values[0] + slope * (x_val - x_nodes[0]) + else: + return y_values[0] + + if x_val >= x_nodes[-1]: + # Linear extrapolation using last two points + if x_num >= 2: + slope = (y_values[-1] - y_values[-2]) / (x_nodes[-1] - x_nodes[-2]) + return y_values[-1] + slope * (x_val - x_nodes[-1]) + else: + return y_values[-1] + + # Binary search for the right interval + left = 0 + right = x_num - 1 + while right - left > 1: + mid = (left + right) // 2 + if x_nodes[mid] <= x_val: + left = mid + else: + right = mid + + # Linear interpolation + x_left = x_nodes[left] + x_right = x_nodes[right] + y_left = y_values[left] + y_right = y_values[right] + + weight = (x_val - x_left) / (x_right - x_left) + return y_left * (1 - weight) + y_right * weight + +@njit +def linear_interp_1d(x_grid, y_values, x_query): + """ + Perform 1D linear interpolation. + """ + x_min, x_max, x_num = x_grid + return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0]) + class AMSS: # WARNING: THE CODE IS EXTREMELY SENSITIVE TO CHOCIES OF PARAMETERS. # DO NOT CHANGE THE PARAMETERS AND EXPECT IT TO WORK @@ -78,6 +139,10 @@ def simulate(self, s_hist, b_0): pref = self.pref x_grid, g, β, S = self.x_grid, self.g, self.β, self.S σ_v_star, σ_w_star = self.σ_v_star, self.σ_w_star + Π = self.Π + + # Extract the grid tuple from the list + grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid T = len(s_hist) s_0 = s_hist[0] @@ -111,8 +176,8 @@ def simulate(self, s_hist, b_0): T = np.zeros(S) for s in range(S): x_arr = np.array([x_]) - l[s] = eval_linear(x_grid, σ_v_star[s_, :, s], x_arr) - T[s] = eval_linear(x_grid, σ_v_star[s_, :, S+s], x_arr) + l[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, s], x_arr) + T[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, S+s], x_arr) c = (1 - l) - g u_c = pref.Uc(c, l) @@ -135,6 +200,8 @@ def simulate(self, s_hist, b_0): def obj_factory(Π, β, x_grid, g): S = len(Π) + # Extract the grid tuple from the list + grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid @njit def obj_V(σ, state, V, pref): @@ -152,7 +219,7 @@ def obj_V(σ, state, V, pref): V_next = np.zeros(S) for s in range(S): - V_next[s] = eval_linear(x_grid, V[s], np.array([x[s]])) + V_next[s] = linear_interp_1d(grid_tuple, V[s], np.array([x[s]])) out = Π[s_] @ (pref.U(c, l) + β * V_next) @@ -167,7 +234,7 @@ def obj_W(σ, state, V, pref): c = (1 - l) - g[s_] x = -pref.Uc(c, l) * (c - T - b_0) + pref.Ul(c, l) * (1 - l) - V_next = eval_linear(x_grid, V[s_], np.array([x])) + V_next = linear_interp_1d(grid_tuple, V[s_], np.array([x])) out = pref.U(c, l) + β * V_next @@ -178,9 +245,11 @@ def obj_W(σ, state, V, pref): def bellman_operator_factory(Π, β, x_grid, g, bounds_v): obj_V, obj_W = obj_factory(Π, β, x_grid, g) - n = x_grid[0][2] + # Extract the grid tuple from the list + grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid + n = grid_tuple[2] S = len(Π) - x_nodes = nodes(x_grid) + x_nodes = get_grid_nodes(grid_tuple) @njit(parallel=True) def T_v(V, V_new, σ_star, pref): diff --git a/lectures/amss.md b/lectures/amss.md index 0bb65d74..196a7c36 100644 --- a/lectures/amss.md +++ b/lectures/amss.md @@ -27,7 +27,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie tags: [hide-output] --- !pip install --upgrade quantecon -!pip install interpolation ``` ## Overview @@ -38,12 +37,74 @@ Let's start with following imports: import numpy as np import matplotlib.pyplot as plt from scipy.optimize import root -from interpolation.splines import eval_linear, UCGrid, nodes from quantecon import optimize, MarkovChain from numba import njit, prange, float64 from numba.experimental import jitclass ``` +Now let's define numba-compatible interpolation functions for this lecture. + +We will soon use the following interpolation functions to interpolate the value function and the policy functions + +```{code-cell} ipython +@njit +def get_grid_nodes(grid): + """ + Get the actual grid points from a grid tuple. + """ + x_min, x_max, x_num = grid + return np.linspace(x_min, x_max, x_num) + +@njit +def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val): + """Helper function for scalar interpolation""" + x_nodes = np.linspace(x_min, x_max, x_num) + + # Extrapolation with linear extension + if x_val <= x_nodes[0]: + # Linear extrapolation using first two points + if x_num >= 2: + slope = (y_values[1] - y_values[0]) / (x_nodes[1] - x_nodes[0]) + return y_values[0] + slope * (x_val - x_nodes[0]) + else: + return y_values[0] + + if x_val >= x_nodes[-1]: + # Linear extrapolation using last two points + if x_num >= 2: + slope = (y_values[-1] - y_values[-2]) / (x_nodes[-1] - x_nodes[-2]) + return y_values[-1] + slope * (x_val - x_nodes[-1]) + else: + return y_values[-1] + + # Binary search for the right interval + left = 0 + right = x_num - 1 + while right - left > 1: + mid = (left + right) // 2 + if x_nodes[mid] <= x_val: + left = mid + else: + right = mid + + # Linear interpolation + x_left = x_nodes[left] + x_right = x_nodes[right] + y_left = y_values[left] + y_right = y_values[right] + + weight = (x_val - x_left) / (x_right - x_left) + return y_left * (1 - weight) + y_right * weight + +@njit +def linear_interp_1d(x_grid, y_values, x_query): + """ + Perform 1D linear interpolation. + """ + x_min, x_max, x_num = x_grid + return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0]) +``` + In {doc}`an earlier lecture `, we described a model of optimal taxation with state-contingent debt due to Robert E. Lucas, Jr., and Nancy Stokey {cite}`LucasStokey1983`. @@ -774,7 +835,7 @@ x_min = -1.5555 x_max = 17.339 x_num = 300 -x_grid = UCGrid((x_min, x_max, x_num)) +x_grid = [(x_min, x_max, x_num)] crra_pref = CRRAutility(β=β, σ=σ, γ=γ) @@ -788,7 +849,7 @@ amss_model = AMSS(crra_pref, β, Π, g, x_grid, bounds_v) ```{code-cell} python3 # WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS V = np.zeros((len(Π), x_num)) -V[:] = -nodes(x_grid).T ** 2 +V[:] = -get_grid_nodes(x_grid[0]) ** 2 σ_v_star = np.ones((S, x_num, S * 2)) σ_v_star[:, :, :S] = 0.0 @@ -914,14 +975,14 @@ x_min = -3.4107 x_max = 3.709 x_num = 300 -x_grid = UCGrid((x_min, x_max, x_num)) +x_grid = [(x_min, x_max, x_num)] log_pref = LogUtility(β=β, ψ=ψ) S = len(Π) bounds_v = np.vstack([np.zeros(2 * S), np.hstack([1 - g, np.ones(S)]) ]).T V = np.zeros((len(Π), x_num)) -V[:] = -(nodes(x_grid).T + x_max) ** 2 / 14 +V[:] = -(get_grid_nodes(x_grid[0]) + x_max) ** 2 / 14 σ_v_star = 1 - np.full((S, x_num, S * 2), 0.55)