56. Cake Eating VI: EGM with JAX#

56.1. Overview#

In this lecture, we’ll implement the endogenous grid method (EGM) using JAX.

This lecture builds on Cake Eating V: The Endogenous Grid Method, which introduced EGM using NumPy.

By converting to JAX, we can leverage fast linear algebra, hardware accelerators, and JIT compilation for improved performance.

We’ll also use JAX’s vmap function to fully vectorize the Coleman-Reffett operator.

Let’s start with some standard imports:

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import quantecon as qe
from typing import NamedTuple
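Before turning to the model, here's a minimal illustration of vmap, which transforms a function written for a single input into one that maps over an axis of an array (the inputs here are arbitrary):

# A function written for a scalar input
square = lambda x: x**2

# vmap maps it over the leading axis of an array
print(jax.vmap(square)(jnp.arange(4.0)))   # [0. 1. 4. 9.]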

56.2. Implementation#

For details on the savings problem and the endogenous grid method (EGM), please see Cake Eating V: The Endogenous Grid Method.

Here we focus on the JAX implementation of EGM.

We use the same setting as in Cake Eating V: The Endogenous Grid Method:

  • \(u(c) = \ln c\),

  • production is Cobb-Douglas, and

  • the shocks are lognormal.

Here are the analytical solutions for comparison.

def v_star(x, α, β, μ):
    """
    True value function
    """
    c1 = jnp.log(1 - α * β) / (1 - β)
    c2 = (μ + α * jnp.log(α * β)) / (1 - α)
    c3 = 1 / (1 - β)
    c4 = 1 / (1 - α * β)
    return c1 + c2 * (c3 - c4) + c4 * jnp.log(x)

def σ_star(x, α, β):
    """
    True optimal policy
    """
    return (1 - α * β) * x
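As a quick sanity check (using the default parameter values α = 0.4 and β = 0.96 set in the model constructor below), the true policy consumes the fixed fraction \(1 - \alpha \beta\) of wealth, so σ_star is linear in x:

x_test = jnp.linspace(0.5, 4.0, 4)
print(σ_star(x_test, 0.4, 0.96))    # the policy at a few test points
print((1 - 0.4 * 0.96) * x_test)    # same values, computed directly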

The Model class stores only the data (grids, shocks, and parameters).

Utility and production functions will be defined globally to work with JAX’s JIT compiler.

class Model(NamedTuple):
    β: float              # discount factor
    μ: float              # shock location parameter
    s: float              # shock scale parameter
    grid: jnp.ndarray     # state grid
    shocks: jnp.ndarray   # shock draws
    α: float              # production function parameter


def create_model(β: float = 0.96,
                 μ: float = 0.0,
                 s: float = 0.1,
                 grid_max: float = 4.0,
                 grid_size: int = 120,
                 shock_size: int = 250,
                 seed: int = 1234,
                 α: float = 0.4) -> Model:
    """
    Creates an instance of the cake eating model.
    """
    # Set up grid
    grid = jnp.linspace(1e-4, grid_max, grid_size)

    # Store shocks (with a seed, so results are reproducible)
    key = jax.random.PRNGKey(seed)
    shocks = jnp.exp(μ + s * jax.random.normal(key, shape=(shock_size,)))

    return Model(β=β, μ=μ, s=s, grid=grid, shocks=shocks, α=α)
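As a small check on the shock draws (not part of the original lecture), the sample mean of the lognormal shocks should be close to the population mean \(\exp(\mu + s^2/2)\):

m = create_model()
print(m.shocks.mean())              # sample mean of the draws
print(jnp.exp(m.μ + m.s**2 / 2))    # population mean of the lognormal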

Here’s the Coleman-Reffett operator using EGM.

The key JAX feature here is vmap, which vectorizes the computation over the grid points.

def K(σ_array: jnp.ndarray, model: Model) -> jnp.ndarray:
    """
    The Coleman-Reffett operator using EGM.
    """

    # Simplify names
    β, α = model.β, model.α
    grid, shocks = model.grid, model.shocks

    # Determine endogenous grid
    x = grid + σ_array  # x_i = k_i + c_i

    # Linear interpolation of policy using endogenous grid
    σ = lambda x_val: jnp.interp(x_val, x, σ_array)

    # Define function to compute consumption at a single grid point
    def compute_c(k):
        vals = u_prime(σ(f(k, α) * shocks)) * f_prime(k, α) * shocks
        return u_prime_inv(β * jnp.mean(vals))

    # Vectorize over grid using vmap
    compute_c_vectorized = jax.vmap(compute_c)
    c = compute_c_vectorized(grid)

    return c
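In symbols, compute_c inverts the Euler equation at each exogenous grid point \(k_i\):

\[ c_i = (u')^{-1} \left( \beta \, \frac{1}{n} \sum_{j=1}^{n} u'\big(\sigma(f(k_i)\xi_j)\big) \, f'(k_i) \, \xi_j \right) \]

where the \(\xi_j\) are the stored shock draws and \(\sigma\) is the policy interpolated on the endogenous grid.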

We define the utility and production functions globally. (K references these names, but Python only looks them up when K is called, so defining them after K is fine.)

Note that f and f_prime take α as an explicit argument, which keeps them pure functions and fits JAX's functional programming model.

# Define utility and production functions with derivatives
u = lambda c: jnp.log(c)
u_prime = lambda c: 1 / c
u_prime_inv = lambda x: 1 / x
f = lambda k, α: k**α
f_prime = lambda k, α: α * k**(α - 1)
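As an aside (a quick check, not part of the original lecture), these pure functions compose cleanly with JAX transformations; for example, jax.grad of f with respect to k recovers f_prime:

k_test = 2.0
print(jax.grad(f)(k_test, 0.4))    # derivative computed by JAX
print(f_prime(k_test, 0.4))        # analytical derivative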

Now we create a model instance.

α = 0.4

model = create_model(α=α)
grid = model.grid
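Before running the full solver, we can apply K once to a simple initial guess and confirm that it returns an updated policy of the same shape as the grid:

σ_guess = jnp.copy(grid)            # initial guess: consume everything
σ_next = K(σ_guess, model)
print(σ_next.shape == grid.shape)   # True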

The solver iterates with jax.lax.while_loop and is JIT-compiled for speed.

@jax.jit
def solve_model_time_iter(model: Model,
                          σ_init: jnp.ndarray,
                          tol: float = 1e-5,
                          max_iter: int = 1000) -> jnp.ndarray:
    """
    Solve the model using time iteration with EGM.
    """

    def condition(loop_state):
        i, σ, error = loop_state
        return (error > tol) & (i < max_iter)

    def body(loop_state):
        i, σ, error = loop_state
        σ_new = K(σ, model)
        error = jnp.max(jnp.abs(σ_new - σ))
        return i + 1, σ_new, error

    # Initialize loop state
    initial_state = (0, σ_init, tol + 1)

    # Run the loop
    i, σ, error = jax.lax.while_loop(condition, body, initial_state)

    return σ
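The loop state already tracks the iteration count and the current error, so a small variant (a sketch, assuming the same K and Model as above) can return them for diagnostics:

@jax.jit
def solve_with_diagnostics(model: Model,
                           σ_init: jnp.ndarray,
                           tol: float = 1e-5,
                           max_iter: int = 1000):
    """
    Like solve_model_time_iter, but returns the full loop state.
    """

    def condition(loop_state):
        i, σ, error = loop_state
        return (error > tol) & (i < max_iter)

    def body(loop_state):
        i, σ, error = loop_state
        σ_new = K(σ, model)
        return i + 1, σ_new, jnp.max(jnp.abs(σ_new - σ))

    # Returns (iterations, policy, final error)
    return jax.lax.while_loop(condition, body, (0, σ_init, tol + 1))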

We solve the model starting from an initial guess.

σ_init = jnp.copy(grid)
σ = solve_model_time_iter(model, σ_init)
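As a quick check, the returned policy should be an approximate fixed point of K:

fp_error = jnp.max(jnp.abs(K(σ, model) - σ))
print(fp_error)    # should be of the order of tol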

Let’s plot the resulting policy against the analytical solution.

x = grid + σ  # x_i = k_i + c_i

fig, ax = plt.subplots()

ax.plot(x, σ, lw=2,
        alpha=0.8, label='approximate policy function')

ax.plot(x, σ_star(x, model.α, model.β), 'k--',
        lw=2, alpha=0.8, label='true policy function')

ax.legend()
plt.show()

The fit is very good.

max_dev = jnp.max(jnp.abs(σ - σ_star(x, model.α, model.β)))
print(f"Maximum absolute deviation: {max_dev:.7}")
Maximum absolute deviation: 1.430511e-06

The JAX implementation is very fast thanks to JIT compilation and vectorization.

with qe.Timer(precision=8):
    σ = solve_model_time_iter(model, σ_init).block_until_ready()
0.00272655 seconds elapsed

This speed comes from:

  • JIT compilation of the entire solver

  • Vectorization via vmap in the Coleman-Reffett operator

  • Use of jax.lax.while_loop instead of a Python loop

  • Efficient JAX array operations throughout
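One caveat: JAX compiles the solver the first time it sees a given set of array shapes, so the timing above excludes compilation (the solver was already called earlier). A rough illustration (timings are machine-dependent, and grid_size=150 is an arbitrary choice) makes the one-off compilation cost visible:

# A new grid size means new array shapes, which forces a fresh trace
model2 = create_model(grid_size=150)
σ_init2 = jnp.copy(model2.grid)

with qe.Timer():    # first call: includes compilation
    solve_model_time_iter(model2, σ_init2).block_until_ready()

with qe.Timer():    # second call: reuses the compiled code
    solve_model_time_iter(model2, σ_init2).block_until_ready()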

56.3. Exercises#

Exercise 56.1

Solve the stochastic cake eating problem with CRRA utility

\[ u(c) = \frac{c^{1 - \gamma} - 1}{1 - \gamma} \]

Compare the optimal policies for values of \(\gamma\) approaching 1 from above (e.g., 1.05, 1.1, 1.2).

Show that as \(\gamma \to 1\), the optimal policy converges to the policy obtained with log utility (\(\gamma = 1\)).

Hint: Use values of \(\gamma\) close to 1 to ensure the endogenous grids have similar coverage and make visual comparison easier.
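To get started (a partial sketch, not a full solution): with CRRA utility, the only pieces of the code that change are the marginal utility and its inverse, since \(u'(c) = c^{-\gamma}\) and \((u')^{-1}(m) = m^{-1/\gamma}\):

def crra_marginals(γ):
    """
    Return u' and its inverse for CRRA utility with parameter γ.
    """
    u_prime = lambda c: c**(-γ)
    u_prime_inv = lambda m: m**(-1/γ)
    return u_prime, u_prime_inv

One subtlety: solve_model_time_iter is JIT-compiled and K reads u_prime and u_prime_inv from the global namespace, so rebinding those globals after the first call will not trigger recompilation. It is safer to thread γ through the functions explicitly (as f does with α) or to rebuild the solver for each γ.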