Skip to content

Remove interpolation package dependency from AMSS lecture and replace with NumPy alternatives #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lectures/_static/downloads/amss_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ dependencies:
- matplotlib
- networkx
- sphinx=2.4.4
- interpolation
- seaborn
- pip:
- sphinxcontrib-jupyter
Expand Down
81 changes: 75 additions & 6 deletions lectures/_static/lecture_specific/amss/recursive_allocation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,64 @@
import numpy as np
from numba import njit, prange
from quantecon import optimize

@njit
def get_grid_nodes(grid):
"""
Get the actual grid points from a grid tuple.
"""
x_min, x_max, x_num = grid
return np.linspace(x_min, x_max, x_num)

@njit
def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val):
"""Helper function for scalar interpolation"""
x_nodes = np.linspace(x_min, x_max, x_num)

# Extrapolation with linear extension
if x_val <= x_nodes[0]:
# Linear extrapolation using first two points
if x_num >= 2:
slope = (y_values[1] - y_values[0]) / (x_nodes[1] - x_nodes[0])
return y_values[0] + slope * (x_val - x_nodes[0])
else:
return y_values[0]

if x_val >= x_nodes[-1]:
# Linear extrapolation using last two points
if x_num >= 2:
slope = (y_values[-1] - y_values[-2]) / (x_nodes[-1] - x_nodes[-2])
return y_values[-1] + slope * (x_val - x_nodes[-1])
else:
return y_values[-1]

# Binary search for the right interval
left = 0
right = x_num - 1
while right - left > 1:
mid = (left + right) // 2
if x_nodes[mid] <= x_val:
left = mid
else:
right = mid

# Linear interpolation
x_left = x_nodes[left]
x_right = x_nodes[right]
y_left = y_values[left]
y_right = y_values[right]

weight = (x_val - x_left) / (x_right - x_left)
return y_left * (1 - weight) + y_right * weight

@njit
def linear_interp_1d(x_grid, y_values, x_query):
"""
Perform 1D linear interpolation.
"""
x_min, x_max, x_num = x_grid
return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0])

class AMSS:
# WARNING: THE CODE IS EXTREMELY SENSITIVE TO CHOCIES OF PARAMETERS.
# DO NOT CHANGE THE PARAMETERS AND EXPECT IT TO WORK
Expand Down Expand Up @@ -78,6 +139,10 @@ def simulate(self, s_hist, b_0):
pref = self.pref
x_grid, g, β, S = self.x_grid, self.g, self.β, self.S
σ_v_star, σ_w_star = self.σ_v_star, self.σ_w_star
Π = self.Π

# Extract the grid tuple from the list
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid

T = len(s_hist)
s_0 = s_hist[0]
Expand Down Expand Up @@ -111,8 +176,8 @@ def simulate(self, s_hist, b_0):
T = np.zeros(S)
for s in range(S):
x_arr = np.array([x_])
l[s] = eval_linear(x_grid, σ_v_star[s_, :, s], x_arr)
T[s] = eval_linear(x_grid, σ_v_star[s_, :, S+s], x_arr)
l[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, s], x_arr)
T[s] = linear_interp_1d(grid_tuple, σ_v_star[s_, :, S+s], x_arr)

c = (1 - l) - g
u_c = pref.Uc(c, l)
Expand All @@ -135,6 +200,8 @@ def simulate(self, s_hist, b_0):

def obj_factory(Π, β, x_grid, g):
S = len(Π)
# Extract the grid tuple from the list
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid

@njit
def obj_V(σ, state, V, pref):
Expand All @@ -152,7 +219,7 @@ def obj_V(σ, state, V, pref):
V_next = np.zeros(S)

for s in range(S):
V_next[s] = eval_linear(x_grid, V[s], np.array([x[s]]))
V_next[s] = linear_interp_1d(grid_tuple, V[s], np.array([x[s]]))

out = Π[s_] @ (pref.U(c, l) + β * V_next)

Expand All @@ -167,7 +234,7 @@ def obj_W(σ, state, V, pref):
c = (1 - l) - g[s_]
x = -pref.Uc(c, l) * (c - T - b_0) + pref.Ul(c, l) * (1 - l)

V_next = eval_linear(x_grid, V[s_], np.array([x]))
V_next = linear_interp_1d(grid_tuple, V[s_], np.array([x]))

out = pref.U(c, l) + β * V_next

Expand All @@ -178,9 +245,11 @@ def obj_W(σ, state, V, pref):

