56. Optimal Savings VI: EGM with JAX

56.1. Overview

In this lecture, we’ll implement the endogenous grid method (EGM) using JAX.

This lecture builds on Optimal Savings V: The Endogenous Grid Method, which introduced EGM using NumPy.

By converting to JAX, we can leverage fast linear algebra, hardware accelerators, and JIT compilation for improved performance.

We’ll also use JAX’s vmap function to fully vectorize the Coleman-Reffett operator.

Let’s start with some standard imports:

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import quantecon as qe
from typing import NamedTuple

56.2. Implementation

For details on the savings problem and the endogenous grid method (EGM), please see Optimal Savings V: The Endogenous Grid Method.

Here we focus on the JAX implementation of EGM.

We use the same setting as in Optimal Savings V: The Endogenous Grid Method:

  • \(u(c) = \ln c\),

  • production is Cobb-Douglas, and

  • the shocks are lognormal.

Here are the analytical solutions for comparison.

def v_star(x, α, β, μ):
    """
    True value function
    """
    c1 = jnp.log(1 - α * β) / (1 - β)
    c2 = (μ + α * jnp.log(α * β)) / (1 - α)
    c3 = 1 / (1 - β)
    c4 = 1 / (1 - α * β)
    return c1 + c2 * (c3 - c4) + c4 * jnp.log(x)

def σ_star(x, α, β):
    """
    True optimal policy
    """
    return (1 - α * β) * x
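
As a quick sanity check on the closed-form policy, one can verify numerically that it satisfies the Euler equation \(u'(\sigma(x)) = \beta \, \mathbb{E}[u'(\sigma(f(s) z)) f'(s) z]\) with \(s = x - \sigma(x)\). The sketch below is only an illustration: the values of α and β mirror the defaults used later in this lecture, the three shock values are made up, and the Euler equation is assumed to take the standard form that the operator K below is built on.

```python
import jax.numpy as jnp

α, β = 0.4, 0.96                      # same defaults as create_model
σ_star = lambda x: (1 - α * β) * x    # closed-form policy

x = jnp.linspace(0.5, 4.0, 8)         # a few income levels (illustrative)
z = jnp.array([0.8, 1.0, 1.25])       # hypothetical shock values

s = x - σ_star(x)                     # savings implied by the policy
y_next = s[:, None]**α * z[None, :]   # next-period income f(s) * z

# Left and right sides of the Euler equation under log utility
lhs = 1 / σ_star(x)
rhs = β * jnp.mean(
    (1 / σ_star(y_next)) * α * s[:, None]**(α - 1) * z[None, :],
    axis=1
)

print(jnp.max(jnp.abs(lhs - rhs)))
```

With log utility and Cobb-Douglas production the shock cancels inside the expectation, so the two sides agree up to floating point error for any choice of positive shock values.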

The Model class stores only the data (grids, shocks, and parameters).

Utility and production functions will be defined globally to work with JAX’s JIT compiler.

class Model(NamedTuple):
    β: float              # discount factor
    μ: float              # shock location parameter
    s: float              # shock scale parameter
    s_grid: jnp.ndarray   # exogenous savings grid
    shocks: jnp.ndarray   # shock draws
    α: float              # production function parameter


def create_model(β: float = 0.96,
                 μ: float = 0.0,
                 s: float = 0.1,
                 grid_max: float = 4.0,
                 grid_size: int = 120,
                 shock_size: int = 250,
                 seed: int = 1234,
                 α: float = 0.4) -> Model:
    """
    Creates an instance of the optimal savings model.
    """
    # Set up exogenous savings grid
    s_grid = jnp.linspace(1e-4, grid_max, grid_size)

    # Store shocks (with a seed, so results are reproducible)
    key = jax.random.PRNGKey(seed)
    shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))

    return Model(β=β, μ=μ, s=s, s_grid=s_grid, shocks=shocks, α=α)
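
One reason a NamedTuple of data works well here is that JAX treats it as a pytree, so an instance can be passed straight into a jitted function and its fields are traced like ordinary arguments. The following minimal sketch illustrates this with a hypothetical two-field class `Params` and a hypothetical function `scaled_sum`; neither is part of the lecture's model.

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

class Params(NamedTuple):        # hypothetical container, for illustration only
    a: float
    grid: jnp.ndarray

@jax.jit
def scaled_sum(p: Params):
    # Fields are accessed by name inside the jitted function
    return p.a * jnp.sum(p.grid)

p = Params(a=2.0, grid=jnp.arange(4.0))
result = scaled_sum(p)                   # 2.0 * (0 + 1 + 2 + 3)
leaves = jax.tree_util.tree_leaves(p)    # one leaf per field

print(result, len(leaves))
```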

Here’s the Coleman-Reffett operator using EGM.
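
In symbols, reading off the code that follows: for each exogenous savings point \(s_i\) in s_grid, with stored shock draws \(z_1, \ldots, z_n\) and \(\sigma\) the linear interpolant of the incoming pairs (x_in, c_in), the operator computes

\[ c_i = (u')^{-1} \left\{ \beta \, \frac{1}{n} \sum_{j=1}^{n} u'\big(\sigma(f(s_i) z_j)\big) \, f'(s_i) \, z_j \right\}, \qquad x_i = s_i + c_i \]

The sample mean over the shock draws stands in for the expectation over the shock distribution.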

The key JAX feature here is vmap, which vectorizes the computation over the grid points.

def K(
        c_in: jnp.ndarray,  # Consumption values on the endogenous grid
        x_in: jnp.ndarray,  # Current endogenous grid
        model: Model        # Model specification
    ):
    """
    The Coleman-Reffett operator using EGM.

    Returns the updated consumption values c_out and the
    corresponding endogenous grid x_out.
    """

    # Simplify names
    β, α = model.β, model.α
    s_grid, shocks = model.s_grid, model.shocks

    # Linear interpolation of policy using endogenous grid
    σ = lambda x_val: jnp.interp(x_val, x_in, c_in)

    # Define function to compute consumption at a single grid point
    def compute_c(s):
        vals = u_prime(σ(f(s, α) * shocks)) * f_prime(s, α) * shocks
        return u_prime_inv(β * jnp.mean(vals))

    # Vectorize over grid using vmap
    compute_c_vectorized = jax.vmap(compute_c)
    c_out = compute_c_vectorized(s_grid)

    # Determine corresponding endogenous grid
    x_out = s_grid + c_out  # x_i = s_i + c_i

    return c_out, x_out
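
The vmap pattern used in K can be seen in isolation in the small sketch below. A function written for a single scalar grid point, which internally averages over a vector of shocks, is mapped over a whole grid and compared with an explicit Python loop. The names `mean_output`, `demo_shocks`, and `demo_grid`, along with their values, are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

demo_shocks = jnp.array([0.9, 1.0, 1.1])    # hypothetical shock draws
demo_grid = jnp.linspace(0.1, 2.0, 5)       # hypothetical savings grid

def mean_output(s):
    # Written for a scalar s; averages over the shock vector
    return jnp.mean(s**0.4 * demo_shocks)

# vmap maps the scalar function over the leading axis of demo_grid
vectorized = jax.vmap(mean_output)(demo_grid)

# Same computation with an explicit Python loop, for comparison
looped = jnp.array([mean_output(s) for s in demo_grid])

print(jnp.allclose(vectorized, looped))
```

Writing the per-point function first and vectorizing afterwards keeps the code close to the mathematics, with one grid point and one expectation at a time.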

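We define the utility and production functions globally. Although K was written first, Python looks up these names only when K is called, so the ordering causes no problem.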

Note that f and f_prime take α as an explicit argument, allowing them to work with JAX’s functional programming model.

# Define utility and production functions with derivatives
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)
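
Since the derivatives are coded by hand, it can be worth checking them against automatic differentiation. The sketch below repeats the definitions so that it runs on its own and compares them with jax.grad at one arbitrary test point; the test values of c, k, and α are illustrative choices.

```python
import jax
import jax.numpy as jnp

# Same definitions as in the lecture, repeated so the check is self-contained
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)

c, k, α = 1.5, 2.0, 0.4        # arbitrary test point

du = jax.grad(u)(c)            # derivative of u at c
df = jax.grad(f)(k, α)         # derivative of f with respect to k only

print(jnp.allclose(du, u_prime(c), rtol=1e-4))
print(jnp.allclose(df, f_prime(k, α), rtol=1e-4))
print(jnp.allclose(u_prime_inv(u_prime(c)), c))
```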

Now we create a model instance.

model = create_model()
s_grid = model.s_grid

The solver iterates with jax.lax.while_loop and is JIT-compiled for speed.

@jax.jit
def solve_model_time_iter(model: Model,
                          c_init: jnp.ndarray,
                          x_init: jnp.ndarray,
                          tol: float = 1e-5,
                          max_iter: int = 1000):
    """
    Solve the model using time iteration with EGM.
    """

    def condition(loop_state):
        i, c, x, error = loop_state
        return (error > tol) & (i < max_iter)

    def body(loop_state):
        i, c, x, error = loop_state
        c_new, x_new = K(c, x, model)
        error = jnp.max(jnp.abs(c_new - c))
        return i + 1, c_new, x_new, error

    # Initialize loop state
    initial_state = (0, c_init, x_init, tol + 1)

    # Run the loop
    i, c, x, error = jax.lax.while_loop(condition, body, initial_state)

    return c, x
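
The loop structure above (a state tuple, a condition function, and a body function) can be illustrated with a much smaller problem. The sketch below uses the same pattern to iterate \(x \mapsto \cos x\) to its fixed point; the function name `iterate_cos` and the problem itself are hypothetical stand-ins, not part of the savings model.

```python
import jax
import jax.numpy as jnp

@jax.jit
def iterate_cos(x0, tol=1e-5, max_iter=1000):
    """Iterate x -> cos(x) until successive values differ by at most tol."""

    def condition(state):
        i, x, error = state
        return (error > tol) & (i < max_iter)

    def body(state):
        i, x, error = state
        x_new = jnp.cos(x)
        return i + 1, x_new, jnp.abs(x_new - x)

    # The state layout must stay the same across iterations
    i, x, error = jax.lax.while_loop(condition, body, (0, x0, tol + 1))
    return x, i

x_fixed, num_iter = iterate_cos(jnp.array(1.0))
print(x_fixed, num_iter)
```

As in the solver, the initial error is set above tol so that the loop body runs at least once.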

We solve the model starting from an initial guess.

c_init = jnp.copy(s_grid)
x_init = s_grid + c_init
c, x = solve_model_time_iter(model, c_init, x_init)

Let’s plot the resulting policy against the analytical solution.

fig, ax = plt.subplots()

ax.plot(x, c, lw=2,
        alpha=0.8, label='approximate policy function')

ax.plot(x, σ_star(x, model.α, model.β), 'k--',
        lw=2, alpha=0.8, label='true policy function')

ax.legend()
plt.show()
[Figure: approximate policy function plotted against the true policy function]

The fit is very good.

max_dev = jnp.max(jnp.abs(c - σ_star(x, model.α, model.β)))
print(f"Maximum absolute deviation: {max_dev:.7}")
Maximum absolute deviation: 1.430511e-06

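The JAX implementation is very fast thanks to JIT compilation and vectorization. The timing below excludes compilation, since the solver was already compiled when it was first called above.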

with qe.Timer(precision=8):
    c, x = solve_model_time_iter(model, c_init, x_init)
    jax.block_until_ready(c)
0.00293183 seconds elapsed

This speed comes from:

  • JIT compilation of the entire solver

  • Vectorization via vmap in the Coleman-Reffett operator

  • Use of jax.lax.while_loop instead of a Python loop

  • Efficient JAX array operations throughout
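
To see how a first call differs from later calls, one can time a jitted function twice. The sketch below does this for a small hypothetical function `g` using a hypothetical helper `timed`; the actual numbers depend on hardware and JAX version, so none are shown here.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def g(x):                      # hypothetical function, for illustration only
    return jnp.sum(jnp.sin(x) ** 2)

def timed(fn, *args):
    # block_until_ready waits for the asynchronous computation to finish
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start

x_demo = jnp.linspace(0.0, 1.0, 1_000)

out1, t_first = timed(g, x_demo)     # includes tracing and compilation
out2, t_second = timed(g, x_demo)    # reuses the compiled function

print(f"first call:  {t_first:.6f} seconds")
print(f"second call: {t_second:.6f} seconds")
```

Typically the first call is noticeably slower, which is why the timing in this lecture is taken on a second call to the solver.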

56.3. Exercises

Exercise 56.1

Solve the optimal savings problem with CRRA utility

\[ u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma} \]

Compare the optimal policies for values of \(\gamma\) approaching 1 from above (e.g., 1.05, 1.1, 1.2).

Show that as \(\gamma \to 1\), the optimal policy converges to the policy obtained with log utility (\(\gamma = 1\)).

Hint: Use values of \(\gamma\) close to 1 to ensure the endogenous grids have similar coverage and make visual comparison easier.
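
A possible starting point, not a full solution: under CRRA utility the marginal utility is \(u'(c) = c^{-\gamma}\), with inverse \((u')^{-1}(x) = x^{-1/\gamma}\), and both reduce to \(1/c\) and \(1/x\) at \(\gamma = 1\). The sketch below codes these with hypothetical function names and an extra γ argument, and checks them against jax.grad at an arbitrary test point. How γ is stored in the model and passed through K is left to the reader.

```python
import jax
import jax.numpy as jnp

def u_crra(c, γ):
    # CRRA utility for γ != 1
    return (c**(1 - γ) - 1) / (1 - γ)

def u_prime_crra(c, γ):
    # Marginal utility
    return c**(-γ)

def u_prime_inv_crra(x, γ):
    # Inverse of marginal utility
    return x**(-1 / γ)

c, γ = 1.5, 1.1                          # arbitrary test point
grad_val = jax.grad(u_crra)(c, γ)        # derivative with respect to c

print(jnp.allclose(grad_val, u_prime_crra(c, γ), rtol=1e-4))
print(jnp.allclose(u_prime_inv_crra(u_prime_crra(c, γ), γ), c))
print(jnp.allclose(u_prime_crra(c, 1.0), 1 / c))   # log-utility case
```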