57. The Income Fluctuation Problem I: Discretization and VFI#

57.1. Overview#

In this lecture, we study an optimal savings problem for an infinitely lived consumer—the “common ancestor” described in [Ljungqvist and Sargent, 2018], section 1.3.

This savings problem is often called an income fluctuation problem or a household problem.

It is an essential sub-problem for many representative macroeconomic models.

It is related to the decision problem in Optimal Savings III: Stochastic Returns but differs in significant ways.

For example,

  1. The choice problem for the agent includes an additive income term that leads to an occasionally binding constraint.

  2. Shocks affecting the budget constraint are correlated, forcing us to track an extra state variable.

We will begin by working with a relatively basic version of the model and solving it via old-fashioned discretization + value function iteration.

Although this approach is not the fastest or the most efficient, it is very robust and flexible.

For example, if we suddenly decided to add Epstein–Zin preferences, or to replace ordinary conditional expectations with quantiles, the technique would continue to work well.

Note

The same is not true of some other methods we will deploy, such as the endogenous grid method.

This is a general rule of computation and analysis: while we can often devise faster algorithms by exploiting structure, these new algorithms are typically less robust.

They are less robust precisely because they exploit more structure, which makes them more vulnerable to changes in the model.

In addition to Anaconda, this lecture will need the following libraries:

!pip install quantecon jax


We will use the following imports:

import quantecon as qe
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from typing import NamedTuple
from time import time

We’ll use 64-bit floats to gain extra precision.

jax.config.update("jax_enable_x64", True)

57.2. Set Up#

We study a household that chooses a state-contingent consumption plan \(\{c_t\}_{t \geq 0}\) to maximize

\[ \mathbb{E} \, \sum_{t=0}^{\infty} \beta^t u(c_t) \]

subject to

\[ a_{t+1} + c_t \leq R a_t + y_t \]

Here

  • \(c_t\) is consumption and \(c_t \geq 0\),

  • \(a_t\) is assets and \(a_t \geq 0\),

  • \(R = 1 + r\) is a gross rate of return, and

  • \((y_t)_{t \geq 0}\) is labor income, taking values in some finite set \(\mathsf Y\).

We assume below that labor income dynamics follow a discretized AR(1) process.
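Specifically, matching the call to qe.tauchen in the code below, we take log income to follow a Gaussian AR(1) process

\[ \ln y_{t+1} = \rho \ln y_t + \nu \, \varepsilon_{t+1}, \qquad (\varepsilon_t) \stackrel{\text{iid}}{\sim} N(0, 1), \]

which we discretize with Tauchen's method; the grid \(\mathsf Y\) collects the exponentials of the discretized state values.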

We set \(\mathsf S := \mathbb{R}_+ \times \mathsf Y\), which represents the state space.

The value function \(v \colon \mathsf S \to \mathbb{R}\) is defined by

(57.1)#\[v(a, y) := \max \, \mathbb{E} \left\{ \sum_{t=0}^{\infty} \beta^t u(c_t) \right\}\]

where the maximization is over all feasible consumption sequences given \((a_0, y_0) = (a, y)\).

The Bellman equation is

\[ v(a, y) = \max_{0 \leq a' \leq Ra + y} \left\{ u(Ra + y - a') + \beta \sum_{y'} v(a', y') Q(y, y') \right\} \]

where

\[ u(c) = \frac{c^{1-\gamma}}{1-\gamma} \]

In the code we use the function

\[ B((a, y), a', v) = u(Ra + y - a') + \beta \sum_{y'} v(a', y') Q(y, y') \]

to encapsulate the right-hand side of the Bellman equation.

57.3. Code#

The following code defines a NamedTuple to store the model parameters and grids.

class Model(NamedTuple):
    β: float              # Discount factor
    R: float              # Gross interest rate
    γ: float              # CRRA parameter
    a_grid: jnp.ndarray   # Asset grid
    y_grid: jnp.ndarray   # Income grid
    Q: jnp.ndarray        # Markov matrix for income


def create_consumption_model(
        R=1.01,                    # Gross interest rate
        β=0.98,                    # Discount factor
        γ=2,                       # CRRA parameter
        a_min=0.01,                # Min assets
        a_max=5.0,                 # Max assets
        a_size=150,                # Grid size
        ρ=0.9, ν=0.1, y_size=100   # Income parameters
    ):
    """
    Creates an instance of the consumption-savings model.

    """
    a_grid = jnp.linspace(a_min, a_max, a_size)
    mc = qe.tauchen(n=y_size, rho=ρ, sigma=ν)
    y_grid, Q = jnp.exp(mc.state_values), jax.device_put(mc.P)
    return Model(β, R, γ, a_grid, y_grid, Q)
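Before moving on, here is a quick sanity check on a default instance (an illustrative sketch; the name m is our own throwaway):

# Illustrative sanity check: grid shapes and row sums of the Markov matrix
m = create_consumption_model()
print(m.a_grid.shape, m.y_grid.shape, m.Q.shape)  # (150,), (100,), (100, 100)
print(jnp.allclose(m.Q.sum(axis=1), 1.0))         # each row of Q sums to one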

Now we define the right-hand side of the Bellman equation.

We’ll use a vectorized coding style reminiscent of Matlab and NumPy (avoiding all loops).

You are invited to explore an alternative style based around jax.vmap in the Exercises.

@jax.jit
def B(v, model):
    """
    A vectorized version of the right-hand side of the Bellman equation
    (before maximization), which is a 3D array representing

        B(a, y, a′) = u(Ra + y - a′) + β Σ_y′ v(a′, y′) Q(y, y′)

    for all (a, y, a′).
    """

    # Unpack
    β, R, γ, a_grid, y_grid, Q = model
    a_size, y_size = len(a_grid), len(y_grid)

    # Compute current rewards r(a, y, ap) as array r[i, j, ip]
    a  = jnp.reshape(a_grid, (a_size, 1, 1))    # a[i]   ->  a[i, j, ip]
    y  = jnp.reshape(y_grid, (1, y_size, 1))    # y[j]   ->  y[i, j, ip]
    ap = jnp.reshape(a_grid, (1, 1, a_size))    # ap[ip] -> ap[i, j, ip]
    c = R * a + y - ap

    # Calculate continuation rewards at all combinations of (a, y, ap)
    v = jnp.reshape(v, (1, 1, a_size, y_size))  # v[ip, jp] -> v[i, j, ip, jp]
    Q = jnp.reshape(Q, (1, y_size, 1, y_size))  # Q[j, jp]  -> Q[i, j, ip, jp]
    EV = jnp.sum(v * Q, axis=3)                 # sum over last index jp

    # Compute the right-hand side of the Bellman equation
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)

Some readers might be concerned that we are creating high-dimensional arrays, leading to inefficiency.

Could they be avoided by more careful vectorization?

In fact this is not necessary: this function will be JIT-compiled by JAX, and the JIT compiler will optimize the compiled code to minimize memory use.
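For example, we can confirm that B produces one value per (a, y, a′) triple (an illustrative check, reusing the instance m from above):

# Illustrative check: B evaluated at a zero value function is a 3D array
v0 = jnp.zeros((len(m.a_grid), len(m.y_grid)))
print(B(v0, m).shape)   # (150, 100, 150): axes ordered as (a, y, a′)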

The Bellman operator \(T\) can be implemented by

@jax.jit
def T(v, model):
    "The Bellman operator."
    return jnp.max(B(v, model), axis=2)

The next function computes a \(v\)-greedy policy given \(v\) (i.e., the policy that maximizes the right-hand side of the Bellman equation).

@jax.jit
def get_greedy(v, model):
    "Computes a v-greedy policy, returned as a set of indices."
    return jnp.argmax(B(v, model), axis=2)
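Note that the greedy policy is returned as indices into a_grid. To recover the savings choices in asset units, index back into the grid (an illustrative sketch, reusing m and v0 from above):

# Map the index-valued greedy policy back into asset units
σ0 = get_greedy(v0, m)     # indices into a_grid, shape (a_size, y_size)
a_policy = m.a_grid[σ0]    # corresponding a′ choices in levels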

57.3.1. Value function iteration#

Now we define a solver that implements VFI.

First we write a simple version using a standard Python loop.

def value_function_iteration_python(model, tol=1e-5, max_iter=10_000):
    """
    Implements VFI using successive approximation with a Python loop.
    """
    v = jnp.zeros((len(model.a_grid), len(model.y_grid)))
    error = tol + 1
    k = 0

    while error > tol and k < max_iter:
        v_new = T(v, model)
        error = jnp.max(jnp.abs(v_new - v))
        v = v_new
        k += 1

    return v, get_greedy(v, model)

Next we write a version that uses jax.lax.while_loop.

@jax.jit
def value_function_iteration(model, tol=1e-5, max_iter=10_000):
    """
    Implements VFI using successive approximation.
    """
    def body_fun(k_v_err):
        k, v, error = k_v_err
        v_new = T(v, model)
        error = jnp.max(jnp.abs(v_new - v))
        return k + 1, v_new, error

    def cond_fun(k_v_err):
        k, v, error = k_v_err
        return jnp.logical_and(error > tol, k < max_iter)

    v_init = jnp.zeros((len(model.a_grid), len(model.y_grid)))
    k, v_star, error = jax.lax.while_loop(cond_fun, body_fun,
                                          (1, v_init, tol + 1))
    return v_star, get_greedy(v_star, model)

57.3.2. Timing#

Let’s create an instance and compare the two implementations.

model = create_consumption_model()

First let’s time the Python version.

print("Starting VFI using Python loop.")
start = time()
v_star_python, σ_star_python = value_function_iteration_python(model)
python_time = time() - start
print(f"VFI completed in {python_time} seconds.")
Starting VFI using Python loop.
VFI completed in 2.6871509552001953 seconds.

Now let’s time the jax.lax.while_loop version.

print("Starting VFI using jax.lax.while_loop.")
start = time()
v_star_jax, σ_star_jax = value_function_iteration(model)
v_star_jax.block_until_ready()
jax_with_compile = time() - start
print(f"VFI completed in {jax_with_compile} seconds.")
Starting VFI using jax.lax.while_loop.
VFI completed in 1.7547013759613037 seconds.

Let’s run it again to eliminate compile time.

start = time()
v_star_jax, σ_star_jax = value_function_iteration(model)
v_star_jax.block_until_ready()
jax_without_compile = time() - start
print(f"VFI completed in {jax_without_compile} seconds.")
VFI completed in 1.3669328689575195 seconds.

Let’s check that the two implementations produce the same result.

print(f"Values match: {jnp.allclose(v_star_python, v_star_jax)}")
print(f"Policies match: {jnp.allclose(σ_star_python, σ_star_jax)}")
Values match: True
Policies match: True

Here’s the speedup from using jax.lax.while_loop.

print(f"Relative speed = {python_time / jax_without_compile:.2f}")
Relative speed = 1.97
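As a final check (an illustrative sketch, not part of the timing comparison), we can plot the greedy savings policy computed above against the 45-degree line:

# Plot the savings policy for the lowest and highest income states
fig, ax = plt.subplots()
ax.plot(model.a_grid, model.a_grid[σ_star_jax[:, 0]], label="lowest income")
ax.plot(model.a_grid, model.a_grid[σ_star_jax[:, -1]], label="highest income")
ax.plot(model.a_grid, model.a_grid, "k--", label="45 degree line")
ax.set_xlabel("current assets")
ax.set_ylabel("next-period assets")
ax.legend()
plt.show()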

57.4. Exercises#

Exercise 57.1

In this exercise, we explore an alternative approach to implementing value function iteration using jax.vmap.

For this simple optimal savings problem, direct vectorization is relatively easy.

In particular, it’s straightforward to express the right-hand side of the Bellman equation as an array that stores evaluations of the function at every state and control.

However, for more complex models, direct vectorization can be much harder.

For this reason, it helps to have another approach to fast JAX implementations up our sleeves.

Your task is to implement a version that:

  1. writes the right-hand side of the Bellman equation as a function of individual states and controls, and

  2. applies jax.vmap on the outside to achieve a parallelized solution.

Specifically:

  1. Rewrite B to take indices (i, j, ip) corresponding to (a, y, a′) and compute the Bellman equation for those specific indices.

  2. Use jax.vmap successively to vectorize over all indices (use staged vmap as shown in earlier examples).

  3. Implement T_vmap and get_greedy_vmap functions using the vectorized B.

  4. Implement value_iteration_vmap using jax.lax.while_loop.

  5. Test that your implementation produces the same results as the direct vectorization approach.

  6. Compare the execution times of both approaches.
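To get started, here is one possible sketch of steps 1 and 2 (the names B_scalar and B_vmap are ours, and this is only one way to stage the vectorization; steps 3 to 6 then mirror the direct approach above):

def B_scalar(i, j, ip, v, model):
    """
    Right-hand side of the Bellman equation at indices (i, j, ip),
    i.e., at the state-control triple (a, y, a′).
    """
    β, R, γ, a_grid, y_grid, Q = model
    a, y, ap = a_grid[i], y_grid[j], a_grid[ip]
    c = R * a + y - ap
    EV = jnp.sum(v[ip, :] * Q[j, :])   # Σ_y′ v(a′, y′) Q(y, y′)
    return jnp.where(c > 0, c**(1-γ)/(1-γ) + β * EV, -jnp.inf)

# Staged vmap: vectorize over a′ indices, then y indices, then a indices
B_vmap = jax.vmap(B_scalar, in_axes=(None, None, 0, None, None))
B_vmap = jax.vmap(B_vmap,   in_axes=(None, 0, None, None, None))
B_vmap = jax.vmap(B_vmap,   in_axes=(0, None, None, None, None))

With a_indices = jnp.arange(a_size) and y_indices = jnp.arange(y_size), the expression jnp.max(B_vmap(a_indices, y_indices, a_indices, v, model), axis=2) then plays the role of T above.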