def bellman_operator_factory(Π, β, x_grid, g, bounds_v):
obj_V, obj_W = obj_factory(Π, β, x_grid, g)
n = x_grid[0][2]
# Extract the grid tuple from the list
grid_tuple = x_grid[0] if isinstance(x_grid, list) else x_grid
n = grid_tuple[2]
S = len(Π)
x_nodes = nodes(x_grid)
x_nodes = get_grid_nodes(grid_tuple)

@njit(parallel=True)
def T_v(V, V_new, σ_star, pref):
Expand Down
73 changes: 67 additions & 6 deletions lectures/amss.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ In addition to what's in Anaconda, this lecture will need the following librarie
tags: [hide-output]
---
!pip install --upgrade quantecon
!pip install interpolation
```

## Overview
Expand All @@ -38,12 +37,74 @@ Let's start with following imports:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import root
from interpolation.splines import eval_linear, UCGrid, nodes
from quantecon import optimize, MarkovChain
from numba import njit, prange, float64
from numba.experimental import jitclass
```

Now let's define numba-compatible interpolation functions for this lecture.

We will soon use the following interpolation functions to interpolate the value function and the policy functions

```{code-cell} ipython
@njit
def get_grid_nodes(grid):
"""
Get the actual grid points from a grid tuple.
"""
x_min, x_max, x_num = grid
return np.linspace(x_min, x_max, x_num)

@njit
def linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_val):
"""Helper function for scalar interpolation"""
x_nodes = np.linspace(x_min, x_max, x_num)

# Extrapolation with linear extension
if x_val <= x_nodes[0]:
# Linear extrapolation using first two points
if x_num >= 2:
slope = (y_values[1] - y_values[0]) / (x_nodes[1] - x_nodes[0])
return y_values[0] + slope * (x_val - x_nodes[0])
else:
return y_values[0]

if x_val >= x_nodes[-1]:
# Linear extrapolation using last two points
if x_num >= 2:
slope = (y_values[-1] - y_values[-2]) / (x_nodes[-1] - x_nodes[-2])
return y_values[-1] + slope * (x_val - x_nodes[-1])
else:
return y_values[-1]

# Binary search for the right interval
left = 0
right = x_num - 1
while right - left > 1:
mid = (left + right) // 2
if x_nodes[mid] <= x_val:
left = mid
else:
right = mid

# Linear interpolation
x_left = x_nodes[left]
x_right = x_nodes[right]
y_left = y_values[left]
y_right = y_values[right]

weight = (x_val - x_left) / (x_right - x_left)
return y_left * (1 - weight) + y_right * weight

@njit
def linear_interp_1d(x_grid, y_values, x_query):
"""
Perform 1D linear interpolation.
"""
x_min, x_max, x_num = x_grid
return linear_interp_1d_scalar(x_min, x_max, x_num, y_values, x_query[0])
```

In {doc}`an earlier lecture <opt_tax_recur>`, we described a model of
optimal taxation with state-contingent debt due to
Robert E. Lucas, Jr., and Nancy Stokey {cite}`LucasStokey1983`.
Expand Down Expand Up @@ -774,7 +835,7 @@ x_min = -1.5555
x_max = 17.339
x_num = 300

x_grid = UCGrid((x_min, x_max, x_num))
x_grid = [(x_min, x_max, x_num)]

crra_pref = CRRAutility(β=β, σ=σ, γ=γ)

Expand All @@ -788,7 +849,7 @@ amss_model = AMSS(crra_pref, β, Π, g, x_grid, bounds_v)
```{code-cell} python3
# WARNING: DO NOT EXPECT THE CODE TO WORK IF YOU CHANGE PARAMETERS
V = np.zeros((len(Π), x_num))
V[:] = -nodes(x_grid).T ** 2
V[:] = -get_grid_nodes(x_grid[0]) ** 2

σ_v_star = np.ones((S, x_num, S * 2))
σ_v_star[:, :, :S] = 0.0
Expand Down Expand Up @@ -914,14 +975,14 @@ x_min = -3.4107
x_max = 3.709
x_num = 300

x_grid = UCGrid((x_min, x_max, x_num))
x_grid = [(x_min, x_max, x_num)]
log_pref = LogUtility(β=β, ψ=ψ)

S = len(Π)
bounds_v = np.vstack([np.zeros(2 * S), np.hstack([1 - g, np.ones(S)]) ]).T

V = np.zeros((len(Π), x_num))
V[:] = -(nodes(x_grid).T + x_max) ** 2 / 14
V[:] = -(get_grid_nodes(x_grid[0]) + x_max) ** 2 / 14

σ_v_star = 1 - np.full((S, x_num, S * 2), 0.55)

Expand Down
Loading