Skip to content
256 changes: 131 additions & 125 deletions lectures/mccall_fitted_vfi.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,18 @@ We will use the following imports:

```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np
from numba import jit, float64
from numba.experimental import jitclass
import jax
import jax.numpy as jnp
from typing import NamedTuple
import quantecon as qe

# Set JAX to use CPU
jax.config.update('jax_platform_name', 'cpu')
```

## The Algorithm
## The algorithm

The model is the same as the McCall model with job separation we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.
The model is the same as the McCall model with job separation that we {doc}`studied before <mccall_model_with_separation>`, except that the wage offer distribution is continuous.

We are going to start with the two Bellman equations we obtained for the model with job separation after {ref}`a simplifying transformation <ast_mcm>`.

Expand All @@ -82,16 +86,16 @@ v(w) = u(w) + \beta

The unknowns here are the function $v$ and the scalar $d$.

The difference between these and the pair of Bellman equations we previously worked on are
The differences between these and the pair of Bellman equations we previously worked on are

1. in {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
1. In {eq}`bell1mcmc`, what used to be a sum over a finite number of wage values is an integral over an infinite set.
1. The function $v$ in {eq}`bell2mcmc` is defined over all $w \in \mathbb R_+$.

The function $q$ in {eq}`bell1mcmc` is the density of the wage offer distribution.

Its support is taken as equal to $\mathbb R_+$.

### Value Function Iteration
### Value function iteration

In theory, we should now proceed as follows:

Expand All @@ -106,12 +110,12 @@ The iterates of the value function can neither be calculated exactly nor stored

To see the issue, consider {eq}`bell2mcmc`.

Even if $v$ is a known function, the only way to store its update $v'$
Even if $v$ is a known function, the only way to store its update $v'$
is to record its value $v'(w)$ for every $w \in \mathbb R_+$.

Clearly, this is impossible.

### Fitted Value Function Iteration
### Fitted value function iteration

What we will do instead is use **fitted value function iteration**.

Expand Down Expand Up @@ -141,25 +145,25 @@ One good choice from both respects is continuous piecewise linear interpolation.

This method

1. combines well with value function iteration (see., e.g.,
1. combines well with value function iteration (see, e.g.,
{cite}`gordon1995stable` or {cite}`stachurski2008continuous`) and
1. preserves useful shape properties such as monotonicity and concavity/convexity.

Linear interpolation will be implemented using [numpy.interp](https://numpy.org/doc/stable/reference/generated/numpy.interp.html).
Linear interpolation will be implemented using JAX's interpolation function `jnp.interp`.

The next figure illustrates piecewise linear interpolation of an arbitrary
function on grid points $0, 0.2, 0.4, 0.6, 0.8, 1$.

```{code-cell} python3
def f(x):
y1 = 2 * np.cos(6 * x) + np.sin(14 * x)
y1 = 2 * jnp.cos(6 * x) + jnp.sin(14 * x)
return y1 + 2.5

c_grid = np.linspace(0, 1, 6)
f_grid = np.linspace(0, 1, 150)
c_grid = jnp.linspace(0, 1, 6)
f_grid = jnp.linspace(0, 1, 150)

def Af(x):
return np.interp(x, c_grid, f(c_grid))
return jnp.interp(x, c_grid, f(c_grid))

fig, ax = plt.subplots()

Expand All @@ -175,123 +179,126 @@ plt.show()

## Implementation

The first step is to build a jitted class for the McCall model with separation and
a continuous wage offer distribution.
The first step is to build a JAX-compatible structure for the McCall model with separation and a continuous wage offer distribution.

We will take the utility function to be the log function for this application, with $u(c) = \ln c$.

We will adopt the lognormal distribution for wages, with $w = \exp(\mu + \sigma z)$
when $z$ is standard normal and $\mu, \sigma$ are parameters.

```{code-cell} python3
@jit
def lognormal_draws(n=1000, μ=2.5, σ=0.5, seed=1234):
np.random.seed(seed)
z = np.random.randn(n)
w_draws = np.exp(μ + σ * z)
key = jax.random.PRNGKey(seed)
z = jax.random.normal(key, (n,))
w_draws = jnp.exp(μ + σ * z)
return w_draws
```

Here's our class.
Here's our model structure using a NamedTuple.

```{code-cell} python3
mccall_data_continuous = [
('c', float64), # unemployment compensation
('α', float64), # job separation rate
('β', float64), # discount factor
('w_grid', float64[:]), # grid of points for fitted VFI
('w_draws', float64[:]) # draws of wages for Monte Carlo
]

@jitclass(mccall_data_continuous)
class McCallModelContinuous:

def __init__(self,
c=1,
α=0.1,
β=0.96,
grid_min=1e-10,
grid_max=5,
grid_size=100,
w_draws=lognormal_draws()):

self.c, self.α, self.β = c, α, β

self.w_grid = np.linspace(grid_min, grid_max, grid_size)
self.w_draws = w_draws

def update(self, v, d):

# Simplify names
c, α, β = self.c, self.α, self.β
w = self.w_grid
u = lambda x: np.log(x)

# Interpolate array represented value function
vf = lambda x: np.interp(x, w, v)

# Update d using Monte Carlo to evaluate integral
d_new = np.mean(np.maximum(vf(self.w_draws), u(c) + β * d))

# Update v
v_new = u(w) + β * ((1 - α) * v + α * d)

return v_new, d_new
class McCallModelContinuous(NamedTuple):
c: float # unemployment compensation
α: float # job separation rate
β: float # discount factor
w_grid: jnp.ndarray # grid of points for fitted VFI
w_draws: jnp.ndarray # draws of wages for Monte Carlo

def create_mccall_model(c=1,
α=0.1,
β=0.96,
grid_min=1e-10,
grid_max=5,
grid_size=100,
μ=2.5,
σ=0.5,
mc_size=1000,
seed=1234,
w_draws=None):
"""Factory function to create a McCall model instance."""
if w_draws is None:
# Generate wage draws if not provided
w_draws = lognormal_draws(n=mc_size, μ=μ, σ=σ, seed=seed)

w_grid = jnp.linspace(grid_min, grid_max, grid_size)
return McCallModelContinuous(c=c, α=α, β=β, w_grid=w_grid, w_draws=w_draws)

@jax.jit
def update(model, v, d):
"""Update value function and continuation value."""
# Unpack model parameters
c, α, β, w_grid, w_draws = model
u = jnp.log

# Interpolate array represented value function
vf = lambda x: jnp.interp(x, w_grid, v)

# Update d using Monte Carlo to evaluate integral
d_new = jnp.mean(jnp.maximum(vf(w_draws), u(c) + β * d))

# Update v
v_new = u(w_grid) + β * ((1 - α) * v + α * d)

return v_new, d_new
```

We then return the current iterate as an approximate solution.

```{code-cell} python3
@jit
def solve_model(mcm, tol=1e-5, max_iter=2000):
@jax.jit
def solve_model(model, tol=1e-5, max_iter=2000):
"""
Iterates to convergence on the Bellman equations

* mcm is an instance of McCallModel
* model is an instance of McCallModelContinuous
"""

v = np.ones_like(mcm.w_grid) # Initial guess of v
d = 1 # Initial guess of d
i = 0
error = tol + 1

while error > tol and i < max_iter:
v_new, d_new = mcm.update(v, d)
error_1 = np.max(np.abs(v_new - v))
error_2 = np.abs(d_new - d)
error = max(error_1, error_2)
v = v_new
d = d_new
i += 1

return v, d

# Initial guesses
v = jnp.ones_like(model.w_grid)
d = 1.0

def body_fun(state):
v, d, i, error = state
v_new, d_new = update(model, v, d)
error_1 = jnp.max(jnp.abs(v_new - v))
error_2 = jnp.abs(d_new - d)
error = jnp.maximum(error_1, error_2)
return v_new, d_new, i + 1, error

def cond_fun(state):
_, _, i, error = state
return (error > tol) & (i < max_iter)

initial_state = (v, d, 0, tol + 1)
v_final, d_final, _, _ = jax.lax.while_loop(cond_fun, body_fun, initial_state)

return v_final, d_final
```

Here's a function `compute_reservation_wage` that takes an instance of `McCallModelContinuous`
and returns the associated reservation wage.

If $v(w) < h$ for all $w$, then the function returns np.inf
If $v(w) < h$ for all $w$, then the function returns `jnp.inf`

```{code-cell} python3
@jit
def compute_reservation_wage(mcm):
@jax.jit
def compute_reservation_wage(model):
"""
Computes the reservation wage of an instance of the McCall model
by finding the smallest w such that v(w) >= h.

If no such w exists, then w_bar is set to np.inf.
If no such w exists, then w_bar is set to inf.
"""
u = lambda x: np.log(x)

v, d = solve_model(mcm)
h = u(mcm.c) + mcm.β * d

w_bar = np.inf
for i, wage in enumerate(mcm.w_grid):
if v[i] > h:
w_bar = wage
break

c, α, β, w_grid, w_draws = model
u = jnp.log

v, d = solve_model(model)
h = u(c) + β * d

# Find the first wage where v(w) >= h
indices = jnp.where(v >= h, size=1, fill_value=-1)
w_bar = jnp.where(indices[0] >= 0, w_grid[indices[0]], jnp.inf)

return w_bar
```

Expand All @@ -305,7 +312,7 @@ The exercises ask you to explore the solution and how it changes with parameters
Use the code above to explore what happens to the reservation wage when the wage parameter $\mu$
changes.

Use the default parameters and $\mu$ in `mu_vals = np.linspace(0.0, 2.0, 15)`.
Use the default parameters and $\mu$ in `μ_vals = jnp.linspace(0.0, 2.0, 15)`.

Is the impact on the reservation wage as you expected?
```
Expand All @@ -317,21 +324,18 @@ Is the impact on the reservation wage as you expected?
Here is one solution

```{code-cell} python3
mcm = McCallModelContinuous()
mu_vals = np.linspace(0.0, 2.0, 15)
w_bar_vals = np.empty_like(mu_vals)

fig, ax = plt.subplots()
def compute_res_wage_given_μ(μ):
model = create_mccall_model(μ=μ)
w_bar = compute_reservation_wage(model)
return w_bar

for i, m in enumerate(mu_vals):
mcm.w_draws = lognormal_draws(μ=m)
w_bar = compute_reservation_wage(mcm)
w_bar_vals[i] = w_bar
μ_vals = jnp.linspace(0.0, 2.0, 15)
w_bar_vals = jax.vmap(compute_res_wage_given_μ)(μ_vals)

fig, ax = plt.subplots()
ax.set(xlabel='mean', ylabel='reservation wage')
ax.plot(mu_vals, w_bar_vals, label=r'$\bar w$ as a function of $\mu$')
ax.plot(μ_vals, w_bar_vals, label=r'$\bar w$ as a function of $\mu$')
ax.legend()

plt.show()
```

Expand All @@ -354,11 +358,11 @@ support.

(This is a form of *mean-preserving spread*.)

Use `s_vals = np.linspace(1.0, 2.0, 15)` and `m = 2.0`.
Use `s_vals = jnp.linspace(1.0, 2.0, 15)` and `m = 2.0`.

State how you expect the reservation wage to vary with $s$.

Now compute it. Is this as you expected?
Now compute it - is this as you expected?
```

```{solution-start} mfv_ex2
Expand All @@ -368,23 +372,25 @@ Now compute it. Is this as you expected?
Here is one solution

```{code-cell} python3
mcm = McCallModelContinuous()
s_vals = np.linspace(1.0, 2.0, 15)
m = 2.0
w_bar_vals = np.empty_like(s_vals)

fig, ax = plt.subplots()

for i, s in enumerate(s_vals):
def compute_res_wage_given_s(s, m=2.0, seed=1234):
a, b = m - s, m + s
mcm.w_draws = np.random.uniform(low=a, high=b, size=10_000)
w_bar = compute_reservation_wage(mcm)
w_bar_vals[i] = w_bar
key = jax.random.PRNGKey(seed)
uniform_draws = jax.random.uniform(key, shape=(10_000,), minval=a, maxval=b)
# Create model with default parameters but replace wage draws
model = create_mccall_model(w_draws=uniform_draws)
w_bar = compute_reservation_wage(model)
return w_bar

s_vals = jnp.linspace(1.0, 2.0, 15)
# Use vmap with different seeds for each s value
seeds = jnp.arange(len(s_vals))
compute_vectorized = jax.vmap(compute_res_wage_given_s, in_axes=(0, None, 0))
w_bar_vals = compute_vectorized(s_vals, 2.0, seeds)

fig, ax = plt.subplots()
ax.set(xlabel='volatility', ylabel='reservation wage')
ax.plot(s_vals, w_bar_vals, label=r'$\bar w$ as a function of wage volatility')
ax.legend()

plt.show()
```

Expand Down
Loading